From 6d5523e30ea1ba9e93ad181172e7def60075a2d5 Mon Sep 17 00:00:00 2001 From: Yota Hamada Date: Thu, 23 May 2024 11:36:29 +0900 Subject: [PATCH] Ensure reentrancy --- command.go | 2 +- mux.go | 5 +++-- mux_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/command.go b/command.go index 73f91f4..92edd38 100644 --- a/command.go +++ b/command.go @@ -131,7 +131,7 @@ func resolveHandler[T Command](op OpType, bus Bus) (HandlerFunc[T], *Mux) { if n != nil { h := n.handler.handler hh := convertInterface[HandlerFunc[T]](h.handler) - storeCache[T](&mx.cache, typ, h.mux, hh) + storeCache[T](mx.cache, typ, h.mux, hh) return hh, h.mux } diff --git a/mux.go b/mux.go index 69c3c45..ee4e39c 100644 --- a/mux.go +++ b/mux.go @@ -26,7 +26,7 @@ type Mux struct { tree *node middlewares [mAll][]middleware mHandlers [mAll]func(ctx Context, fn mHandlerFunc) error - cache syncMap + cache *syncMap // context pool pool *sync.Pool @@ -68,7 +68,7 @@ func newMux() *Mux { mux.pool.New = func() interface{} { return NewContext() } - mux.cache.kv = make(map[reflect.Type]any) + mux.cache = &syncMap{kv: make(map[reflect.Type]any)} return mux } @@ -229,6 +229,7 @@ func (mx *Mux) child() Bus { inline: true, middlewares: mws, tree: mx.tree, + cache: mx.cache, } } diff --git a/mux_test.go b/mux_test.go index 0aa5a44..354dbe0 100644 --- a/mux_test.go +++ b/mux_test.go @@ -177,6 +177,47 @@ func TestMux_QueryAsync_Error(t *testing.T) { } } +func TestMux_Reentrant(t *testing.T) { + mux := dew.New() + mux.Register(new(userHandler)) + mux.Register(new(postHandler)) + + type findUserPost struct { + ID int + Result struct { + User string + Post string + } + } + + mux.Register(dew.HandlerFunc[findUserPost]( + func(ctx context.Context, query *findUserPost) error { + findUserQuery, err := dew.Query(ctx, dew.NewQuery(dew.FromContext(ctx), &findUser{ID: query.ID})) + if err != nil { + return err + } + postQuery, err := dew.Query(ctx, dew.NewQuery(dew.FromContext(ctx), &findPost{ID: query.ID})) + if err != nil { + return err + } + query.Result.User = findUserQuery.Result + query.Result.Post = postQuery.Result + return nil + }, + )) + + query, err := dew.Query(context.Background(), dew.NewQuery(mux, &findUserPost{ID: 1})) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if query.Result.User != "john" { + t.Fatalf("unexpected result: %s", query.Result.User) + } + if query.Result.Post != "hello" { + t.Fatalf("unexpected result: %s", query.Result.Post) + } +} + type ctxKey struct { name string } @@ -702,6 +743,11 @@ func (h *postHandler) CreatePost(_ context.Context, command *createPost) error { return nil } +func (h *postHandler) FindPost(_ context.Context, query *findPost) error { + query.Result = "hello" + return nil +} + func (*userHandler) FindUser(_ context.Context, query *findUser) error { if query.ID == 1 { query.Result = "john"