Skip to content

Commit

Permalink
feat: add response header for trace sample status
Browse files Browse the repository at this point in the history
  • Loading branch information
costela committed Jul 22, 2024
1 parent 5f6a4bf commit 17fe0df
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 34 deletions.
28 changes: 16 additions & 12 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,21 @@ import (
oteltrace "go.opentelemetry.io/otel/trace"
)

const defaultTraceResponseHeaderKey = "X-Trace-Id"
const (
defaultTraceIDResponseHeaderKey = "X-Trace-Id"
defaultTraceSampledResponseHeaderKey = "X-Trace-Sampled"
)

// config is used to configure the mux middleware.
type config struct {
TracerProvider oteltrace.TracerProvider
Propagators propagation.TextMapPropagator
ChiRoutes chi.Routes
RequestMethodInSpanName bool
Filters []Filter
TraceResponseHeaderKey string
PublicEndpointFn func(r *http.Request) bool
TracerProvider oteltrace.TracerProvider
Propagators propagation.TextMapPropagator
ChiRoutes chi.Routes
RequestMethodInSpanName bool
Filters []Filter
TraceIDResponseHeaderKey string
TraceSampledResponseKey string
PublicEndpointFn func(r *http.Request) bool
}

// Option specifies instrumentation configuration options.
Expand All @@ -32,7 +36,7 @@ func (o optionFunc) apply(c *config) {
o(c)
}

// Filter is a predicate used to determine whether a given http.request should
// Filter is a predicate used to determine whether a given [http.Request] should
// be traced. A Filter must return true if the request should be traced.
type Filter func(*http.Request) bool

Expand Down Expand Up @@ -98,9 +102,9 @@ func WithFilter(filter Filter) Option {
func WithTraceIDResponseHeader(headerKeyFunc func() string) Option {
return optionFunc(func(cfg *config) {
if headerKeyFunc == nil {
cfg.TraceResponseHeaderKey = defaultTraceResponseHeaderKey // use default trace header
cfg.TraceIDResponseHeaderKey = defaultTraceIDResponseHeaderKey // use default trace header
} else {
cfg.TraceResponseHeaderKey = headerKeyFunc()
cfg.TraceIDResponseHeaderKey = headerKeyFunc()
}
})
}
Expand Down Expand Up @@ -138,7 +142,7 @@ func WithPublicEndpoint() Option {
// incoming span context. Otherwise, the generated span will be set as the
// child span of the incoming span context.
//
// Essentially it has the same functionality as WithPublicEndpoint but with
// Essentially it has the same functionality as [WithPublicEndpoint] but with
// more flexibility.
func WithPublicEndpointFn(fn func(r *http.Request) bool) Option {
return optionFunc(func(cfg *config) {
Expand Down
50 changes: 28 additions & 22 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package otelchi

import (
"net/http"
"strconv"
"sync"

"github.com/felixge/httpsnoop"
Expand All @@ -23,7 +24,9 @@ const (
// requests. The serverName parameter should describe the name of the
// (virtual) server handling the request.
func Middleware(serverName string, opts ...Option) func(next http.Handler) http.Handler {
cfg := config{}
cfg := config{
TraceSampledResponseKey: defaultTraceSampledResponseHeaderKey,
}
for _, opt := range opts {
opt.apply(&cfg)
}
Expand All @@ -40,29 +43,31 @@ func Middleware(serverName string, opts ...Option) func(next http.Handler) http.

return func(handler http.Handler) http.Handler {
return traceware{
serverName: serverName,
tracer: tracer,
propagators: cfg.Propagators,
handler: handler,
chiRoutes: cfg.ChiRoutes,
reqMethodInSpanName: cfg.RequestMethodInSpanName,
filters: cfg.Filters,
traceResponseHeaderKey: cfg.TraceResponseHeaderKey,
publicEndpointFn: cfg.PublicEndpointFn,
serverName: serverName,
tracer: tracer,
propagators: cfg.Propagators,
handler: handler,
chiRoutes: cfg.ChiRoutes,
reqMethodInSpanName: cfg.RequestMethodInSpanName,
filters: cfg.Filters,
traceIDResponseHeaderKey: cfg.TraceIDResponseHeaderKey,
traceSampledResponseKey: cfg.TraceSampledResponseKey,
publicEndpointFn: cfg.PublicEndpointFn,
}
}
}

type traceware struct {
serverName string
tracer oteltrace.Tracer
propagators propagation.TextMapPropagator
handler http.Handler
chiRoutes chi.Routes
reqMethodInSpanName bool
filters []Filter
traceResponseHeaderKey string
publicEndpointFn func(r *http.Request) bool
serverName string
tracer oteltrace.Tracer
propagators propagation.TextMapPropagator
handler http.Handler
chiRoutes chi.Routes
reqMethodInSpanName bool
filters []Filter
traceIDResponseHeaderKey string
traceSampledResponseKey string
publicEndpointFn func(r *http.Request) bool
}

type recordingResponseWriter struct {
Expand Down Expand Up @@ -175,9 +180,10 @@ func (tw traceware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, span := tw.tracer.Start(ctx, spanName, spanOpts...)
defer span.End()

// put trace_id to response header only when WithTraceResponseHeaderKey is used
if len(tw.traceResponseHeaderKey) > 0 && span.SpanContext().HasTraceID() {
w.Header().Add(tw.traceResponseHeaderKey, span.SpanContext().TraceID().String())
// put trace_id to response header only when [WithTraceIDResponseHeader] is used
if len(tw.traceIDResponseHeaderKey) > 0 && span.SpanContext().HasTraceID() {
w.Header().Add(tw.traceIDResponseHeaderKey, span.SpanContext().TraceID().String())
w.Header().Add(tw.traceSampledResponseKey, strconv.FormatBool(span.SpanContext().IsSampled()))
}

// get recording response writer
Expand Down

0 comments on commit 17fe0df

Please sign in to comment.