Skip to content

Commit

Permalink
Make client cert reloading dynamic
Browse files Browse the repository at this point in the history
Instead of requiring a restart of edm to pick up new client certificates
for MQTT and HTTP aggregate sending we now use fsnotify monitoring to
dynamically reload a certificate if it is updated on disk.

This also required modification of registerFSWatcher() as it would
previously only support a single callback function for a given filename.
Because MQTT and HTTP aggregate sending may use the same certificate
file we now support the assignment of multiple callbacks for a given
file.
  • Loading branch information
eest committed Oct 31, 2024
1 parent 4e3c9a2 commit 8410b24
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 38 deletions.
8 changes: 4 additions & 4 deletions pkg/runner/aggregate_sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type aggregateSender struct {
signingHTTPClient *httpsign.Client
}

func (edm *dnstapMinimiser) newAggregateSender(aggrecURL *url.URL, signingKeyName string, signingKey *ecdsa.PrivateKey, caCertPool *x509.CertPool, clientCert tls.Certificate) aggregateSender {
func (edm *dnstapMinimiser) newAggregateSender(aggrecURL *url.URL, signingKeyName string, signingKey *ecdsa.PrivateKey, caCertPool *x509.CertPool, clientCertStore *certStore) aggregateSender {
// Create HTTP handler for sending aggregate files to aggrec
httpClient := http.Client{
Transport: &http.Transport{
Expand All @@ -38,9 +38,9 @@ func (edm *dnstapMinimiser) newAggregateSender(aggrecURL *url.URL, signingKeyNam
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{
RootCAs: caCertPool,
Certificates: []tls.Certificate{clientCert},
MinVersion: tls.VersionTLS13,
RootCAs: caCertPool,
GetClientCertificate: clientCertStore.getClientCertficate,
MinVersion: tls.VersionTLS13,
},
},
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/runner/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/lestrrat-go/jwx/v2/jws"
)

func (edm *dnstapMinimiser) newAutoPahoClientConfig(caCertPool *x509.CertPool, server string, clientID string, clientCert tls.Certificate, mqttKeepAlive uint16) (autopaho.ClientConfig, error) {
func (edm *dnstapMinimiser) newAutoPahoClientConfig(caCertPool *x509.CertPool, server string, clientID string, clientCertStore *certStore, mqttKeepAlive uint16) (autopaho.ClientConfig, error) {
u, err := url.Parse(server)
if err != nil {
return autopaho.ClientConfig{}, fmt.Errorf("newAutoPahoClientConfig: unable to parse URL: %w", err)
Expand All @@ -23,9 +23,9 @@ func (edm *dnstapMinimiser) newAutoPahoClientConfig(caCertPool *x509.CertPool, s
cliCfg := autopaho.ClientConfig{
ServerUrls: []*url.URL{u},
TlsCfg: &tls.Config{
RootCAs: caCertPool,
Certificates: []tls.Certificate{clientCert},
MinVersion: tls.VersionTLS13,
RootCAs: caCertPool,
GetClientCertificate: clientCertStore.getClientCertficate,
MinVersion: tls.VersionTLS13,
},
KeepAlive: mqttKeepAlive,
OnConnectionUp: func(*autopaho.ConnectionManager, *paho.Connack) { edm.log.Info("mqtt connection up") },
Expand Down
113 changes: 83 additions & 30 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,34 @@ type prevSessions struct {
rotationTime time.Time
}

type certStore struct {
cert *tls.Certificate
mtx sync.RWMutex
}

// Implements tls.Config.GetClientCertificate
func (cs *certStore) getClientCertficate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
cs.mtx.RLock()
defer cs.mtx.RUnlock()
return cs.cert, nil
}

func (cs *certStore) setCert(certPath string, keyPath string) error {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return fmt.Errorf("unable to load x509 cert: %w", err)
}
cs.mtx.Lock()
cs.cert = &cert
cs.mtx.Unlock()

return nil
}

func newCertStore() *certStore {
return &certStore{}
}

func (edm *dnstapMinimiser) setHistogramLabels(labels []string, labelLimit int, hd *histogramData) {
// If labels is nil (the "." zone) we can depend on the zero type of
// the label fields being nil, so nothing to do
Expand Down Expand Up @@ -413,7 +441,7 @@ func setHllDefaults() error {
return err
}

func (edm *dnstapMinimiser) setupHistogramSender() {
func (edm *dnstapMinimiser) setupHistogramSender(httpClientCertStore *certStore) {
httpURL, err := url.Parse(viper.GetString("http-url"))
if err != nil {
edm.log.Error("unable to parse 'http-url' setting", "error", err)
Expand All @@ -438,16 +466,10 @@ func (edm *dnstapMinimiser) setupHistogramSender() {
}
}

httpClientCert, err := tls.LoadX509KeyPair(viper.GetString("http-client-cert-file"), viper.GetString("http-client-key-file"))
if err != nil {
edm.log.Error("unable to load x509 HTTP client cert", "error", err)
os.Exit(1)
}

edm.aggregSender = edm.newAggregateSender(httpURL, viper.GetString("http-signing-key-id"), httpSigningKey, httpCACertPool, httpClientCert)
edm.aggregSender = edm.newAggregateSender(httpURL, viper.GetString("http-signing-key-id"), httpSigningKey, httpCACertPool, httpClientCertStore)
}

func (edm *dnstapMinimiser) setupMQTT() {
func (edm *dnstapMinimiser) setupMQTT(mqttClientCertStore *certStore) {
mqttSigningKey, err := ecdsaPrivateKeyFromFile(viper.GetString("mqtt-signing-key-file"))
if err != nil {
edm.log.Error("unable to parse key material from 'mqtt-signing-key-file'", "error", err)
Expand Down Expand Up @@ -475,14 +497,7 @@ func (edm *dnstapMinimiser) setupMQTT() {
}
}

// Setup client cert/key for mTLS authentication
mqttClientCert, err := tls.LoadX509KeyPair(viper.GetString("mqtt-client-cert-file"), viper.GetString("mqtt-client-key-file"))
if err != nil {
edm.log.Error("unable to load x509 mqtt client cert", "error", err)
os.Exit(1)
}

autopahoConfig, err := edm.newAutoPahoClientConfig(mqttCACertPool, viper.GetString("mqtt-server"), viper.GetString("mqtt-client-id"), mqttClientCert, uint16(viper.GetInt("mqtt-keepalive")))
autopahoConfig, err := edm.newAutoPahoClientConfig(mqttCACertPool, viper.GetString("mqtt-server"), viper.GetString("mqtt-client-id"), mqttClientCertStore, uint16(viper.GetInt("mqtt-keepalive")))
if err != nil {
edm.log.Error("unable to create autopaho config", "error", err)
os.Exit(1)
Expand Down Expand Up @@ -639,11 +654,13 @@ func (edm *dnstapMinimiser) fsEventWatcher() {
timers := map[string]*time.Timer{}
timersMutex := new(sync.Mutex)

callbackHandler := func(callback func(string) error, name string) func() {
callbackHandler := func(callbacks []func(string) error, name string) func() {
return func() {
err := callback(name)
if err != nil {
edm.log.Error("fsEventWatcher: callback error", "filename", name, "error", err)
for _, callback := range callbacks {
err := callback(name)
if err != nil {
edm.log.Error("fsEventWatcher: callback error", "filename", name, "error", err)
}
}

// Cleanup expired timer
Expand All @@ -668,7 +685,7 @@ func (edm *dnstapMinimiser) fsEventWatcher() {
cleanName := filepath.Clean(event.Name)

edm.fsWatcherMutex.RLock()
callback, ok := edm.fsWatcherFuncs[cleanName]
callbacks, ok := edm.fsWatcherFuncs[cleanName]
edm.fsWatcherMutex.RUnlock()
if !ok {
if edm.debug {
Expand All @@ -681,7 +698,7 @@ func (edm *dnstapMinimiser) fsEventWatcher() {
t, ok := timers[cleanName]
timersMutex.Unlock()
if !ok {
t = time.AfterFunc(math.MaxInt64, callbackHandler(callback, cleanName))
t = time.AfterFunc(math.MaxInt64, callbackHandler(callbacks, cleanName))
t.Stop()

timersMutex.Lock()
Expand Down Expand Up @@ -709,7 +726,7 @@ func (edm *dnstapMinimiser) registerFSWatcher(filename string, callback func(str
}

edm.fsWatcherMutex.Lock()
edm.fsWatcherFuncs[filename] = callback
edm.fsWatcherFuncs[filename] = append(edm.fsWatcherFuncs[filename], callback)
edm.fsWatcherMutex.Unlock()

return nil
Expand Down Expand Up @@ -763,8 +780,6 @@ func Run(logger *slog.Logger) {
os.Exit(1)
}

go edm.fsEventWatcher()

viperNotifyCh := make(chan fsnotify.Event)

go configUpdater(viperNotifyCh, edm)
Expand All @@ -787,13 +802,51 @@ func Run(logger *slog.Logger) {
}()

if !edm.histogramSenderDisabled {
edm.setupHistogramSender()
// Setup client cert/key for mTLS authentication
httpClientCertStore := newCertStore()
err = httpClientCertStore.setCert(viper.GetString("http-client-cert-file"), viper.GetString("http-client-key-file"))
if err != nil {
edm.log.Error("unable to load x509 HTTP client cert", "error", err)
os.Exit(1)
}

edm.setupHistogramSender(httpClientCertStore)

err = edm.registerFSWatcher(viper.GetString("http-client-cert-file"), func(filename string) error {
edm.log.Info("reloading HTTP cert store because file was modified", "filename", filename)
err := httpClientCertStore.setCert(viper.GetString("http-client-cert-file"), viper.GetString("http-client-key-file"))
return err
})
if err != nil {
logger.Error("unable to register fsWatcher callback", "filename", viper.GetString("http-client-cert-file"), "error", err)
os.Exit(1)
}
}

if !edm.mqttDisabled {
edm.setupMQTT()
// Setup client cert/key for mTLS authentication
mqttClientCertStore := newCertStore()
err = mqttClientCertStore.setCert(viper.GetString("mqtt-client-cert-file"), viper.GetString("mqtt-client-key-file"))
if err != nil {
edm.log.Error("unable to load x509 mqtt client cert", "error", err)
os.Exit(1)
}

edm.setupMQTT(mqttClientCertStore)

err = edm.registerFSWatcher(viper.GetString("mqtt-client-cert-file"), func(filename string) error {
edm.log.Info("reloading MQTT cert store because file was modified", "filename", filename)
err := mqttClientCertStore.setCert(viper.GetString("mqtt-client-cert-file"), viper.GetString("mqtt-client-key-file"))
return err
})
if err != nil {
logger.Error("unable to register fsWatcher callback", "filename", viper.GetString("mqtt-client-cert-file"), "error", err)
os.Exit(1)
}
}

go edm.fsEventWatcher()

// Setup the dnstap.Input, only one at a time is supported.
var dti *dnstap.FrameStreamSockInput
if viper.GetString("input-unix") != "" {
Expand Down Expand Up @@ -1029,7 +1082,7 @@ type dnstapMinimiser struct {
ignoredQuestions dawg.Finder
ignoredQuestionsMutex sync.RWMutex
fsWatcher *fsnotify.Watcher
fsWatcherFuncs map[string]func(string) error
fsWatcherFuncs map[string][]func(string) error
fsWatcherMutex sync.RWMutex
}

Expand Down Expand Up @@ -1132,7 +1185,7 @@ func newDnstapMinimiser(logger *slog.Logger, cryptopanKey string, cryptopanSalt
return nil, fmt.Errorf("newDnstapMinimiser: unable to create fsWatcher: %w", err)
}

edm.fsWatcherFuncs = map[string]func(string) error{}
edm.fsWatcherFuncs = map[string][]func(string) error{}

// Setup channels for feeding writers and data senders that should do
// their work outside the main minimiser loop. They are buffered to
Expand Down

0 comments on commit 8410b24

Please sign in to comment.