Skip to content

Commit

Permalink
Properly handle role.AliasNameSource migration (#135) (#136)
Browse files Browse the repository at this point in the history
- set role.AliasNameSource to be the default if both it and its field
  value are unset instead of returning an error
- add tests for the update operation

Co-authored-by: Tom Proctor <[email protected]>
  • Loading branch information
benashz and tomhjp authored Feb 7, 2022
1 parent e2ded2a commit decdb57
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 3 deletions.
13 changes: 10 additions & 3 deletions path_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,14 +315,21 @@ func (b *kubeAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical
}

if source, ok := data.GetOk("alias_name_source"); ok {
if err := validateAliasNameSource(source.(string)); err != nil {
return logical.ErrorResponse(err.Error()), nil
// migrate the role.AliasNameSource to be the default
// if both it and the field value are unset
if role.AliasNameSource == aliasNameSourceUnset && source.(string) == aliasNameSourceUnset {
role.AliasNameSource = data.GetDefaultOrZero("alias_name_source").(string)
} else {
role.AliasNameSource = source.(string)
}
role.AliasNameSource = source.(string)
} else if role.AliasNameSource == aliasNameSourceUnset {
role.AliasNameSource = data.Get("alias_name_source").(string)
}

if err := validateAliasNameSource(role.AliasNameSource); err != nil {
return logical.ErrorResponse(err.Error()), nil
}

// Store the entry.
entry, err := logical.StorageEntryJSON("role/"+strings.ToLower(roleName), role)
if err != nil {
Expand Down
171 changes: 171 additions & 0 deletions path_role_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package kubeauth

import (
"context"
"encoding/json"
"errors"
"fmt"
"testing"
Expand Down Expand Up @@ -311,3 +312,173 @@ func TestPath_Delete(t *testing.T) {
t.Fatalf("Unexpected resp data: expected nil got %#v\n", resp.Data)
}
}

func TestPath_Update(t *testing.T) {
testCases := map[string]struct {
storageData map[string]interface{}
requestData map[string]interface{}
expected *roleStorageEntry
wantErr error
}{
"default": {
storageData: map[string]interface{}{
"bound_service_account_names": []string{"name"},
"bound_service_account_namespaces": []string{"namespace"},
"policies": []string{"test"},
"period": 1 * time.Second,
"ttl": 1 * time.Second,
"num_uses": 12,
"max_ttl": 5 * time.Second,
"alias_name_source": aliasNameSourceDefault,
},
requestData: map[string]interface{}{
"alias_name_source": aliasNameSourceDefault,
"policies": []string{"bar", "foo"},
"period": "3s",
},
expected: &roleStorageEntry{
TokenParams: tokenutil.TokenParams{
TokenPolicies: []string{"bar", "foo"},
TokenPeriod: 3 * time.Second,
TokenTTL: 1 * time.Second,
TokenMaxTTL: 5 * time.Second,
TokenNumUses: 12,
TokenBoundCIDRs: nil,
},
Policies: []string{"bar", "foo"},
Period: 3 * time.Second,
ServiceAccountNames: []string{"name"},
ServiceAccountNamespaces: []string{"namespace"},
TTL: 1 * time.Second,
MaxTTL: 5 * time.Second,
NumUses: 12,
BoundCIDRs: nil,
AliasNameSource: aliasNameSourceDefault,
},
wantErr: nil,
},
"migrate-alias-name-source": {
storageData: map[string]interface{}{
"bound_service_account_names": []string{"name"},
"bound_service_account_namespaces": []string{"namespace"},
"policies": []string{"test"},
"period": 1 * time.Second,
"ttl": 1 * time.Second,
"num_uses": 12,
"max_ttl": 5 * time.Second,
},
requestData: map[string]interface{}{
"alias_name_source": aliasNameSourceUnset,
},
expected: &roleStorageEntry{
TokenParams: tokenutil.TokenParams{
TokenPolicies: []string{"test"},
TokenPeriod: 1 * time.Second,
TokenTTL: 1 * time.Second,
TokenMaxTTL: 5 * time.Second,
TokenNumUses: 12,
TokenBoundCIDRs: nil,
},
Policies: []string{"test"},
Period: 1 * time.Second,
ServiceAccountNames: []string{"name"},
ServiceAccountNamespaces: []string{"namespace"},
TTL: 1 * time.Second,
MaxTTL: 5 * time.Second,
NumUses: 12,
BoundCIDRs: nil,
AliasNameSource: aliasNameSourceDefault,
},
wantErr: nil,
},
"invalid-alias-name-source": {
storageData: map[string]interface{}{
"bound_service_account_names": []string{"name"},
"bound_service_account_namespaces": []string{"namespace"},
"alias_name_source": aliasNameSourceDefault,
},
requestData: map[string]interface{}{
"alias_name_source": "_invalid_",
},
wantErr: errInvalidAliasNameSource,
},
"invalid-alias-name-source-in-storage": {
storageData: map[string]interface{}{
"bound_service_account_names": []string{"name"},
"bound_service_account_namespaces": []string{"namespace"},
"alias_name_source": "_invalid_",
},
requestData: map[string]interface{}{},
wantErr: errInvalidAliasNameSource,
},
"invalid-alias-name-source-migration": {
storageData: map[string]interface{}{
"bound_service_account_names": []string{"name"},
"bound_service_account_namespaces": []string{"namespace"},
"alias_name_source": aliasNameSourceUnset,
},
requestData: map[string]interface{}{
"alias_name_source": "_invalid_",
},
wantErr: errInvalidAliasNameSource,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
b, storage := getBackend(t)
path := fmt.Sprintf("role/%s", name)

data, err := json.Marshal(tc.storageData)
if err != nil {
t.Fatal(err)
}

entry := &logical.StorageEntry{
Key: path,
Value: data,
SealWrap: false,
}
if err := storage.Put(context.Background(), entry); err != nil {
t.Fatal(err)
}

req := &logical.Request{
Operation: logical.UpdateOperation,
Path: path,
Storage: storage,
Data: tc.requestData,
}

resp, err := b.HandleRequest(context.Background(), req)

if tc.wantErr != nil {
var actual error
if err != nil {
actual = err
} else if resp != nil && resp.IsError() {
actual = resp.Error()
} else {
t.Fatalf("expected error")
}

if tc.wantErr.Error() != actual.Error() {
t.Fatalf("expected err %q, actual %q", tc.wantErr, actual)
}
} else {
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}

actual, err := b.(*kubeAuthBackend).role(context.Background(), storage, name)
if err != nil {
t.Fatal(err)
}

if diff := deep.Equal(tc.expected, actual); diff != nil {
t.Fatal(diff)
}
}
})
}
}

0 comments on commit decdb57

Please sign in to comment.