Skip to content

Commit

Permalink
Prototypes propagation with multiple values. Adds MultiTextMapCarrier…
Browse files Browse the repository at this point in the history
…, extending TextMapCarrier.

Gives example extracting requests with multiple 'baggage' headers set.
  • Loading branch information
jamesmoessis committed Nov 14, 2024
1 parent d428313 commit 530b431
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 4 deletions.
32 changes: 29 additions & 3 deletions propagation/baggage.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ func (b Baggage) Inject(ctx context.Context, carrier TextMapCarrier) {

// Extract returns a copy of parent with the baggage from the carrier added.
func (b Baggage) Extract(parent context.Context, carrier TextMapCarrier) context.Context {
multiCarrier, isMultiCarrier := carrier.(MultiTextMapCarrier)
if isMultiCarrier {
return extractMultiBaggage(parent, multiCarrier)
}
return extractSingleBaggage(parent, carrier)
}

// Fields returns the keys who's values are set with Inject.
func (b Baggage) Fields() []string {
return []string{baggageHeader}
}

func extractSingleBaggage(parent context.Context, carrier TextMapCarrier) context.Context {
bStr := carrier.Get(baggageHeader)
if bStr == "" {
return parent
Expand All @@ -41,7 +54,20 @@ func (b Baggage) Extract(parent context.Context, carrier TextMapCarrier) context
return baggage.ContextWithBaggage(parent, bag)
}

// Fields returns the keys who's values are set with Inject.
func (b Baggage) Fields() []string {
return []string{baggageHeader}
func extractMultiBaggage(parent context.Context, carrier MultiTextMapCarrier) context.Context {
bVals := carrier.GetAll(baggageHeader)
members := make([]baggage.Member, 0)
for _, bStr := range bVals {
currBag, err := baggage.Parse(bStr)
if err != nil {
continue
}
members = append(members, currBag.Members()...)
}

b, err := baggage.New(members...)
if err != nil || b.Len() == 0 {
return parent
}
return baggage.ContextWithBaggage(parent, b)
}
49 changes: 49 additions & 0 deletions propagation/baggage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,55 @@ func TestExtractValidBaggageFromHTTPReq(t *testing.T) {
}
}

func TestExtractValidMultipleBaggageHeaders(t *testing.T) {
prop := propagation.TextMapPropagator(propagation.Baggage{})
tests := []struct {
name string
headers []string
want members
}{
{
name: "non conflicting headers",
headers: []string{"key1=val1", "key2=val2"},
want: members{
{Key: "key1", Value: "val1"},
{Key: "key2", Value: "val2"},
},
},
{
name: "conflicting keys, uses last val",
headers: []string{"key1=val1", "key1=val2"},
want: members{
{Key: "key1", Value: "val2"},
},
},
{
name: "single empty",
headers: []string{"", "key1=val1"},
want: members{
{Key: "key1", Value: "val1"},
},
},
{
name: "all empty",
headers: []string{"", ""},
want: members{},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, _ := http.NewRequest("GET", "http://example.com", nil)
req.Header["Baggage"] = tt.headers

ctx := context.Background()
ctx = prop.Extract(ctx, propagation.HeaderCarrier(req.Header))
expected := tt.want.Baggage(t)
assert.Equal(t, expected, baggage.FromContext(ctx))
})
}
}

func TestExtractInvalidDistributedContextFromHTTPReq(t *testing.T) {
prop := propagation.TextMapPropagator(propagation.Baggage{})
tests := []struct {
Expand Down
16 changes: 15 additions & 1 deletion propagation/propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ type TextMapCarrier interface {
// must never be done outside of a new major release.
}

// MultiTextMapCarrier is a TextMapCarrier that can return multiple values for a single key.
type MultiTextMapCarrier interface {
TextMapCarrier
// GetAll returns all values associated with the passed key.
GetAll(key string) []string
// DO NOT CHANGE: any modification will not be backwards compatible and
// must never be done outside of a new major release.
}

// MapCarrier is a TextMapCarrier that uses a map held in memory as a storage
// medium for propagated key-value pairs.
type MapCarrier map[string]string
Expand Down Expand Up @@ -58,11 +67,16 @@ func (c MapCarrier) Keys() []string {
// HeaderCarrier adapts http.Header to satisfy the TextMapCarrier interface.
type HeaderCarrier http.Header

// Get returns the value associated with the passed key.
// Get returns the first value associated with the passed key.
func (hc HeaderCarrier) Get(key string) string {
return http.Header(hc).Get(key)
}

// GetAll returns all values associated with the passed key.
func (hc HeaderCarrier) GetAll(key string) []string {
return http.Header(hc).Values(key)
}

// Set stores the key-value pair.
func (hc HeaderCarrier) Set(key string, value string) {
http.Header(hc).Set(key, value)
Expand Down

0 comments on commit 530b431

Please sign in to comment.