Skip to content

Commit

Permalink
Fix -race bug, tests and improve recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
VojtechVitek committed Oct 14, 2024
1 parent fd8022c commit 4e2b9b0
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 287 deletions.
15 changes: 10 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
# go-cloudsecrets

Go package to hydrate runtime secrets from Cloud providers
Go package for hydrating config secrets from Cloud secret managers
- [x] `"gcp"`, GCP Secret Manager
- [ ] `"aws"`, AWS Secrets Manager
- [ ] `""`, empty provider, which errors out on `$SECRET:` value

```go
cloudsecrets.Hydrate(ctx, "gcp", &Config{})
```

`Hydrate()` recursively walks a given config (struct pointer) and hydrates all string
values matching `"$SECRET:"` prefix using a given Cloud secrets provider.
`Hydrate()` recursively walks given config (a `struct` pointer) and replaces all string
fields having `"$SECRET:"` prefix with a value fetched from a given Cloud secret provider.

The secret values to be replaced must have a format of `"$SECRET:{name|path}"`.
The value to be replaced must have a format of `"$SECRET:{name|path}"`.

Secrets are de-duplicated and fetched only once.

The `Hydrate()` function tries to replace as many fields as possible before returning error.

## Usage
```go
Expand All @@ -23,7 +28,7 @@ func main() {
Database: "postgres",
Host: "localhost:5432",
Username: "sequence",
DPassword: "$SECRET:dbPassword", // to be hydrated
DPassword: "$SECRET:dbPassword", // will be hydrated (replaced with value of "dbPassword" secret)
},
}

Expand Down
43 changes: 17 additions & 26 deletions collector.go
Original file line number Diff line number Diff line change
@@ -1,66 +1,57 @@
package cloudsecrets

import (
"fmt"
"reflect"
"slices"
"strings"
)

func collectSecretFields(v reflect.Value) (map[string]string, error) {
c := &collector{
fields: map[string]string{},
}
c.collectSecretFields(v, "config")
if c.err != nil {
return nil, fmt.Errorf("failed to collect fields: %w", c.err)
}
// Returns de-duplicated secret keys found recursively in given v.
func collectSecretKeys(v reflect.Value) []string {
c := collector{}
c.collectSecretFields(v)

return c.fields, nil
}
slices.Sort(c)
dedup := slices.Compact(c)

type collector struct {
fields map[string]string
err error
return []string(dedup)
}

// Walks given reflect value recursively and collects any string fields matching $SECRET: prefix.
func (c *collector) collectSecretFields(v reflect.Value, path string) {
type collector []string

// Walk given reflect value recursively and collects any string fields matching $SECRET: prefix.
func (c *collector) collectSecretFields(v reflect.Value) {
switch v.Kind() {
case reflect.Ptr:
if v.IsNil() {
return
}

// Dereference pointer
c.collectSecretFields(v.Elem(), path)
c.collectSecretFields(v.Elem())

case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
c.collectSecretFields(field, fmt.Sprintf("%v.%v", path, v.Type().Field(i).Name))
c.collectSecretFields(field)
}

case reflect.Slice, reflect.Array:
for i := 0; i < v.Len(); i++ {
item := v.Index(i)
c.collectSecretFields(item, fmt.Sprintf("%v[%v]", path, i))
c.collectSecretFields(item)
}

case reflect.Map:
for _, key := range v.MapKeys() {
item := v.MapIndex(key)
c.collectSecretFields(item, fmt.Sprintf("%v[%v]", path, key))
c.collectSecretFields(item)
}

case reflect.String:
secretName, found := strings.CutPrefix(v.String(), "$SECRET:")
if !found {
return
}

if _, ok := c.fields[secretName]; !ok {
c.fields[secretName] = path
}
*c = append(*c, secretName)

default:
return
Expand Down
51 changes: 11 additions & 40 deletions collector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@ package cloudsecrets

import (
"reflect"
"slices"
"testing"

"github.com/google/go-cmp/cmp"
)

func TestCollectFields(t *testing.T) {
func TestCollectSecretKeys(t *testing.T) {
tt := []struct {
Name string
Input any
Out []string // field paths
Out []string // collected secret keys
Error bool
}{
{
Expand Down Expand Up @@ -112,27 +111,16 @@ func TestCollectFields(t *testing.T) {
Input: &cfg{
DB: dbConfig{
User: "db-user",
Password: "$SECRET:dup",
Password: "$SECRET:duplicatedKey",
},
JWTSecrets: []jwtSecret{"$SECRET:dup", "$SECRET:dup"},
JWTSecrets: []jwtSecret{"$SECRET:duplicatedKey", "$SECRET:duplicatedKey"},
ProvidersPtr: map[string]*providerConfig{
"provider1": {Name: "provider1", Secret: "$SECRET:dup"},
"provider2": {Name: "provider2", Secret: "$SECRET:dup"},
"provider3": {Name: "provider3", Secret: "$SECRET:dup"},
"provider1": {Name: "provider1", Secret: "$SECRET:duplicatedKey"},
"provider2": {Name: "provider2", Secret: "$SECRET:duplicatedKey"},
"provider3": {Name: "provider3", Secret: "$SECRET:duplicatedKey"},
},
},
Out: []string{"dup"},
},
{
Name: "Unexported_field_should_fail_to_hydrate",
Input: &cfg{
unexported: dbConfig{ // unexported fields can't be updated via reflect pkg
User: "db-user",
Password: "$SECRET:secretName", // match inside unexported field
},
},
Out: []string{},
Error: true, // expect error
Out: []string{"duplicatedKey"},
},
}

Expand All @@ -141,17 +129,9 @@ func TestCollectFields(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) {
v := reflect.ValueOf(tc.Input)

secretFields, err := collectSecretFields(v)
if tc.Error && err == nil {
t.Error("expected error, got nil")
return
} else if !tc.Error && err != nil {
t.Errorf("unexpected error: %v", err)
return
}

if !cmp.Equal(mapKeysSorted(secretFields), tc.Out) {
t.Errorf(cmp.Diff(tc.Out, mapKeysSorted(secretFields)))
secretFields := collectSecretKeys(v)
if !cmp.Equal(secretFields, tc.Out) {
t.Errorf(cmp.Diff(tc.Out, secretFields))
}
})
}
Expand Down Expand Up @@ -181,12 +161,3 @@ type providerConfig struct {
type jwtSecret string

func ptr[T any](v T) *T { return &v }

func mapKeysSorted(m map[string]string) []string {
keys := []string{}
for key, _ := range m {
keys = append(keys, key)
}
slices.Sort(keys)
return keys
}
12 changes: 6 additions & 6 deletions gcp/gcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@ type SecretsProvider struct {
func NewSecretsProvider() (*SecretsProvider, error) {
gcpClient, err := secretmanager.NewClient(context.Background())
if err != nil {
return nil, fmt.Errorf("secretmanager client: %w", err)
return nil, fmt.Errorf("gcp: secretmanager client: %w", err)
}

var projectNumber string
if metadata.OnGCE() {
projectNumber, err = metadata.NumericProjectID()
if err != nil {
return nil, fmt.Errorf("getting project ID from metadata: %w", err)
return nil, fmt.Errorf("gcp: getting project ID from metadata: %w", err)
}
} else {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

projectNumber, err = getProjectNumberFromGcloud(ctx)
if err != nil {
return nil, fmt.Errorf("getting project ID from gcloud: %w", err)
return nil, fmt.Errorf("gcp: getting project ID from gcloud: %w", err)
}
}

Expand All @@ -59,7 +59,7 @@ func (p SecretsProvider) FetchSecret(ctx context.Context, secretId string) (stri
// Access the secret version
result, err := p.client.AccessSecretVersion(reqCtx, req)
if err != nil {
return "", fmt.Errorf("accessing GCP secret %v (%v): %w", secretId, versionId, err)
return "", fmt.Errorf("gcp: accessing secret: %w", err)
}

// Return the secret value
Expand All @@ -75,15 +75,15 @@ func getProjectNumberFromGcloud(ctx context.Context) (string, error) {
if projectId == "" {
out, err := exec.CommandContext(ctx, "gcloud", "config", "get-value", "project").Output()
if err != nil {
return "", fmt.Errorf("getting current gcloud project (try `gcloud auth application-default login'): %w", err)
return "", fmt.Errorf("gcp: getting current gcloud project (try `gcloud auth application-default login'): %w", err)
}
projectId = strings.TrimSpace(string(out))
}

