Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 Refactor certificate watcher to use polling, instead of fsnotify #3020

Merged
merged 3 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/scratch-env/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
github.com/evanphx/json-patch/v5 v5.9.0 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/zapr v1.3.0 // indirect
Expand Down
2 changes: 0 additions & 2 deletions examples/scratch-env/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8
github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ=
github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg=
github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E=
github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.23.0

require (
github.com/evanphx/json-patch/v5 v5.9.0
github.com/fsnotify/fsnotify v1.7.0
github.com/go-logr/logr v1.4.2
github.com/go-logr/zapr v1.3.0
github.com/google/go-cmp v0.6.0
Expand Down Expand Up @@ -41,6 +40,7 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
Expand Down
161 changes: 57 additions & 104 deletions pkg/certwatcher/certwatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,58 +17,55 @@ limitations under the License.
package certwatcher

import (
"bytes"
"context"
"crypto/tls"
"fmt"
"os"
"sync"
"time"

"github.com/fsnotify/fsnotify"
kerrors "k8s.io/apimachinery/pkg/util/errors"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait"
"sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics"
logf "sigs.k8s.io/controller-runtime/pkg/internal/log"
)

var log = logf.RuntimeLog.WithName("certwatcher")

// CertWatcher watches certificate and key files for changes. When either file
// changes, it reads and parses both and calls an optional callback with the new
// certificate.
const defaultWatchInterval = 10 * time.Second

// CertWatcher watches certificate and key files for changes.
// It always returns the cached version,
// but periodically reads and parses certificate and key for changes
// and calls an optional callback with the new certificate.
type CertWatcher struct {
sync.RWMutex

currentCert *tls.Certificate
watcher *fsnotify.Watcher
interval time.Duration

certPath string
keyPath string

cachedKeyPEMBlock []byte

// callback is a function to be invoked when the certificate changes.
callback func(tls.Certificate)
}

// New returns a new CertWatcher watching the given certificate and key.
func New(certPath, keyPath string) (*CertWatcher, error) {
var err error

cw := &CertWatcher{
certPath: certPath,
keyPath: keyPath,
interval: defaultWatchInterval,
}

// Initial read of certificate and key.
if err := cw.ReadCertificate(); err != nil {
return nil, err
}

cw.watcher, err = fsnotify.NewWatcher()
if err != nil {
return nil, err
}
return cw, cw.ReadCertificate()
}

return cw, nil
// WithWatchInterval sets the watch interval and returns the CertWatcher pointer
func (cw *CertWatcher) WithWatchInterval(interval time.Duration) *CertWatcher {
cw.interval = interval
return cw
}

// RegisterCallback registers a callback to be invoked when the certificate changes.
Expand All @@ -91,72 +88,64 @@ func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate,

// Start starts the watch on the certificate and key files.
func (cw *CertWatcher) Start(ctx context.Context) error {
files := sets.New(cw.certPath, cw.keyPath)

{
var watchErr error
if err := wait.PollUntilContextTimeout(ctx, 1*time.Second, 10*time.Second, true, func(ctx context.Context) (done bool, err error) {
for _, f := range files.UnsortedList() {
if err := cw.watcher.Add(f); err != nil {
watchErr = err
return false, nil //nolint:nilerr // We want to keep trying.
}
// We've added the watch, remove it from the set.
files.Delete(f)
}
return true, nil
}); err != nil {
return fmt.Errorf("failed to add watches: %w", kerrors.NewAggregate([]error{err, watchErr}))
}
}

go cw.Watch()
ticker := time.NewTicker(cw.interval)
defer ticker.Stop()

log.Info("Starting certificate watcher")

// Block until the context is done.
<-ctx.Done()

return cw.watcher.Close()
}

// Watch reads events from the watcher's channel and reacts to changes.
func (cw *CertWatcher) Watch() {
for {
select {
case event, ok := <-cw.watcher.Events:
// Channel is closed.
if !ok {
return
case <-ctx.Done():
return nil
case <-ticker.C:
if err := cw.ReadCertificate(); err != nil {
log.Error(err, "failed read certificate")
}
}
}
}

cw.handleEvent(event)

case err, ok := <-cw.watcher.Errors:
// Channel is closed.
if !ok {
return
}
// updateCachedCertificate checks if the new certificate differs from the cache,
// updates it and returns the result if it was updated or not
func (cw *CertWatcher) updateCachedCertificate(cert *tls.Certificate, keyPEMBlock []byte) bool {
cw.Lock()
defer cw.Unlock()

log.Error(err, "certificate watch error")
}
if cw.currentCert != nil &&
bytes.Equal(cw.currentCert.Certificate[0], cert.Certificate[0]) &&
bytes.Equal(cw.cachedKeyPEMBlock, keyPEMBlock) {
log.V(7).Info("certificate already cached")
return false
}
cw.currentCert = cert
cw.cachedKeyPEMBlock = keyPEMBlock
return true
}

// ReadCertificate reads the certificate and key files from disk, parses them,
// and updates the current certificate on the watcher. If a callback is set, it
// and updates the current certificate on the watcher if updated. If a callback is set, it
// is invoked with the new certificate.
func (cw *CertWatcher) ReadCertificate() error {
metrics.ReadCertificateTotal.Inc()
cert, err := tls.LoadX509KeyPair(cw.certPath, cw.keyPath)
certPEMBlock, err := os.ReadFile(cw.certPath)
if err != nil {
metrics.ReadCertificateErrors.Inc()
return err
}
keyPEMBlock, err := os.ReadFile(cw.keyPath)
if err != nil {
metrics.ReadCertificateErrors.Inc()
return err
}

cw.Lock()
cw.currentCert = &cert
cw.Unlock()
cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
metrics.ReadCertificateErrors.Inc()
return err
}

if !cw.updateCachedCertificate(&cert, keyPEMBlock) {
return nil
}

log.Info("Updated current TLS certificate")

Expand All @@ -170,39 +159,3 @@ func (cw *CertWatcher) ReadCertificate() error {
}
return nil
}

