diff --git a/cmd/extensions/main.go b/cmd/extensions/main.go index 5e5f253d99..a7ed0e9e47 100644 --- a/cmd/extensions/main.go +++ b/cmd/extensions/main.go @@ -17,7 +17,6 @@ package main import ( "context" - "crypto/tls" "io" "net/http" "os" @@ -36,7 +35,6 @@ import ( "agones.dev/agones/pkg/gameserversets" "agones.dev/agones/pkg/metrics" "agones.dev/agones/pkg/util/apiserver" - "agones.dev/agones/pkg/util/fswatch" "agones.dev/agones/pkg/util/https" "agones.dev/agones/pkg/util/runtime" "agones.dev/agones/pkg/util/signals" @@ -53,10 +51,6 @@ import ( "k8s.io/client-go/tools/clientcmd" ) -const ( - tlsDir = "/home/agones/certs/" -) - const ( enableStackdriverMetricsFlag = "stackdriver-exporter" stackdriverLabels = "stackdriver-labels" @@ -144,22 +138,7 @@ func main() { logger.WithError(err).Fatal("Could not initialize cloud product") } // https server and the items that share the Mux for routing - httpsServer := https.NewServer(ctlConf.CertFile, ctlConf.KeyFile) - - cancelTLS, err := fswatch.Watch(logger, tlsDir, time.Second, func() { - tlsCert, err := readTLSCert() - if err != nil { - logger.WithError(err).Error("could not load TLS certs; keeping old one") - return - } - httpsServer.SetCertificate(tlsCert) - logger.Info("TLS certs updated") - }) - if err != nil { - logger.WithError(err).Fatal("could not create watcher for TLS certs") - } - defer cancelTLS() - + httpsServer := https.NewServer(ctlConf.CertFile, ctlConf.KeyFile, logger) wh := webhooks.NewWebHook(httpsServer.Mux) api := apiserver.NewAPIServer(httpsServer.Mux) @@ -242,14 +221,6 @@ func main() { logger.Info("Shut down agones extensions") } -func readTLSCert() (*tls.Certificate, error) { - tlsCert, err := tls.LoadX509KeyPair(tlsDir+"server.crt", tlsDir+"server.key") - if err != nil { - return nil, err - } - return &tlsCert, nil -} - func parseEnvFlags() config { exec, err := os.Executable() if err != nil { diff --git a/pkg/util/https/server.go b/pkg/util/https/server.go index 0e8824d3fa..2e0645e5cf 100644 --- a/pkg/util/https/server.go +++ b/pkg/util/https/server.go @@ -19,12 +19,20 @@ import ( "crypto/tls" "net/http" "sync" + "time" + "agones.dev/agones/pkg/util/fswatch" "agones.dev/agones/pkg/util/runtime" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) +const ( + tlsDir = "/certs/" +) + +var tlsMutex sync.Mutex + // tls is a http server interface to enable easier testing type testTLS interface { Close() error @@ -38,26 +46,46 @@ type Server struct { logger *logrus.Entry Mux *http.ServeMux tls testTLS - certMu sync.RWMutex - cert *tls.Certificate certFile string keyFile string } // NewServer returns a Server instance. -func NewServer(certFile, keyFile string) *Server { +func NewServer(certFile, keyFile string, logger *logrus.Entry) *Server { mux := http.NewServeMux() - tls := &http.Server{ + tls_server := &http.Server{ Addr: ":8081", Handler: mux, } + go func() { + cancelTLS, err := fswatch.Watch(logger, tlsDir, time.Second, func() { + tlsCert, err := readTLSCert() + if err != nil { + logger.WithError(err).Error("could not load TLS certs; keeping old one") + return + } + tlsMutex.Lock() + defer tlsMutex.Unlock() + tls_server.TLSConfig = &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return tlsCert, nil + }, + } + logger.Info("TLS certs updated") + }) + if err != nil { + logger.WithError(err).Fatal("could not create watcher for TLS certs") + } + defer cancelTLS() + + }() + wh := &Server{ Mux: mux, - tls: tls, + tls: tls_server, certFile: certFile, keyFile: keyFile, - cert: nil, } wh.Mux.HandleFunc("/", wh.defaultHandler) wh.logger = runtime.NewLoggerWithType(wh) @@ -65,12 +93,6 @@ func NewServer(certFile, keyFile string) *Server { return wh } -func (s *Server) SetCertificate(cert *tls.Certificate) { - s.certMu.Lock() - defer s.certMu.Unlock() - s.cert = cert -} - // Run runs the webhook server, starting a https listener. // Will close the http server on stop channel close. func (s *Server) Run(ctx context.Context, _ int) error { @@ -101,3 +123,11 @@ func (s *Server) defaultHandler(w http.ResponseWriter, r *http.Request) { FourZeroFour(s.logger, w, r) } + +func readTLSCert() (*tls.Certificate, error) { + tlsCert, err := tls.LoadX509KeyPair(tlsDir+"server.crt", tlsDir+"server.key") + if err != nil { + return nil, err + } + return &tlsCert, nil +} diff --git a/pkg/util/https/server_test.go b/pkg/util/https/server_test.go index 95348e6308..4a9d97b0f2 100644 --- a/pkg/util/https/server_test.go +++ b/pkg/util/https/server_test.go @@ -41,7 +41,7 @@ func (ts *testServer) ListenAndServeTLS(certFile, keyFile string) error { func TestServerRun(t *testing.T) { t.Parallel() - s := NewServer("", "") + s := NewServer("", "",nil) ts := &testServer{server: httptest.NewUnstartedServer(s.Mux)} s.tls = ts