diff --git a/filters/tracing/statebagtotag.go b/filters/tracing/statebagtotag.go index c96092e128..50debd281d 100644 --- a/filters/tracing/statebagtotag.go +++ b/filters/tracing/statebagtotag.go @@ -25,7 +25,7 @@ func (stateBagToTagSpec) Name() string { } func (stateBagToTagSpec) CreateFilter(args []interface{}) (filters.Filter, error) { - if len(args) < 1 { + if len(args) < 1 || len(args) > 2 { return nil, filters.ErrInvalidFilterParameters } @@ -43,7 +43,7 @@ func (stateBagToTagSpec) CreateFilter(args []interface{}) (filters.Filter, error tagName = tagNameArg } - return stateBagToTagFilter{ + return &stateBagToTagFilter{ stateBagItemName: stateBagItemName, tagName: tagName, }, nil @@ -53,16 +53,22 @@ func NewStateBagToTag() filters.Spec { return stateBagToTagSpec{} } -func (f stateBagToTagFilter) Request(ctx filters.FilterContext) { +func (f *stateBagToTagFilter) Request(ctx filters.FilterContext) { + value, ok := ctx.StateBag()[f.stateBagItemName] + if !ok { + return + } + span := opentracing.SpanFromContext(ctx.Request().Context()) if span == nil { return } - value, ok := ctx.StateBag()[f.stateBagItemName] - if !ok { - return + + if _, ok := value.(string); ok { + span.SetTag(f.tagName, value) + } else { + span.SetTag(f.tagName, fmt.Sprint(value)) } - span.SetTag(f.tagName, fmt.Sprint(value)) } -func (stateBagToTagFilter) Response(ctx filters.FilterContext) {} +func (*stateBagToTagFilter) Response(ctx filters.FilterContext) {} diff --git a/filters/tracing/statebagtotag_test.go b/filters/tracing/statebagtotag_test.go index 0a724b00ec..af487cdcbd 100644 --- a/filters/tracing/statebagtotag_test.go +++ b/filters/tracing/statebagtotag_test.go @@ -59,13 +59,18 @@ func TestStateBagToTag_CreateFilter(t *testing.T) { args: []interface{}{""}, err: filters.ErrInvalidFilterParameters, }, + { + msg: "too many args", + args: []interface{}{"foo", "bar", "baz"}, + err: filters.ErrInvalidFilterParameters, + }, } { t.Run(ti.msg, func(t *testing.T) { f, err := NewStateBagToTag().CreateFilter(ti.args) assert.Equal(t, ti.err, err) if err == nil { - ff := f.(stateBagToTagFilter) + ff := f.(*stateBagToTagFilter) assert.Equal(t, ti.stateBag, ff.stateBagItemName) assert.Equal(t, ti.tag, ff.tagName)