diff --git a/pkg/runner/aggregate_sender.go b/pkg/runner/aggregate_sender.go index 734b296..23ff2ec 100644 --- a/pkg/runner/aggregate_sender.go +++ b/pkg/runner/aggregate_sender.go @@ -27,7 +27,7 @@ type aggregateSender struct { signingHTTPClient *httpsign.Client } -func (edm *dnstapMinimiser) newAggregateSender(aggrecURL *url.URL, signingKeyName string, signingKey *ecdsa.PrivateKey, caCertPool *x509.CertPool, clientCert tls.Certificate) aggregateSender { +func (edm *dnstapMinimiser) newAggregateSender(aggrecURL *url.URL, signingKeyName string, signingKey *ecdsa.PrivateKey, caCertPool *x509.CertPool, clientCertStore *certStore) aggregateSender { // Create HTTP handler for sending aggregate files to aggrec httpClient := http.Client{ Transport: &http.Transport{ @@ -38,9 +38,9 @@ func (edm *dnstapMinimiser) newAggregateSender(aggrecURL *url.URL, signingKeyNam TLSHandshakeTimeout: 10 * time.Second, ResponseHeaderTimeout: 10 * time.Second, TLSClientConfig: &tls.Config{ - RootCAs: caCertPool, - Certificates: []tls.Certificate{clientCert}, - MinVersion: tls.VersionTLS13, + RootCAs: caCertPool, + GetClientCertificate: clientCertStore.getClientCertficate, + MinVersion: tls.VersionTLS13, }, }, } diff --git a/pkg/runner/mqtt.go b/pkg/runner/mqtt.go index ae02789..1d9e8b3 100644 --- a/pkg/runner/mqtt.go +++ b/pkg/runner/mqtt.go @@ -14,7 +14,7 @@ import ( "github.com/lestrrat-go/jwx/v2/jws" ) -func (edm *dnstapMinimiser) newAutoPahoClientConfig(caCertPool *x509.CertPool, server string, clientID string, clientCert tls.Certificate, mqttKeepAlive uint16) (autopaho.ClientConfig, error) { +func (edm *dnstapMinimiser) newAutoPahoClientConfig(caCertPool *x509.CertPool, server string, clientID string, clientCertStore *certStore, mqttKeepAlive uint16) (autopaho.ClientConfig, error) { u, err := url.Parse(server) if err != nil { return autopaho.ClientConfig{}, fmt.Errorf("newAutoPahoClientConfig: unable to parse URL: %w", err) @@ -23,9 +23,9 @@ func (edm *dnstapMinimiser) newAutoPahoClientConfig(caCertPool *x509.CertPool, s cliCfg := autopaho.ClientConfig{ ServerUrls: []*url.URL{u}, TlsCfg: &tls.Config{ - RootCAs: caCertPool, - Certificates: []tls.Certificate{clientCert}, - MinVersion: tls.VersionTLS13, + RootCAs: caCertPool, + GetClientCertificate: clientCertStore.getClientCertficate, + MinVersion: tls.VersionTLS13, }, KeepAlive: mqttKeepAlive, OnConnectionUp: func(*autopaho.ConnectionManager, *paho.Connack) { edm.log.Info("mqtt connection up") }, diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 564c293..2ebd8f1 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -162,6 +162,34 @@ type prevSessions struct { rotationTime time.Time } +type certStore struct { + cert *tls.Certificate + mtx sync.RWMutex +} + +// Implements tls.Config.GetClientCertificate +func (cs *certStore) getClientCertficate(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + cs.mtx.RLock() + defer cs.mtx.RUnlock() + return cs.cert, nil +} + +func (cs *certStore) setCert(certPath string, keyPath string) error { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return fmt.Errorf("unable to load x509 cert: %w", err) + } + cs.mtx.Lock() + cs.cert = &cert + cs.mtx.Unlock() + + return nil +} + +func newCertStore() *certStore { + return &certStore{} +} + func (edm *dnstapMinimiser) setHistogramLabels(labels []string, labelLimit int, hd *histogramData) { // If labels is nil (the "." zone) we can depend on the zero type of // the label fields being nil, so nothing to do @@ -413,7 +441,7 @@ func setHllDefaults() error { return err } -func (edm *dnstapMinimiser) setupHistogramSender() { +func (edm *dnstapMinimiser) setupHistogramSender(httpClientCertStore *certStore) { httpURL, err := url.Parse(viper.GetString("http-url")) if err != nil { edm.log.Error("unable to parse 'http-url' setting", "error", err) @@ -438,16 +466,10 @@ func (edm *dnstapMinimiser) setupHistogramSender() { } } - httpClientCert, err := tls.LoadX509KeyPair(viper.GetString("http-client-cert-file"), viper.GetString("http-client-key-file")) - if err != nil { - edm.log.Error("unable to load x509 HTTP client cert", "error", err) - os.Exit(1) - } - - edm.aggregSender = edm.newAggregateSender(httpURL, viper.GetString("http-signing-key-id"), httpSigningKey, httpCACertPool, httpClientCert) + edm.aggregSender = edm.newAggregateSender(httpURL, viper.GetString("http-signing-key-id"), httpSigningKey, httpCACertPool, httpClientCertStore) } -func (edm *dnstapMinimiser) setupMQTT() { +func (edm *dnstapMinimiser) setupMQTT(mqttClientCertStore *certStore) { mqttSigningKey, err := ecdsaPrivateKeyFromFile(viper.GetString("mqtt-signing-key-file")) if err != nil { edm.log.Error("unable to parse key material from 'mqtt-signing-key-file'", "error", err) @@ -475,14 +497,7 @@ func (edm *dnstapMinimiser) setupMQTT() { } } - // Setup client cert/key for mTLS authentication - mqttClientCert, err := tls.LoadX509KeyPair(viper.GetString("mqtt-client-cert-file"), viper.GetString("mqtt-client-key-file")) - if err != nil { - edm.log.Error("unable to load x509 mqtt client cert", "error", err) - os.Exit(1) - } - - autopahoConfig, err := edm.newAutoPahoClientConfig(mqttCACertPool, viper.GetString("mqtt-server"), viper.GetString("mqtt-client-id"), mqttClientCert, uint16(viper.GetInt("mqtt-keepalive"))) + autopahoConfig, err := edm.newAutoPahoClientConfig(mqttCACertPool, viper.GetString("mqtt-server"), viper.GetString("mqtt-client-id"), mqttClientCertStore, uint16(viper.GetInt("mqtt-keepalive"))) if err != nil { edm.log.Error("unable to create autopaho config", "error", err) os.Exit(1) @@ -639,11 +654,13 @@ func (edm *dnstapMinimiser) fsEventWatcher() { timers := map[string]*time.Timer{} timersMutex := new(sync.Mutex) - callbackHandler := func(callback func(string) error, name string) func() { + callbackHandler := func(callbacks []func(string) error, name string) func() { return func() { - err := callback(name) - if err != nil { - edm.log.Error("fsEventWatcher: callback error", "filename", name, "error", err) + for _, callback := range callbacks { + err := callback(name) + if err != nil { + edm.log.Error("fsEventWatcher: callback error", "filename", name, "error", err) + } } // Cleanup expired timer @@ -668,7 +685,7 @@ func (edm *dnstapMinimiser) fsEventWatcher() { cleanName := filepath.Clean(event.Name) edm.fsWatcherMutex.RLock() - callback, ok := edm.fsWatcherFuncs[cleanName] + callbacks, ok := edm.fsWatcherFuncs[cleanName] edm.fsWatcherMutex.RUnlock() if !ok { if edm.debug { @@ -681,7 +698,7 @@ func (edm *dnstapMinimiser) fsEventWatcher() { t, ok := timers[cleanName] timersMutex.Unlock() if !ok { - t = time.AfterFunc(math.MaxInt64, callbackHandler(callback, cleanName)) + t = time.AfterFunc(math.MaxInt64, callbackHandler(callbacks, cleanName)) t.Stop() timersMutex.Lock() @@ -709,7 +726,7 @@ func (edm *dnstapMinimiser) registerFSWatcher(filename string, callback func(str } edm.fsWatcherMutex.Lock() - edm.fsWatcherFuncs[filename] = callback + edm.fsWatcherFuncs[filename] = append(edm.fsWatcherFuncs[filename], callback) edm.fsWatcherMutex.Unlock() return nil @@ -763,8 +780,6 @@ func Run(logger *slog.Logger) { os.Exit(1) } - go edm.fsEventWatcher() - viperNotifyCh := make(chan fsnotify.Event) go configUpdater(viperNotifyCh, edm) @@ -787,13 +802,51 @@ func Run(logger *slog.Logger) { }() if !edm.histogramSenderDisabled { - edm.setupHistogramSender() + // Setup client cert/key for mTLS authentication + httpClientCertStore := newCertStore() + err = httpClientCertStore.setCert(viper.GetString("http-client-cert-file"), viper.GetString("http-client-key-file")) + if err != nil { + edm.log.Error("unable to load x509 HTTP client cert", "error", err) + os.Exit(1) + } + + edm.setupHistogramSender(httpClientCertStore) + + err = edm.registerFSWatcher(viper.GetString("http-client-cert-file"), func(filename string) error { + edm.log.Info("reloading HTTP cert store because file was modified", "filename", filename) + err := httpClientCertStore.setCert(viper.GetString("http-client-cert-file"), viper.GetString("http-client-key-file")) + return err + }) + if err != nil { + logger.Error("unable to register fsWatcher callback", "filename", viper.GetString("http-client-cert-file"), "error", err) + os.Exit(1) + } } if !edm.mqttDisabled { - edm.setupMQTT() + // Setup client cert/key for mTLS authentication + mqttClientCertStore := newCertStore() + err = mqttClientCertStore.setCert(viper.GetString("mqtt-client-cert-file"), viper.GetString("mqtt-client-key-file")) + if err != nil { + edm.log.Error("unable to load x509 mqtt client cert", "error", err) + os.Exit(1) + } + + edm.setupMQTT(mqttClientCertStore) + + err = edm.registerFSWatcher(viper.GetString("mqtt-client-cert-file"), func(filename string) error { + edm.log.Info("reloading MQTT cert store because file was modified", "filename", filename) + err := mqttClientCertStore.setCert(viper.GetString("mqtt-client-cert-file"), viper.GetString("mqtt-client-key-file")) + return err + }) + if err != nil { + logger.Error("unable to register fsWatcher callback", "filename", viper.GetString("mqtt-client-cert-file"), "error", err) + os.Exit(1) + } } + go edm.fsEventWatcher() + // Setup the dnstap.Input, only one at a time is supported. var dti *dnstap.FrameStreamSockInput if viper.GetString("input-unix") != "" { @@ -1029,7 +1082,7 @@ type dnstapMinimiser struct { ignoredQuestions dawg.Finder ignoredQuestionsMutex sync.RWMutex fsWatcher *fsnotify.Watcher - fsWatcherFuncs map[string]func(string) error + fsWatcherFuncs map[string][]func(string) error fsWatcherMutex sync.RWMutex } @@ -1132,7 +1185,7 @@ func newDnstapMinimiser(logger *slog.Logger, cryptopanKey string, cryptopanSalt return nil, fmt.Errorf("newDnstapMinimiser: unable to create fsWatcher: %w", err) } - edm.fsWatcherFuncs = map[string]func(string) error{} + edm.fsWatcherFuncs = map[string][]func(string) error{} // Setup channels for feeding writers and data senders that should do // their work outside the main minimiser loop. They are buffered to