From 5339148971dd66ad3a053f6fbf4894abc7552586 Mon Sep 17 00:00:00 2001 From: Ggicci Date: Sat, 13 Apr 2024 18:12:18 -0400 Subject: [PATCH] feat: new api Resolver.ResolveTo() to specify resolving target --- context.go | 1 - option.go | 33 +++++++----- resolver.go | 85 +++++++++++++++++------------ resolver_test.go | 136 ++++++++++++++++++++++++++++++++++++++++++----- 4 files changed, 196 insertions(+), 59 deletions(-) diff --git a/context.go b/context.go index b67747a..e911fb9 100644 --- a/context.go +++ b/context.go @@ -5,5 +5,4 @@ type contextKey int const ( ckNamespace contextKey = iota ckResolveNestedDirectives - ckDirectiveRunOrder ) diff --git a/option.go b/option.go index 03d28db..f8f34f2 100644 --- a/option.go +++ b/option.go @@ -1,6 +1,8 @@ package owl -import "context" +import ( + "context" +) // Option is an option for New. type Option interface { @@ -16,28 +18,35 @@ func (f OptionFunc) Apply(ctx context.Context) context.Context { // WithNamespace binds a namespace to the resolver. The namespace is used to // lookup directive executors. There's a default namespace, which is used when -// the namespace is not specified. The namespace set in New will be overridden -// by the namespace set in Resolve or Scan. +// the namespace is not specified. The namespace set in New() will be overridden +// by the namespace set in Resolve() or Scan(). func WithNamespace(ns *Namespace) Option { return WithValue(ckNamespace, ns) } +// WithNestedDirectivesEnabled controls whether to resolve nested directives. +// The default value is true. When set to false, the nested directives will not +// be executed. The value set in New() will be overridden by the value set in +// Resolve() or Scan(). +func WithNestedDirectivesEnabled(resolve bool) Option { + return WithValue(ckResolveNestedDirectives, resolve) +} + // WithValue binds a value to the context. // -// When used in New, the value is bound to Resolver.Context. +// When used in New(), the value is bound to Resolver.Context. // -// When used in Resolve or Scan, the value is bound to DirectiveRuntime.Context. -// See DirectiveRuntime.Context for more details. +// When used in Resolve() or Scan(), the value is bound to +// DirectiveRuntime.Context. See DirectiveRuntime.Context for more details. func WithValue(key, value interface{}) Option { return OptionFunc(func(ctx context.Context) context.Context { return context.WithValue(ctx, key, value) }) } -// WithNestedDirectivesEnabled controls whether to resolve nested directives. -// The default value is true. When set to false, the nested directives will not -// be executed. The value set in New will be overridden by the value set in -// Resolve or Scan. -func WithNestedDirectivesEnabled(resolve bool) Option { - return WithValue(ckResolveNestedDirectives, resolve) +func buildContextWithOptionsApplied(ctx context.Context, opts ...Option) context.Context { + for _, opt := range opts { + ctx = opt.Apply(ctx) + } + return ctx } diff --git a/resolver.go b/resolver.go index 5c153b7..8e66d25 100644 --- a/resolver.go +++ b/resolver.go @@ -31,7 +31,8 @@ type Resolver struct { // New builds a resolver tree from a struct value. The given options will be // applied to all the resolvers. In the resolver tree, each node is also a -// Resolver. +// Resolver. Available options are WithNamespace, WithNestedDirectivesEnabled +// and WithValue. func New(structValue interface{}, opts ...Option) (*Resolver, error) { typ, err := reflectStructType(structValue) if err != nil { @@ -47,16 +48,7 @@ func New(structValue interface{}, opts ...Option) (*Resolver, error) { // Apply options, build the context for each resolver. defaultOpts := []Option{WithNamespace(defaultNS)} opts = append(defaultOpts, opts...) - ctx := context.Background() - for _, opt := range opts { - ctx = opt.Apply(ctx) - } - - // Apply the context to each resolver. - tree.Iterate(func(r *Resolver) error { - r.Context = ctx - return nil - }) + tree.applyContext(buildContextWithOptionsApplied(context.Background(), opts...)) if tree.Namespace() == nil { return nil, errors.New("nil namespace") @@ -150,7 +142,7 @@ func findResolver(root *Resolver, path []string) *Resolver { return nil } -func shouldResolveNestedDirectives(r *Resolver, ctx context.Context) bool { +func shouldResolveNestedDirectives(ctx context.Context, r *Resolver) bool { if r.IsRoot() { return true // always resolve the root } @@ -186,7 +178,7 @@ func (root *Resolver) iterate(ctx context.Context, fn func(*Resolver) error) err return err } - if shouldResolveNestedDirectives(root, ctx) { + if shouldResolveNestedDirectives(ctx, root) { for _, field := range root.Children { if err := field.iterate(ctx, fn); err != nil { return err @@ -196,6 +188,14 @@ func (root *Resolver) iterate(ctx context.Context, fn func(*Resolver) error) err return nil } +// applyContext applies the context to the resolver and its children. +func (r *Resolver) applyContext(ctx context.Context) { + r.Iterate(func(x *Resolver) error { + x.Context = ctx + return nil + }) +} + // Scan scans the struct value by traversing the fields in depth-first order. The value is required // to have the same type as the resolver holds. While scanning, it will run the directives on each // field. The DirectiveRuntime that can be accessed during the directive exeuction will have its @@ -204,7 +204,7 @@ func (root *Resolver) iterate(ctx context.Context, fn func(*Resolver) error) err // based on the struct value, etc. // // Use WithValue to create an Option that can add custom values to the context, the context can be -// used by the directive executors during the resolution. +// used by the directive executors during the scanning. // // NOTE: Unlike Resolve, it will iterate the whole resolver tree against the given // value, try to access each corresponding field. Even scan fails on one of the fields, @@ -225,12 +225,8 @@ func (r *Resolver) Scan(value any, opts ...Option) error { ErrTypeMismatch, rv.Type(), r.Type) } - ctx := context.Background() - for _, opt := range opts { - ctx = opt.Apply(ctx) - } - var errs []error + ctx := buildContextWithOptionsApplied(context.Background(), opts...) r.iterate(ctx, func(r *Resolver) error { errs = append(errs, scan(r, ctx, rv)) return nil @@ -283,17 +279,36 @@ func scan(resolver *Resolver, ctx context.Context, rootValue reflect.Value) erro // resolver := owl.New(Settings{}) // settings, err := resolver.Resolve(WithValue("app_config", appConfig)) // -// NOTE: while iterating the tree, if resolving a field fails, the iteration -// will be stopped and the error will be returned. +// NOTE: while iterating the tree, if resolving a field failed, the iteration +// will be stopped immediately and the error will be returned. func (r *Resolver) Resolve(opts ...Option) (reflect.Value, error) { - ctx := context.Background() - for _, opt := range opts { - ctx = opt.Apply(ctx) - } - rootValue := reflect.New(r.Type) + ctx := buildContextWithOptionsApplied(context.Background(), opts...) + rootValue := reflect.New(r.Type) // Type:User -> rootValue:*User return rootValue, r.resolve(ctx, rootValue) } +// ResolveTo works like Resolve, but it resolves the struct value to the given +// pointer value instead of creating a new value. The pointer value must be +// non-nil and a pointer to the type the resolver holds. +func (r *Resolver) ResolveTo(value any, opts ...Option) error { + if value == nil { + return fmt.Errorf("cannot resolve to nil value") + } + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return fmt.Errorf("cannot resolve to non-pointer value") + } + if rv.IsNil() { + return fmt.Errorf("cannot resolve to nil pointer value") + } + if rv.Type().Elem() != r.Type { + return fmt.Errorf("%w: cannot resolve to value of type %q, expecting type %q", + ErrTypeMismatch, rv.Type().Elem(), r.Type) + } + ctx := buildContextWithOptionsApplied(context.Background(), opts...) + return r.resolve(ctx, rv) +} + func (root *Resolver) resolve(ctx context.Context, rootValue reflect.Value) error { // Run the directives on current field. if err := root.runDirectives(ctx, rootValue); err != nil { @@ -301,17 +316,19 @@ func (root *Resolver) resolve(ctx context.Context, rootValue reflect.Value) erro } // Resolve the children fields. - if shouldResolveNestedDirectives(root, ctx) { - // If the root is a pointer, we need to allocate memory for it. - // We only expect it's a one-level pointer, e.g. *User, not **User. - underlyingValue := rootValue + if shouldResolveNestedDirectives(ctx, root) { + // If the root is a pointer, we need to allocate memory for it when it's + // not instantiated yet. We only expect it's a one-level pointer, e.g. + // *User, not **User. + underlying := rootValue if root.Type.Kind() == reflect.Ptr { - underlyingValue = reflect.New(root.Type.Elem()) - rootValue.Elem().Set(underlyingValue) + if rootValue.Elem().IsNil() { // instantiate the pointer on demand + rootValue.Elem().Set(reflect.New(root.Type.Elem())) + } + underlying = rootValue.Elem() } - for _, child := range root.Children { - if err := child.resolve(ctx, underlyingValue.Elem().Field(child.Index[len(child.Index)-1]).Addr()); err != nil { + if err := child.resolve(ctx, underlying.Elem().Field(child.Index[len(child.Index)-1]).Addr()); err != nil { return &ResolveError{ fieldError: fieldError{ Err: err, diff --git a/resolver_test.go b/resolver_test.go index b4e1fef..8aa1257 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -374,16 +374,24 @@ func TestResolve_SimpleFlatStruct(t *testing.T) { resolver, err := owl.New(GenerateAccessTokenRequest{}, owl.WithNamespace(ns)) assert.NoError(err) - - _, err = resolver.Resolve() - assert.NoError(err) - - assert.Equal([]*owl.Directive{ + expectedExecutedDirectives := []*owl.Directive{ owl.NewDirective("env", "ACCESS_TOKEN_KEY_GENERATION_KEY"), owl.NewDirective("form", "username"), owl.NewDirective("form", "expiry"), owl.NewDirective("default", "3600"), - }, tracker.Executed.ExecutedDirectives(), "should execute all directives in order") + } + + // Resolve + _, err = resolver.Resolve() + assert.NoError(err) + assert.Equal(expectedExecutedDirectives, tracker.Executed.ExecutedDirectives(), "should execute all directives in order") + + // ResolveTo + tracker.Reset() + var targetValue = new(GenerateAccessTokenRequest) + err = resolver.ResolveTo(targetValue) + assert.NoError(err) + assert.Equal(expectedExecutedDirectives, tracker.Executed.ExecutedDirectives(), "should execute all directives in order") } func TestResolve_EmbeddedStruct(t *testing.T) { @@ -402,18 +410,26 @@ func TestResolve_EmbeddedStruct(t *testing.T) { resolver, err := owl.New(UserListQuery{}, owl.WithNamespace(ns)) assert.NoError(err) - - _, err = resolver.Resolve() - assert.NoError(err) - - assert.Equal([]*owl.Directive{ + expectedExecutedDirectives := []*owl.Directive{ owl.NewDirective("form", "gender"), owl.NewDirective("form", "age", "age[]"), owl.NewDirective("default", "18", "999"), owl.NewDirective("form", "roles", "roles[]"), owl.NewDirective("form", "page"), owl.NewDirective("form", "size"), - }, tracker.Executed.ExecutedDirectives(), "should execute all directives in order") + } + + // Resolve + _, err = resolver.Resolve() + assert.NoError(err) + assert.Equal(expectedExecutedDirectives, tracker.Executed.ExecutedDirectives(), "should execute all directives in order") + + // ResolveTo + tracker.Reset() + var targetValue = new(UserListQuery) + err = resolver.ResolveTo(targetValue) + assert.NoError(err) + assert.Equal(expectedExecutedDirectives, tracker.Executed.ExecutedDirectives(), "should execute all directives in order") } func TestResolve_UnexportedField(t *testing.T) { @@ -609,6 +625,102 @@ func TestResolve_WithNestedDirectivesEnabled_false(t *testing.T) { }, tracker.Executed.ExecutedDirectives(), "should resolve nested directives") } +func TestResolveTo_InstantializeOnlyNilPointerForNestedStruct(t *testing.T) { + type Owner struct { + Type string `owl:"env=type"` + Name string `owl:"env=name"` + } + + type AddOwnershipRequest struct { + ResourceId string `owl:"env=resource_id"` + Owner *Owner + } + + ns := owl.NewNamespace() + ns.RegisterDirectiveExecutor("env", owl.DirectiveExecutorFunc(exeEnvReader)) + resolver, err := owl.New(AddOwnershipRequest{}, owl.WithNamespace(ns)) + assert.NoError(t, err) + + os.Setenv("type", "usergroup") + os.Setenv("name", "admin") + os.Setenv("resource_id", "123") + + useOwner := &Owner{} + reqWithOwnerInstantiated := &AddOwnershipRequest{ + ResourceId: "", + Owner: useOwner, + } + err = resolver.ResolveTo(reqWithOwnerInstantiated) + assert.NoError(t, err) + + // The Owner field is already instantiated, so we only populate the fields, + // but not create a new instance and assign it to the Owner field. + assert.Same(t, useOwner, reqWithOwnerInstantiated.Owner) + assert.Equal(t, "usergroup", reqWithOwnerInstantiated.Owner.Type) + assert.Equal(t, "admin", reqWithOwnerInstantiated.Owner.Name) + assert.Equal(t, "123", reqWithOwnerInstantiated.ResourceId) + + // The Owner field is nil, so we create a new instance when resolving. + reqWithOwnerNotInstantiated := &AddOwnershipRequest{ + ResourceId: "", + Owner: nil, + } + err = resolver.ResolveTo(reqWithOwnerNotInstantiated) + assert.NoError(t, err) + assert.Equal(t, &Owner{Type: "usergroup", Name: "admin"}, reqWithOwnerNotInstantiated.Owner) + assert.Equal(t, "123", reqWithOwnerNotInstantiated.ResourceId) +} + +func TestResolveTo_PopulateFieldsOnDemand(t *testing.T) { + type User struct { + Name string `owl:"env=OWL_TEST_NAME"` + } + + ns := owl.NewNamespace() + ns.RegisterDirectiveExecutor("env", owl.DirectiveExecutorFunc(exeEnvReader)) + resolver, err := owl.New(User{}, owl.WithNamespace(ns)) + assert.NoError(t, err) + + user := &User{Name: "admin"} + err = resolver.ResolveTo(user) + assert.NoError(t, err) + assert.Equal(t, "admin", user.Name) // not changed + + os.Setenv("OWL_TEST_NAME", "owl") + err = resolver.ResolveTo(user) + assert.NoError(t, err) + assert.Equal(t, "owl", user.Name) // changed +} + +func TestResolveTo_ErrNilValue(t *testing.T) { + resolver, err := owl.New(User{}) + assert.NoError(t, err) + + err = resolver.ResolveTo(nil) + assert.ErrorContains(t, err, "nil") + + err = resolver.ResolveTo((*User)(nil)) + assert.ErrorContains(t, err, "nil pointer") +} + +func TestResolveTo_ErrNonPointerValue(t *testing.T) { + resolver, err := owl.New(User{}) + assert.NoError(t, err) + + var user User + err = resolver.ResolveTo(user) + assert.ErrorContains(t, err, "non-pointer") +} + +func TestResolveTo_ErrTypeMismatch(t *testing.T) { + resolver, err := owl.New(User{}) + assert.NoError(t, err) + + var user = new(Pagination) + err = resolver.ResolveTo(user) + assert.ErrorIs(t, err, owl.ErrTypeMismatch) +} + func TestScan(t *testing.T) { ns, tracker := createNsForTracking() resolver, err := owl.New(User{}, owl.WithNamespace(ns))