From 17fe0df40868e19daba190b931b4ee54e2a9f4d8 Mon Sep 17 00:00:00 2001 From: Leo Antunes Date: Wed, 10 Jul 2024 09:52:36 +0200 Subject: [PATCH] feat: add response header for trace sample status --- config.go | 28 ++++++++++++++++------------ middleware.go | 50 ++++++++++++++++++++++++++++---------------------- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/config.go b/config.go index 1763a28..62636ae 100644 --- a/config.go +++ b/config.go @@ -8,17 +8,21 @@ import ( oteltrace "go.opentelemetry.io/otel/trace" ) -const defaultTraceResponseHeaderKey = "X-Trace-Id" +const ( + defaultTraceIDResponseHeaderKey = "X-Trace-Id" + defaultTraceSampledResponseHeaderKey = "X-Trace-Sampled" +) // config is used to configure the mux middleware. type config struct { - TracerProvider oteltrace.TracerProvider - Propagators propagation.TextMapPropagator - ChiRoutes chi.Routes - RequestMethodInSpanName bool - Filters []Filter - TraceResponseHeaderKey string - PublicEndpointFn func(r *http.Request) bool + TracerProvider oteltrace.TracerProvider + Propagators propagation.TextMapPropagator + ChiRoutes chi.Routes + RequestMethodInSpanName bool + Filters []Filter + TraceIDResponseHeaderKey string + TraceSampledResponseKey string + PublicEndpointFn func(r *http.Request) bool } // Option specifies instrumentation configuration options. @@ -32,7 +36,7 @@ func (o optionFunc) apply(c *config) { o(c) } -// Filter is a predicate used to determine whether a given http.request should +// Filter is a predicate used to determine whether a given [http.Request] should // be traced. A Filter must return true if the request should be traced. type Filter func(*http.Request) bool @@ -98,9 +102,9 @@ func WithFilter(filter Filter) Option { func WithTraceIDResponseHeader(headerKeyFunc func() string) Option { return optionFunc(func(cfg *config) { if headerKeyFunc == nil { - cfg.TraceResponseHeaderKey = defaultTraceResponseHeaderKey // use default trace header + cfg.TraceIDResponseHeaderKey = defaultTraceIDResponseHeaderKey // use default trace header } else { - cfg.TraceResponseHeaderKey = headerKeyFunc() + cfg.TraceIDResponseHeaderKey = headerKeyFunc() } }) } @@ -138,7 +142,7 @@ func WithPublicEndpoint() Option { // incoming span context. Otherwise, the generated span will be set as the // child span of the incoming span context. // -// Essentially it has the same functionality as WithPublicEndpoint but with +// Essentially it has the same functionality as [WithPublicEndpoint] but with // more flexibility. func WithPublicEndpointFn(fn func(r *http.Request) bool) Option { return optionFunc(func(cfg *config) { diff --git a/middleware.go b/middleware.go index 4d9cdaf..ee517de 100644 --- a/middleware.go +++ b/middleware.go @@ -2,6 +2,7 @@ package otelchi import ( "net/http" + "strconv" "sync" "github.com/felixge/httpsnoop" @@ -23,7 +24,9 @@ const ( // requests. The serverName parameter should describe the name of the // (virtual) server handling the request. func Middleware(serverName string, opts ...Option) func(next http.Handler) http.Handler { - cfg := config{} + cfg := config{ + TraceSampledResponseKey: defaultTraceSampledResponseHeaderKey, + } for _, opt := range opts { opt.apply(&cfg) } @@ -40,29 +43,31 @@ func Middleware(serverName string, opts ...Option) func(next http.Handler) http. return func(handler http.Handler) http.Handler { return traceware{ - serverName: serverName, - tracer: tracer, - propagators: cfg.Propagators, - handler: handler, - chiRoutes: cfg.ChiRoutes, - reqMethodInSpanName: cfg.RequestMethodInSpanName, - filters: cfg.Filters, - traceResponseHeaderKey: cfg.TraceResponseHeaderKey, - publicEndpointFn: cfg.PublicEndpointFn, + serverName: serverName, + tracer: tracer, + propagators: cfg.Propagators, + handler: handler, + chiRoutes: cfg.ChiRoutes, + reqMethodInSpanName: cfg.RequestMethodInSpanName, + filters: cfg.Filters, + traceIDResponseHeaderKey: cfg.TraceIDResponseHeaderKey, + traceSampledResponseKey: cfg.TraceSampledResponseKey, + publicEndpointFn: cfg.PublicEndpointFn, } } } type traceware struct { - serverName string - tracer oteltrace.Tracer - propagators propagation.TextMapPropagator - handler http.Handler - chiRoutes chi.Routes - reqMethodInSpanName bool - filters []Filter - traceResponseHeaderKey string - publicEndpointFn func(r *http.Request) bool + serverName string + tracer oteltrace.Tracer + propagators propagation.TextMapPropagator + handler http.Handler + chiRoutes chi.Routes + reqMethodInSpanName bool + filters []Filter + traceIDResponseHeaderKey string + traceSampledResponseKey string + publicEndpointFn func(r *http.Request) bool } type recordingResponseWriter struct { @@ -175,9 +180,10 @@ func (tw traceware) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx, span := tw.tracer.Start(ctx, spanName, spanOpts...) defer span.End() - // put trace_id to response header only when WithTraceResponseHeaderKey is used - if len(tw.traceResponseHeaderKey) > 0 && span.SpanContext().HasTraceID() { - w.Header().Add(tw.traceResponseHeaderKey, span.SpanContext().TraceID().String()) + // put trace_id to response header only when [WithTraceIDResponseHeader] is used + if len(tw.traceIDResponseHeaderKey) > 0 && span.SpanContext().HasTraceID() { + w.Header().Add(tw.traceIDResponseHeaderKey, span.SpanContext().TraceID().String()) + w.Header().Add(tw.traceSampledResponseKey, strconv.FormatBool(span.SpanContext().IsSampled())) } // get recording response writer