diff --git a/pkg/solver/ratelimit_test.go b/pkg/solver/ratelimit_test.go index 3e4a8c3c..617255ea 100644 --- a/pkg/solver/ratelimit_test.go +++ b/pkg/solver/ratelimit_test.go @@ -49,12 +49,13 @@ func TestRateLimiter(t *testing.T) { // The running solver is configured to exempt localhost. // When no headers are set, test using the IP address from // the underlying connection (also localhost) - exemptHeaders := []map[string]string{ - {"True-Client-IP": "127.0.0.1"}, - {"X-Real-IP": "127.0.0.1"}, - {"X-Forwarded-For": "127.0.0.1"}, - {}, // No headers case - uses RemoteAddr - } + // TODO: re-enable exempt IP rate limiting + // exemptHeaders := []map[string]string{ + // {"True-Client-IP": "127.0.0.1"}, + // {"X-Real-IP": "127.0.0.1"}, + // {"X-Forwarded-For": "127.0.0.1"}, + // {}, // No headers case - uses RemoteAddr + // } t.Run("non-exempt IP is rate limited", func(t *testing.T) { // Select a random header on each test run. Over time we test them all. @@ -68,17 +69,18 @@ func TestRateLimiter(t *testing.T) { runRateLimitTest(t, paths, tc) }) - t.Run("exempt IP is not rate limited", func(t *testing.T) { - // Select a random header on each test run. Over time we test them all. - headers := exemptHeaders[rand.Intn(len(exemptHeaders))] - tc := rateTestCase{ - name: fmt.Sprintf("exempt with headers %v", headers), - headers: headers, - expectedOK: 100, - expectedLimit: 0, - } - runRateLimitTest(t, paths, tc) - }) + // TODO: re-enable exempt IP rate limiting + // t.Run("exempt IP is not rate limited", func(t *testing.T) { + // // Select a random header on each test run. Over time we test them all. + // headers := exemptHeaders[rand.Intn(len(exemptHeaders))] + // tc := rateTestCase{ + // name: fmt.Sprintf("exempt with headers %v", headers), + // headers: headers, + // expectedOK: 100, + // expectedLimit: 0, + // } + // runRateLimitTest(t, paths, tc) + // }) } func runRateLimitTest(t *testing.T, paths []string, tc rateTestCase) { diff --git a/pkg/solver/server.go b/pkg/solver/server.go index 7bd421fa..312f80ef 100644 --- a/pkg/solver/server.go +++ b/pkg/solver/server.go @@ -70,23 +70,29 @@ func (solverServer *solverServer) ListenAndServe(ctx context.Context, cm *system subrouter.Use(otelmux.Middleware("solver", otelmux.WithTracerProvider(tracerProvider))) exemptIPs := solverServer.options.RateLimiter.ExemptedIPs + // TODO: re-enable exempt IP rate limiting + // subrouter.Use(httprate.Limit( + // solverServer.options.RateLimiter.RequestLimit, + // time.Duration(solverServer.options.RateLimiter.WindowLength)*time.Second, + // httprate.WithKeyFuncs( + // exemptIPKeyFunc(exemptIPs), + // httprate.KeyByEndpoint, + // ), + // httprate.WithLimitHandler(func(w corehttp.ResponseWriter, r *corehttp.Request) { + + // key, _ := exemptIPKeyFunc(exemptIPs)(r) + // if strings.HasPrefix(key, "exempt-") { + // return + // } + + // corehttp.Error(w, "Too Many Requests", corehttp.StatusTooManyRequests) + // }), + // )) subrouter.Use(httprate.Limit( solverServer.options.RateLimiter.RequestLimit, time.Duration(solverServer.options.RateLimiter.WindowLength)*time.Second, - httprate.WithKeyFuncs( - exemptIPKeyFunc(exemptIPs), - httprate.KeyByEndpoint, - ), - httprate.WithLimitHandler(func(w corehttp.ResponseWriter, r *corehttp.Request) { - - key, _ := exemptIPKeyFunc(exemptIPs)(r) - if strings.HasPrefix(key, "exempt-") { - return - } - - corehttp.Error(w, "Too Many Requests", corehttp.StatusTooManyRequests) - }), + httprate.WithKeyFuncs(httprate.KeyByRealIP, httprate.KeyByEndpoint), )) log.Info().Strs("exemptIPs", exemptIPs).Msg("Loaded rate limit exemptions")