From 4e2b9b0988c9a456ea6a0afc2c710177eeef261f Mon Sep 17 00:00:00 2001 From: Vojtech Vitek Date: Mon, 14 Oct 2024 17:14:30 +0200 Subject: [PATCH] Fix -race bug, tests and improve recursion --- README.md | 15 +++-- collector.go | 43 +++++-------- collector_test.go | 51 ++++----------- gcp/gcp.go | 12 ++-- hydrate.go | 47 +++++++------- hydrate_test.go | 20 ++++-- replacer.go | 58 +++++++++-------- replacer_test.go | 157 ---------------------------------------------- 8 files changed, 116 insertions(+), 287 deletions(-) delete mode 100644 replacer_test.go diff --git a/README.md b/README.md index ff32d6b..095dcdd 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,22 @@ # go-cloudsecrets -Go package to hydrate runtime secrets from Cloud providers +Go package for hydrating config secrets from Cloud secret managers - [x] `"gcp"`, GCP Secret Manager - [ ] `"aws"`, AWS Secrets Manager +- [ ] `""`, empty provider, which errors out on `$SECRET:` value ```go cloudsecrets.Hydrate(ctx, "gcp", &Config{}) ``` -`Hydrate()` recursively walks a given config (struct pointer) and hydrates all string -values matching `"$SECRET:"` prefix using a given Cloud secrets provider. +`Hydrate()` recursively walks given config (a `struct` pointer) and replaces all string +fields having `"$SECRET:"` prefix with a value fetched from a given Cloud secret provider. -The secret values to be replaced must have a format of `"$SECRET:{name|path}"`. +The value to be replaced must have a format of `"$SECRET:{name|path}"`. + +Secrets are de-duplicated and fetched only once. + +The `Hydrate()` function tries to replace as many fields as possible before returning error. ## Usage ```go @@ -23,7 +28,7 @@ func main() { Database: "postgres", Host: "localhost:5432", Username: "sequence", - DPassword: "$SECRET:dbPassword", // to be hydrated + DPassword: "$SECRET:dbPassword", // will be hydrated (replaced with value of "dbPassword" secret) }, } diff --git a/collector.go b/collector.go index 68ec5c5..f7aa216 100644 --- a/collector.go +++ b/collector.go @@ -1,55 +1,49 @@ package cloudsecrets import ( - "fmt" "reflect" + "slices" "strings" ) -func collectSecretFields(v reflect.Value) (map[string]string, error) { - c := &collector{ - fields: map[string]string{}, - } - c.collectSecretFields(v, "config") - if c.err != nil { - return nil, fmt.Errorf("failed to collect fields: %w", c.err) - } +// Returns de-duplicated secret keys found recursively in given v. +func collectSecretKeys(v reflect.Value) []string { + c := collector{} + c.collectSecretFields(v) - return c.fields, nil -} + slices.Sort(c) + dedup := slices.Compact(c) -type collector struct { - fields map[string]string - err error + return []string(dedup) } -// Walks given reflect value recursively and collects any string fields matching $SECRET: prefix. -func (c *collector) collectSecretFields(v reflect.Value, path string) { +type collector []string + +// Walk given reflect value recursively and collects any string fields matching $SECRET: prefix. +func (c *collector) collectSecretFields(v reflect.Value) { switch v.Kind() { case reflect.Ptr: if v.IsNil() { return } - - // Dereference pointer - c.collectSecretFields(v.Elem(), path) + c.collectSecretFields(v.Elem()) case reflect.Struct: for i := 0; i < v.NumField(); i++ { field := v.Field(i) - c.collectSecretFields(field, fmt.Sprintf("%v.%v", path, v.Type().Field(i).Name)) + c.collectSecretFields(field) } case reflect.Slice, reflect.Array: for i := 0; i < v.Len(); i++ { item := v.Index(i) - c.collectSecretFields(item, fmt.Sprintf("%v[%v]", path, i)) + c.collectSecretFields(item) } case reflect.Map: for _, key := range v.MapKeys() { item := v.MapIndex(key) - c.collectSecretFields(item, fmt.Sprintf("%v[%v]", path, key)) + c.collectSecretFields(item) } case reflect.String: @@ -57,10 +51,7 @@ func (c *collector) collectSecretFields(v reflect.Value, path string) { if !found { return } - - if _, ok := c.fields[secretName]; !ok { - c.fields[secretName] = path - } + *c = append(*c, secretName) default: return diff --git a/collector_test.go b/collector_test.go index cff5fbf..78c166f 100644 --- a/collector_test.go +++ b/collector_test.go @@ -2,17 +2,16 @@ package cloudsecrets import ( "reflect" - "slices" "testing" "github.com/google/go-cmp/cmp" ) -func TestCollectFields(t *testing.T) { +func TestCollectSecretKeys(t *testing.T) { tt := []struct { Name string Input any - Out []string // field paths + Out []string // collected secret keys Error bool }{ { @@ -112,27 +111,16 @@ func TestCollectFields(t *testing.T) { Input: &cfg{ DB: dbConfig{ User: "db-user", - Password: "$SECRET:dup", + Password: "$SECRET:duplicatedKey", }, - JWTSecrets: []jwtSecret{"$SECRET:dup", "$SECRET:dup"}, + JWTSecrets: []jwtSecret{"$SECRET:duplicatedKey", "$SECRET:duplicatedKey"}, ProvidersPtr: map[string]*providerConfig{ - "provider1": {Name: "provider1", Secret: "$SECRET:dup"}, - "provider2": {Name: "provider2", Secret: "$SECRET:dup"}, - "provider3": {Name: "provider3", Secret: "$SECRET:dup"}, + "provider1": {Name: "provider1", Secret: "$SECRET:duplicatedKey"}, + "provider2": {Name: "provider2", Secret: "$SECRET:duplicatedKey"}, + "provider3": {Name: "provider3", Secret: "$SECRET:duplicatedKey"}, }, }, - Out: []string{"dup"}, - }, - { - Name: "Unexported_field_should_fail_to_hydrate", - Input: &cfg{ - unexported: dbConfig{ // unexported fields can't be updated via reflect pkg - User: "db-user", - Password: "$SECRET:secretName", // match inside unexported field - }, - }, - Out: []string{}, - Error: true, // expect error + Out: []string{"duplicatedKey"}, }, } @@ -141,17 +129,9 @@ func TestCollectFields(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { v := reflect.ValueOf(tc.Input) - secretFields, err := collectSecretFields(v) - if tc.Error && err == nil { - t.Error("expected error, got nil") - return - } else if !tc.Error && err != nil { - t.Errorf("unexpected error: %v", err) - return - } - - if !cmp.Equal(mapKeysSorted(secretFields), tc.Out) { - t.Errorf(cmp.Diff(tc.Out, mapKeysSorted(secretFields))) + secretFields := collectSecretKeys(v) + if !cmp.Equal(secretFields, tc.Out) { + t.Errorf(cmp.Diff(tc.Out, secretFields)) } }) } @@ -181,12 +161,3 @@ type providerConfig struct { type jwtSecret string func ptr[T any](v T) *T { return &v } - -func mapKeysSorted(m map[string]string) []string { - keys := []string{} - for key, _ := range m { - keys = append(keys, key) - } - slices.Sort(keys) - return keys -} diff --git a/gcp/gcp.go b/gcp/gcp.go index 937926f..4b93670 100644 --- a/gcp/gcp.go +++ b/gcp/gcp.go @@ -21,14 +21,14 @@ type SecretsProvider struct { func NewSecretsProvider() (*SecretsProvider, error) { gcpClient, err := secretmanager.NewClient(context.Background()) if err != nil { - return nil, fmt.Errorf("secretmanager client: %w", err) + return nil, fmt.Errorf("gcp: secretmanager client: %w", err) } var projectNumber string if metadata.OnGCE() { projectNumber, err = metadata.NumericProjectID() if err != nil { - return nil, fmt.Errorf("getting project ID from metadata: %w", err) + return nil, fmt.Errorf("gcp: getting project ID from metadata: %w", err) } } else { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) @@ -36,7 +36,7 @@ func NewSecretsProvider() (*SecretsProvider, error) { projectNumber, err = getProjectNumberFromGcloud(ctx) if err != nil { - return nil, fmt.Errorf("getting project ID from gcloud: %w", err) + return nil, fmt.Errorf("gcp: getting project ID from gcloud: %w", err) } } @@ -59,7 +59,7 @@ func (p SecretsProvider) FetchSecret(ctx context.Context, secretId string) (stri // Access the secret version result, err := p.client.AccessSecretVersion(reqCtx, req) if err != nil { - return "", fmt.Errorf("accessing GCP secret %v (%v): %w", secretId, versionId, err) + return "", fmt.Errorf("gcp: accessing secret: %w", err) } // Return the secret value @@ -75,7 +75,7 @@ func getProjectNumberFromGcloud(ctx context.Context) (string, error) { if projectId == "" { out, err := exec.CommandContext(ctx, "gcloud", "config", "get-value", "project").Output() if err != nil { - return "", fmt.Errorf("getting current gcloud project (try `gcloud auth application-default login'): %w", err) + return "", fmt.Errorf("gcp: getting current gcloud project (try `gcloud auth application-default login'): %w", err) } projectId = strings.TrimSpace(string(out)) } @@ -83,7 +83,7 @@ func getProjectNumberFromGcloud(ctx context.Context) (string, error) { // We need projectNumber (not projectName!) for GCP Secret Manager APIs. out, err := exec.CommandContext(ctx, "gcloud", "projects", "describe", projectId, "--format=value(projectNumber)").Output() if err != nil { - return "", fmt.Errorf("getting gcloud projectNumber from projectId %q: %w", projectId, err) + return "", fmt.Errorf("gcp: getting gcloud projectNumber from projectId %q: %w", projectId, err) } return strings.TrimSpace(string(out)), nil } diff --git a/hydrate.go b/hydrate.go index 353713b..317a0d3 100644 --- a/hydrate.go +++ b/hydrate.go @@ -3,7 +3,6 @@ package cloudsecrets import ( "context" "fmt" - "log" "reflect" "github.com/0xsequence/go-cloudsecrets/gcp" @@ -43,45 +42,45 @@ func Hydrate(ctx context.Context, providerName string, config interface{}) error } func hydrateConfig(ctx context.Context, provider secretsProvider, v reflect.Value) error { - if v.Kind() == reflect.Ptr { - if v.IsNil() { - return fmt.Errorf("passed config is nil") - } - v = v.Elem() + if v.Kind() != reflect.Ptr { + return fmt.Errorf("passed config must be a pointer") } - - if v.Kind() != reflect.Struct { - return fmt.Errorf("passed config must be struct, actual %s", v.Kind()) + if v.IsNil() { + return fmt.Errorf("passed config is nil") } + v = v.Elem() - secretFields, err := collectSecretFields(v) - if err != nil { - return fmt.Errorf("collecting secrets: %w", err) + if v.Kind() != reflect.Struct { + return fmt.Errorf("passed config must be pointer to a struct, got pointer to %s", v.Kind()) } - secretValues := map[string]string{} + secretKeys := collectSecretKeys(v) + secrets := make([]secret, len(secretKeys)) g := &errgroup.Group{} - for secretName, fieldPath := range secretFields { - secretName := secretName - fieldPath := fieldPath + for i, key := range secretKeys { + i, key := i, key g.Go(func() error { - secretValue, err := provider.FetchSecret(ctx, secretName) - if err != nil { - return fmt.Errorf("field %v=%q: fetching secret: %w", fieldPath, "$SECRET:"+secretName, err) + value, err := provider.FetchSecret(ctx, key) + secrets[i] = secret{ + key: key, + value: value, + fetchErr: err, } - log.Printf("Fetched %v\n", secretName) - secretValues[secretName] = secretValue - return nil }) } - if err := g.Wait(); err != nil { return err } - return replaceSecrets(v, secretValues) + return replaceSecrets(v, secrets) +} + +type secret struct { + key string + value string + fetchErr error } diff --git a/hydrate_test.go b/hydrate_test.go index 634c8af..50922e9 100644 --- a/hydrate_test.go +++ b/hydrate_test.go @@ -35,13 +35,25 @@ type service struct { Pass string } -func TestFailWhenPassedValueIsNotStruct(t *testing.T) { - input := "hello" +func TestHydrateFailIfNotPointerToStruct(t *testing.T) { + ctx := context.Background() + + str := "hello" + assert.Error(t, Hydrate(ctx, "", str)) + assert.Error(t, Hydrate(ctx, "", &str)) + + slice := []string{"hello", "hello2"} + assert.Error(t, Hydrate(ctx, "", slice)) + assert.Error(t, Hydrate(ctx, "", &slice)) - assert.Error(t, Hydrate(context.Background(), "", input)) + cfg := struct { + X, Y string + }{} + assert.Error(t, Hydrate(ctx, "", cfg)) + assert.NoError(t, Hydrate(ctx, "", &cfg)) } -func TestReplacePlaceholdersWithSecrets(t *testing.T) { +func TestHydrate(t *testing.T) { ctx := context.Background() tests := []struct { diff --git a/replacer.go b/replacer.go index 70953d6..6856a61 100644 --- a/replacer.go +++ b/replacer.go @@ -7,44 +7,53 @@ import ( "strings" ) -func replaceSecrets(v reflect.Value, secretValues map[string]string) error { - c := &replacer{ - secrets: secretValues, +// Replace values with "$SECRET:" prefix in v with values from secrets. +func replaceSecrets(v reflect.Value, secrets []secret) error { + r := &replacer{ + secretValues: map[string]string{}, + fetchErrors: map[string]error{}, } - c.replaceSecrets(v, "config") - if c.err != nil { - return fmt.Errorf("failed to collect fields: %w", c.err) + for _, secret := range secrets { + if secret.fetchErr != nil { + r.fetchErrors[secret.key] = secret.fetchErr + } else { + r.secretValues[secret.key] = secret.value + } + } + + r.replaceSecrets(v, "config") + if len(r.errs) > 0 { + return fmt.Errorf("failed to replace %v field(s): %w", len(r.errs), errors.Join(r.errs...)) } return nil } type replacer struct { - secrets map[string]string - err error + secretValues map[string]string + fetchErrors map[string]error + errs []error } -// Walk given value recursively and replace all string fields matching $SECRET: prefix. -func (c *replacer) replaceSecrets(v reflect.Value, path string) { +// Walk given v recursively and try to replace all secrets. Record errors along the way. +func (r *replacer) replaceSecrets(v reflect.Value, path string) { switch v.Kind() { case reflect.Ptr: if v.IsNil() { return } - - // Dereference pointer - c.replaceSecrets(v.Elem(), path) + r.replaceSecrets(v.Elem(), path) case reflect.Struct: for i := 0; i < v.NumField(); i++ { field := v.Field(i) - c.replaceSecrets(field, fmt.Sprintf("%v.%v", path, v.Type().Field(i).Name)) + r.replaceSecrets(field, fmt.Sprintf("%v.%v", path, v.Type().Field(i).Name)) } case reflect.Slice, reflect.Array: for i := 0; i < v.Len(); i++ { item := v.Index(i) - c.replaceSecrets(item, fmt.Sprintf("%v[%v]", path, i)) + r.replaceSecrets(item, fmt.Sprintf("%v[%v]", path, i)) } case reflect.Map: @@ -52,33 +61,32 @@ func (c *replacer) replaceSecrets(v reflect.Value, path string) { item := v.MapIndex(key) if item.Kind() == reflect.Struct { - // If the value is a struct, create a pointer to the map value and modify via pointer + // If the value is a struct, create a pointer to it, update the value and reassign the map. ptr := reflect.New(item.Type()) ptr.Elem().Set(item) - - c.replaceSecrets(ptr, fmt.Sprintf("%v[%v]", path, key)) - - // Set the modified struct back into the map + r.replaceSecrets(ptr, fmt.Sprintf("%v[%v]", path, key)) v.SetMapIndex(key, ptr.Elem()) } else { - c.replaceSecrets(item, fmt.Sprintf("%v[%v]", path, key)) + r.replaceSecrets(item, fmt.Sprintf("%v[%v]", path, key)) } } case reflect.String: - secretName, found := strings.CutPrefix(v.String(), "$SECRET:") + secretKey, found := strings.CutPrefix(v.String(), "$SECRET:") if !found { return } if !v.CanSet() { - c.err = errors.Join(c.err, fmt.Errorf("can't set field %v", path)) + r.errs = append(r.errs, fmt.Errorf("%v: reflect: can't set field", path)) return } - secretValue, ok := c.secrets[secretName] + secretValue, ok := r.secretValues[secretKey] if !ok { - c.err = errors.Join(c.err, fmt.Errorf("secret %v not found for field %v", secretName, path)) + err, _ := r.fetchErrors[secretKey] + r.errs = append(r.errs, fmt.Errorf("%v: %w", path, err)) + return } v.SetString(secretValue) diff --git a/replacer_test.go b/replacer_test.go deleted file mode 100644 index 30a1ab2..0000000 --- a/replacer_test.go +++ /dev/null @@ -1,157 +0,0 @@ -package cloudsecrets - -import ( - "reflect" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestReplaceFields(t *testing.T) { - tt := []struct { - Name string - Input any - Out []string // field paths - Error bool - }{ - { - Name: "DB_config_with_no_creds", - Input: &cfg{ - DB: dbConfig{ - User: "db-user", - Password: "db-password", - }, - DBPtr: &dbConfig{ - User: "db-user", - Password: "db-password", - }, - DBDoublePtr: ptr(&dbConfig{ - User: "db-user", - Password: "db-password", - }), - }, - Out: []string{}, - }, - { - Name: "DB_config_with_creds", - Input: &cfg{ - DB: dbConfig{ - User: "db-user", - Password: "$SECRET:db-password", - }, - }, - Out: []string{"db-password"}, - }, - { - Name: "DB config ptr with creds", - Input: &cfg{ - DBPtr: &dbConfig{ - User: "db-user", - Password: "$SECRET:db-password", - }, - }, - Out: []string{"db-password"}, - }, - { - Name: "DB_config_double_ptr_with_creds", - Input: &cfg{ - DBDoublePtr: ptr(&dbConfig{ - User: "db-user", - Password: "$SECRET:db-password", - }), - }, - Out: []string{"db-password"}, - }, - { - Name: "Slice_of_secret_values", - Input: &cfg{ - DB: dbConfig{ - User: "db-user", - Password: "$SECRET:secretName", - }, - JWTSecrets: []jwtSecret{"$SECRET:jwtSecret1", "$SECRET:jwtSecret2", "nope"}, - }, - Out: []string{"jwtSecret1", "jwtSecret2", "secretName"}, - }, - { - Name: "Slice_of_secret_pointer_values", - Input: &cfg{ - DB: dbConfig{ - User: "db-user", - Password: "$SECRET:secretName", - }, - JWTSecretsPtr: []*jwtSecret{ptr(jwtSecret("$SECRET:jwtSecret1")), ptr(jwtSecret("$SECRET:jwtSecret2")), ptr(jwtSecret("nope"))}, - }, - Out: []string{"jwtSecret1", "jwtSecret2", "secretName"}, - }, - { - Name: "Map_with_values", - Input: &cfg{ - Providers: map[string]providerConfig{ - "provider1": {Name: "provider1", Secret: "$SECRET:secretProvider1"}, - "provider2": {Name: "provider2", Secret: "$SECRET:secretProvider2"}, - "provider3": {Name: "provider3", Secret: "$SECRET:secretProvider3"}, - }, - }, - Out: []string{"secretProvider1", "secretProvider2", "secretProvider3"}, - }, - { - Name: "Map_with_ptr_values", - Input: &cfg{ - ProvidersPtr: map[string]*providerConfig{ - "provider1": {Name: "provider1", Secret: "$SECRET:secretProvider1"}, - "provider2": {Name: "provider2", Secret: "$SECRET:secretProvider2"}, - "provider3": {Name: "provider3", Secret: "$SECRET:secretProvider3"}, - }, - }, - Out: []string{"secretProvider1", "secretProvider2", "secretProvider3"}, - }, - { - Name: "Duplicated_secret", - Input: &cfg{ - DB: dbConfig{ - User: "db-user", - Password: "$SECRET:dup", - }, - JWTSecrets: []jwtSecret{"$SECRET:dup", "$SECRET:dup"}, - ProvidersPtr: map[string]*providerConfig{ - "provider1": {Name: "provider1", Secret: "$SECRET:dup"}, - "provider2": {Name: "provider2", Secret: "$SECRET:dup"}, - "provider3": {Name: "provider3", Secret: "$SECRET:dup"}, - }, - }, - Out: []string{"dup"}, - }, - { - Name: "Unexported_field_should_fail_to_hydrate", - Input: &cfg{ - unexported: dbConfig{ // unexported fields can't be updated via reflect pkg - User: "db-user", - Password: "$SECRET:secretName", // match inside unexported field - }, - }, - Out: []string{}, - Error: true, // expect error - }, - } - - for _, tc := range tt { - tc := tc - t.Run(tc.Name, func(t *testing.T) { - v := reflect.ValueOf(tc.Input) - - secretFields, err := collectSecretFields(v) - if tc.Error && err == nil { - t.Error("expected error, got nil") - return - } else if !tc.Error && err != nil { - t.Errorf("unexpected error: %v", err) - return - } - - if !cmp.Equal(mapKeysSorted(secretFields), tc.Out) { - t.Errorf(cmp.Diff(tc.Out, mapKeysSorted(secretFields))) - } - }) - } -}