diff --git a/transport/tchannel/channel_outbound.go b/transport/tchannel/channel_outbound.go index 21e75526f..2d3adc982 100644 --- a/transport/tchannel/channel_outbound.go +++ b/transport/tchannel/channel_outbound.go @@ -195,7 +195,9 @@ func (o *ChannelOutbound) Call(ctx context.Context, req *transport.Request) (*tr } err = getResponseError(headers) - deleteReservedHeaders(headers) + // no check: err will be returned as is + + deleteReservedHeaders(headers, o.transport.reservedHeaderMetric.With(req.Caller, req.Service)) resp := &transport.Response{ Headers: headers, diff --git a/transport/tchannel/channel_transport.go b/transport/tchannel/channel_transport.go index 2b1331cfd..49d04cc1d 100644 --- a/transport/tchannel/channel_transport.go +++ b/transport/tchannel/channel_transport.go @@ -26,6 +26,7 @@ import ( "github.com/opentracing/opentracing-go" "github.com/uber/tchannel-go" "go.uber.org/yarpc/api/transport" + "go.uber.org/yarpc/internal/observability" "go.uber.org/yarpc/pkg/lifecycle" "go.uber.org/zap" ) @@ -82,13 +83,14 @@ func (options transportOptions) newChannelTransport() *ChannelTransport { logger = zap.NewNop() } return &ChannelTransport{ - once: lifecycle.NewOnce(), - ch: options.ch, - addr: options.addr, - tracer: options.tracer, - logger: logger.Named("tchannel"), - originalHeaders: options.originalHeaders, - newResponseWriter: newHandlerWriter, + once: lifecycle.NewOnce(), + ch: options.ch, + addr: options.addr, + tracer: options.tracer, + logger: logger.Named("tchannel"), + originalHeaders: options.originalHeaders, + newResponseWriter: newHandlerWriter, + reservedHeaderMetric: observability.NewReserveHeaderMetrics(options.meter, TransportName+"_channel"), } } @@ -97,14 +99,15 @@ func (options transportOptions) newChannelTransport() *ChannelTransport { // If you have a YARPC peer.Chooser, use the unqualified tchannel.Transport // instead. type ChannelTransport struct { - once *lifecycle.Once - ch Channel - addr string - tracer opentracing.Tracer - logger *zap.Logger - router transport.Router - originalHeaders bool - newResponseWriter func(inboundCallResponse, tchannel.Format, headerCase) responseWriter + once *lifecycle.Once + ch Channel + addr string + tracer opentracing.Tracer + logger *zap.Logger + router transport.Router + originalHeaders bool + newResponseWriter responseWriterConstructor + reservedHeaderMetric *observability.ReservedHeaderMetrics } // Channel returns the underlying TChannel "Channel" instance. @@ -140,11 +143,12 @@ func (t *ChannelTransport) start() error { sc := t.ch.GetSubChannel(s) existing := sc.GetHandlers() sc.SetHandler(handler{ - existing: existing, - router: t.router, - tracer: t.tracer, - logger: t.logger, - newResponseWriter: t.newResponseWriter, + existing: existing, + router: t.router, + tracer: t.tracer, + logger: t.logger, + reservedHeaderMetrics: t.reservedHeaderMetric, + newResponseWriter: t.newResponseWriter, }) } } diff --git a/transport/tchannel/handler.go b/transport/tchannel/handler.go index 672e88a7d..b7df406c6 100644 --- a/transport/tchannel/handler.go +++ b/transport/tchannel/handler.go @@ -30,6 +30,7 @@ import ( "go.uber.org/multierr" "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/internal/bufferpool" + "go.uber.org/yarpc/internal/observability" "go.uber.org/yarpc/pkg/errors" "go.uber.org/yarpc/yarpcerrors" "go.uber.org/zap" @@ -99,6 +100,7 @@ type handler struct { tracer opentracing.Tracer headerCase headerCase logger *zap.Logger + reservedHeaderMetrics *observability.ReservedHeaderMetrics newResponseWriter responseWriterConstructor excludeServiceHeaderInResponse bool } @@ -109,7 +111,7 @@ func (h handler) Handle(ctx ncontext.Context, call *tchannel.InboundCall) { func (h handler) handle(ctx context.Context, call inboundCall) { // you MUST close the responseWriter no matter what unless you have a tchannel.SystemError - responseWriter := h.newResponseWriter(call.Response(), call.Format(), h.headerCase) + responseWriter := h.newResponseWriter(call.Response(), call.Format(), h.headerCase, h.reservedHeaderMetrics.With(call.CallerName(), call.ServiceName())) defer responseWriter.ReleaseBuffer() if !h.excludeServiceHeaderInResponse { @@ -183,6 +185,7 @@ func (h handler) callHandler(ctx context.Context, call inboundCall, responseWrit } transportHeadersToRequest(treq, headers) + deleteReservedPrefixHeaders(headers, h.reservedHeaderMetrics.With(call.CallerName(), call.ServiceName())) treq.Headers = headers if tcall, ok := call.(tchannelCall); ok { diff --git a/transport/tchannel/handler_test.go b/transport/tchannel/handler_test.go index 1fe316276..9e40c0021 100644 --- a/transport/tchannel/handler_test.go +++ b/transport/tchannel/handler_test.go @@ -38,6 +38,7 @@ import ( "go.uber.org/yarpc/api/transport/transporttest" "go.uber.org/yarpc/encoding/json" "go.uber.org/yarpc/encoding/raw" + "go.uber.org/yarpc/internal/observability" "go.uber.org/yarpc/internal/routertest" "go.uber.org/yarpc/internal/testtime" pkgerrors "go.uber.org/yarpc/pkg/errors" @@ -580,7 +581,7 @@ func TestResponseWriter(t *testing.T) { resp := newResponseRecorder() call.resp = resp - w := newHandlerWriter(call.Response(), call.Format(), tt.headerCase) + w := newHandlerWriter(call.Response(), call.Format(), tt.headerCase, observability.ReservedHeaderEdgeMetrics{}) tt.apply(w) assert.NoError(t, w.Close()) @@ -623,7 +624,7 @@ func TestResponseWriterFailure(t *testing.T) { resp := newResponseRecorder() tt.setupResp(resp) - w := newHandlerWriter(resp, tchannel.Raw, canonicalizedHeaderCase) + w := newHandlerWriter(resp, tchannel.Raw, canonicalizedHeaderCase, observability.ReservedHeaderEdgeMetrics{}) _, err := w.Write([]byte("foo")) assert.NoError(t, err) _, err = w.Write([]byte("bar")) @@ -638,7 +639,7 @@ func TestResponseWriterFailure(t *testing.T) { func TestResponseWriterEmptyBodyHeaders(t *testing.T) { res := newResponseRecorder() - w := newHandlerWriter(res, tchannel.Raw, canonicalizedHeaderCase) + w := newHandlerWriter(res, tchannel.Raw, canonicalizedHeaderCase, observability.ReservedHeaderEdgeMetrics{}) w.AddHeaders(transport.NewHeaders().With("foo", "bar")) require.NoError(t, w.Close()) @@ -809,7 +810,7 @@ func TestRpcServiceHeader(t *testing.T) { hw := &responseWriterImpl{} h := handler{ headerCase: canonicalizedHeaderCase, - newResponseWriter: func(inboundCallResponse, tchannel.Format, headerCase) responseWriter { + newResponseWriter: func(inboundCallResponse, tchannel.Format, headerCase, observability.ReservedHeaderEdgeMetrics) responseWriter { return hw }, } diff --git a/transport/tchannel/header.go b/transport/tchannel/header.go index cd2920b99..e2457fcd0 100644 --- a/transport/tchannel/header.go +++ b/transport/tchannel/header.go @@ -28,6 +28,7 @@ import ( "github.com/uber/tchannel-go" "go.uber.org/yarpc/api/transport" + "go.uber.org/yarpc/internal/observability" "go.uber.org/yarpc/transport/tchannel/internal" "go.uber.org/yarpc/yarpcerrors" ) @@ -68,11 +69,25 @@ var _reservedHeaderKeys = map[string]struct{}{ CallerProcedureHeader: {}, } +var ( + // enforceHeaderRules is a feature flag for a more strict header handling rules. + // If true and isReservedHeaderPrefix also true, an error will be returned for + // attempt to set such header; header will be stripped for incoming requests and receiving responses. + // See https://github.com/yarpc/yarpc-go/issues/2265 for more details. + enforceHeaderRules = false +) + +// isReservedHeaderKey checks header name by exact match. func isReservedHeaderKey(key string) bool { _, ok := _reservedHeaderKeys[strings.ToLower(key)] return ok } +// isReservedHeaderPrefix checks header name by prefix match. +func isReservedHeaderPrefix(header string) bool { + return strings.HasPrefix(strings.ToLower(header), "rpc-") || strings.HasPrefix(strings.ToLower(header), "$rpc$-") +} + // readRequestHeaders reads headers and baggage from an incoming request. func readRequestHeaders( ctx context.Context, @@ -236,10 +251,46 @@ func getHeaderMap(hs transport.Headers, headerCase headerCase) map[string]string } } -func deleteReservedHeaders(headers transport.Headers) { +func findReservedHeaderPrefix(headers map[string]string) (string, bool) { + for key := range headers { + if isReservedHeaderPrefix(key) { + return key, true + } + } + return "", false +} + +func validateApplicationHeaders(headers map[string]string, edgeMetrics observability.ReservedHeaderEdgeMetrics) error { + key, found := findReservedHeaderPrefix(headers) + if !found { + return nil + } + + edgeMetrics.IncError() + + if enforceHeaderRules { + return yarpcerrors.InternalErrorf("header with rpc prefix is not allowed in request application headers (%s was passed)", key) + } + return nil +} + +func deleteReservedHeaders(headers transport.Headers, edgeMetrics observability.ReservedHeaderEdgeMetrics) { for headerKey := range _reservedHeaderKeys { headers.Del(headerKey) } + + deleteReservedPrefixHeaders(headers, edgeMetrics) +} + +func deleteReservedPrefixHeaders(headers transport.Headers, edgeMetrics observability.ReservedHeaderEdgeMetrics) { + for key := range headers.Items() { + if isReservedHeaderPrefix(key) { + edgeMetrics.IncStripped() + if enforceHeaderRules { + headers.Del(key) + } + } + } } // this check ensures that the service we're issuing a request to is the one diff --git a/transport/tchannel/header_test.go b/transport/tchannel/header_test.go index 7aa98084a..e914917be 100644 --- a/transport/tchannel/header_test.go +++ b/transport/tchannel/header_test.go @@ -29,7 +29,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber/tchannel-go" + "go.uber.org/net/metrics" "go.uber.org/yarpc/api/transport" + "go.uber.org/yarpc/internal/observability" "go.uber.org/yarpc/yarpcerrors" ) @@ -360,3 +362,243 @@ func TestValidateServiceHeaders(t *testing.T) { }) } } + +func TestFindReservedHeaderPrefix(t *testing.T) { + tests := map[string]struct { + headers map[string]string + expKeys []string + expFound bool + }{ + "nil-headers": {}, + "no-reserved-headers": { + headers: map[string]string{ + "any-header-1": "any-value-1", + "any-header-2": "any-value-2", + }, + }, + "reserved-known-headers": { + headers: map[string]string{ + ServiceHeaderKey: "any-value", + }, + expKeys: []string{ServiceHeaderKey}, + expFound: true, + }, + "reserved-prefix": { + headers: map[string]string{ + "rpc-any": "any-value", + "any-header": "any-value", + }, + expKeys: []string{"rpc-any"}, + expFound: true, + }, + "reserved-dollar-prefix": { + headers: map[string]string{ + "$rpc$-any": "any-value", + "any-header": "any-value", + }, + expKeys: []string{"$rpc$-any"}, + expFound: true, + }, + "multiple-reserved-prefix": { + headers: map[string]string{ + "rpc-any": "any-value", + "$rpc$-any": "any-value", + "any-header": "any-value", + }, + expKeys: []string{"rpc-any", "$rpc$-any"}, + expFound: true, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + key, found := findReservedHeaderPrefix(tt.headers) + if len(tt.expKeys) > 0 { + assert.Contains(t, tt.expKeys, key) + } else { + assert.Empty(t, key) + } + assert.Equal(t, tt.expFound, found) + }) + } +} + +func TestValidateApplicationHeaders(t *testing.T) { + tests := map[string]struct { + headers map[string]string + enforceHeaderRule bool + expErr error + expReportHeader bool + }{ + "no-headers-no-error": {}, + "valid-headers-no-error": { + headers: map[string]string{ + "valid-key": "valid-value", + }, + }, + "reserved-rpc-header-error": { + headers: map[string]string{ + "rpc-any": "any-value", + }, + expReportHeader: true, + }, + "reserved-rpc-header-error-enforced-rule": { + headers: map[string]string{ + "rpc-any": "any-value", + }, + enforceHeaderRule: true, + expReportHeader: true, + expErr: yarpcerrors.InternalErrorf("header with rpc prefix is not allowed in request application headers (rpc-any was passed)"), + }, + "reserved-dollad-rpc-header-error": { + headers: map[string]string{ + "$rpc$-any": "any-value", + }, + expReportHeader: true, + }, + "reserved-dollad-rpc-header-error-enforced-rule": { + headers: map[string]string{ + "$rpc$-any": "any-value", + }, + enforceHeaderRule: true, + expReportHeader: true, + expErr: yarpcerrors.InternalErrorf("header with rpc prefix is not allowed in request application headers ($rpc$-any was passed)"), + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + switchEnforceHeaderRules(t, tt.enforceHeaderRule) + + root := metrics.New() + m := observability.NewReserveHeaderMetrics(root.Scope(), "tchannel") + + err := validateApplicationHeaders(tt.headers, m.With("any-source", "any-dest")) + assert.Equal(t, tt.expErr, err) + + if tt.expReportHeader { + assertTuple(t, root.Snapshot().Counters, tuple{"tchannel_reserved_headers_error", "any-source", "any-dest", 1}) + } else { + assertEmptyMetrics(t, root.Snapshot()) + } + }) + } +} + +func TestDeleteReservedHeaders(t *testing.T) { + tests := map[string]struct { + headers map[string]string + enforceHeaderRule bool + expHeaders map[string]string + expReservedHeadersMetric int64 + }{ + "nil-headers": {}, + "no-reserved-headers": { + headers: map[string]string{ + "any-header": "any-value", + }, + expHeaders: map[string]string{ + "any-header": "any-value", + }, + }, + "reserved-known-headers": { + headers: map[string]string{ + ServiceHeaderKey: "any-value", + "any-header": "any-value", + }, + expHeaders: map[string]string{ + "any-header": "any-value", + }, + }, + "reserved-rpc-headers": { + headers: map[string]string{ + "rpc-any": "any-value", + "any-header": "any-value", + }, + expHeaders: map[string]string{ + "rpc-any": "any-value", + "any-header": "any-value", + }, + expReservedHeadersMetric: 1, + }, + "reserved-dollar-rpc-headers": { + headers: map[string]string{ + "$rpc$-any": "any-value", + "any-header": "any-value", + }, + expHeaders: map[string]string{ + "$rpc$-any": "any-value", + "any-header": "any-value", + }, + expReservedHeadersMetric: 1, + }, + "enforce-header-rules": { + headers: map[string]string{ + "rpc-any": "any-value", + "$rpc$-any": "any-value", + "any-header": "any-value", + }, + enforceHeaderRule: true, + expHeaders: map[string]string{ + "any-header": "any-value", + }, + expReservedHeadersMetric: 2, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + switchEnforceHeaderRules(t, tt.enforceHeaderRule) + + root := metrics.New() + m := observability.NewReserveHeaderMetrics(root.Scope(), "tchannel") + + headers := transport.HeadersFromMap(tt.headers) + deleteReservedHeaders(headers, m.With("any-source", "any-dest")) + assert.Equal(t, transport.HeadersFromMap(tt.expHeaders), headers) + + if tt.expReservedHeadersMetric > 0 { + assertTuple(t, root.Snapshot().Counters, tuple{"tchannel_reserved_headers_stripped", "any-source", "any-dest", tt.expReservedHeadersMetric}) + } else { + assertEmptyMetrics(t, root.Snapshot()) + } + }) + + } +} + +func switchEnforceHeaderRules(t *testing.T, cond bool) { + if !cond { + return + } + + enforceHeaderRules = true + t.Cleanup(func() { + enforceHeaderRules = false + }) +} + +type tuple struct { + name, tag1, tag2 string + value int64 +} + +func assertTuple(t *testing.T, snapshots []metrics.Snapshot, expected tuple) { + assertTuples(t, snapshots, []tuple{expected}) +} + +func assertTuples(t *testing.T, snapshots []metrics.Snapshot, expected []tuple) { + actual := make([]tuple, 0, len(snapshots)) + + for _, c := range snapshots { + actual = append(actual, tuple{c.Name, c.Tags["source"], c.Tags["dest"], c.Value}) + } + + assert.ElementsMatch(t, expected, actual) +} + +func assertEmptyMetrics(t *testing.T, snapshot *metrics.RootSnapshot) { + assert.Empty(t, snapshot.Counters) + assert.Empty(t, snapshot.Gauges) + assert.Empty(t, snapshot.Histograms) +} diff --git a/transport/tchannel/outbound.go b/transport/tchannel/outbound.go index 0da49c5d8..bdbea6952 100644 --- a/transport/tchannel/outbound.go +++ b/transport/tchannel/outbound.go @@ -32,6 +32,7 @@ import ( "go.uber.org/yarpc/api/x/introspection" "go.uber.org/yarpc/internal/bufferpool" "go.uber.org/yarpc/internal/iopool" + "go.uber.org/yarpc/internal/observability" intyarpcerrors "go.uber.org/yarpc/internal/yarpcerrors" peerchooser "go.uber.org/yarpc/peer" "go.uber.org/yarpc/peer/hostport" @@ -121,11 +122,11 @@ func (o *Outbound) Call(ctx context.Context, req *transport.Request) (*transport // Call sends an RPC to this specific peer. func (p *tchannelPeer) Call(ctx context.Context, req *transport.Request, reuseBuffer bool) (*transport.Response, error) { - return callWithPeer(ctx, req, p.getPeer(), p.transport.headerCase, reuseBuffer) + return callWithPeer(ctx, req, p.getPeer(), p.transport.headerCase, reuseBuffer, p.transport.reservedHeaderMetric) } // callWithPeer sends a request with the chosen peer. -func callWithPeer(ctx context.Context, req *transport.Request, peer *tchannel.Peer, headerCase headerCase, reuseBuffer bool) (*transport.Response, error) { +func callWithPeer(ctx context.Context, req *transport.Request, peer *tchannel.Peer, headerCase headerCase, reuseBuffer bool, headerMetrics *observability.ReservedHeaderMetrics) (*transport.Response, error) { // NB(abg): Under the current API, the local service's name is required // twice: once when constructing the TChannel and then again when // constructing the RPC. @@ -158,6 +159,11 @@ func callWithPeer(ctx context.Context, req *transport.Request, peer *tchannel.Pe } reqHeaders := getHeaderMap(req.Headers, headerCase) + edgeMetrics := headerMetrics.With(req.Caller, req.Service) + + if err := validateApplicationHeaders(reqHeaders, edgeMetrics); err != nil { + return nil, err + } reqHeaders = requestToTransportHeaders(req, reqHeaders) @@ -211,7 +217,7 @@ func callWithPeer(ctx context.Context, req *transport.Request, peer *tchannel.Pe applicationErrorDetails, _ := headers.Get(ApplicationErrorDetailsHeaderKey) err = getResponseError(headers) - deleteReservedHeaders(headers) + deleteReservedHeaders(headers, edgeMetrics) resp := &transport.Response{ Headers: headers, diff --git a/transport/tchannel/outbound_test.go b/transport/tchannel/outbound_test.go index 66d277cca..3103aa82d 100644 --- a/transport/tchannel/outbound_test.go +++ b/transport/tchannel/outbound_test.go @@ -33,6 +33,7 @@ import ( "github.com/uber/tchannel-go/testutils" "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/encoding/raw" + "go.uber.org/yarpc/internal/observability" "go.uber.org/yarpc/internal/testtime" "go.uber.org/yarpc/yarpcerrors" "golang.org/x/net/context" @@ -95,7 +96,7 @@ func TestOutboundHeaders(t *testing.T) { return } - deleteReservedHeaders(headers) + deleteReservedHeaders(headers, observability.ReservedHeaderEdgeMetrics{}) assert.Equal(t, tt.wantHeaders, headers.OriginalItems(), "headers did not match") // write a response diff --git a/transport/tchannel/response_writer.go b/transport/tchannel/response_writer.go index 76ee10062..a0ecaec24 100644 --- a/transport/tchannel/response_writer.go +++ b/transport/tchannel/response_writer.go @@ -27,9 +27,10 @@ import ( "github.com/uber/tchannel-go" "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/internal/bufferpool" + "go.uber.org/yarpc/internal/observability" ) -type responseWriterConstructor func(inboundCallResponse, tchannel.Format, headerCase) responseWriter +type responseWriterConstructor func(inboundCallResponse, tchannel.Format, headerCase, observability.ReservedHeaderEdgeMetrics) responseWriter type responseWriterImpl struct { failedWith error @@ -39,22 +40,42 @@ type responseWriterImpl struct { response inboundCallResponse applicationError bool headerCase headerCase + edgeMetrics observability.ReservedHeaderEdgeMetrics } -func newHandlerWriter(response inboundCallResponse, format tchannel.Format, headerCase headerCase) responseWriter { +func newHandlerWriter(response inboundCallResponse, format tchannel.Format, headerCase headerCase, edgeMetrics observability.ReservedHeaderEdgeMetrics) responseWriter { return &responseWriterImpl{ - response: response, - format: format, - headerCase: headerCase, + response: response, + format: format, + headerCase: headerCase, + edgeMetrics: edgeMetrics, } } func (w *responseWriterImpl) AddHeaders(h transport.Headers) { for k, v := range h.OriginalItems() { + if !isReservedHeaderPrefix(k) { + w.addHeader(k, v) + continue + } + + w.edgeMetrics.IncError() + + // Return error for prefix if new rules are enforced already + if enforceHeaderRules { + w.failedWith = appendError(w.failedWith, fmt.Errorf("header with rpc prefix is not allowed in response application headers (%s was passed)", k)) + return + } + + // If the header is a reserved header, return error regardless of the new rules feature flag, + // because it's an enforced by default rule. if isReservedHeaderKey(k) { w.failedWith = appendError(w.failedWith, fmt.Errorf("cannot use reserved header key: %s", k)) return } + + // Header with reserved prefix is used, but it's not a reserved key (existing rule), + // and new rule is not enforced yet, so we just report it. w.addHeader(k, v) } } diff --git a/transport/tchannel/response_writer_test.go b/transport/tchannel/response_writer_test.go new file mode 100644 index 000000000..1bdf3b2df --- /dev/null +++ b/transport/tchannel/response_writer_test.go @@ -0,0 +1,91 @@ +// Copyright (c) 2024 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package tchannel + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/multierr" + "go.uber.org/net/metrics" + "go.uber.org/yarpc/api/transport" + "go.uber.org/yarpc/internal/observability" +) + +func TestResponseWriterAddHeaders(t *testing.T) { + tests := map[string]struct { + h transport.Headers + enforceHeaderRules bool + expErr error + expReportHeader bool + expHeaders transport.Headers + }{ + "success": { + h: transport.NewHeaders().With("foo", "bar"), + expHeaders: transport.NewHeaders().With("foo", "bar"), + }, + "known-reserved-header-used-which-lead-to-error": { + h: transport.NewHeaders().With(ServiceHeaderKey, "any-value"), + expErr: fmt.Errorf("cannot use reserved header key: %s", ServiceHeaderKey), + expReportHeader: true, + }, + "unknown-reserved-header-used-which-lead-reporting-metric": { + h: transport.NewHeaders().With("rpc-any", "any-value"), + expHeaders: transport.NewHeaders().With("rpc-any", "any-value"), + expReportHeader: true, + }, + "enforce-header-rules": { + h: transport.NewHeaders().With("rpc-any", "any-value"), + enforceHeaderRules: true, + expErr: fmt.Errorf("header with rpc prefix is not allowed in response application headers (rpc-any was passed)"), + expReportHeader: true, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + switchEnforceHeaderRules(t, tt.enforceHeaderRules) + + root := metrics.New() + m := observability.NewReserveHeaderMetrics(root.Scope(), "tchannel") + + rw := responseWriterImpl{edgeMetrics: m.With("any-source", "any-dest")} + + rw.AddHeaders(tt.h) + if tt.expErr != nil { + errs := multierr.Errors(rw.failedWith) + require.Len(t, errs, 1) + assert.Equal(t, tt.expErr, errs[0]) + } else { + assert.NoError(t, rw.failedWith) + } + assert.Equal(t, tt.expHeaders, rw.headers) + + if tt.expReportHeader { + assertTuple(t, root.Snapshot().Counters, tuple{"tchannel_reserved_headers_error", "any-source", "any-dest", 1}) + } else { + assertEmptyMetrics(t, root.Snapshot()) + } + }) + } +} diff --git a/transport/tchannel/tchannel_utils_test.go b/transport/tchannel/tchannel_utils_test.go index 87c47921a..70262552d 100644 --- a/transport/tchannel/tchannel_utils_test.go +++ b/transport/tchannel/tchannel_utils_test.go @@ -26,6 +26,7 @@ import ( "io" "github.com/uber/tchannel-go" + "go.uber.org/yarpc/internal/observability" ) func readArgs(r tchannel.ArgReadable) (arg2, arg3 []byte, err error) { @@ -173,7 +174,7 @@ func (fr *faultyResponseRecorder) SendSystemError(err error) error { // inside tchannel.Handle. type faultyHandlerWriter struct{ responseWriterImpl } -func newFaultyHandlerWriter(inboundCallResponse, tchannel.Format, headerCase) responseWriter { +func newFaultyHandlerWriter(inboundCallResponse, tchannel.Format, headerCase, observability.ReservedHeaderEdgeMetrics) responseWriter { return &faultyHandlerWriter{} } diff --git a/transport/tchannel/transport.go b/transport/tchannel/transport.go index bba84daf6..e5cb60115 100644 --- a/transport/tchannel/transport.go +++ b/transport/tchannel/transport.go @@ -36,6 +36,7 @@ import ( "go.uber.org/yarpc/api/peer" "go.uber.org/yarpc/api/transport" yarpctls "go.uber.org/yarpc/api/transport/tls" + "go.uber.org/yarpc/internal/observability" "go.uber.org/yarpc/pkg/lifecycle" "go.uber.org/yarpc/transport/internal/tls/dialer" "go.uber.org/yarpc/transport/internal/tls/muxlistener" @@ -85,6 +86,8 @@ type Transport struct { outboundTLSConfigProvider yarpctls.OutboundTLSConfigProvider outboundChannels []*outboundChannel + + reservedHeaderMetric *observability.ReservedHeaderMetrics } // NewTransport is a YARPC transport that facilitates sending and receiving @@ -136,6 +139,7 @@ func (o transportOptions) newTransport() *Transport { inboundTLSConfig: o.inboundTLSConfig, inboundTLSMode: o.inboundTLSMode, outboundTLSConfigProvider: o.outboundTLSConfigProvider, + reservedHeaderMetric: observability.NewReserveHeaderMetrics(o.meter, TransportName), } } @@ -225,6 +229,7 @@ func (t *Transport) start() error { tracer: t.tracer, headerCase: t.headerCase, logger: t.logger, + reservedHeaderMetrics: t.reservedHeaderMetric, newResponseWriter: t.newResponseWriter, excludeServiceHeaderInResponse: t.excludeServiceHeaderInResponse, },