func (cw *CertWatcher) handleEvent(event fsnotify.Event) {
// Only care about events which may modify the contents of the file.
if !(isWrite(event) || isRemove(event) || isCreate(event) || isChmod(event)) {
return
}

log.V(1).Info("certificate event", "event", event)

// If the file was removed or renamed, re-add the watch to the previous name
if isRemove(event) || isChmod(event) {
if err := cw.watcher.Add(event.Name); err != nil {
log.Error(err, "error re-watching file")
}
}

if err := cw.ReadCertificate(); err != nil {
log.Error(err, "error re-reading certificate")
}
}

func isWrite(event fsnotify.Event) bool {
return event.Op.Has(fsnotify.Write)
}

func isCreate(event fsnotify.Event) bool {
return event.Op.Has(fsnotify.Create)
}

func isRemove(event fsnotify.Event) bool {
return event.Op.Has(fsnotify.Remove)
}

func isChmod(event fsnotify.Event) bool {
return event.Op.Has(fsnotify.Chmod)
}
1 change: 1 addition & 0 deletions pkg/certwatcher/certwatcher_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

logf "sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
)
Expand Down
50 changes: 39 additions & 11 deletions pkg/certwatcher/certwatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/prometheus/client_golang/prometheus/testutil"

"sigs.k8s.io/controller-runtime/pkg/certwatcher"
"sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics"
)
Expand Down Expand Up @@ -80,7 +81,7 @@ var _ = Describe("CertWatcher", func() {
go func() {
defer GinkgoRecover()
defer close(doneCh)
Expect(watcher.Start(ctx)).To(Succeed())
Expect(watcher.WithWatchInterval(time.Second).Start(ctx)).To(Succeed())
}()
// wait till we read first cert
Eventually(func() error {
Expand Down Expand Up @@ -113,7 +114,7 @@ var _ = Describe("CertWatcher", func() {
Eventually(func() bool {
secondcert, _ := watcher.GetCertificate(nil)
first := firstcert.PrivateKey.(*rsa.PrivateKey)
return first.Equal(secondcert.PrivateKey)
return first.Equal(secondcert.PrivateKey) || firstcert.Leaf.SerialNumber == secondcert.Leaf.SerialNumber
}).ShouldNot(BeTrue())

ctxCancel()
Expand Down Expand Up @@ -143,14 +144,41 @@ var _ = Describe("CertWatcher", func() {
Eventually(func() bool {
secondcert, _ := watcher.GetCertificate(nil)
first := firstcert.PrivateKey.(*rsa.PrivateKey)
return first.Equal(secondcert.PrivateKey)
return first.Equal(secondcert.PrivateKey) || firstcert.Leaf.SerialNumber == secondcert.Leaf.SerialNumber
}).ShouldNot(BeTrue())

ctxCancel()
Eventually(doneCh, "4s").Should(BeClosed())
Expect(called.Load()).To(BeNumerically(">=", 1))
})

It("should reload currentCert after move out", func() {
doneCh := startWatcher()
called := atomic.Int64{}
watcher.RegisterCallback(func(crt tls.Certificate) {
called.Add(1)
Expect(crt.Certificate).ToNot(BeEmpty())
})

firstcert, _ := watcher.GetCertificate(nil)

Expect(os.Rename(certPath, certPath+".old")).To(Succeed())
Expect(os.Rename(keyPath, keyPath+".old")).To(Succeed())

err := writeCerts(certPath, keyPath, "192.168.0.3")
Expect(err).ToNot(HaveOccurred())

Eventually(func() bool {
secondcert, _ := watcher.GetCertificate(nil)
first := firstcert.PrivateKey.(*rsa.PrivateKey)
return first.Equal(secondcert.PrivateKey) || firstcert.Leaf.SerialNumber == secondcert.Leaf.SerialNumber
}, "10s", "1s").ShouldNot(BeTrue())

ctxCancel()
Eventually(doneCh, "4s").Should(BeClosed())
Expect(called.Load()).To(BeNumerically(">=", 1))
})

Context("prometheus metric read_certificate_total", func() {
var readCertificateTotalBefore float64
var readCertificateErrorsBefore float64
Expand All @@ -165,8 +193,8 @@ var _ = Describe("CertWatcher", func() {

Eventually(func() error {
readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal)
if readCertificateTotalAfter != readCertificateTotalBefore+1.0 {
return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter)
if readCertificateTotalAfter < readCertificateTotalBefore+1.0 {
return fmt.Errorf("metric read certificate total expected at least: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter)
}
return nil
}, "4s").Should(Succeed())
Expand All @@ -180,8 +208,8 @@ var _ = Describe("CertWatcher", func() {

Eventually(func() error {
readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal)
if readCertificateTotalAfter != readCertificateTotalBefore+1.0 {
return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter)
if readCertificateTotalAfter < readCertificateTotalBefore+1.0 {
return fmt.Errorf("metric read certificate total expected at least: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter)
}
readCertificateTotalBefore = readCertificateTotalAfter
return nil
Expand All @@ -192,15 +220,15 @@ var _ = Describe("CertWatcher", func() {
// Note, we are checking two errors here, because os.Remove generates two fsnotify events: Chmod + Remove
Eventually(func() error {
readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal)
if readCertificateTotalAfter != readCertificateTotalBefore+2.0 {
return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+2.0, readCertificateTotalAfter)
if readCertificateTotalAfter < readCertificateTotalBefore+2.0 {
return fmt.Errorf("metric read certificate total expected at least: %v and got: %v", readCertificateTotalBefore+2.0, readCertificateTotalAfter)
}
return nil
}, "4s").Should(Succeed())
Eventually(func() error {
readCertificateErrorsAfter := testutil.ToFloat64(metrics.ReadCertificateErrors)
if readCertificateErrorsAfter != readCertificateErrorsBefore+2.0 {
return fmt.Errorf("metric read certificate errors expected: %v and got: %v", readCertificateErrorsBefore+2.0, readCertificateErrorsAfter)
if readCertificateErrorsAfter < readCertificateErrorsBefore+2.0 {
return fmt.Errorf("metric read certificate errors expected at least: %v and got: %v", readCertificateErrorsBefore+2.0, readCertificateErrorsAfter)
}
return nil
}, "4s").Should(Succeed())
Expand Down
2 changes: 1 addition & 1 deletion pkg/certwatcher/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func Example() {
panic(err)
}

// Start goroutine with certwatcher running fsnotify against supplied certdir
// Start goroutine with certwatcher running against supplied cert
go func() {
if err := watcher.Start(ctx); err != nil {
panic(err)
Expand Down
1 change: 1 addition & 0 deletions pkg/certwatcher/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package metrics

import (
"github.com/prometheus/client_golang/prometheus"

"sigs.k8s.io/controller-runtime/pkg/metrics"
)

Expand Down