package org.wikidata.query.rdf.blazegraph.throttling;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Stopwatch;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.format.DateTimeFormatter;
import java.util.Collection;
import java.util.Locale;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.LongAdder;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
import java.util.stream.Stream;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.isomorphism.util.TokenBuckets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.wikidata.query.rdf.blazegraph.events.QueryEventGenerator;
import org.wikidata.query.rdf.blazegraph.filters.FilterConfiguration;
import org.wikidata.query.rdf.blazegraph.filters.MonitoredFilter;

/* loaded from: input_file:org/wikidata/query/rdf/blazegraph/throttling/ThrottlingFilter.class */
public class ThrottlingFilter extends MonitoredFilter implements Filter, ThrottlingMXBean {
    private static final Logger log = LoggerFactory.getLogger(ThrottlingFilter.class);
    private boolean enabled;
    private Bucketing userAgentIpBucketing;
    private Bucketing regexBucketing;
    private Bucketing agentBucketing;
    private TimeAndErrorsThrottler<ThrottlingState> timeAndErrorsThrottler;
    private BanThrottler<ThrottlingState> banThrottler;
    private final LongAdder nbThrottledRequests = new LongAdder();
    private final LongAdder nbBannedRequests = new LongAdder();
    private Cache<Object, ThrottlingState> stateStore;

    @Override // org.wikidata.query.rdf.blazegraph.filters.MonitoredFilter
    public void init(FilterConfig filterConfig) throws ServletException {
        super.init(filterConfig);
        ThrottlingFilterConfig throttlingFilterConfig = new ThrottlingFilterConfig(new FilterConfiguration(filterConfig, FilterConfiguration.WDQS_CONFIG_PREFIX));
        this.enabled = throttlingFilterConfig.isFilterEnabled();
        this.userAgentIpBucketing = new UserAgentIpAddressBucketing();
        this.regexBucketing = new RegexpBucketing(loadRegexPatterns(throttlingFilterConfig.getRegexPatternsFile()), httpServletRequest -> {
            return httpServletRequest.getParameter(QueryEventGenerator.QUERY_PARAM);
        });
        this.agentBucketing = new RegexpBucketing(loadRegexPatterns(throttlingFilterConfig.getAgentPatternsFile()), httpServletRequest2 -> {
            return httpServletRequest2.getHeader("User-Agent");
        });
        this.stateStore = CacheBuilder.newBuilder().maximumSize(throttlingFilterConfig.getMaxStateSize()).expireAfterAccess(throttlingFilterConfig.getStateExpiration().toMillis(), TimeUnit.MILLISECONDS).build();
        Callable<ThrottlingState> createThrottlingState = createThrottlingState(throttlingFilterConfig.getTimeBucketCapacity(), throttlingFilterConfig.getTimeBucketRefillAmount(), throttlingFilterConfig.getTimeBucketRefillPeriod(), throttlingFilterConfig.getErrorBucketCapacity(), throttlingFilterConfig.getErrorBucketRefillAmount(), throttlingFilterConfig.getErrorBucketRefillPeriod(), throttlingFilterConfig.getThrottleBucketCapacity(), throttlingFilterConfig.getThrottleBucketRefillAmount(), throttlingFilterConfig.getThrottleBucketRefillPeriod(), throttlingFilterConfig.getBanDuration());
        this.timeAndErrorsThrottler = new TimeAndErrorsThrottler<>(throttlingFilterConfig.getRequestDurationThreshold(), createThrottlingState, this.stateStore, throttlingFilterConfig.getEnableThrottlingIfHeader(), throttlingFilterConfig.getAlwaysThrottleParam(), Clock.systemUTC());
        this.banThrottler = new BanThrottler<>(createThrottlingState, this.stateStore, throttlingFilterConfig.getEnableBanIfHeader(), throttlingFilterConfig.getAlwaysBanParam(), Clock.systemUTC());
    }

