From e2ded2af4e3410f306c2a2257b34dab46105d2fc Mon Sep 17 00:00:00 2001 From: Theron Voran Date: Wed, 19 Jan 2022 12:33:41 -0800 Subject: [PATCH] Do not store local service account token and CA to config. (#122) (#131) When defaulting to local JWT token and CA certificate in a pod, always read them from local filesystem and do not store them persistently with the config. Token will be re-read periodically to avoid using expired token. The change allows running Vault on Kubernetes 1.21 and newer, which switched to ID token that is bound to the pod and will expire. Signed-off-by: Tero Saarni * review comment fix: load only token or ca cert if other is given in config * changed the reload period to 1 minute * fixed review comments * take lock also on alias lookahead path * more review fixes * proposal to fix the read/write lock issue * fixed typo * cachedFile by value to avoid mutation while not holding lock * acquire lock in same place as in pathLogin * added debug log entry when local token is not found and falling back to client token Co-authored-by: Tero Saarni --- backend.go | 60 ++++++++++++++- caching_file_reader.go | 68 +++++++++++++++++ caching_file_reader_test.go | 65 ++++++++++++++++ path_config.go | 26 ++----- path_config_test.go | 146 +++++++++++++++++++++++++++++------- path_login.go | 3 + 6 files changed, 320 insertions(+), 48 deletions(-) create mode 100644 caching_file_reader.go create mode 100644 caching_file_reader_test.go diff --git a/backend.go b/backend.go index 775cba57..2d12c931 100644 --- a/backend.go +++ b/backend.go @@ -7,6 +7,7 @@ import ( "fmt" "strings" "sync" + "time" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" @@ -27,6 +28,18 @@ var ( // when adding new alias name sources make sure to update the corresponding FieldSchema description in path_role.go aliasNameSources = []string{aliasNameSourceSAUid, aliasNameSourceSAName} errInvalidAliasNameSource = fmt.Errorf(`invalid alias_name_source, must be one 
of: %s`, strings.Join(aliasNameSources, ", ")) + + // jwtReloadPeriod is the time period how often the in-memory copy of local + // service account token can be used, before reading it again from disk. + // + // The value is selected according to recommendation in Kubernetes 1.21 changelog: + // "Clients should reload the token from disk periodically (once per minute + // is recommended) to ensure they continue to use a valid token." + jwtReloadPeriod = 1 * time.Minute + + // caReloadPeriod is the time period how often the in-memory copy of local + // CA cert can be used, before reading it again from disk. + caReloadPeriod = 1 * time.Hour ) // kubeAuthBackend implements logical.Backend @@ -38,6 +51,19 @@ type kubeAuthBackend struct { // review. Mocks should only be used in tests. reviewFactory tokenReviewFactory + // localSATokenReader caches the service account token in memory. + // It periodically reloads the token to support token rotation/renewal. + // Local token is used when running in a pod with following configuration + // - token_reviewer_jwt is not set + // - disable_local_ca_jwt is false + localSATokenReader *cachingFileReader + + // localCACertReader contains the local CA certificate. 
Local CA certificate is + // used when running in a pod with following configuration + // - kubernetes_ca_cert is not set + // - disable_local_ca_jwt is false + localCACertReader *cachingFileReader + l sync.RWMutex } @@ -51,7 +77,10 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, } func Backend() *kubeAuthBackend { - b := &kubeAuthBackend{} + b := &kubeAuthBackend{ + localSATokenReader: newCachingFileReader(localJWTPath, jwtReloadPeriod, time.Now), + localCACertReader: newCachingFileReader(localCACertPath, caReloadPeriod, time.Now), + } b.Backend = &framework.Backend{ AuthRenew: b.pathLoginRenew(), @@ -80,7 +109,8 @@ func Backend() *kubeAuthBackend { return b } -// config takes a storage object and returns a kubeConfig object +// config takes a storage object and returns a kubeConfig object. +// It does not return local token and CA file which are specific to the pod we run in. func (b *kubeAuthBackend) config(ctx context.Context, s logical.Storage) (*kubeConfig, error) { raw, err := s.Get(ctx, configPath) if err != nil { @@ -107,6 +137,8 @@ func (b *kubeAuthBackend) config(ctx context.Context, s logical.Storage) (*kubeC return conf, nil } +// loadConfig fetches the kubeConfig from storage and optionally decorates it with +// local token and CA certificate. func (b *kubeAuthBackend) loadConfig(ctx context.Context, s logical.Storage) (*kubeConfig, error) { config, err := b.config(ctx, s) if err != nil { @@ -115,6 +147,30 @@ func (b *kubeAuthBackend) loadConfig(ctx context.Context, s logical.Storage) (*k if config == nil { return nil, errors.New("could not load backend configuration") } + + // Nothing more to do if loading local CA cert and JWT token is disabled. + if config.DisableLocalCAJwt { + return config, nil + } + + // Read local JWT token unless it was stored in config. 
+ if config.TokenReviewerJWT == "" { + config.TokenReviewerJWT, err = b.localSATokenReader.ReadFile() + if err != nil { + // Ignore error: make best effort trying to load local JWT, + // otherwise the JWT submitted in login payload will be used. + b.Logger().Debug("failed to read local service account token, will use client token", "error", err) + } + } + + // Read local CA cert unless it was stored in config. + if config.CACert == "" { + config.CACert, err = b.localCACertReader.ReadFile() + if err != nil { + return nil, err + } + } + return config, nil } diff --git a/caching_file_reader.go b/caching_file_reader.go new file mode 100644 index 00000000..d18445fc --- /dev/null +++ b/caching_file_reader.go @@ -0,0 +1,68 @@ +package kubeauth + +import ( + "io/ioutil" + "sync" + "time" +) + +// cachingFileReader reads a file and keeps an in-memory copy of it, until the +// copy is considered stale. Next ReadFile() after expiry will re-read the file from disk. +type cachingFileReader struct { + // path is the file path to the cached file. + path string + + // ttl is the time-to-live duration when cached file is considered stale + ttl time.Duration + + // cache is the buffer holding the in-memory copy of the file. + cache cachedFile + + l sync.RWMutex + + // currentTime is a function that returns the current local time. + // Normally set to time.Now but it can be overwritten by test cases to manipulate time. + currentTime func() time.Time +} + +type cachedFile struct { + // buf is the buffer holding the in-memory copy of the file. + buf string + + // expiry is the time when the cached copy is considered stale and must be re-read. 
+ expiry time.Time +} + +func newCachingFileReader(path string, ttl time.Duration, currentTime func() time.Time) *cachingFileReader { + return &cachingFileReader{ + path: path, + ttl: ttl, + currentTime: currentTime, + } +} + +func (r *cachingFileReader) ReadFile() (string, error) { + // Fast path requiring read lock only: file is already in memory and not stale. + r.l.RLock() + now := r.currentTime() + cache := r.cache + r.l.RUnlock() + if now.Before(cache.expiry) { + return cache.buf, nil + } + + // Slow path: read the file from disk. + r.l.Lock() + defer r.l.Unlock() + + buf, err := ioutil.ReadFile(r.path) + if err != nil { + return "", err + } + r.cache = cachedFile{ + buf: string(buf), + expiry: now.Add(r.ttl), + } + + return r.cache.buf, nil +} diff --git a/caching_file_reader_test.go b/caching_file_reader_test.go new file mode 100644 index 00000000..ba282510 --- /dev/null +++ b/caching_file_reader_test.go @@ -0,0 +1,65 @@ +package kubeauth + +import ( + "io/ioutil" + "os" + "testing" + "time" +) + +func TestCachingFileReader(t *testing.T) { + content1 := "before" + content2 := "after" + + // Create temporary file. + f, err := ioutil.TempFile("", "testfile") + if err != nil { + t.Error(err) + } + f.Close() + defer os.Remove(f.Name()) + + currentTime := time.Now() + + r := newCachingFileReader(f.Name(), 1*time.Minute, + func() time.Time { + return currentTime + }) + + // Write initial content to file and check that we can read it. + ioutil.WriteFile(f.Name(), []byte(content1), 0644) + got, err := r.ReadFile() + if err != nil { + t.Error(err) + } + if got != content1 { + t.Errorf("got '%s', expected '%s'", got, content1) + } + + // Write new content to the file. + ioutil.WriteFile(f.Name(), []byte(content2), 0644) + + // Advance simulated time, but not enough for cache to expire. + currentTime = currentTime.Add(30 * time.Second) + + // Read again and check we still got the old cached content. 
+ got, err = r.ReadFile() + if err != nil { + t.Error(err) + } + if got != content1 { + t.Errorf("got '%s', expected '%s'", got, content1) + } + + // Advance simulated time for cache to expire. + currentTime = currentTime.Add(30 * time.Second) + + // Read again and check that we got the new content. + got, err = r.ReadFile() + if err != nil { + t.Error(err) + } + if got != content2 { + t.Errorf("got '%s', expected '%s'", got, content2) + } +} diff --git a/path_config.go b/path_config.go index b593dede..d61b5d93 100644 --- a/path_config.go +++ b/path_config.go @@ -7,14 +7,13 @@ import ( "crypto/x509" "encoding/pem" "errors" - "io/ioutil" "github.com/briankassouf/jose/jws" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" ) -var ( +const ( localCACertPath = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" localJWTPath = "/var/run/secrets/kubernetes.io/serviceaccount/token" ) @@ -126,30 +125,13 @@ func (b *kubeAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Requ } disableLocalJWT := data.Get("disable_local_ca_jwt").(bool) - localCACert := []byte{} - localTokenReviewer := []byte{} - if !disableLocalJWT { - localCACert, _ = ioutil.ReadFile(localCACertPath) - localTokenReviewer, _ = ioutil.ReadFile(localJWTPath) - } pemList := data.Get("pem_keys").([]string) caCert := data.Get("kubernetes_ca_cert").(string) issuer := data.Get("issuer").(string) disableIssValidation := data.Get("disable_iss_validation").(bool) - if len(pemList) == 0 && len(caCert) == 0 { - if len(localCACert) > 0 { - caCert = string(localCACert) - } else { - return logical.ErrorResponse("one of pem_keys or kubernetes_ca_cert must be set"), nil - } - } - tokenReviewer := data.Get("token_reviewer_jwt").(string) - if !disableLocalJWT && len(tokenReviewer) == 0 && len(localTokenReviewer) > 0 { - tokenReviewer = string(localTokenReviewer) - } - if len(tokenReviewer) > 0 { + if tokenReviewer != "" { // Validate it's a JWT _, err := 
jws.ParseJWT([]byte(tokenReviewer)) if err != nil { @@ -157,6 +139,10 @@ func (b *kubeAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Requ } } + if disableLocalJWT && caCert == "" { + return logical.ErrorResponse("kubernetes_ca_cert must be given when disable_local_ca_jwt is true"), nil + } + config := &kubeConfig{ PublicKeys: make([]interface{}, len(pemList)), PEMKeys: pemList, diff --git a/path_config_test.go b/path_config_test.go index c49ed6a4..e5cd86bb 100644 --- a/path_config_test.go +++ b/path_config_test.go @@ -6,13 +6,40 @@ import ( "os" "reflect" "testing" + "time" "github.com/hashicorp/vault/sdk/logical" ) +func setupLocalFiles(t *testing.T, b logical.Backend) func() { + cert, err := ioutil.TempFile("", "ca.crt") + if err != nil { + t.Fatal(err) + } + cert.WriteString(testLocalCACert) + cert.Close() + + token, err := ioutil.TempFile("", "token") + if err != nil { + t.Fatal(err) + } + token.WriteString(testLocalJWT) + token.Close() + b.(*kubeAuthBackend).localCACertReader = newCachingFileReader(cert.Name(), caReloadPeriod, time.Now) + b.(*kubeAuthBackend).localSATokenReader = newCachingFileReader(token.Name(), jwtReloadPeriod, time.Now) + + return func() { + os.Remove(cert.Name()) + os.Remove(token.Name()) + } +} + func TestConfig_Read(t *testing.T) { b, storage := getBackend(t) + cleanup := setupLocalFiles(t, b) + defer cleanup() + data := map[string]interface{}{ "pem_keys": []string{testRSACert, testECCert}, "kubernetes_host": "host", @@ -54,6 +81,9 @@ func TestConfig_Read(t *testing.T) { func TestConfig(t *testing.T) { b, storage := getBackend(t) + cleanup := setupLocalFiles(t, b) + defer cleanup() + // test no certificate data := map[string]interface{}{ "kubernetes_host": "host", @@ -67,11 +97,8 @@ func TestConfig(t *testing.T) { } resp, err := b.HandleRequest(context.Background(), req) - if resp == nil || !resp.IsError() { - t.Fatal("expected error") - } - if resp.Error().Error() != "one of pem_keys or kubernetes_ca_cert must be set" { 
- t.Fatalf("got unexpected error: %v", resp.Error()) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) } // test no host @@ -331,24 +358,16 @@ func TestConfig(t *testing.T) { } func TestConfig_LocalCaJWT(t *testing.T) { - b, storage := getBackend(t) - - // write "local" CA and JWT, and override local path vars - caFile := writeToTempFile(t, testLocalCACert) - localCACertPath = caFile - defer os.Remove(caFile) - jwtFile := writeToTempFile(t, testLocalJWT) - localJWTPath = jwtFile - defer os.Remove(jwtFile) - testCases := map[string]struct { - config map[string]interface{} - expected *kubeConfig + config map[string]interface{} + setupInClusterFiles bool + expected *kubeConfig }{ "no CA or JWT, default to local": { config: map[string]interface{}{ "kubernetes_host": "host", }, + setupInClusterFiles: true, expected: &kubeConfig{ PublicKeys: []interface{}{}, PEMKeys: []string{}, @@ -364,6 +383,7 @@ func TestConfig_LocalCaJWT(t *testing.T) { "kubernetes_host": "host", "kubernetes_ca_cert": testCACert, }, + setupInClusterFiles: true, expected: &kubeConfig{ PublicKeys: []interface{}{}, PEMKeys: []string{}, @@ -379,6 +399,7 @@ func TestConfig_LocalCaJWT(t *testing.T) { "kubernetes_host": "host", "token_reviewer_jwt": jwtData, }, + setupInClusterFiles: true, expected: &kubeConfig{ PublicKeys: []interface{}{}, PEMKeys: []string{}, @@ -409,6 +430,13 @@ func TestConfig_LocalCaJWT(t *testing.T) { for name, tc := range testCases { t.Run(name, func(t *testing.T) { + b, storage := getBackend(t) + + if tc.setupInClusterFiles { + cleanup := setupLocalFiles(t, b) + defer cleanup() + } + req := &logical.Request{ Operation: logical.CreateOperation, Path: configPath, @@ -421,7 +449,7 @@ func TestConfig_LocalCaJWT(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, resp) } - conf, err := b.(*kubeAuthBackend).config(context.Background(), storage) + conf, err := b.(*kubeAuthBackend).loadConfig(context.Background(), storage) if err != nil { 
t.Fatal(err) } @@ -433,18 +461,84 @@ func TestConfig_LocalCaJWT(t *testing.T) { } } -func writeToTempFile(t *testing.T, contents string) string { - t.Helper() +func TestConfig_LocalJWTRenewal(t *testing.T) { + b, storage := getBackend(t) - f, err := ioutil.TempFile("", "test") + cleanup := setupLocalFiles(t, b) + defer cleanup() + + // Create temp file that will be used as token. + f, err := ioutil.TempFile("", "renewed-token") if err != nil { - t.Fatalf("Failure to create test file: %s", err) + t.Error(err) } - _, err = f.WriteString(contents) - if err != nil { - t.Fatalf("Failure to write test file: %s", err) + f.Close() + defer os.Remove(f.Name()) + + currentTime := time.Now() + + b.(*kubeAuthBackend).localSATokenReader = newCachingFileReader(f.Name(), jwtReloadPeriod, func() time.Time { + return currentTime + }) + + token1 := "before-renewal" + token2 := "after-renewal" + + // Write initial token to the temp file. + ioutil.WriteFile(f.Name(), []byte(token1), 0644) + + data := map[string]interface{}{ + "kubernetes_host": "host", } - return f.Name() + req := &logical.Request{ + Operation: logical.CreateOperation, + Path: configPath, + Storage: storage, + Data: data, + } + + resp, err := b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Loading the config will load the initial token file from disk. + conf, err := b.(*kubeAuthBackend).loadConfig(context.Background(), storage) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Check that we loaded the initial token. + if conf.TokenReviewerJWT != token1 { + t.Fatalf("got unexpected JWT: expected %#v\n got %#v\n", token1, conf.TokenReviewerJWT) + } + + // Write new value to the token file to simulate renewal. + ioutil.WriteFile(f.Name(), []byte(token2), 0644) + + // Load again to check we still got the old cached token from memory. 
+ conf, err = b.(*kubeAuthBackend).loadConfig(context.Background(), storage) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if conf.TokenReviewerJWT != token1 { + t.Fatalf("got unexpected JWT: expected %#v\n got %#v\n", token1, conf.TokenReviewerJWT) + } + + // Advance simulated time for cache to expire + currentTime = currentTime.Add(1 * time.Minute) + + // Load again and check we got the new renewed token from disk. + conf, err = b.(*kubeAuthBackend).loadConfig(context.Background(), storage) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if conf.TokenReviewerJWT != token2 { + t.Fatalf("got unexpected JWT: expected %#v\n got %#v\n", token2, conf.TokenReviewerJWT) + } + } var testLocalCACert string = `-----BEGIN CERTIFICATE----- diff --git a/path_login.go b/path_login.go index 51f1fcf6..ec8d8a1b 100644 --- a/path_login.go +++ b/path_login.go @@ -180,6 +180,9 @@ func (b *kubeAuthBackend) aliasLookahead(ctx context.Context, req *logical.Reque return resp, nil } + b.l.RLock() + defer b.l.RUnlock() + role, err := b.role(ctx, req.Storage, roleName) if err != nil { return nil, err