diff --git a/contrib/go-chi/chi/chi.go b/contrib/go-chi/chi/chi.go index 3f2f29c90d..e4aa097ac2 100644 --- a/contrib/go-chi/chi/chi.go +++ b/contrib/go-chi/chi/chi.go @@ -67,7 +67,7 @@ func Middleware(opts ...Option) func(next http.Handler) http.Handler { } span.SetTag(ext.HTTPCode, strconv.Itoa(status)) - if status >= 500 && status < 600 { + if cfg.isStatusError(status) { // mark 5xx server error span.SetTag(ext.Error, fmt.Errorf("%d: %s", status, http.StatusText(status))) } diff --git a/contrib/go-chi/chi/chi_test.go b/contrib/go-chi/chi/chi_test.go index 27a1328585..82f72cbfbc 100644 --- a/contrib/go-chi/chi/chi_test.go +++ b/contrib/go-chi/chi/chi_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strconv" "testing" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" @@ -99,37 +100,74 @@ func TestTrace200(t *testing.T) { } func TestError(t *testing.T) { - assert := assert.New(t) - mt := mocktracer.Start() - defer mt.Stop() + assertSpan := func(assert *assert.Assertions, spans []mocktracer.Span, code int) { + assert.Len(spans, 1) + if len(spans) < 1 { + t.Fatalf("no spans") + } + span := spans[0] + assert.Equal("http.request", span.OperationName()) + assert.Equal("foobar", span.Tag(ext.ServiceName)) - // setup - router := chi.NewRouter() - router.Use(Middleware(WithServiceName("foobar"))) - code := 500 - wantErr := fmt.Sprintf("%d: %s", code, http.StatusText(code)) + assert.Equal(strconv.Itoa(code), span.Tag(ext.HTTPCode)) + + wantErr := fmt.Sprintf("%d: %s", code, http.StatusText(code)) + assert.Equal(wantErr, span.Tag(ext.Error).(error).Error()) + } - // a handler with an error and make the requests - router.Get("/err", func(w http.ResponseWriter, r *http.Request) { - http.Error(w, fmt.Sprintf("%d!", code), code) + t.Run("default", func(t *testing.T) { + assert := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() + + // setup + router := chi.NewRouter() + router.Use(Middleware(WithServiceName("foobar"))) + code := 500 + + // a handler with an error and make the requests + router.Get("/err", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, fmt.Sprintf("%d!", code), code) + }) + r := httptest.NewRequest("GET", "/err", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, r) + response := w.Result() + assert.Equal(response.StatusCode, code) + + // verify the errors and status are correct + spans := mt.FinishedSpans() + assertSpan(assert, spans, code) }) - r := httptest.NewRequest("GET", "/err", nil) - w := httptest.NewRecorder() - router.ServeHTTP(w, r) - response := w.Result() - assert.Equal(response.StatusCode, 500) - // verify the errors and status are correct - spans := mt.FinishedSpans() - assert.Len(spans, 1) - if len(spans) < 1 { - t.Fatalf("no spans") - } - span := spans[0] - assert.Equal("http.request", span.OperationName()) - assert.Equal("foobar", span.Tag(ext.ServiceName)) - assert.Equal("500", span.Tag(ext.HTTPCode)) - assert.Equal(wantErr, span.Tag(ext.Error).(error).Error()) + t.Run("custom", func(t *testing.T) { + assert := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() + + // setup + router := chi.NewRouter() + router.Use(Middleware( + WithServiceName("foobar"), + WithStatusCheck(func(statusCode int) bool { + return statusCode >= 400 + }), + )) + code := 404 + // a handler with an error and make the requests + router.Get("/err", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, fmt.Sprintf("%d!", code), code) + }) + r := httptest.NewRequest("GET", "/err", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, r) + response := w.Result() + assert.Equal(response.StatusCode, code) + + // verify the errors and status are correct + spans := mt.FinishedSpans() + assertSpan(assert, spans, code) + }) } func TestGetSpanNotInstrumented(t *testing.T) { diff --git a/contrib/go-chi/chi/option.go b/contrib/go-chi/chi/option.go index 1db67384bb..a894d7834d 100644 --- a/contrib/go-chi/chi/option.go +++ b/contrib/go-chi/chi/option.go @@ -17,6 +17,7 @@ type config struct { serviceName string spanOpts []ddtrace.StartSpanOption // additional span options to be applied analyticsRate float64 + isStatusError func(statusCode int) bool } // Option represents an option that can be passed to NewRouter. @@ -32,6 +33,7 @@ func defaults(cfg *config) { } else { cfg.analyticsRate = globalconfig.AnalyticsRate() } + cfg.isStatusError = isServerError } // WithServiceName sets the given service name for the router. @@ -71,3 +73,15 @@ func WithAnalyticsRate(rate float64) Option { } } } + +// WithStatusCheck specifies a function fn which reports whether the passed +// statusCode should be considered an error. +func WithStatusCheck(fn func(statusCode int) bool) Option { + return func(cfg *config) { + cfg.isStatusError = fn + } +} + +func isServerError(statusCode int) bool { + return statusCode >= 500 && statusCode < 600 +}