    private static Callable<ThrottlingState> createThrottlingState(Duration duration, Duration duration2, Duration duration3, int i, int i2, Duration duration4, int i3, int i4, Duration duration5, Duration duration6) {
        return () -> {
            return new ThrottlingState(TokenBuckets.builder().withCapacity(duration.toMillis()).withFixedIntervalRefillStrategy(duration2.toMillis(), duration3.toMillis(), TimeUnit.MILLISECONDS).build(), TokenBuckets.builder().withCapacity(i).withFixedIntervalRefillStrategy(i2, duration4.toMillis(), TimeUnit.MILLISECONDS).build(), TokenBuckets.builder().withCapacity(i3).withFixedIntervalRefillStrategy(i4, duration5.toMillis(), TimeUnit.MILLISECONDS).build(), duration6);
        };
    }

    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;
        HttpServletResponse httpServletResponse = (HttpServletResponse) servletResponse;
        Object bucket = this.regexBucketing.bucket(httpServletRequest);
        if (bucket == null) {
            bucket = this.agentBucketing.bucket(httpServletRequest);
        }
        if (bucket == null) {
            bucket = this.userAgentIpBucketing.bucket(httpServletRequest);
        }
        Instant throttledUntil = this.banThrottler.throttledUntil(bucket, httpServletRequest);
        if (throttledUntil.isAfter(Instant.now())) {
            log.info("A request is being banned.");
            if (this.enabled) {
                this.nbBannedRequests.increment();
                notifyUserBanned(httpServletResponse, throttledUntil);
                return;
            }
        }
        Duration throttledDuration = this.timeAndErrorsThrottler.throttledDuration(bucket, httpServletRequest);
        if (!throttledDuration.isNegative()) {
            log.info("A request is being throttled.");
            if (this.enabled) {
                this.nbThrottledRequests.increment();
                notifyUserThrottled(httpServletResponse, throttledDuration);
                this.banThrottler.throttled(bucket, httpServletRequest);
                return;
            }
        }
        Stopwatch createStarted = Stopwatch.createStarted();
        try {
            filterChain.doFilter(servletRequest, servletResponse);
            if (httpServletResponse.getStatus() < 400) {
                this.timeAndErrorsThrottler.success(bucket, httpServletRequest, createStarted.elapsed());
            } else {
                this.timeAndErrorsThrottler.failure(bucket, httpServletRequest, createStarted.elapsed());
            }
        } catch (IOException | ServletException e) {
            this.timeAndErrorsThrottler.failure(bucket, httpServletRequest, createStarted.elapsed());
            throw e;
        }
    }

    private void notifyUserBanned(HttpServletResponse httpServletResponse, Instant instant) throws IOException {
        httpServletResponse.sendError(403, formattedBanMessage(instant));
    }

    @VisibleForTesting
    static String formattedBanMessage(Instant instant) {
        return String.format(Locale.ENGLISH, "You have been banned until %s, please respect throttling and retry-after headers.", DateTimeFormatter.ISO_INSTANT.format(instant));
    }

    private void notifyUserThrottled(HttpServletResponse httpServletResponse, Duration duration) throws IOException {
        String l = Long.toString(duration.getSeconds());
        httpServletResponse.setHeader("Retry-After", l);
        httpServletResponse.sendError(429, String.format(Locale.ENGLISH, "Too Many Requests - Please retry in %s seconds.", l));
    }

    @Override // org.wikidata.query.rdf.blazegraph.throttling.ThrottlingMXBean
    public long getStateSize() {
        return this.stateStore.size();
    }

    @Override // org.wikidata.query.rdf.blazegraph.throttling.ThrottlingMXBean
    public long getNumberOfThrottledRequests() {
        return this.nbThrottledRequests.longValue();
    }

    @Override // org.wikidata.query.rdf.blazegraph.throttling.ThrottlingMXBean
    public long getNumberOfBannedRequests() {
        return this.nbBannedRequests.longValue();
    }

    private Pattern safeCompile(String str) {
        try {
            return Pattern.compile(str, 32);
        } catch (PatternSyntaxException e) {
            log.warn("Invalid pattern: {}", str);
            return null;
        }
    }

    private Collection<Pattern> loadRegexPatterns(String str) {
        try {
            Path path = Paths.get(str, new String[0]);
            if (!path.toFile().exists()) {
                log.info("Patterns file {} not found, ignoring.", str);
                return ImmutableList.of();
            }
            Stream<String> lines = Files.lines(path, StandardCharsets.UTF_8);
            Throwable th = null;
            try {
                try {
                    ImmutableList immutableList = (ImmutableList) lines.map(this::safeCompile).filter((v0) -> {
                        return Objects.nonNull(v0);
                    }).collect(ImmutableList.toImmutableList());
                    log.info("Loaded {} patterns from {}", Integer.valueOf(immutableList.size()), str);
                    if (lines != null) {
                        if (0 != 0) {
                            try {
                                lines.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            lines.close();
                        }
                    }
                    return immutableList;
                } finally {
                }
            } finally {
            }
        } catch (IOException e) {
            log.warn("Failed reading from patterns file {}.", str);
            return ImmutableList.of();
        }
    }
}
