From decdb571b8dcf3d04a983a5cdb3e33c5c864e1e8 Mon Sep 17 00:00:00 2001 From: Ben Ash <32777270+benashz@users.noreply.github.com> Date: Mon, 7 Feb 2022 10:29:54 -0500 Subject: [PATCH] Properly handle role.AliasNameSource migration (#135) (#136) - set role.AliasNameSource to be the default if both it and its field value are unset instead of returning an error - add tests for the update operation Co-authored-by: Tom Proctor --- path_role.go | 13 +++- path_role_test.go | 171 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+), 3 deletions(-) diff --git a/path_role.go b/path_role.go index 3822d19b..52918905 100644 --- a/path_role.go +++ b/path_role.go @@ -315,14 +315,21 @@ func (b *kubeAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical } if source, ok := data.GetOk("alias_name_source"); ok { - if err := validateAliasNameSource(source.(string)); err != nil { - return logical.ErrorResponse(err.Error()), nil + // migrate the role.AliasNameSource to be the default + // if both it and the field value are unset + if role.AliasNameSource == aliasNameSourceUnset && source.(string) == aliasNameSourceUnset { + role.AliasNameSource = data.GetDefaultOrZero("alias_name_source").(string) + } else { + role.AliasNameSource = source.(string) } - role.AliasNameSource = source.(string) } else if role.AliasNameSource == aliasNameSourceUnset { role.AliasNameSource = data.Get("alias_name_source").(string) } + if err := validateAliasNameSource(role.AliasNameSource); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + // Store the entry. entry, err := logical.StorageEntryJSON("role/"+strings.ToLower(roleName), role) if err != nil { diff --git a/path_role_test.go b/path_role_test.go index 3abafa38..3b63113c 100644 --- a/path_role_test.go +++ b/path_role_test.go @@ -2,6 +2,7 @@ package kubeauth import ( "context" + "encoding/json" "errors" "fmt" "testing" @@ -311,3 +312,173 @@ func TestPath_Delete(t *testing.T) { t.Fatalf("Unexpected resp data: expected nil got %#v\n", resp.Data) } } + +func TestPath_Update(t *testing.T) { + testCases := map[string]struct { + storageData map[string]interface{} + requestData map[string]interface{} + expected *roleStorageEntry + wantErr error + }{ + "default": { + storageData: map[string]interface{}{ + "bound_service_account_names": []string{"name"}, + "bound_service_account_namespaces": []string{"namespace"}, + "policies": []string{"test"}, + "period": 1 * time.Second, + "ttl": 1 * time.Second, + "num_uses": 12, + "max_ttl": 5 * time.Second, + "alias_name_source": aliasNameSourceDefault, + }, + requestData: map[string]interface{}{ + "alias_name_source": aliasNameSourceDefault, + "policies": []string{"bar", "foo"}, + "period": "3s", + }, + expected: &roleStorageEntry{ + TokenParams: tokenutil.TokenParams{ + TokenPolicies: []string{"bar", "foo"}, + TokenPeriod: 3 * time.Second, + TokenTTL: 1 * time.Second, + TokenMaxTTL: 5 * time.Second, + TokenNumUses: 12, + TokenBoundCIDRs: nil, + }, + Policies: []string{"bar", "foo"}, + Period: 3 * time.Second, + ServiceAccountNames: []string{"name"}, + ServiceAccountNamespaces: []string{"namespace"}, + TTL: 1 * time.Second, + MaxTTL: 5 * time.Second, + NumUses: 12, + BoundCIDRs: nil, + AliasNameSource: aliasNameSourceDefault, + }, + wantErr: nil, + }, + "migrate-alias-name-source": { + storageData: map[string]interface{}{ + "bound_service_account_names": []string{"name"}, + "bound_service_account_namespaces": []string{"namespace"}, + "policies": []string{"test"}, + "period": 1 * time.Second, + "ttl": 1 * time.Second, + "num_uses": 12, + "max_ttl": 5 * time.Second, + }, + requestData: map[string]interface{}{ + "alias_name_source": aliasNameSourceUnset, + }, + expected: &roleStorageEntry{ + TokenParams: tokenutil.TokenParams{ + TokenPolicies: []string{"test"}, + TokenPeriod: 1 * time.Second, + TokenTTL: 1 * time.Second, + TokenMaxTTL: 5 * time.Second, + TokenNumUses: 12, + TokenBoundCIDRs: nil, + }, + Policies: []string{"test"}, + Period: 1 * time.Second, + ServiceAccountNames: []string{"name"}, + ServiceAccountNamespaces: []string{"namespace"}, + TTL: 1 * time.Second, + MaxTTL: 5 * time.Second, + NumUses: 12, + BoundCIDRs: nil, + AliasNameSource: aliasNameSourceDefault, + }, + wantErr: nil, + }, + "invalid-alias-name-source": { + storageData: map[string]interface{}{ + "bound_service_account_names": []string{"name"}, + "bound_service_account_namespaces": []string{"namespace"}, + "alias_name_source": aliasNameSourceDefault, + }, + requestData: map[string]interface{}{ + "alias_name_source": "_invalid_", + }, + wantErr: errInvalidAliasNameSource, + }, + "invalid-alias-name-source-in-storage": { + storageData: map[string]interface{}{ + "bound_service_account_names": []string{"name"}, + "bound_service_account_namespaces": []string{"namespace"}, + "alias_name_source": "_invalid_", + }, + requestData: map[string]interface{}{}, + wantErr: errInvalidAliasNameSource, + }, + "invalid-alias-name-source-migration": { + storageData: map[string]interface{}{ + "bound_service_account_names": []string{"name"}, + "bound_service_account_namespaces": []string{"namespace"}, + "alias_name_source": aliasNameSourceUnset, + }, + requestData: map[string]interface{}{ + "alias_name_source": "_invalid_", + }, + wantErr: errInvalidAliasNameSource, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + b, storage := getBackend(t) + path := fmt.Sprintf("role/%s", name) + + data, err := json.Marshal(tc.storageData) + if err != nil { + t.Fatal(err) + } + + entry := &logical.StorageEntry{ + Key: path, + Value: data, + SealWrap: false, + } + if err := storage.Put(context.Background(), entry); err != nil { + t.Fatal(err) + } + + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: path, + Storage: storage, + Data: tc.requestData, + } + + resp, err := b.HandleRequest(context.Background(), req) + + if tc.wantErr != nil { + var actual error + if err != nil { + actual = err + } else if resp != nil && resp.IsError() { + actual = resp.Error() + } else { + t.Fatalf("expected error") + } + + if tc.wantErr.Error() != actual.Error() { + t.Fatalf("expected err %q, actual %q", tc.wantErr, actual) + } + } else { + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + actual, err := b.(*kubeAuthBackend).role(context.Background(), storage, name) + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(tc.expected, actual); diff != nil { + t.Fatal(diff) + } + } + }) + } +}