Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix middleware wrapping #255

Merged
merged 3 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions destination_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ import (
"golang.org/x/time/rate"
)

var destinationMiddlewareType = reflect.TypeFor[DestinationMiddleware]()

// DestinationMiddleware wraps a Destination and adds functionality to it.
type DestinationMiddleware interface {
Wrap(Destination) Destination
Expand Down Expand Up @@ -88,7 +86,21 @@ func DestinationWithMiddleware(d Destination) Destination {
if cfgVal.Kind() != reflect.Ptr {
panic("The struct returned in Config() must be a pointer")
}
cfgVal = cfgVal.Elem()

// Collect all middlewares from the config and wrap the destination with them
mw := destinationMiddlewareFromConfigRecursive(cfgVal.Elem())

// Wrap the middleware in reverse order to preserve the order as specified.
for i := len(mw) - 1; i >= 0; i-- {
d = mw[i].Wrap(d)
}

return d
}

func destinationMiddlewareFromConfigRecursive(cfgVal reflect.Value) []DestinationMiddleware {
destinationMiddlewareType := reflect.TypeFor[DestinationMiddleware]()
cfgType := cfgVal.Type()

// Collect all middlewares from the config and wrap the destination with them
var mw []DestinationMiddleware
Expand All @@ -100,19 +112,21 @@ func DestinationWithMiddleware(d Destination) Destination {
if field.Kind() != reflect.Ptr {
field = field.Addr()
}
if field.Type().Implements(destinationMiddlewareType) {

switch {
case field.Type().Implements(destinationMiddlewareType):
// This is a middleware config, store it.
//nolint:forcetypeassert // type checked above with field.Type().Implements()
mw = append(mw, field.Interface().(DestinationMiddleware))
}
}

// Wrap the middleware in reverse order to preserve the order as specified.
for i := len(mw) - 1; i >= 0; i-- {
d = mw[i].Wrap(d)
case cfgType.Field(i).Anonymous &&
cfgType.Field(i).Type.Kind() == reflect.Struct:
// This is an embedded struct, dive deeper.
mw = append(mw, destinationMiddlewareFromConfigRecursive(field.Elem())...)
}
}

return d
return mw
}

// -- DestinationWithBatch -----------------------------------------------------
Expand Down
22 changes: 22 additions & 0 deletions destination_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@ import (
"golang.org/x/time/rate"
)

func TestDestinationWithMiddleware(t *testing.T) {
is := is.New(t)

ctrl := gomock.NewController(t)
src := NewMockDestination(ctrl)

cfg := struct {
DefaultDestinationMiddleware
}{}
src.EXPECT().Config().Return(&cfg)

got := DestinationWithMiddleware(src)

var want Destination = src
want = (&DestinationWithSchemaExtraction{}).Wrap(want)
want = (&DestinationWithBatch{}).Wrap(want)
want = (&DestinationWithRecordFormat{}).Wrap(want)
want = (&DestinationWithRateLimit{}).Wrap(want)

is.Equal(want, got)
}

// -- DestinationWithBatch -----------------------------------------------------

func TestDestinationWithBatch_Open(t *testing.T) {
Expand Down
4 changes: 0 additions & 4 deletions destination_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,6 @@ func TestDestinationPluginAdapter_Run_Write(t *testing.T) {
}

func TestDestinationPluginAdapter_Run_WriteBatch_Success(t *testing.T) {
t.Skip("TODO fix this test")

is := is.New(t)
ctrl := gomock.NewController(t)
dst := NewMockDestination(ctrl)
Expand Down Expand Up @@ -249,8 +247,6 @@ func TestDestinationPluginAdapter_Run_WriteBatch_Success(t *testing.T) {
}

func TestDestinationPluginAdapter_Run_WriteBatch_Partial(t *testing.T) {
t.Skip("TODO fix this test")

is := is.New(t)
ctrl := gomock.NewController(t)
dst := NewMockDestination(ctrl)
Expand Down
34 changes: 24 additions & 10 deletions source_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ import (
"github.com/jpillora/backoff"
)

var sourceMiddlewareType = reflect.TypeFor[SourceMiddleware]()

// SourceMiddleware wraps a Source and adds functionality to it.
type SourceMiddleware interface {
Wrap(Source) Source
Expand Down Expand Up @@ -84,7 +82,21 @@ func SourceWithMiddleware(s Source) Source {
if cfgVal.Kind() != reflect.Ptr {
panic("The struct returned in Config() must be a pointer")
}
cfgVal = cfgVal.Elem()

// Collect all middlewares from the config and wrap the source with them
mw := sourceMiddlewareFromConfigRecursive(cfgVal.Elem())

// Wrap the middleware in reverse order to preserve the order as specified.
for i := len(mw) - 1; i >= 0; i-- {
s = mw[i].Wrap(s)
}

return s
}

func sourceMiddlewareFromConfigRecursive(cfgVal reflect.Value) []SourceMiddleware {
sourceMiddlewareType := reflect.TypeFor[SourceMiddleware]()
cfgType := cfgVal.Type()

// Collect all middlewares from the config and wrap the source with them
var mw []SourceMiddleware
Expand All @@ -96,19 +108,21 @@ func SourceWithMiddleware(s Source) Source {
if field.Kind() != reflect.Ptr {
field = field.Addr()
}
if field.Type().Implements(sourceMiddlewareType) {

switch {
case field.Type().Implements(sourceMiddlewareType):
// This is a middleware config, store it.
//nolint:forcetypeassert // type checked above with field.Type().Implements()
mw = append(mw, field.Interface().(SourceMiddleware))
}
}

// Wrap the middleware in reverse order to preserve the order as specified.
for i := len(mw) - 1; i >= 0; i-- {
s = mw[i].Wrap(s)
case cfgType.Field(i).Anonymous &&
cfgType.Field(i).Type.Kind() == reflect.Struct:
// This is an embedded struct, dive deeper.
mw = append(mw, sourceMiddlewareFromConfigRecursive(field.Elem())...)
}
}

return s
return mw
}

// -- SourceWithSchemaExtraction -----------------------------------------------
Expand Down
22 changes: 22 additions & 0 deletions source_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,28 @@ import (
"go.uber.org/mock/gomock"
)

func TestSourceWithMiddleware(t *testing.T) {
is := is.New(t)

ctrl := gomock.NewController(t)
src := NewMockSource(ctrl)

cfg := struct {
DefaultSourceMiddleware
}{}
src.EXPECT().Config().Return(&cfg)

got := SourceWithMiddleware(src)

var want Source = src
want = (&SourceWithSchemaExtraction{}).Wrap(want)
want = (&SourceWithSchemaContext{}).Wrap(want)
want = (&SourceWithEncoding{}).Wrap(want)
want = (&SourceWithBatch{}).Wrap(want)

is.Equal(want, got)
}

// -- SourceWithSchemaExtraction -----------------------------------------------

func TestSourceWithSchemaExtraction_SchemaType(t *testing.T) {
Expand Down