diff --git a/examples/gorilla/main.go b/examples/gorilla/main.go index a2d43bf..0dcc2d9 100644 --- a/examples/gorilla/main.go +++ b/examples/gorilla/main.go @@ -12,7 +12,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" metrics "github.com/slok/go-http-metrics/metrics/prometheus" "github.com/slok/go-http-metrics/middleware" - "github.com/slok/go-http-metrics/middleware/std" + muxmiddleware "github.com/slok/go-http-metrics/middleware/std" ) const ( @@ -28,7 +28,7 @@ func main() { // Create our router with the metrics middleware. r := mux.NewRouter() - r.Use(std.HandlerProvider("", mdlw)) + r.Use(muxmiddleware.HandlerProvider("", mdlw)) // Add paths. r.Methods("GET").Path("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/middleware/mux/example_test.go b/middleware/mux/example_test.go new file mode 100644 index 0000000..927e07d --- /dev/null +++ b/middleware/mux/example_test.go @@ -0,0 +1,42 @@ +package mux_test + +import ( + "log" + "net/http" + + "github.com/prometheus/client_golang/prometheus/promhttp" + + metrics "github.com/slok/go-http-metrics/metrics/prometheus" + "github.com/slok/go-http-metrics/middleware" + muxmiddleware "github.com/slok/go-http-metrics/middleware/mux" +) + +// MuxMiddleware shows how you would create a default middleware factory and use it +// to create a Gorilla Mux `http.Handler` compatible middleware. +func Example_muxMiddleware() { + // Create our middleware factory with the default settings. + mdlw := middleware.New(middleware.Config{ + Recorder: metrics.NewRecorder(metrics.Config{}), + }) + + // Create our handler. + myHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("hello world!")) + }) + + // Wrap our handler with the middleware. + h := muxmiddleware.Handler("", mdlw, myHandler) + + // Serve metrics from the default prometheus registry. + log.Printf("serving metrics at: %s", ":8081") + go func() { + _ = http.ListenAndServe(":8081", promhttp.Handler()) + }() + + // Serve our handler. + log.Printf("listening at: %s", ":8080") + if err := http.ListenAndServe(":8080", h); err != nil { + log.Panicf("error while serving: %s", err) + } +} diff --git a/middleware/mux/mux.go b/middleware/mux/mux.go new file mode 100644 index 0000000..5895f03 --- /dev/null +++ b/middleware/mux/mux.go @@ -0,0 +1,102 @@ +// Package mux is a helper package to get a Gorilla Mux compatible middleware. +package mux + +import ( + "bufio" + "context" + "errors" + "net" + "net/http" + + "github.com/gorilla/mux" + "github.com/slok/go-http-metrics/middleware" +) + +// Handler returns an measuring standard http.Handler. +func Handler(handlerID string, m middleware.Middleware, h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wi := &responseWriterInterceptor{ + statusCode: http.StatusOK, + ResponseWriter: w, + } + reporter := &muxReporter{ + w: wi, + r: r, + } + + m.Measure(handlerID, reporter, func() { + h.ServeHTTP(wi, r) + }) + }) +} + +// HandlerProvider is a helper method that returns a handler provider. This kind of +// provider is a defacto standard in some frameworks (e.g: Gorilla, Chi...). +func HandlerProvider(handlerID string, m middleware.Middleware) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return Handler(handlerID, m, next) + } +} + +type muxReporter struct { + w *responseWriterInterceptor + r *http.Request +} + +func (m *muxReporter) Method() string { return m.r.Method } + +func (m *muxReporter) Context() context.Context { return m.r.Context() } + +func (m *muxReporter) URLPath() string { + path, err := mux.CurrentRoute(m.r).GetPathTemplate() + if err != nil { + return m.r.URL.Path + } + return path +} + +func (m *muxReporter) StatusCode() int { return m.w.statusCode } + +func (m *muxReporter) BytesWritten() int64 { return int64(m.w.bytesWritten) } + +// responseWriterInterceptor is a simple wrapper to intercept set data on a +// ResponseWriter. +type responseWriterInterceptor struct { + http.ResponseWriter + statusCode int + bytesWritten int +} + +func (w *responseWriterInterceptor) WriteHeader(statusCode int) { + w.statusCode = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *responseWriterInterceptor) Write(p []byte) (int, error) { + w.bytesWritten += len(p) + return w.ResponseWriter.Write(p) +} + +func (w *responseWriterInterceptor) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := w.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("type assertion failed http.ResponseWriter not a http.Hijacker") + } + return h.Hijack() +} + +func (w *responseWriterInterceptor) Flush() { + f, ok := w.ResponseWriter.(http.Flusher) + if !ok { + return + } + + f.Flush() +} + +// Check interface implementations. +var ( + _ http.ResponseWriter = &responseWriterInterceptor{} + _ http.Hijacker = &responseWriterInterceptor{} + _ http.Flusher = &responseWriterInterceptor{} +) diff --git a/middleware/mux/mux_test.go b/middleware/mux/mux_test.go new file mode 100644 index 0000000..ab936b3 --- /dev/null +++ b/middleware/mux/mux_test.go @@ -0,0 +1,163 @@ +package mux_test + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/gorilla/mux" + mmetrics "github.com/slok/go-http-metrics/internal/mocks/metrics" + "github.com/slok/go-http-metrics/metrics" + "github.com/slok/go-http-metrics/middleware" + muxmiddleware "github.com/slok/go-http-metrics/middleware/mux" +) + +func TestMiddleware(t *testing.T) { + tests := map[string]struct { + handlerID string + config middleware.Config + req func() *http.Request + mock func(m *mmetrics.Recorder) + handler func() http.Handler + expRespCode int + expRespBody string + }{ + "A default HTTP middleware should call the recorder to measure.": { + req: func() *http.Request { + return httptest.NewRequest(http.MethodPost, "/test", nil) + }, + mock: func(m *mmetrics.Recorder) { + expHTTPReqProps := metrics.HTTPReqProperties{ + ID: "/test", + Service: "", + Method: "POST", + Code: "202", + } + m.On("ObserveHTTPRequestDuration", mock.Anything, expHTTPReqProps, mock.Anything).Once() + m.On("ObserveHTTPResponseSize", mock.Anything, expHTTPReqProps, int64(15)).Once() + + expHTTPProps := metrics.HTTPProperties{ + ID: "/test", + Service: "", + } + m.On("AddInflightRequests", mock.Anything, expHTTPProps, 1).Once() + m.On("AddInflightRequests", mock.Anything, expHTTPProps, -1).Once() + }, + handler: func() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(202) + w.Write([]byte("Я бэтмен")) // nolint: errcheck + }) + }, + expRespCode: 202, + expRespBody: "Я бэтмен", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + // Mocks. + mr := &mmetrics.Recorder{} + test.mock(mr) + + // Create our mux instance with the middleware. + test.config.Recorder = mr + m := middleware.New(test.config) + h := muxmiddleware.Handler(test.handlerID, m, test.handler()) + r := mux.NewRouter() + r.Handle("/test", h) + + // Make the request. + resp := httptest.NewRecorder() + r.ServeHTTP(resp, test.req()) + + // Check. + mr.AssertExpectations(t) + assert.Equal(test.expRespCode, resp.Result().StatusCode) + gotBody, err := io.ReadAll(resp.Result().Body) + require.NoError(err) + assert.Equal(test.expRespBody, string(gotBody)) + }) + } +} + +func TestProvider(t *testing.T) { + tests := map[string]struct { + handlerID string + config middleware.Config + req func() *http.Request + mock func(m *mmetrics.Recorder) + handler func() http.Handler + expRespCode int + expRespBody string + }{ + "A default HTTP middleware should call the recorder to measure.": { + req: func() *http.Request { + return httptest.NewRequest(http.MethodPost, "/test", nil) + }, + mock: func(m *mmetrics.Recorder) { + expHTTPReqProps := metrics.HTTPReqProperties{ + ID: "/test", + Service: "", + Method: "POST", + Code: "202", + } + m.On("ObserveHTTPRequestDuration", mock.Anything, expHTTPReqProps, mock.Anything).Once() + m.On("ObserveHTTPResponseSize", mock.Anything, expHTTPReqProps, int64(15)).Once() + + expHTTPProps := metrics.HTTPProperties{ + ID: "/test", + Service: "", + } + m.On("AddInflightRequests", mock.Anything, expHTTPProps, 1).Once() + m.On("AddInflightRequests", mock.Anything, expHTTPProps, -1).Once() + }, + handler: func() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(202) + w.Write([]byte("Я бэтмен")) // nolint: errcheck + }) + }, + expRespCode: 202, + expRespBody: "Я бэтмен", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + // Mocks. + mr := &mmetrics.Recorder{} + test.mock(mr) + + // Create our mux instance with the middleware. + test.config.Recorder = mr + m := middleware.New(test.config) + provider := muxmiddleware.HandlerProvider(test.handlerID, m) + h := provider(test.handler()) + r := mux.NewRouter() + r.Handle("/test", h) + + // Make the request. + resp := httptest.NewRecorder() + r.ServeHTTP(resp, test.req()) + + // Check. + mr.AssertExpectations(t) + assert.Equal(test.expRespCode, resp.Result().StatusCode) + gotBody, err := io.ReadAll(resp.Result().Body) + require.NoError(err) + assert.Equal(test.expRespBody, string(gotBody)) + }) + } +}