// We need projectNumber (not projectName!) for GCP Secret Manager APIs.
out, err := exec.CommandContext(ctx, "gcloud", "projects", "describe", projectId, "--format=value(projectNumber)").Output()
if err != nil {
return "", fmt.Errorf("getting gcloud projectNumber from projectId %q: %w", projectId, err)
return "", fmt.Errorf("gcp: getting gcloud projectNumber from projectId %q: %w", projectId, err)
}
return strings.TrimSpace(string(out)), nil
}
47 changes: 23 additions & 24 deletions hydrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cloudsecrets
import (
"context"
"fmt"
"log"
"reflect"

"github.com/0xsequence/go-cloudsecrets/gcp"
Expand Down Expand Up @@ -43,45 +42,45 @@ func Hydrate(ctx context.Context, providerName string, config interface{}) error
}

func hydrateConfig(ctx context.Context, provider secretsProvider, v reflect.Value) error {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return fmt.Errorf("passed config is nil")
}
v = v.Elem()
if v.Kind() != reflect.Ptr {
return fmt.Errorf("passed config must be a pointer")
}

if v.Kind() != reflect.Struct {
return fmt.Errorf("passed config must be struct, actual %s", v.Kind())
if v.IsNil() {
return fmt.Errorf("passed config is nil")
}
v = v.Elem()

secretFields, err := collectSecretFields(v)
if err != nil {
return fmt.Errorf("collecting secrets: %w", err)
if v.Kind() != reflect.Struct {
return fmt.Errorf("passed config must be pointer to a struct, got pointer to %s", v.Kind())
}

secretValues := map[string]string{}
secretKeys := collectSecretKeys(v)
secrets := make([]secret, len(secretKeys))

g := &errgroup.Group{}
for secretName, fieldPath := range secretFields {
secretName := secretName
fieldPath := fieldPath
for i, key := range secretKeys {
i, key := i, key

g.Go(func() error {
secretValue, err := provider.FetchSecret(ctx, secretName)
if err != nil {
return fmt.Errorf("field %v=%q: fetching secret: %w", fieldPath, "$SECRET:"+secretName, err)
value, err := provider.FetchSecret(ctx, key)
secrets[i] = secret{
key: key,
value: value,
fetchErr: err,
}

log.Printf("Fetched %v\n", secretName)
secretValues[secretName] = secretValue

return nil
})
}

if err := g.Wait(); err != nil {
return err
}

return replaceSecrets(v, secretValues)
return replaceSecrets(v, secrets)
}

type secret struct {
key string
value string
fetchErr error
}
20 changes: 16 additions & 4 deletions hydrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,25 @@ type service struct {
Pass string
}

func TestFailWhenPassedValueIsNotStruct(t *testing.T) {
input := "hello"
func TestHydrateFailIfNotPointerToStruct(t *testing.T) {
ctx := context.Background()

str := "hello"
assert.Error(t, Hydrate(ctx, "", str))
assert.Error(t, Hydrate(ctx, "", &str))

slice := []string{"hello", "hello2"}
assert.Error(t, Hydrate(ctx, "", slice))
assert.Error(t, Hydrate(ctx, "", &slice))

assert.Error(t, Hydrate(context.Background(), "", input))
cfg := struct {
X, Y string
}{}
assert.Error(t, Hydrate(ctx, "", cfg))
assert.NoError(t, Hydrate(ctx, "", &cfg))
}

func TestReplacePlaceholdersWithSecrets(t *testing.T) {
func TestHydrate(t *testing.T) {
ctx := context.Background()

tests := []struct {
Expand Down
Loading

0 comments on commit 4e2b9b0

Please sign in to comment.