From 0d21b9c85d506bd5ad75332dff61c315cf5ce3d1 Mon Sep 17 00:00:00 2001 From: Vasily Tsybenko Date: Wed, 4 Sep 2024 10:54:40 +0300 Subject: [PATCH] Make metrics collection in throttle and middleware packages more flexible Now middlewares receive interfaces instead of the concrete Prometheus implementations. --- httpserver/http_server.go | 22 +-- httpserver/middleware/example_test.go | 2 +- httpserver/middleware/metrics.go | 128 +++++++++++------- httpserver/middleware/metrics_test.go | 14 +- .../middleware/throttle/example_test.go | 10 +- httpserver/middleware/throttle/metrics.go | 117 +++++++++++----- httpserver/middleware/throttle/middleware.go | 54 ++++---- .../middleware/throttle/middleware_test.go | 2 +- httpserver/router.go | 4 +- 9 files changed, 217 insertions(+), 136 deletions(-) diff --git a/httpserver/http_server.go b/httpserver/http_server.go index c984d12..27e8f9f 100644 --- a/httpserver/http_server.go +++ b/httpserver/http_server.go @@ -91,9 +91,9 @@ type HTTPServer struct { Logger log.FieldLogger ShutdownTimeout time.Duration - port int32 - httpServerDone chan struct{} - httpReqMetricsCollector *middleware.HTTPRequestMetricsCollector + port int32 + httpServerDone chan struct{} + httpReqPrometheusMetrics *middleware.HTTPRequestPrometheusMetrics } var _ service.Unit = (*HTTPServer)(nil) @@ -102,20 +102,20 @@ var _ service.MetricsRegisterer = (*HTTPServer)(nil) // New creates a new HTTPServer with predefined logging, metrics collecting, // recovering after panics and health-checking functionality. func New(cfg *Config, logger log.FieldLogger, opts Opts) (*HTTPServer, error) { //nolint // hugeParam: opts is heavy, it's ok in this case. - httpReqMetricsCollector := middleware.NewHTTPRequestMetricsCollectorWithOpts( - middleware.HTTPRequestMetricsCollectorOpts{ + httpReqPromMetrics := middleware.NewHTTPRequestPrometheusMetricsWithOpts( + middleware.HTTPRequestPrometheusMetricsOpts{ Namespace: opts.HTTPRequestMetrics.Namespace, DurationBuckets: opts.HTTPRequestMetrics.DurationBuckets, ConstLabels: opts.HTTPRequestMetrics.ConstLabels, }) router := chi.NewRouter() - if err := applyDefaultMiddlewaresToRouter(router, cfg, logger, opts, httpReqMetricsCollector); err != nil { + if err := applyDefaultMiddlewaresToRouter(router, cfg, logger, opts, httpReqPromMetrics); err != nil { return nil, err } configureRouter(router, logger, opts.routerOpts()) appSrv := NewWithHandler(cfg, logger, router) - appSrv.httpReqMetricsCollector = httpReqMetricsCollector + appSrv.httpReqPrometheusMetrics = httpReqPromMetrics return appSrv, nil } @@ -259,15 +259,15 @@ func (s *HTTPServer) Stop(gracefully bool) error { // MustRegisterMetrics registers metrics in Prometheus client and panics if any error occurs. func (s *HTTPServer) MustRegisterMetrics() { - if s.httpReqMetricsCollector != nil { - s.httpReqMetricsCollector.MustRegister() + if s.httpReqPrometheusMetrics != nil { + s.httpReqPrometheusMetrics.MustRegister() } } // UnregisterMetrics unregisters metrics in Prometheus client. func (s *HTTPServer) UnregisterMetrics() { - if s.httpReqMetricsCollector != nil { - s.httpReqMetricsCollector.Unregister() + if s.httpReqPrometheusMetrics != nil { + s.httpReqPrometheusMetrics.Unregister() } } diff --git a/httpserver/middleware/example_test.go b/httpserver/middleware/example_test.go index 31d453b..a52d072 100644 --- a/httpserver/middleware/example_test.go +++ b/httpserver/middleware/example_test.go @@ -30,7 +30,7 @@ func Example() { RequestBodyLimit(1024*1024, errDomain), ) - metricsCollector := NewHTTPRequestMetricsCollector() + metricsCollector := NewHTTPRequestPrometheusMetrics() router.Use(HTTPRequestMetricsWithOpts(metricsCollector, getChiRoutePattern, HTTPRequestMetricsOpts{ ExcludedEndpoints: []string{"/metrics", "/healthz"}, // Metrics will not be collected for "/metrics" and "/healthz" endpoints. })) diff --git a/httpserver/middleware/metrics.go b/httpserver/middleware/metrics.go index a0643d2..16f0500 100644 --- a/httpserver/middleware/metrics.go +++ b/httpserver/middleware/metrics.go @@ -27,11 +27,30 @@ const ( userAgentTypeHTTPClient = "http-client" ) +// HTTPRequestInfoMetrics represents a request info for collecting metrics. +type HTTPRequestInfoMetrics struct { + Method string + RoutePattern string + UserAgentType string +} + +// HTTPRequestMetricsCollector is an interface for collecting metrics for incoming HTTP requests. +type HTTPRequestMetricsCollector interface { + // IncInFlightRequests increments the counter of in-flight requests. + IncInFlightRequests(requestInfo HTTPRequestInfoMetrics) + + // DecInFlightRequests decrements the counter of in-flight requests. + DecInFlightRequests(requestInfo HTTPRequestInfoMetrics) + + // ObserveRequestFinish observes the duration of the request and the status code. + ObserveRequestFinish(requestInfo HTTPRequestInfoMetrics, status int, startTime time.Time) +} + // DefaultHTTPRequestDurationBuckets is default buckets into which observations of serving HTTP requests are counted. var DefaultHTTPRequestDurationBuckets = []float64{0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60, 150, 300, 600} -// HTTPRequestMetricsCollectorOpts represents an options for HTTPRequestMetricsCollector. -type HTTPRequestMetricsCollectorOpts struct { +// HTTPRequestPrometheusMetricsOpts represents an options for HTTPRequestPrometheusMetrics. +type HTTPRequestPrometheusMetricsOpts struct { // Namespace is a namespace for metrics. It will be prepended to all metric names. Namespace string @@ -42,26 +61,26 @@ type HTTPRequestMetricsCollectorOpts struct { ConstLabels prometheus.Labels // CurriedLabelNames is a list of label names that will be curried with the provided labels. - // See HTTPRequestMetricsCollector.MustCurryWith method for more details. + // See HTTPRequestPrometheusMetrics.MustCurryWith method for more details. // Keep in mind that if this list is not empty, - // HTTPRequestMetricsCollector.MustCurryWith method must be called further with the same labels. + // HTTPRequestPrometheusMetrics.MustCurryWith method must be called further with the same labels. // Otherwise, the collector will panic. CurriedLabelNames []string } -// HTTPRequestMetricsCollector represents collector of metrics for incoming HTTP requests. -type HTTPRequestMetricsCollector struct { +// HTTPRequestPrometheusMetrics represents collector of metrics for incoming HTTP requests. +type HTTPRequestPrometheusMetrics struct { Durations *prometheus.HistogramVec InFlight *prometheus.GaugeVec } -// NewHTTPRequestMetricsCollector creates a new metrics collector. -func NewHTTPRequestMetricsCollector() *HTTPRequestMetricsCollector { - return NewHTTPRequestMetricsCollectorWithOpts(HTTPRequestMetricsCollectorOpts{}) +// NewHTTPRequestPrometheusMetrics creates a new instance of HTTPRequestPrometheusMetrics with default options. +func NewHTTPRequestPrometheusMetrics() *HTTPRequestPrometheusMetrics { + return NewHTTPRequestPrometheusMetricsWithOpts(HTTPRequestPrometheusMetricsOpts{}) } -// NewHTTPRequestMetricsCollectorWithOpts is a more configurable version of creating HTTPRequestMetricsCollector. -func NewHTTPRequestMetricsCollectorWithOpts(opts HTTPRequestMetricsCollectorOpts) *HTTPRequestMetricsCollector { +// NewHTTPRequestPrometheusMetricsWithOpts creates a new instance of HTTPRequestPrometheusMetrics with the provided options. +func NewHTTPRequestPrometheusMetricsWithOpts(opts HTTPRequestPrometheusMetricsOpts) *HTTPRequestPrometheusMetrics { makeLabelNames := func(names ...string) []string { l := append(make([]string, 0, len(opts.CurriedLabelNames)+len(names)), opts.CurriedLabelNames...) return append(l, names...) @@ -101,52 +120,62 @@ func NewHTTPRequestMetricsCollectorWithOpts(opts HTTPRequestMetricsCollectorOpts ), ) - return &HTTPRequestMetricsCollector{ + return &HTTPRequestPrometheusMetrics{ Durations: durations, InFlight: inFlight, } } // MustCurryWith curries the metrics collector with the provided labels. -func (c *HTTPRequestMetricsCollector) MustCurryWith(labels prometheus.Labels) *HTTPRequestMetricsCollector { - return &HTTPRequestMetricsCollector{ - Durations: c.Durations.MustCurryWith(labels).(*prometheus.HistogramVec), - InFlight: c.InFlight.MustCurryWith(labels), +func (pm *HTTPRequestPrometheusMetrics) MustCurryWith(labels prometheus.Labels) *HTTPRequestPrometheusMetrics { + return &HTTPRequestPrometheusMetrics{ + Durations: pm.Durations.MustCurryWith(labels).(*prometheus.HistogramVec), + InFlight: pm.InFlight.MustCurryWith(labels), } } // MustRegister does registration of metrics collector in Prometheus and panics if any error occurs. -func (c *HTTPRequestMetricsCollector) MustRegister() { +func (pm *HTTPRequestPrometheusMetrics) MustRegister() { prometheus.MustRegister( - c.Durations, - c.InFlight, + pm.Durations, + pm.InFlight, ) } // Unregister cancels registration of metrics collector in Prometheus. -func (c *HTTPRequestMetricsCollector) Unregister() { - prometheus.Unregister(c.InFlight) - prometheus.Unregister(c.Durations) +func (pm *HTTPRequestPrometheusMetrics) Unregister() { + prometheus.Unregister(pm.InFlight) + prometheus.Unregister(pm.Durations) } -func (c *HTTPRequestMetricsCollector) trackRequestEnd(reqInfo *httpRequestInfo, status int, startTime time.Time) { - labels := reqInfo.makeLabels() - labels[httpRequestMetricsLabelStatusCode] = strconv.Itoa(status) - c.Durations.With(labels).Observe(time.Since(startTime).Seconds()) +// IncInFlightRequests increments the counter of in-flight requests. +func (pm *HTTPRequestPrometheusMetrics) IncInFlightRequests(requestInfo HTTPRequestInfoMetrics) { + pm.InFlight.With(prometheus.Labels{ + httpRequestMetricsLabelMethod: requestInfo.Method, + httpRequestMetricsLabelRoutePattern: requestInfo.RoutePattern, + httpRequestMetricsLabelUserAgentType: requestInfo.UserAgentType, + }).Inc() } -type httpRequestInfo struct { - method string - routePattern string - userAgentType string +// DecInFlightRequests decrements the counter of in-flight requests. +func (pm *HTTPRequestPrometheusMetrics) DecInFlightRequests(requestInfo HTTPRequestInfoMetrics) { + pm.InFlight.With(prometheus.Labels{ + httpRequestMetricsLabelMethod: requestInfo.Method, + httpRequestMetricsLabelRoutePattern: requestInfo.RoutePattern, + httpRequestMetricsLabelUserAgentType: requestInfo.UserAgentType, + }).Dec() } -func (hri *httpRequestInfo) makeLabels() prometheus.Labels { - return prometheus.Labels{ - httpRequestMetricsLabelMethod: hri.method, - httpRequestMetricsLabelRoutePattern: hri.routePattern, - httpRequestMetricsLabelUserAgentType: hri.userAgentType, - } +// ObserveRequestFinish observes the duration of the request and the status code. +func (pm *HTTPRequestPrometheusMetrics) ObserveRequestFinish( + requestInfo HTTPRequestInfoMetrics, status int, startTime time.Time, +) { + pm.Durations.With(prometheus.Labels{ + httpRequestMetricsLabelMethod: requestInfo.Method, + httpRequestMetricsLabelRoutePattern: requestInfo.RoutePattern, + httpRequestMetricsLabelUserAgentType: requestInfo.UserAgentType, + httpRequestMetricsLabelStatusCode: strconv.Itoa(status), + }).Observe(time.Since(startTime).Seconds()) } // UserAgentTypeGetterFunc is a function for getting user agent type from the request. @@ -161,21 +190,21 @@ type HTTPRequestMetricsOpts struct { type httpRequestMetricsHandler struct { next http.Handler - collector *HTTPRequestMetricsCollector + collector HTTPRequestMetricsCollector getRoutePattern RoutePatternGetterFunc opts HTTPRequestMetricsOpts } // HTTPRequestMetrics is a middleware that collects metrics for incoming HTTP requests using Prometheus data types. func HTTPRequestMetrics( - collector *HTTPRequestMetricsCollector, getRoutePattern RoutePatternGetterFunc, + collector HTTPRequestMetricsCollector, getRoutePattern RoutePatternGetterFunc, ) func(next http.Handler) http.Handler { return HTTPRequestMetricsWithOpts(collector, getRoutePattern, HTTPRequestMetricsOpts{}) } // HTTPRequestMetricsWithOpts is a more configurable version of HTTPRequestMetrics middleware. func HTTPRequestMetricsWithOpts( - collector *HTTPRequestMetricsCollector, + collector HTTPRequestMetricsCollector, getRoutePattern RoutePatternGetterFunc, opts HTTPRequestMetricsOpts, ) func(next http.Handler) http.Handler { @@ -204,15 +233,14 @@ func (h *httpRequestMetricsHandler) ServeHTTP(rw http.ResponseWriter, r *http.Re r = r.WithContext(NewContextWithRequestStartTime(r.Context(), startTime)) } - reqInfo := &httpRequestInfo{ - method: r.Method, - routePattern: h.getRoutePattern(r), - userAgentType: h.opts.GetUserAgentType(r), + reqInfo := HTTPRequestInfoMetrics{ + Method: r.Method, + RoutePattern: h.getRoutePattern(r), + UserAgentType: h.opts.GetUserAgentType(r), } - inFlightGauge := h.collector.InFlight.With(reqInfo.makeLabels()) - inFlightGauge.Inc() - defer inFlightGauge.Dec() + h.collector.IncInFlightRequests(reqInfo) + defer h.collector.DecInFlightRequests(reqInfo) r = r.WithContext(NewContextWithHTTPMetricsEnabled(r.Context())) @@ -222,16 +250,16 @@ func (h *httpRequestMetricsHandler) ServeHTTP(rw http.ResponseWriter, r *http.Re return } - if reqInfo.routePattern == "" { - reqInfo.routePattern = h.getRoutePattern(r) + if reqInfo.RoutePattern == "" { + reqInfo.RoutePattern = h.getRoutePattern(r) } if p := recover(); p != nil { if p != http.ErrAbortHandler { - h.collector.trackRequestEnd(reqInfo, http.StatusInternalServerError, startTime) + h.collector.ObserveRequestFinish(reqInfo, http.StatusInternalServerError, startTime) } panic(p) } - h.collector.trackRequestEnd(reqInfo, wrw.Status(), startTime) + h.collector.ObserveRequestFinish(reqInfo, wrw.Status(), startTime) }() h.next.ServeHTTP(wrw, r) diff --git a/httpserver/middleware/metrics_test.go b/httpserver/middleware/metrics_test.go index ed10e3c..c44344d 100644 --- a/httpserver/middleware/metrics_test.go +++ b/httpserver/middleware/metrics_test.go @@ -132,7 +132,7 @@ func TestHttpRequestMetricsHandler_ServeHTTP(t *testing.T) { for k := range tt.curriedLabels { curriedLabelNames = append(curriedLabelNames, k) } - collector := NewHTTPRequestMetricsCollectorWithOpts(HTTPRequestMetricsCollectorOpts{ + collector := NewHTTPRequestPrometheusMetricsWithOpts(HTTPRequestPrometheusMetricsOpts{ CurriedLabelNames: curriedLabelNames, }) collector = collector.MustCurryWith(tt.curriedLabels) @@ -168,30 +168,30 @@ func TestHttpRequestMetricsHandler_ServeHTTP(t *testing.T) { }) t.Run("collect 500 on panic", func(t *testing.T) { - collector := NewHTTPRequestMetricsCollector() + promMetrics := NewHTTPRequestPrometheusMetrics() next := &mockRecoveryNextHandler{} req := httptest.NewRequest(http.MethodGet, "/internal-error", nil) resp := httptest.NewRecorder() - h := HTTPRequestMetrics(collector, getRoutePattern)(next) + h := HTTPRequestMetrics(promMetrics, getRoutePattern)(next) if assert.Panics(t, func() { h.ServeHTTP(resp, req) }) { assert.Equal(t, 1, next.called) labels := makeLabels(http.MethodGet, "/internal-error", "http-client", "500") - hist := collector.Durations.With(labels).(prometheus.Histogram) + hist := promMetrics.Durations.With(labels).(prometheus.Histogram) testutil.AssertSamplesCountInHistogram(t, hist, 1) } }) t.Run("not collect if disabled", func(t *testing.T) { - collector := NewHTTPRequestMetricsCollector() + promMetrics := NewHTTPRequestPrometheusMetrics() next := &mockHTTPRequestMetricsDisabledHandler{} req := httptest.NewRequest(http.MethodGet, "/hello", nil) req.Header.Set("User-Agent", "http-client") resp := httptest.NewRecorder() - h := HTTPRequestMetrics(collector, getRoutePattern)(next) + h := HTTPRequestMetrics(promMetrics, getRoutePattern)(next) h.ServeHTTP(resp, req) assert.Equal(t, http.StatusOK, resp.Code) labels := makeLabels(http.MethodGet, "/hello", "http-client", "200") - hist := collector.Durations.With(labels).(prometheus.Histogram) + hist := promMetrics.Durations.With(labels).(prometheus.Histogram) testutil.AssertSamplesCountInHistogram(t, hist, 0) }) } diff --git a/httpserver/middleware/throttle/example_test.go b/httpserver/middleware/throttle/example_test.go index be86c80..08209de 100644 --- a/httpserver/middleware/throttle/example_test.go +++ b/httpserver/middleware/throttle/example_test.go @@ -157,16 +157,16 @@ rules: } func makeExampleTestServer(cfg *throttle.Config, longWorkDelay time.Duration) *httptest.Server { - throttleMetrics := throttle.NewMetricsCollector("") - throttleMetrics.MustRegister() - defer throttleMetrics.Unregister() + promMetrics := throttle.NewPrometheusMetrics() + promMetrics.MustRegister() + defer promMetrics.Unregister() // Configure middleware that should do global throttling ("all_reqs" tag says about that). - allReqsThrottleMiddleware := throttle.MiddlewareWithOpts(cfg, apiErrDomain, throttleMetrics, throttle.MiddlewareOpts{ + allReqsThrottleMiddleware := throttle.MiddlewareWithOpts(cfg, apiErrDomain, promMetrics, throttle.MiddlewareOpts{ Tags: []string{"all_reqs"}}) // Configure middleware that should do per-client throttling based on the username from basic auth ("authenticated_reqs" tag says about that). - authenticatedReqsThrottleMiddleware := throttle.MiddlewareWithOpts(cfg, apiErrDomain, throttleMetrics, throttle.MiddlewareOpts{ + authenticatedReqsThrottleMiddleware := throttle.MiddlewareWithOpts(cfg, apiErrDomain, promMetrics, throttle.MiddlewareOpts{ Tags: []string{"authenticated_reqs"}, GetKeyIdentity: func(r *http.Request) (key string, bypass bool, err error) { username, _, ok := r.BasicAuth() diff --git a/httpserver/middleware/throttle/metrics.go b/httpserver/middleware/throttle/metrics.go index 0cbec1c..d55252a 100644 --- a/httpserver/middleware/throttle/metrics.go +++ b/httpserver/middleware/throttle/metrics.go @@ -6,7 +6,9 @@ Released under MIT license. package throttle -import "github.com/prometheus/client_golang/prometheus" +import ( + "github.com/prometheus/client_golang/prometheus" +) const ( metricsLabelDryRun = "dry_run" @@ -19,70 +21,121 @@ const ( metricsValNo = "no" ) -// MetricsCollector represents collector of metrics for rate/in-flight limiting rejects. -type MetricsCollector struct { +// MetricsCollector represents a collector of metrics for rate/in-flight limiting rejects. +type MetricsCollector interface { + // IncInFlightLimitRejects increments the counter of rejected requests due to in-flight limit exceeded. + IncInFlightLimitRejects(ruleName string, dryRun bool, backlogged bool) + + // IncRateLimitRejects increments the counter of rejected requests due to rate limit exceeded. + IncRateLimitRejects(ruleName string, dryRun bool) +} + +// PrometheusMetricsOpts represents options for PrometheusMetrics. +type PrometheusMetricsOpts struct { + // Namespace is a namespace for metrics. It will be prepended to all metric names. + Namespace string + + // ConstLabels is a set of labels that will be applied to all metrics. + ConstLabels prometheus.Labels + + // CurriedLabelNames is a list of label names that will be curried with the provided labels. + // See PrometheusMetrics.MustCurryWith method for more details. + // Keep in mind that if this list is not empty, + // PrometheusMetrics.MustCurryWith method must be called further with the same labels. + // Otherwise, the collector will panic. + CurriedLabelNames []string +} + +// PrometheusMetrics represents a collector of Prometheus metrics for rate/in-flight limiting rejects. +type PrometheusMetrics struct { InFlightLimitRejects *prometheus.CounterVec RateLimitRejects *prometheus.CounterVec } -// NewMetricsCollector creates a new instance of MetricsCollector. -func NewMetricsCollector(namespace string) *MetricsCollector { +// NewPrometheusMetrics creates a new instance of PrometheusMetrics. +func NewPrometheusMetrics() *PrometheusMetrics { + return NewPrometheusMetricsWithOpts(PrometheusMetricsOpts{}) +} + +// NewPrometheusMetricsWithOpts creates a new instance of PrometheusMetrics with the provided options. +func NewPrometheusMetricsWithOpts(opts PrometheusMetricsOpts) *PrometheusMetrics { + makeLabelNames := func(names ...string) []string { + l := append(make([]string, 0, len(opts.CurriedLabelNames)+len(names)), opts.CurriedLabelNames...) + return append(l, names...) + } + inFlightLimitRejects := prometheus.NewCounterVec(prometheus.CounterOpts{ - Namespace: namespace, - Name: "in_flight_limit_rejects_total", - Help: "Number of rejected requests due to in-flight limit exceeded.", - }, []string{metricsLabelDryRun, metricsLabelRule, metricsLabelBacklogged}) + Namespace: opts.Namespace, + Name: "in_flight_limit_rejects_total", + Help: "Number of rejected requests due to in-flight limit exceeded.", + ConstLabels: opts.ConstLabels, + }, makeLabelNames(metricsLabelDryRun, metricsLabelRule, metricsLabelBacklogged)) rateLimitRejects := prometheus.NewCounterVec(prometheus.CounterOpts{ - Namespace: namespace, - Name: "rate_limit_rejects_total", - Help: "Number of rejected requests due to rate limit exceeded.", - }, []string{metricsLabelDryRun, metricsLabelRule}) + Namespace: opts.Namespace, + Name: "rate_limit_rejects_total", + Help: "Number of rejected requests due to rate limit exceeded.", + ConstLabels: opts.ConstLabels, + }, makeLabelNames(metricsLabelDryRun, metricsLabelRule)) - return &MetricsCollector{ + return &PrometheusMetrics{ InFlightLimitRejects: inFlightLimitRejects, RateLimitRejects: rateLimitRejects, } } // MustCurryWith curries the metrics collector with the provided labels. -func (mc *MetricsCollector) MustCurryWith(labels prometheus.Labels) *MetricsCollector { - return &MetricsCollector{ - InFlightLimitRejects: mc.InFlightLimitRejects.MustCurryWith(labels), - RateLimitRejects: mc.RateLimitRejects.MustCurryWith(labels), +func (pm *PrometheusMetrics) MustCurryWith(labels prometheus.Labels) *PrometheusMetrics { + return &PrometheusMetrics{ + InFlightLimitRejects: pm.InFlightLimitRejects.MustCurryWith(labels), + RateLimitRejects: pm.RateLimitRejects.MustCurryWith(labels), } } // MustRegister does registration of metrics collector in Prometheus and panics if any error occurs. -func (mc *MetricsCollector) MustRegister() { +func (pm *PrometheusMetrics) MustRegister() { prometheus.MustRegister( - mc.InFlightLimitRejects, - mc.RateLimitRejects, + pm.InFlightLimitRejects, + pm.RateLimitRejects, ) } // Unregister cancels registration of metrics collector in Prometheus. -func (mc *MetricsCollector) Unregister() { - prometheus.Unregister(mc.InFlightLimitRejects) - prometheus.Unregister(mc.RateLimitRejects) +func (pm *PrometheusMetrics) Unregister() { + prometheus.Unregister(pm.InFlightLimitRejects) + prometheus.Unregister(pm.RateLimitRejects) } -func makeCommonPromLabels(dryRun bool, rule string) prometheus.Labels { +// IncInFlightLimitRejects increments the counter of rejected requests due to in-flight limit exceeded. +func (pm *PrometheusMetrics) IncInFlightLimitRejects(ruleName string, dryRun bool, backlogged bool) { dryRunVal := metricsValNo if dryRun { dryRunVal = metricsValYes } - return prometheus.Labels{metricsLabelDryRun: dryRunVal, metricsLabelRule: rule} -} - -func makePromLabelsForInFlightLimit(commonLabels prometheus.Labels, backlogged bool) prometheus.Labels { backloggedVal := metricsValNo if backlogged { backloggedVal = metricsValYes } - return prometheus.Labels{ - metricsLabelDryRun: commonLabels[metricsLabelDryRun], - metricsLabelRule: commonLabels[metricsLabelRule], + pm.InFlightLimitRejects.With(prometheus.Labels{ + metricsLabelDryRun: dryRunVal, + metricsLabelRule: ruleName, metricsLabelBacklogged: backloggedVal, + }).Inc() +} + +// IncRateLimitRejects increments the counter of rejected requests due to rate limit exceeded. +func (pm *PrometheusMetrics) IncRateLimitRejects(ruleName string, dryRun bool) { + dryRunVal := metricsValNo + if dryRun { + dryRunVal = metricsValYes } + pm.RateLimitRejects.With(prometheus.Labels{ + metricsLabelDryRun: dryRunVal, + metricsLabelRule: ruleName, + }).Inc() } + +type disabledMetrics struct{} + +func (disabledMetrics) IncInFlightLimitRejects(string, bool, bool) {} +func (disabledMetrics) IncRateLimitRejects(string, bool) {} diff --git a/httpserver/middleware/throttle/middleware.go b/httpserver/middleware/throttle/middleware.go index 5c09d63..929ed55 100644 --- a/httpserver/middleware/throttle/middleware.go +++ b/httpserver/middleware/throttle/middleware.go @@ -20,7 +20,7 @@ import ( "github.com/acronis/go-appkit/restapi" ) -// RuleLogFieldName is a logged field that contains name of the throttling rule. +// RuleLogFieldName is a logged field that contains the name of the throttling rule. const RuleLogFieldName = "throttle_rule" // MiddlewareOpts represents an options for Middleware. @@ -58,9 +58,9 @@ type MiddlewareOpts struct { BuildHandlerAtInit bool } -// RateLimitOpts returns rateLimitMiddlewareParams that may be used for constructing RateLimitMiddleware. -func (opts MiddlewareOpts) RateLimitOpts() rateLimitMiddlewareParams { - return rateLimitMiddlewareParams{ +// rateLimitOpts returns options for constructing rate limiting middleware. +func (opts MiddlewareOpts) rateLimitOpts() rateLimitMiddlewareOpts { + return rateLimitMiddlewareOpts{ GetKeyIdentity: opts.GetKeyIdentity, RateLimitOnReject: opts.RateLimitOnReject, RateLimitOnRejectInDryRun: opts.RateLimitOnRejectInDryRun, @@ -68,9 +68,9 @@ func (opts MiddlewareOpts) RateLimitOpts() rateLimitMiddlewareParams { } } -// InFlightLimitOpts returns inFlightLimitMiddlewareParams that may be used for constructing InFlightLimitMiddleware. -func (opts MiddlewareOpts) InFlightLimitOpts() inFlightLimitMiddlewareParams { - return inFlightLimitMiddlewareParams{ +// inFlightLimitOpts returns options for constructing in-flight limiting middleware. +func (opts MiddlewareOpts) inFlightLimitOpts() inFlightLimitMiddlewareOpts { + return inFlightLimitMiddlewareOpts{ GetKeyIdentity: opts.GetKeyIdentity, InFlightLimitOnReject: opts.InFlightLimitOnReject, InFlightLimitOnRejectInDryRun: opts.InFlightLimitOnRejectInDryRun, @@ -79,12 +79,16 @@ func (opts MiddlewareOpts) InFlightLimitOpts() inFlightLimitMiddlewareParams { } // Middleware is a middleware that throttles incoming HTTP requests based on the passed configuration. -func Middleware(cfg *Config, errDomain string, mc *MetricsCollector) func(next http.Handler) http.Handler { +func Middleware(cfg *Config, errDomain string, mc MetricsCollector) func(next http.Handler) http.Handler { return MiddlewareWithOpts(cfg, errDomain, mc, MiddlewareOpts{}) } // MiddlewareWithOpts is a more configurable version of Middleware. -func MiddlewareWithOpts(cfg *Config, errDomain string, mc *MetricsCollector, opts MiddlewareOpts) func(next http.Handler) http.Handler { +func MiddlewareWithOpts(cfg *Config, errDomain string, mc MetricsCollector, opts MiddlewareOpts) func(next http.Handler) http.Handler { + if mc == nil { + mc = disabledMetrics{} + } + mustMakeRoutesManager := func(next http.Handler) *restapi.RoutesManager { routes, err := makeRoutes(cfg, errDomain, mc, opts, next) if err != nil { @@ -134,7 +138,7 @@ func (h *handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { // nolint: gocyclo // we would like to have high functional cohesion here. func makeRoutes( - cfg *Config, errDomain string, mc *MetricsCollector, opts MiddlewareOpts, handler http.Handler, + cfg *Config, errDomain string, mc MetricsCollector, opts MiddlewareOpts, handler http.Handler, ) (routes []restapi.Route, err error) { for _, rule := range cfg.Rules { if len(rule.RateLimits) == 0 && len(rule.InFlightLimits) == 0 { @@ -156,7 +160,7 @@ func makeRoutes( } var inFlightLimitMw func(next http.Handler) http.Handler inFlightLimitMw, err = makeInFlightLimitMiddleware( - &cfgZone, errDomain, rule.Name(), mc, opts.InFlightLimitOpts()) + &cfgZone, errDomain, rule.Name(), mc, opts.inFlightLimitOpts()) if err != nil { return nil, fmt.Errorf("new in-flight limit middleware for zone %q: %w", zoneName, err) } @@ -172,7 +176,7 @@ func makeRoutes( } var rateLimitMw func(next http.Handler) http.Handler rateLimitMw, err = makeRateLimitMiddleware( - &cfgZone, errDomain, rule.Name(), mc, opts.RateLimitOpts()) + &cfgZone, errDomain, rule.Name(), mc, opts.rateLimitOpts()) if err != nil { return nil, fmt.Errorf("new rate limit middleware for zone %q: %w", zoneName, err) } @@ -197,8 +201,8 @@ func makeRoutes( return routes, nil } -// rateLimitMiddlewareParams represents an options for RateLimitMiddleware. -type rateLimitMiddlewareParams struct { +// rateLimitMiddlewareOpts represents an options for RateLimitMiddleware. +type rateLimitMiddlewareOpts struct { // GetKeyIdentity is a function that returns identity string representation. // The returned string is used as a key for zone when key.type is "identity". GetKeyIdentity func(r *http.Request) (key string, bypass bool, err error) @@ -218,8 +222,8 @@ func makeRateLimitMiddleware( cfg *RateLimitZoneConfig, errDomain string, ruleName string, - mc *MetricsCollector, - opts rateLimitMiddlewareParams, + mc MetricsCollector, + opts rateLimitMiddlewareOpts, ) (func(next http.Handler) http.Handler, error) { var alg middleware.RateLimitAlg switch cfg.Alg { @@ -256,11 +260,10 @@ func makeRateLimitMiddleware( if onReject == nil { onReject = middleware.DefaultRateLimitOnReject } - rejectPromLabels := makeCommonPromLabels(false, ruleName) onRejectWithMetrics := func( rw http.ResponseWriter, r *http.Request, params middleware.RateLimitParams, next http.Handler, logger log.FieldLogger, ) { - mc.RateLimitRejects.With(rejectPromLabels).Inc() + mc.IncRateLimitRejects(ruleName, false) if logger != nil { logger = logger.With(log.String(RuleLogFieldName, ruleName)) } @@ -271,11 +274,10 @@ func makeRateLimitMiddleware( if onRejectInDryRun == nil { onRejectInDryRun = middleware.DefaultRateLimitOnRejectInDryRun } - rejectInDryRunPromLabels := makeCommonPromLabels(true, ruleName) onRejectInDryRunWithMetrics := func( rw http.ResponseWriter, r *http.Request, params middleware.RateLimitParams, next http.Handler, logger log.FieldLogger, ) { - mc.RateLimitRejects.With(rejectInDryRunPromLabels).Inc() + mc.IncRateLimitRejects(ruleName, true) if logger != nil { logger = logger.With(log.String(RuleLogFieldName, ruleName)) } @@ -299,7 +301,7 @@ func makeRateLimitMiddleware( }), nil } -type inFlightLimitMiddlewareParams struct { +type inFlightLimitMiddlewareOpts struct { // GetKeyIdentity is a function that returns identity string representation. // The returned string is used as a key for zone when key.type is "identity". GetKeyIdentity func(r *http.Request) (key string, bypass bool, err error) @@ -319,8 +321,8 @@ func makeInFlightLimitMiddleware( cfg *InFlightLimitZoneConfig, errDomain string, ruleName string, - mc *MetricsCollector, - opts inFlightLimitMiddlewareParams, + mc MetricsCollector, + opts inFlightLimitMiddlewareOpts, ) (func(next http.Handler) http.Handler, error) { if cfg.Key.Type == ZoneKeyTypeIdentity && opts.GetKeyIdentity == nil { return nil, fmt.Errorf("GetKeyIdentity is required for identity key type") @@ -342,11 +344,10 @@ func makeInFlightLimitMiddleware( if onReject == nil { onReject = middleware.DefaultInFlightLimitOnReject } - rejectPromLabels := makeCommonPromLabels(false, ruleName) onRejectWithMetrics := func( rw http.ResponseWriter, r *http.Request, params middleware.InFlightLimitParams, next http.Handler, logger log.FieldLogger, ) { - mc.InFlightLimitRejects.With(makePromLabelsForInFlightLimit(rejectPromLabels, params.RequestBacklogged)).Inc() + mc.IncInFlightLimitRejects(ruleName, false, params.RequestBacklogged) if logger != nil { logger = logger.With(log.String(RuleLogFieldName, ruleName)) } @@ -357,11 +358,10 @@ func makeInFlightLimitMiddleware( if onRejectInDryRun == nil { onRejectInDryRun = middleware.DefaultInFlightLimitOnRejectInDryRun } - rejectInDryRunPromLabels := makeCommonPromLabels(true, ruleName) onRejectInDryRunWithMetrics := func( rw http.ResponseWriter, r *http.Request, params middleware.InFlightLimitParams, next http.Handler, logger log.FieldLogger, ) { - mc.InFlightLimitRejects.With(makePromLabelsForInFlightLimit(rejectInDryRunPromLabels, params.RequestBacklogged)).Inc() + mc.IncInFlightLimitRejects(ruleName, true, params.RequestBacklogged) if logger != nil { logger = logger.With(log.String(RuleLogFieldName, ruleName)) } diff --git a/httpserver/middleware/throttle/middleware_test.go b/httpserver/middleware/throttle/middleware_test.go index 160fed4..818be31 100644 --- a/httpserver/middleware/throttle/middleware_test.go +++ b/httpserver/middleware/throttle/middleware_test.go @@ -593,7 +593,7 @@ func makeHandlerWrappedIntoMiddleware( cfg *Config, blockCh chan struct{}, tags []string, buildHandlerAtInit bool, ) (http.Handler, *testCounters) { c := &testCounters{} - mid := MiddlewareWithOpts(cfg, testErrDomain, NewMetricsCollector(""), MiddlewareOpts{ + mid := MiddlewareWithOpts(cfg, testErrDomain, NewPrometheusMetrics(), MiddlewareOpts{ GetKeyIdentity: func(r *http.Request) (key string, bypass bool, err error) { username, _, ok := r.BasicAuth() if !ok { diff --git a/httpserver/router.go b/httpserver/router.go index 5248692..3b6edd3 100644 --- a/httpserver/router.go +++ b/httpserver/router.go @@ -74,7 +74,7 @@ func configureRouter(router chi.Router, logger log.FieldLogger, opts RouterOpts) // nolint // hugeParam: opts is heavy, it's ok in this case. func applyDefaultMiddlewaresToRouter( - router chi.Router, cfg *Config, logger log.FieldLogger, opts Opts, metricsCollector *middleware.HTTPRequestMetricsCollector, + router chi.Router, cfg *Config, logger log.FieldLogger, opts Opts, promMetrics *middleware.HTTPRequestPrometheusMetrics, ) error { router.Use(func(handler http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -109,7 +109,7 @@ func applyDefaultMiddlewaresToRouter( // Custom route pattern parser getRoutePattern = opts.HTTPRequestMetrics.GetRoutePattern } - metricsMiddleware := middleware.HTTPRequestMetricsWithOpts(metricsCollector, getRoutePattern, + metricsMiddleware := middleware.HTTPRequestMetricsWithOpts(promMetrics, getRoutePattern, middleware.HTTPRequestMetricsOpts{ GetUserAgentType: opts.HTTPRequestMetrics.GetUserAgentType, ExcludedEndpoints: systemEndpoints,