Skip to content

Commit

Permalink
Fix certificate handling
Browse files Browse the repository at this point in the history
  • Loading branch information
mrIncompetent committed Dec 13, 2022
1 parent c8f8042 commit 7641a93
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 19 deletions.
89 changes: 73 additions & 16 deletions pkg/tls/certificate_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"reflect"
"sync"
"time"

Expand Down Expand Up @@ -44,23 +45,63 @@ type CertificateWatcher struct {
metrics CertificateWatcherMetrics
}

func (cw *CertificateWatcher) setupWatches(watcher *fsnotify.Watcher) error {
existingWatches := map[string]bool{}
for _, p := range watcher.WatchList() {
existingWatches[p] = true
}

for _, path := range []string{cw.certPath, cw.keyPath} {
if _, exists := existingWatches[path]; !exists {
if err := watcher.Add(path); err != nil {
return fmt.Errorf("failed to add file to watcher for %s: %w", path, err)
}
}
}

return nil
}

func (cw *CertificateWatcher) Run(ctx context.Context) error {
reloadChan := make(chan bool, 1)
defer close(reloadChan)
reWatchChan := make(chan bool, 1)
defer close(reWatchChan)

watcher, err := fsnotify.NewWatcher()
if err != nil {
return fmt.Errorf("failed to create filewatcher: %w", err)
}
defer watcher.Close()

if err := watcher.Add(cw.certPath); err != nil {
return fmt.Errorf("failed to watch %s: %w", cw.certPath, err)
if err := cw.setupWatches(watcher); err != nil {
return fmt.Errorf("failed to setup filewatcher: %w", err)
}

if err := watcher.Add(cw.keyPath); err != nil {
return fmt.Errorf("failed to watch %s: %w", cw.certPath, err)
}
go func() {
for {
select {
case <-ctx.Done():
return
default:
}

_, ok := <-reWatchChan
if !ok {
return
}

for {
if err := cw.setupWatches(watcher); err != nil {
cw.log.Error("failed to setup filewatcher", zap.Error(err))
time.Sleep(1 * time.Second)
continue
}
reloadChan <- true
break
}
}
}()

go func() {
for {
Expand All @@ -74,13 +115,25 @@ func (cw *CertificateWatcher) Run(ctx context.Context) error {
if !ok {
return
}

cw.log.Info("Watcher event", zap.Object("event", zapcore.ObjectMarshalerFunc(func(enc zapcore.ObjectEncoder) error {
cw.log.Debug("Watcher event", zap.Object("event", zapcore.ObjectMarshalerFunc(func(enc zapcore.ObjectEncoder) error {
enc.AddString("name", event.Name)
enc.AddString("op", event.Op.String())
return nil
})))
reloadChan <- true

if event.Op == fsnotify.Chmod {
continue
}
if event.Op == fsnotify.Remove {
reWatchChan <- true
continue
}

if event.Op == fsnotify.Remove {
reWatchChan <- true
} else {
reloadChan <- true
}
}
}()

Expand All @@ -96,15 +149,13 @@ func (cw *CertificateWatcher) Run(ctx context.Context) error {
return
}

if err := cw.reload(); err != nil {
cw.log.Error("failed to reload certificate", zap.Error(err))

go func() {
for {
if err := cw.reload(); err != nil {
cw.log.Error("failed to reload certificates", zap.Error(err))
time.Sleep(1 * time.Second)
reloadChan <- true
}()
} else {
cw.log.Info("Reloaded certificates")
continue
}
break
}
}
}()
Expand All @@ -127,13 +178,19 @@ func (cw *CertificateWatcher) reload() error {
return fmt.Errorf("failed to parse certificate: %w", err)
}

if reflect.DeepEqual(&cert, cw.cert) {
return nil
}

cw.certLock.Lock()
defer cw.certLock.Unlock()
cw.cert = &cert

cw.metrics.ReloadSuccess()
cw.metrics.CertificateExpirationTimestamp(x509Cert.NotAfter)

cw.log.Info("Reloaded certificates")

return nil
}

Expand Down
4 changes: 1 addition & 3 deletions pkg/tls/certificate_watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ func TestNewCertificateWatcher(t *testing.T) {
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv)
require.NoError(t, err)

dir, err := os.MkdirTemp(os.TempDir(), "cert-watcher")
require.NoError(t, err)
t.Cleanup(func() { assert.NoError(t, os.RemoveAll(dir)) })
dir := t.TempDir()

certPath := path.Join(dir, "cert.pem")
require.NoError(t, os.WriteFile(certPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), 0666))
Expand Down

0 comments on commit 7641a93

Please sign in to comment.