From 7641a9349a52670f9f5829840aace34bd3fa9b31 Mon Sep 17 00:00:00 2001 From: mrIncompetent Date: Tue, 13 Dec 2022 23:30:36 +0100 Subject: [PATCH] Fix certificate handling --- pkg/tls/certificate_watcher.go | 89 +++++++++++++++++++++++------ pkg/tls/certificate_watcher_test.go | 4 +- 2 files changed, 74 insertions(+), 19 deletions(-) diff --git a/pkg/tls/certificate_watcher.go b/pkg/tls/certificate_watcher.go index 3df9ae5..e8e3b11 100644 --- a/pkg/tls/certificate_watcher.go +++ b/pkg/tls/certificate_watcher.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "reflect" "sync" "time" @@ -44,9 +45,28 @@ type CertificateWatcher struct { metrics CertificateWatcherMetrics } +func (cw *CertificateWatcher) setupWatches(watcher *fsnotify.Watcher) error { + existingWatches := map[string]bool{} + for _, p := range watcher.WatchList() { + existingWatches[p] = true + } + + for _, path := range []string{cw.certPath, cw.keyPath} { + if _, exists := existingWatches[path]; !exists { + if err := watcher.Add(path); err != nil { + return fmt.Errorf("failed to add file to watcher for %s: %w", path, err) + } + } + } + + return nil +} + func (cw *CertificateWatcher) Run(ctx context.Context) error { reloadChan := make(chan bool, 1) defer close(reloadChan) + reWatchChan := make(chan bool, 1) + defer close(reWatchChan) watcher, err := fsnotify.NewWatcher() if err != nil { @@ -54,13 +74,34 @@ func (cw *CertificateWatcher) Run(ctx context.Context) error { } defer watcher.Close() - if err := watcher.Add(cw.certPath); err != nil { - return fmt.Errorf("failed to watch %s: %w", cw.certPath, err) + if err := cw.setupWatches(watcher); err != nil { + return fmt.Errorf("failed to setup filewatcher: %w", err) } - if err := watcher.Add(cw.keyPath); err != nil { - return fmt.Errorf("failed to watch %s: %w", cw.certPath, err) - } + go func() { + for { + select { + case <-ctx.Done(): + return + default: + } + + _, ok := <-reWatchChan + if !ok { + return + } + + for { + if err := cw.setupWatches(watcher); err != nil { + cw.log.Error("failed to setup filewatcher", zap.Error(err)) + time.Sleep(1 * time.Second) + continue + } + reloadChan <- true + break + } + } + }() go func() { for { @@ -74,13 +115,25 @@ func (cw *CertificateWatcher) Run(ctx context.Context) error { if !ok { return } - - cw.log.Info("Watcher event", zap.Object("event", zapcore.ObjectMarshalerFunc(func(enc zapcore.ObjectEncoder) error { + cw.log.Debug("Watcher event", zap.Object("event", zapcore.ObjectMarshalerFunc(func(enc zapcore.ObjectEncoder) error { enc.AddString("name", event.Name) enc.AddString("op", event.Op.String()) return nil }))) - reloadChan <- true + + if event.Op == fsnotify.Chmod { + continue + } + if event.Op == fsnotify.Remove { + reWatchChan <- true + continue + } + + if event.Op == fsnotify.Remove { + reWatchChan <- true + } else { + reloadChan <- true + } } }() @@ -96,15 +149,13 @@ func (cw *CertificateWatcher) Run(ctx context.Context) error { return } - if err := cw.reload(); err != nil { - cw.log.Error("failed to reload certificate", zap.Error(err)) - - go func() { + for { + if err := cw.reload(); err != nil { + cw.log.Error("failed to reload certificates", zap.Error(err)) time.Sleep(1 * time.Second) - reloadChan <- true - }() - } else { - cw.log.Info("Reloaded certificates") + continue + } + break } } }() @@ -127,6 +178,10 @@ func (cw *CertificateWatcher) reload() error { return fmt.Errorf("failed to parse certificate: %w", err) } + if reflect.DeepEqual(&cert, cw.cert) { + return nil + } + cw.certLock.Lock() defer cw.certLock.Unlock() cw.cert = &cert @@ -134,6 +189,8 @@ func (cw *CertificateWatcher) reload() error { cw.metrics.ReloadSuccess() cw.metrics.CertificateExpirationTimestamp(x509Cert.NotAfter) + cw.log.Info("Reloaded certificates") + return nil } diff --git a/pkg/tls/certificate_watcher_test.go b/pkg/tls/certificate_watcher_test.go index 9bdeca1..2896a52 100644 --- a/pkg/tls/certificate_watcher_test.go +++ b/pkg/tls/certificate_watcher_test.go @@ -40,9 +40,7 @@ func TestNewCertificateWatcher(t *testing.T) { derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv) require.NoError(t, err) - dir, err := os.MkdirTemp(os.TempDir(), "cert-watcher") - require.NoError(t, err) - t.Cleanup(func() { assert.NoError(t, os.RemoveAll(dir)) }) + dir := t.TempDir() certPath := path.Join(dir, "cert.pem") require.NoError(t, os.WriteFile(certPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), 0666))