From d31d722ab986b8f4f428d6a48c512b9f9aa34f5f Mon Sep 17 00:00:00 2001 From: Vasily Tsybenko Date: Mon, 9 Sep 2024 12:52:33 +0300 Subject: [PATCH] Avoid implicit panics in middleware and middleware/throttle packages --- httpserver/config.go | 2 +- httpserver/http_server.go | 3 -- httpserver/middleware/example_test.go | 4 +- httpserver/middleware/in_flight_limit.go | 28 +++++++++-- httpserver/middleware/in_flight_limit_test.go | 22 ++++---- httpserver/middleware/rate_limit.go | 28 +++++++++-- httpserver/middleware/rate_limit_test.go | 22 ++++---- .../middleware/throttle/example_test.go | 24 ++++++--- httpserver/middleware/throttle/middleware.go | 50 +++++++++---------- .../middleware/throttle/middleware_test.go | 23 ++++++--- httpserver/router.go | 11 ++-- 11 files changed, 132 insertions(+), 85 deletions(-) diff --git a/httpserver/config.go b/httpserver/config.go index c29f24a..2018d77 100644 --- a/httpserver/config.go +++ b/httpserver/config.go @@ -179,7 +179,7 @@ func NewConfig() *Config { } // NewConfigWithKeyPrefix creates a new instance of the Config. -// Allows to specify key prefix which will be used for parsing configuration parameters. +// Allows specifying key prefix which will be used for parsing configuration parameters. func NewConfigWithKeyPrefix(keyPrefix string) *Config { return &Config{keyPrefix: keyPrefix} } diff --git a/httpserver/http_server.go b/httpserver/http_server.go index 27e8f9f..8886a33 100644 --- a/httpserver/http_server.go +++ b/httpserver/http_server.go @@ -25,9 +25,6 @@ import ( "github.com/acronis/go-appkit/service" ) -// ErrInvalidMaxServingRequests error is returned when maximum number of currently serving requests is negative. -var ErrInvalidMaxServingRequests = errors.New("maximum number of currently serving requests must not be negative") - const ( networkTCP = "tcp" networkUnix = "unix" diff --git a/httpserver/middleware/example_test.go b/httpserver/middleware/example_test.go index a52d072..d897de4 100644 --- a/httpserver/middleware/example_test.go +++ b/httpserver/middleware/example_test.go @@ -35,7 +35,7 @@ func Example() { ExcludedEndpoints: []string{"/metrics", "/healthz"}, // Metrics will not be collected for "/metrics" and "/healthz" endpoints. })) - userCreateInFlightLimitMiddleware := InFlightLimitWithOpts(32, errDomain, InFlightLimitOpts{ + userCreateInFlightLimitMiddleware := MustInFlightLimitWithOpts(32, errDomain, InFlightLimitOpts{ GetKey: func(r *http.Request) (string, bool, error) { key := r.Header.Get("X-Client-ID") return key, key == "", nil @@ -49,7 +49,7 @@ func Example() { BacklogTimeout: time.Second * 10, }) - usersListRateLimitMiddleware := RateLimit(Rate{Count: 100, Duration: time.Second}, errDomain) + usersListRateLimitMiddleware := MustRateLimit(Rate{Count: 100, Duration: time.Second}, errDomain) router.Route("/users", func(r chi.Router) { r.With(usersListRateLimitMiddleware).Get("/", func(rw http.ResponseWriter, req *http.Request) { diff --git a/httpserver/middleware/in_flight_limit.go b/httpserver/middleware/in_flight_limit.go index 3f12fbe..8782a19 100644 --- a/httpserver/middleware/in_flight_limit.go +++ b/httpserver/middleware/in_flight_limit.go @@ -90,19 +90,28 @@ type InFlightLimitOpts struct { // InFlightLimit is a middleware that limits the total number of currently served (in-flight) HTTP requests. // It checks how many requests are in-flight and rejects with 503 if exceeded. -func InFlightLimit(limit int, errDomain string) func(next http.Handler) http.Handler { +func InFlightLimit(limit int, errDomain string) (func(next http.Handler) http.Handler, error) { return InFlightLimitWithOpts(limit, errDomain, InFlightLimitOpts{}) } +// MustInFlightLimit is a version of InFlightLimit that panics on error. +func MustInFlightLimit(limit int, errDomain string) func(next http.Handler) http.Handler { + mw, err := InFlightLimit(limit, errDomain) + if err != nil { + panic(err) + } + return mw +} + // InFlightLimitWithOpts is a configurable version of a middleware to limit in-flight HTTP requests. -func InFlightLimitWithOpts(limit int, errDomain string, opts InFlightLimitOpts) func(next http.Handler) http.Handler { +func InFlightLimitWithOpts(limit int, errDomain string, opts InFlightLimitOpts) (func(next http.Handler) http.Handler, error) { if limit <= 0 { - panic(fmt.Errorf("limit should be positive, got %d", limit)) + return nil, fmt.Errorf("limit should be positive, got %d", limit) } backlogLimit := opts.BacklogLimit if backlogLimit < 0 { - panic(fmt.Errorf("backlog limit should not be negative, got %d", backlogLimit)) + return nil, fmt.Errorf("backlog limit should not be negative, got %d", backlogLimit) } backlogTimeout := opts.BacklogTimeout @@ -120,7 +129,7 @@ func InFlightLimitWithOpts(limit int, errDomain string, opts InFlightLimitOpts) getSlots, err := makeInFlightLimitSlotsProvider(limit, backlogLimit, maxKeys) if err != nil { - panic(fmt.Errorf("make in-flight limit slots provider: %w", err)) + return nil, fmt.Errorf("make in-flight limit slots provider: %w", err) } respStatusCode := opts.ResponseStatusCode @@ -141,7 +150,16 @@ func InFlightLimitWithOpts(limit int, errDomain string, opts InFlightLimitOpts) onReject: makeInFlightLimitOnRejectFunc(opts), onError: makeInFlightLimitOnErrorFunc(opts), } + }, nil +} + +// MustInFlightLimitWithOpts is a version of InFlightLimitWithOpts that panics on error. +func MustInFlightLimitWithOpts(limit int, errDomain string, opts InFlightLimitOpts) func(next http.Handler) http.Handler { + mw, err := InFlightLimitWithOpts(limit, errDomain, opts) + if err != nil { + panic(err) } + return mw } func (h *inFlightLimitHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { diff --git a/httpserver/middleware/in_flight_limit_test.go b/httpserver/middleware/in_flight_limit_test.go index 7b7b493..9810a3c 100644 --- a/httpserver/middleware/in_flight_limit_test.go +++ b/httpserver/middleware/in_flight_limit_test.go @@ -45,7 +45,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) { } rw.WriteHeader(http.StatusOK) }) - handler := InFlightLimit(1, errDomain)(next) + handler := MustInFlightLimit(1, errDomain)(next) respCode := make(chan int) go func() { @@ -84,7 +84,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) { } rw.WriteHeader(http.StatusOK) }) - handler := InFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{BacklogLimit: 1, BacklogTimeout: backlogTimeout})(next) + handler := MustInFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{BacklogLimit: 1, BacklogTimeout: backlogTimeout})(next) resp1Code := make(chan int) go func() { @@ -139,7 +139,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) { rw.WriteHeader(http.StatusOK) }) - handler := InFlightLimit(limit, errDomain)(next) + handler := MustInFlightLimit(limit, errDomain)(next) var wg sync.WaitGroup for i := 0; i < reqsNum; i++ { wg.Add(1) @@ -184,7 +184,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) { } rw.WriteHeader(http.StatusOK) }) - handler := InFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{ + handler := MustInFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{ GetKey: makeInFlightLimitGetKeyByHeader(headerClientID), ResponseStatusCode: http.StatusTooManyRequests, })(next) @@ -235,7 +235,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) { } rw.WriteHeader(http.StatusOK) }) - handler := InFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{ + handler := MustInFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{ GetKey: func(r *http.Request) (string, bool, error) { key := r.Header.Get(headerClientID) return key, key == "", nil @@ -308,7 +308,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) { } rw.WriteHeader(http.StatusOK) }) - handler := InFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{ + handler := MustInFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{ GetKey: func(r *http.Request) (string, bool, error) { return r.Header.Get(headerClientID), false, nil }, @@ -358,7 +358,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) { rw.WriteHeader(http.StatusOK) }) - handler := InFlightLimitWithOpts(limitPerClient, errDomain, InFlightLimitOpts{ + handler := MustInFlightLimitWithOpts(limitPerClient, errDomain, InFlightLimitOpts{ GetKey: makeInFlightLimitGetKeyByHeader(headerClientID), ResponseStatusCode: http.StatusTooManyRequests, })(next) @@ -417,7 +417,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) { } rw.WriteHeader(http.StatusOK) }) - handler := InFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{ + handler := MustInFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{ GetKey: makeInFlightLimitGetKeyByHeader(headerClientID), MaxKeys: 1, ResponseStatusCode: http.StatusTooManyRequests, @@ -460,7 +460,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) { rw.WriteHeader(http.StatusOK) }) - handler := InFlightLimitWithOpts(limit, errDomain, InFlightLimitOpts{ + handler := MustInFlightLimitWithOpts(limit, errDomain, InFlightLimitOpts{ GetKey: makeInFlightLimitGetKeyByHeader(headerClientID), MaxKeys: 1, ResponseStatusCode: http.StatusTooManyRequests, @@ -499,7 +499,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) { <-reqContinued rw.WriteHeader(http.StatusOK) }) - handler := InFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{GetRetryAfter: func(r *http.Request) time.Duration { + handler := MustInFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{GetRetryAfter: func(r *http.Request) time.Duration { return retryAfter }})(next) @@ -530,7 +530,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) { rw.WriteHeader(http.StatusOK) }) - handler := InFlightLimitWithOpts(limit, errDomain, InFlightLimitOpts{ + handler := MustInFlightLimitWithOpts(limit, errDomain, InFlightLimitOpts{ DryRun: true, BacklogLimit: backlogLimit, BacklogTimeout: time.Millisecond * 10, diff --git a/httpserver/middleware/rate_limit.go b/httpserver/middleware/rate_limit.go index 5ac3f26..1fea203 100644 --- a/httpserver/middleware/rate_limit.go +++ b/httpserver/middleware/rate_limit.go @@ -109,15 +109,24 @@ type Rate struct { } // RateLimit is a middleware that limits the rate of HTTP requests. -func RateLimit(maxRate Rate, errDomain string) func(next http.Handler) http.Handler { +func RateLimit(maxRate Rate, errDomain string) (func(next http.Handler) http.Handler, error) { return RateLimitWithOpts(maxRate, errDomain, RateLimitOpts{GetRetryAfter: GetRetryAfterEstimatedTime}) } +// MustRateLimit is a version of RateLimit that panics if an error occurs. +func MustRateLimit(maxRate Rate, errDomain string) func(next http.Handler) http.Handler { + mw, err := RateLimit(maxRate, errDomain) + if err != nil { + panic(err) + } + return mw +} + // RateLimitWithOpts is a configurable version of a middleware to limit the rate of HTTP requests. -func RateLimitWithOpts(maxRate Rate, errDomain string, opts RateLimitOpts) func(next http.Handler) http.Handler { +func RateLimitWithOpts(maxRate Rate, errDomain string, opts RateLimitOpts) (func(next http.Handler) http.Handler, error) { backlogLimit := opts.BacklogLimit if backlogLimit < 0 { - panic(fmt.Errorf("backlog limit should not be negative, got %d", backlogLimit)) + return nil, fmt.Errorf("backlog limit should not be negative, got %d", backlogLimit) } if opts.DryRun { backlogLimit = 0 @@ -148,12 +157,12 @@ func RateLimitWithOpts(maxRate Rate, errDomain string, opts RateLimitOpts) func( } limiter, err := makeLimiter() if err != nil { - panic(err) + return nil, err } getBacklogSlots, err := makeRateLimitBacklogSlotsProvider(backlogLimit, maxKeys) if err != nil { - panic(err) + return nil, fmt.Errorf("make rate limit backlog slots provider: %w", err) } backlogTimeout := opts.BacklogTimeout @@ -174,7 +183,16 @@ func RateLimitWithOpts(maxRate Rate, errDomain string, opts RateLimitOpts) func( onReject: makeRateLimitOnRejectFunc(opts), onError: makeRateLimitOnErrorFunc(opts), } + }, nil +} + +// MustRateLimitWithOpts is a version of RateLimitWithOpts that panics if an error occurs. +func MustRateLimitWithOpts(maxRate Rate, errDomain string, opts RateLimitOpts) func(next http.Handler) http.Handler { + mw, err := RateLimitWithOpts(maxRate, errDomain, opts) + if err != nil { + panic(err) } + return mw } //nolint:funlen,gocyclo diff --git a/httpserver/middleware/rate_limit_test.go b/httpserver/middleware/rate_limit_test.go index 4d679ce..9284891 100644 --- a/httpserver/middleware/rate_limit_test.go +++ b/httpserver/middleware/rate_limit_test.go @@ -64,7 +64,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) { t.Run("leaky bucket, maxRate=1r/s, maxBurst=0, no key", func(t *testing.T) { next, nextServedCount := makeNext() - handler := RateLimit(Rate{1, time.Second}, errDomain)(next) + handler := MustRateLimit(Rate{1, time.Second}, errDomain)(next) _ = sendReqAndCheckCode(t, handler, http.StatusOK, nil) retryAfter := sendReqAndCheckCode(t, handler, http.StatusServiceUnavailable, nil) time.Sleep(retryAfter) @@ -84,7 +84,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) { wantRetryAfter := time.Second * time.Duration(math.Ceil(emissionInterval.Seconds())) next, nextServedCount := makeNext() - handler := RateLimitWithOpts(rate, errDomain, RateLimitOpts{MaxBurst: maxBurst, GetRetryAfter: GetRetryAfterEstimatedTime})(next) + handler := MustRateLimitWithOpts(rate, errDomain, RateLimitOpts{MaxBurst: maxBurst, GetRetryAfter: GetRetryAfterEstimatedTime})(next) sendNReqsConcurrentlyAndCheck := func(n int) { var okCount, tooManyReqsCount, unexpectedCodeReqsCount, wrongRetryAfterReqsCount, getRetryAfterErrsCount atomic.Int32 @@ -140,7 +140,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) { t.Run("leaky bucket, maxRate=1r/s, maxBurst=0, by key", func(t *testing.T) { const headerClientID = "X-Client-ID" next, nextServedCount := makeNext() - handler := RateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{ + handler := MustRateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{ GetKey: makeRateLimitGetKeyByHeader(headerClientID), GetRetryAfter: GetRetryAfterEstimatedTime, ResponseStatusCode: http.StatusTooManyRequests, @@ -167,7 +167,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) { t.Run("leaky bucket, maxRate=1r/s, maxBurst=0, by key, no bypass empty key", func(t *testing.T) { const headerClientID = "X-Client-ID" next, nextServedCount := makeNext() - handler := RateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{ + handler := MustRateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{ GetKey: func(r *http.Request) (key string, bypass bool, err error) { return r.Header.Get(headerClientID), false, nil }, @@ -204,7 +204,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) { wantRetryAfter := time.Second * time.Duration(math.Ceil(emissionInterval.Seconds())) next, nextServedCount := makeNext() - handler := RateLimitWithOpts(rate, errDomain, RateLimitOpts{ + handler := MustRateLimitWithOpts(rate, errDomain, RateLimitOpts{ MaxBurst: maxBurst, GetKey: makeRateLimitGetKeyByHeader(headerClientID), GetRetryAfter: GetRetryAfterEstimatedTime, @@ -311,7 +311,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) { t.Run("sliding window, maxRate=1r/s, no key", func(t *testing.T) { next, nextServedCount := makeNext() - handler := RateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{ + handler := MustRateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{ Alg: RateLimitAlgSlidingWindow, GetRetryAfter: GetRetryAfterEstimatedTime, })(next) @@ -330,7 +330,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) { rate := Rate{2, time.Second} next, nextServedCount := makeNext() - handler := RateLimitWithOpts(rate, errDomain, RateLimitOpts{ + handler := MustRateLimitWithOpts(rate, errDomain, RateLimitOpts{ Alg: RateLimitAlgSlidingWindow, GetKey: makeRateLimitGetKeyByHeader(headerClientID), GetRetryAfter: GetRetryAfterEstimatedTime, @@ -387,7 +387,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) { t.Run("RetryAfter custom", func(t *testing.T) { next, _ := makeNext() - handler := RateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{ + handler := MustRateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{ GetRetryAfter: func(r *http.Request, estimatedTime time.Duration) time.Duration { return estimatedTime * 3 }, @@ -399,7 +399,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) { t.Run("leaky bucket, maxRate=1r/s, maxBurst=0, backlogLimit=1, no key", func(t *testing.T) { next, nextServedCount := makeNext() - handler := RateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{BacklogLimit: 1})(next) + handler := MustRateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{BacklogLimit: 1})(next) sendReqAndCheckCode(t, handler, http.StatusOK, nil) startTime := time.Now() sendReqAndCheckCode(t, handler, http.StatusOK, nil) @@ -425,7 +425,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) { t.Run("leaky bucket, maxRate=1r/m, maxBurst=0, backlogLimit=1, backlogTimeout=1s, no key", func(t *testing.T) { next, nextServedCount := makeNext() rateLimitOpts := RateLimitOpts{BacklogLimit: 1, BacklogTimeout: time.Second, GetRetryAfter: GetRetryAfterEstimatedTime} - handler := RateLimitWithOpts(Rate{1, time.Minute}, errDomain, rateLimitOpts)(next) + handler := MustRateLimitWithOpts(Rate{1, time.Minute}, errDomain, rateLimitOpts)(next) sendReqAndCheckCode(t, handler, http.StatusOK, nil) startTime := time.Now() sendReqAndCheckCode(t, handler, http.StatusServiceUnavailable, nil) @@ -436,7 +436,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) { t.Run("leaky bucket, maxRate=1r/s, maxBurst=0, backlogLimit=1, by key", func(t *testing.T) { const headerClientID = "X-Client-ID" next, nextServedCount := makeNext() - handler := RateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{ + handler := MustRateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{ GetKey: makeRateLimitGetKeyByHeader(headerClientID), BacklogLimit: 1, GetRetryAfter: GetRetryAfterEstimatedTime, diff --git a/httpserver/middleware/throttle/example_test.go b/httpserver/middleware/throttle/example_test.go index 08209de..bf0f31b 100644 --- a/httpserver/middleware/throttle/example_test.go +++ b/httpserver/middleware/throttle/example_test.go @@ -82,7 +82,11 @@ rules: const longWorkDelay = time.Second - srv := makeExampleTestServer(cfg, longWorkDelay) + srv, err := makeExampleTestServer(cfg, longWorkDelay) + if err != nil { + stdlog.Fatal(err) + return + } defer srv.Close() // Rate limiting. @@ -156,17 +160,20 @@ rules: // [9] PUT /api/2/tenants/446507ba-2f9b-4347-adbc-63581383ba25 204 } -func makeExampleTestServer(cfg *throttle.Config, longWorkDelay time.Duration) *httptest.Server { +func makeExampleTestServer(cfg *throttle.Config, longWorkDelay time.Duration) (*httptest.Server, error) { promMetrics := throttle.NewPrometheusMetrics() promMetrics.MustRegister() defer promMetrics.Unregister() // Configure middleware that should do global throttling ("all_reqs" tag says about that). - allReqsThrottleMiddleware := throttle.MiddlewareWithOpts(cfg, apiErrDomain, promMetrics, throttle.MiddlewareOpts{ + globalThrottleMiddleware, err := throttle.MiddlewareWithOpts(cfg, apiErrDomain, promMetrics, throttle.MiddlewareOpts{ Tags: []string{"all_reqs"}}) + if err != nil { + return nil, fmt.Errorf("create global throttling middleware: %w", err) + } // Configure middleware that should do per-client throttling based on the username from basic auth ("authenticated_reqs" tag says about that). - authenticatedReqsThrottleMiddleware := throttle.MiddlewareWithOpts(cfg, apiErrDomain, promMetrics, throttle.MiddlewareOpts{ + clientThrottleMiddleware, err := throttle.MiddlewareWithOpts(cfg, apiErrDomain, promMetrics, throttle.MiddlewareOpts{ Tags: []string{"authenticated_reqs"}, GetKeyIdentity: func(r *http.Request) (key string, bypass bool, err error) { username, _, ok := r.BasicAuth() @@ -176,9 +183,12 @@ func makeExampleTestServer(cfg *throttle.Config, longWorkDelay time.Duration) *h return username, false, nil }, }) + if err != nil { + return nil, fmt.Errorf("create client throttling middleware: %w", err) + } restoreTenantPathRegExp := regexp.MustCompile(`^/api/2/tenants/([\w-]{36})/?$`) - return httptest.NewServer(allReqsThrottleMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + return httptest.NewServer(globalThrottleMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/long-work": if r.Method != http.MethodPost { @@ -213,12 +223,12 @@ func makeExampleTestServer(cfg *throttle.Config, longWorkDelay time.Duration) *h rw.WriteHeader(http.StatusMethodNotAllowed) return } - authenticatedReqsThrottleMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + clientThrottleMiddleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusNoContent) })).ServeHTTP(rw, r) return } rw.WriteHeader(http.StatusNotFound) - }))) + }))), nil } diff --git a/httpserver/middleware/throttle/middleware.go b/httpserver/middleware/throttle/middleware.go index 929ed55..3be908a 100644 --- a/httpserver/middleware/throttle/middleware.go +++ b/httpserver/middleware/throttle/middleware.go @@ -79,34 +79,39 @@ func (opts MiddlewareOpts) inFlightLimitOpts() inFlightLimitMiddlewareOpts { } // Middleware is a middleware that throttles incoming HTTP requests based on the passed configuration. -func Middleware(cfg *Config, errDomain string, mc MetricsCollector) func(next http.Handler) http.Handler { +func Middleware(cfg *Config, errDomain string, mc MetricsCollector) (func(next http.Handler) http.Handler, error) { return MiddlewareWithOpts(cfg, errDomain, mc, MiddlewareOpts{}) } // MiddlewareWithOpts is a more configurable version of Middleware. -func MiddlewareWithOpts(cfg *Config, errDomain string, mc MetricsCollector, opts MiddlewareOpts) func(next http.Handler) http.Handler { +func MiddlewareWithOpts( + cfg *Config, errDomain string, mc MetricsCollector, opts MiddlewareOpts, +) (func(next http.Handler) http.Handler, error) { if mc == nil { mc = disabledMetrics{} } - mustMakeRoutesManager := func(next http.Handler) *restapi.RoutesManager { - routes, err := makeRoutes(cfg, errDomain, mc, opts, next) - if err != nil { - panic(err) // Should be validated above. - } - return restapi.NewRoutesManager(routes) + routes, err := makeRoutes(cfg, errDomain, mc, opts) + if err != nil { + return nil, err } if opts.BuildHandlerAtInit { return func(next http.Handler) http.Handler { - return &handler{next: next, routesManager: mustMakeRoutesManager(next)} - } + for i := range routes { + route := &routes[i] + route.Handler = next + for j := len(route.Middlewares) - 1; j >= 0; j-- { + route.Handler = route.Middlewares[j](route.Handler) + } + } + return &handler{next: next, routesManager: restapi.NewRoutesManager(routes)} + }, nil } - routesManager := mustMakeRoutesManager(nil) return func(next http.Handler) http.Handler { - return &handler{next: next, routesManager: routesManager} - } + return &handler{next: next, routesManager: restapi.NewRoutesManager(routes)} + }, nil } type handler struct { @@ -138,7 +143,7 @@ func (h *handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { // nolint: gocyclo // we would like to have high functional cohesion here. func makeRoutes( - cfg *Config, errDomain string, mc MetricsCollector, opts MiddlewareOpts, handler http.Handler, + cfg *Config, errDomain string, mc MetricsCollector, opts MiddlewareOpts, ) (routes []restapi.Route, err error) { for _, rule := range cfg.Rules { if len(rule.RateLimits) == 0 && len(rule.InFlightLimits) == 0 { @@ -162,7 +167,7 @@ func makeRoutes( inFlightLimitMw, err = makeInFlightLimitMiddleware( &cfgZone, errDomain, rule.Name(), mc, opts.inFlightLimitOpts()) if err != nil { - return nil, fmt.Errorf("new in-flight limit middleware for zone %q: %w", zoneName, err) + return nil, fmt.Errorf("make in-flight limit middleware for zone %q: %w", zoneName, err) } middlewares = append(middlewares, inFlightLimitMw) } @@ -178,20 +183,13 @@ func makeRoutes( rateLimitMw, err = makeRateLimitMiddleware( &cfgZone, errDomain, rule.Name(), mc, opts.rateLimitOpts()) if err != nil { - return nil, fmt.Errorf("new rate limit middleware for zone %q: %w", zoneName, err) + return nil, fmt.Errorf("make rate limit middleware for zone %q: %w", zoneName, err) } middlewares = append(middlewares, rateLimitMw) } - routeHandler := handler - if routeHandler != nil { - for i := len(middlewares) - 1; i >= 0; i-- { - routeHandler = middlewares[i](routeHandler) - } - } - for _, cfgRoute := range rule.Routes { - routes = append(routes, restapi.NewRoute(cfgRoute, routeHandler, middlewares)) + routes = append(routes, restapi.NewRoute(cfgRoute, nil, middlewares)) } for _, exclCfgRoute := range rule.ExcludedRoutes { routes = append(routes, restapi.NewExcludedRoute(exclCfgRoute)) @@ -298,7 +296,7 @@ func makeRateLimitMiddleware( OnReject: onRejectWithMetrics, OnRejectInDryRun: onRejectInDryRunWithMetrics, OnError: opts.RateLimitOnError, - }), nil + }) } type inFlightLimitMiddlewareOpts struct { @@ -379,7 +377,7 @@ func makeInFlightLimitMiddleware( OnReject: onRejectWithMetrics, OnRejectInDryRun: onRejectInDryRunWithMetrics, OnError: opts.InFlightLimitOnError, - }), nil + }) } // nolint: gocyclo // we would like to have high functional cohesion here. diff --git a/httpserver/middleware/throttle/middleware_test.go b/httpserver/middleware/throttle/middleware_test.go index 818be31..fd5bc91 100644 --- a/httpserver/middleware/throttle/middleware_test.go +++ b/httpserver/middleware/throttle/middleware_test.go @@ -591,9 +591,9 @@ func (c *testCounters) checkInFlightLimit(t *testing.T, wantRejects, wantDryRunR func makeHandlerWrappedIntoMiddleware( cfg *Config, blockCh chan struct{}, tags []string, buildHandlerAtInit bool, -) (http.Handler, *testCounters) { +) (http.Handler, *testCounters, error) { c := &testCounters{} - mid := MiddlewareWithOpts(cfg, testErrDomain, NewPrometheusMetrics(), MiddlewareOpts{ + mw, err := MiddlewareWithOpts(cfg, testErrDomain, NewPrometheusMetrics(), MiddlewareOpts{ GetKeyIdentity: func(r *http.Request) (key string, bypass bool, err error) { username, _, ok := r.BasicAuth() if !ok { @@ -642,7 +642,10 @@ func makeHandlerWrappedIntoMiddleware( Tags: tags, BuildHandlerAtInit: buildHandlerAtInit, }) - return mid(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if err != nil { + return nil, nil, fmt.Errorf("create throttling middleware: %w", err) + } + return mw(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { c.nextCalls.Inc() if blockCh != nil { if err := waitSend(blockCh, time.Second*5); err != nil { @@ -651,7 +654,7 @@ func makeHandlerWrappedIntoMiddleware( } } rw.WriteHeader(http.StatusOK) - })), c + })), c, nil } // nolint @@ -669,7 +672,8 @@ func checkRateLimiting( panic("totalReqsNum should be > burst+1") } - throttleHandler, counters := makeHandlerWrappedIntoMiddleware(cfg, nil, tags, false) + throttleHandler, counters, err := makeHandlerWrappedIntoMiddleware(cfg, nil, tags, false) + require.NoError(t, err) // First N requests SHOULD NOT BE throttled. for i := 0; i < wantNotThrottledReqsNum; i++ { @@ -708,7 +712,8 @@ func checkNoRateLimiting(t *testing.T, cfg *Config, reqsGen func() *http.Request func checkNoRateLimitingOrDryRun( t *testing.T, cfg *Config, reqsGen func() *http.Request, reqsNum, wantDryRunRejects int, tags ...string, ) { - throttleHandler, counters := makeHandlerWrappedIntoMiddleware(cfg, nil, tags, false) + throttleHandler, counters, err := makeHandlerWrappedIntoMiddleware(cfg, nil, tags, false) + require.NoError(t, err) for i := 0; i < reqsNum; i++ { respRec := httptest.NewRecorder() throttleHandler.ServeHTTP(respRec, reqsGen()) @@ -734,7 +739,8 @@ func checkInFlightLimiting(t *testing.T, cfg *Config, params checkInFlightLimiti panic("reqsNum should be > totalLimit") } blockCh := make(chan struct{}) - throttleHandler, counters := makeHandlerWrappedIntoMiddleware(cfg, blockCh, params.tags, params.buildHandlerAtInit) + throttleHandler, counters, err := makeHandlerWrappedIntoMiddleware(cfg, blockCh, params.tags, params.buildHandlerAtInit) + require.NoError(t, err) var okCodes, throttledCodes, unexpectedCodes, wrongRetryAfterNums atomic.Int32 var wg sync.WaitGroup for i := 0; i < params.reqsNum; i++ { @@ -803,7 +809,8 @@ func checkNoInFlightLimitingOrDryRun( tags ...string, ) { blockCh := make(chan struct{}) - throttleHandler, counters := makeHandlerWrappedIntoMiddleware(cfg, blockCh, tags, false) + throttleHandler, counters, err := makeHandlerWrappedIntoMiddleware(cfg, blockCh, tags, false) + require.NoError(t, err) var okCodes, unexpectedCodes atomic.Int32 var wg sync.WaitGroup for i := 0; i < reqsNum; i++ { diff --git a/httpserver/router.go b/httpserver/router.go index 3b6edd3..019e309 100644 --- a/httpserver/router.go +++ b/httpserver/router.go @@ -116,12 +116,11 @@ func applyDefaultMiddlewaresToRouter( }) router.Use(metricsMiddleware) - // To limit the number of currently serving requests, we use Throttle middleware from chi. - if cfg.Limits.MaxRequests < 0 { - return ErrInvalidMaxServingRequests - } - if cfg.Limits.MaxRequests > 0 { - inFlightLimitMw := middleware.InFlightLimit(cfg.Limits.MaxRequests, opts.ErrorDomain) + if cfg.Limits.MaxRequests != 0 { + inFlightLimitMw, err := middleware.InFlightLimit(cfg.Limits.MaxRequests, opts.ErrorDomain) + if err != nil { + return fmt.Errorf("create in-flight limit middleware: %w", err) + } router.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { for i := 0; i < len(systemEndpoints); i++ {