Skip to content

Commit

Permalink
Merge pull request #1155 from DirectXMan12/bug/webhook-server-threadsafe
Browse files Browse the repository at this point in the history
🐛 Ensure that webhook server is thread/start-safe
  • Loading branch information
k8s-ci-robot authored Sep 22, 2020
2 parents ea6a506 + 5f1af13 commit 5757a38
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 44 deletions.
2 changes: 1 addition & 1 deletion pkg/builder/webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
)

var _ = Describe("application", func() {
var _ = Describe("webhook", func() {
var stop chan struct{}

BeforeEach(func() {
Expand Down
21 changes: 19 additions & 2 deletions pkg/envtest/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ type WebhookInstallOptions struct {
// it will be automatically populated by the local temp dir
LocalServingCertDir string

// CAData is the CA that can be used to trust the serving certificates in LocalServingCertDir.
LocalServingCAData []byte

// MaxTime is the max time to wait
MaxTime time.Duration

Expand Down Expand Up @@ -143,8 +146,12 @@ func (o *WebhookInstallOptions) generateHostPort() (string, error) {
return net.JoinHostPort(host, fmt.Sprintf("%d", port)), nil
}

// Install installs specified webhooks to the API server
func (o *WebhookInstallOptions) Install(config *rest.Config) error {
// PrepWithoutInstalling does the setup parts of Install (populating host-port,
// setting up CAs, etc), without actually truing to do anything with webhook
// definitions. This is largely useful for internal testing of
// controller-runtime, where we need a random host-port & caData for webhook
// tests, but may be useful in similar scenarios.
func (o *WebhookInstallOptions) PrepWithoutInstalling() error {
hookCA, err := o.setupCA()
if err != nil {
return err
Expand All @@ -158,6 +165,15 @@ func (o *WebhookInstallOptions) Install(config *rest.Config) error {
return err
}

return nil
}

// Install installs specified webhooks to the API server
func (o *WebhookInstallOptions) Install(config *rest.Config) error {
if err := o.PrepWithoutInstalling(); err != nil {
return err
}

if err := createWebhooks(config, o.MutatingWebhooks, o.ValidatingWebhooks); err != nil {
return err
}
Expand Down Expand Up @@ -273,6 +289,7 @@ func (o *WebhookInstallOptions) setupCA() ([]byte, error) {
return nil, fmt.Errorf("unable to write webhook serving key to disk: %v", err)
}

o.LocalServingCAData = certData
return certData, nil
}

Expand Down
23 changes: 19 additions & 4 deletions pkg/manager/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,17 +353,32 @@ func (cm *controllerManager) GetAPIReader() client.Reader {
}

func (cm *controllerManager) GetWebhookServer() *webhook.Server {
if cm.webhookServer == nil {
server, wasNew := func() (*webhook.Server, bool) {
cm.mu.Lock()
defer cm.mu.Unlock()

if cm.webhookServer != nil {
return cm.webhookServer, false
}

cm.webhookServer = &webhook.Server{
Port: cm.port,
Host: cm.host,
CertDir: cm.certDir,
}
if err := cm.Add(cm.webhookServer); err != nil {
panic("unable to add webhookServer to the controller manager")
return cm.webhookServer, true
}()

// only add the server if *we ourselves* just registered it.
// Add has its own lock, so just do this separately -- there shouldn't
// be a "race" in this lock gap because the condition is the population
// of cm.webhookServer, not anything to do with Add.
if wasNew {
if err := cm.Add(server); err != nil {
panic("unable to add webhook server to the controller manager")
}
}
return cm.webhookServer
return server
}

func (cm *controllerManager) GetLogger() logr.Logger {
Expand Down
59 changes: 43 additions & 16 deletions pkg/webhook/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ type Server struct {

// defaultingOnce ensures that the default fields are only ever set once.
defaultingOnce sync.Once

// mu protects access to the webhook map & setFields for Start, Register, etc
mu sync.Mutex
}

// setDefaults does defaulting for the Server.
Expand Down Expand Up @@ -111,6 +114,9 @@ func (*Server) NeedLeaderElection() bool {
// Register marks the given webhook as being served at the given path.
// It panics if two hooks are registered on the same path.
func (s *Server) Register(path string, hook http.Handler) {
s.mu.Lock()
defer s.mu.Unlock()

s.defaultingOnce.Do(s.setDefaults)
_, found := s.webhooks[path]
if found {
Expand All @@ -119,7 +125,28 @@ func (s *Server) Register(path string, hook http.Handler) {
// TODO(directxman12): call setfields if we've already started the server
s.webhooks[path] = hook
s.WebhookMux.Handle(path, instrumentedHook(path, hook))
log.Info("registering webhook", "path", path)

regLog := log.WithValues("path", path)
regLog.Info("registering webhook")

// we've already been "started", inject dependencies here.
// Otherwise, InjectFunc will do this for us later.
if s.setFields != nil {
if err := s.setFields(hook); err != nil {
// TODO(directxman12): swallowing this error isn't great, but we'd have to
// change the signature to fix that
regLog.Error(err, "unable to inject fields into webhook during registration")
}

baseHookLog := log.WithName("webhooks")

// NB(directxman12): we don't propagate this further by wrapping setFields because it's
// unclear if this is how we want to deal with log propagation. In this specific instance,
// we want to be able to pass a logger to webhooks because they don't know their own path.
if _, err := inject.LoggerInto(baseHookLog.WithValues("webhook", path), hook); err != nil {
regLog.Error(err, "unable to logger into webhook during registration")
}
}
}

// instrumentedHook adds some instrumentation on top of the given webhook.
Expand Down Expand Up @@ -151,21 +178,6 @@ func (s *Server) Start(stop <-chan struct{}) error {
baseHookLog := log.WithName("webhooks")
baseHookLog.Info("starting webhook server")

// inject fields here as opposed to in Register so that we're certain to have our setFields
// function available.
for hookPath, webhook := range s.webhooks {
if err := s.setFields(webhook); err != nil {
return err
}

// NB(directxman12): we don't propagate this further by wrapping setFields because it's
// unclear if this is how we want to deal with log propagation. In this specific instance,
// we want to be able to pass a logger to webhooks because they don't know their own path.
if _, err := inject.LoggerInto(baseHookLog.WithValues("webhook", hookPath), webhook); err != nil {
return err
}
}

certPath := filepath.Join(s.CertDir, s.CertName)
keyPath := filepath.Join(s.CertDir, s.KeyName)

Expand Down Expand Up @@ -238,5 +250,20 @@ func (s *Server) Start(stop <-chan struct{}) error {
// InjectFunc injects the field setter into the server.
func (s *Server) InjectFunc(f inject.Func) error {
s.setFields = f

// inject fields here that weren't injected in Register because we didn't have setFields yet.
baseHookLog := log.WithName("webhooks")
for hookPath, webhook := range s.webhooks {
if err := s.setFields(webhook); err != nil {
return err
}

// NB(directxman12): we don't propagate this further by wrapping setFields because it's
// unclear if this is how we want to deal with log propagation. In this specific instance,
// we want to be able to pass a logger to webhooks because they don't know their own path.
if _, err := inject.LoggerInto(baseHookLog.WithValues("webhook", hookPath), webhook); err != nil {
return err
}
}
return nil
}
187 changes: 187 additions & 0 deletions pkg/webhook/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
Copyright 2019 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package webhook_test

import (
"fmt"
"io/ioutil"
"net"
"net/http"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"k8s.io/client-go/rest"
"sigs.k8s.io/controller-runtime/pkg/envtest"
"sigs.k8s.io/controller-runtime/pkg/webhook"
)

var _ = Describe("Webhook Server", func() {
var (
stop chan struct{}
testHostPort string
client *http.Client
server *webhook.Server
)

BeforeEach(func() {
stop = make(chan struct{})
// closed in indivual tests differently

servingOpts := envtest.WebhookInstallOptions{}
Expect(servingOpts.PrepWithoutInstalling()).To(Succeed())

testHostPort = net.JoinHostPort(servingOpts.LocalServingHost, fmt.Sprintf("%d", servingOpts.LocalServingPort))

// bypass needing to set up the x509 cert pool, etc ourselves
clientTransport, err := rest.TransportFor(&rest.Config{
TLSClientConfig: rest.TLSClientConfig{CAData: servingOpts.LocalServingCAData},
})
Expect(err).NotTo(HaveOccurred())
client = &http.Client{
Transport: clientTransport,
}

server = &webhook.Server{
Host: servingOpts.LocalServingHost,
Port: servingOpts.LocalServingPort,
CertDir: servingOpts.LocalServingCertDir,
}

// TODO(directxman12): cleanup generated certificate dir, etc
})

startServer := func() (done <-chan struct{}) {
doneCh := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(doneCh)
Expect(server.Start(stop)).To(Succeed())
}()
// wait till we can ping the server to start the test
Eventually(func() error {
_, err := client.Get(fmt.Sprintf("https://%s/unservedpath", testHostPort))
return err
}).Should(Succeed())

// this is normally called before Start by the manager
Expect(server.InjectFunc(func(i interface{}) error {
boolInj, canInj := i.(interface{ InjectBool(bool) error })
if !canInj {
return nil
}
return boolInj.InjectBool(true)
})).To(Succeed())

return doneCh
}

// TODO(directxman12): figure out a good way to test all the serving setup
// with httptest.Server to get all the niceness from that.

Context("when serving", func() {
PIt("should verify the client CA name when asked to", func() {

})
PIt("should support HTTP/2", func() {

})

// TODO(directxman12): figure out a good way to test the port default, etc
})

It("should panic if a duplicate path is registered", func() {
server.Register("/somepath", &testHandler{})
doneCh := startServer()

Expect(func() { server.Register("/somepath", &testHandler{}) }).To(Panic())

close(stop)
Eventually(doneCh, "4s").Should(BeClosed())
})

Context("when registering new webhooks before starting", func() {
It("should serve a webhook on the requested path", func() {
server.Register("/somepath", &testHandler{})

doneCh := startServer()

Eventually(func() ([]byte, error) {
resp, err := client.Get(fmt.Sprintf("https://%s/somepath", testHostPort))
Expect(err).NotTo(HaveOccurred())
defer resp.Body.Close()
return ioutil.ReadAll(resp.Body)
}).Should(Equal([]byte("gadzooks!")))

close(stop)
Eventually(doneCh, "4s").Should(BeClosed())
})

It("should inject dependencies eventually, given an inject func is eventually provided", func() {
handler := &testHandler{}
server.Register("/somepath", handler)
doneCh := startServer()

Eventually(func() bool { return handler.injectedField }).Should(BeTrue())

close(stop)
Eventually(doneCh, "4s").Should(BeClosed())
})
})

Context("when registering webhooks after starting", func() {
var (
doneCh <-chan struct{}
)
BeforeEach(func() {
doneCh = startServer()
})
AfterEach(func() {
// wait for cleanup to happen
close(stop)
Eventually(doneCh, "4s").Should(BeClosed())
})

It("should serve a webhook on the requested path", func() {
server.Register("/somepath", &testHandler{})
resp, err := client.Get(fmt.Sprintf("https://%s/somepath", testHostPort))
Expect(err).NotTo(HaveOccurred())
defer resp.Body.Close()

Expect(ioutil.ReadAll(resp.Body)).To(Equal([]byte("gadzooks!")))
})

It("should inject dependencies, if an inject func has been provided already", func() {
handler := &testHandler{}
server.Register("/somepath", handler)
Expect(handler.injectedField).To(BeTrue())
})
})
})

type testHandler struct {
injectedField bool
}

func (t *testHandler) InjectBool(val bool) error {
t.injectedField = val
return nil
}
func (t *testHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
if _, err := resp.Write([]byte("gadzooks!")); err != nil {
panic("unable to write http response!")
}
}
Loading

0 comments on commit 5757a38

Please sign in to comment.