diff --git a/config/config.go b/config/config.go index 5f9df1d99e..e62ce592cc 100644 --- a/config/config.go +++ b/config/config.go @@ -245,6 +245,7 @@ type Config struct { ExpectContinueTimeoutBackend time.Duration `yaml:"expect-continue-timeout-backend"` MaxIdleConnsBackend int `yaml:"max-idle-connection-backend"` DisableHTTPKeepalives bool `yaml:"disable-http-keepalives"` + EnableHttp2Cleartext bool `yaml:"enable-http2-cleartext"` // swarm: EnableSwarm bool `yaml:"enable-swarm"` @@ -523,6 +524,7 @@ func NewConfig() *Config { flag.IntVar(&cfg.MaxIdleConnsBackend, "max-idle-connection-backend", 0, "sets the maximum idle connections for all backend connections") flag.BoolVar(&cfg.DisableHTTPKeepalives, "disable-http-keepalives", false, "forces backend to always create a new connection") flag.BoolVar(&cfg.KubernetesEnableTLS, "kubernetes-enable-tls", false, "enable using kubnernetes resources to terminate tls") + flag.BoolVar(&cfg.EnableHttp2Cleartext, "enable-http2-cleartext", false, "enables HTTP/2 connections over cleartext TCP") // Swarm: flag.BoolVar(&cfg.EnableSwarm, "enable-swarm", false, "enable swarm communication between nodes in a skipper fleet") @@ -850,6 +852,7 @@ func (c *Config) ToOptions() skipper.Options { MaxIdleConnsBackend: c.MaxIdleConnsBackend, DisableHTTPKeepalives: c.DisableHTTPKeepalives, KubernetesEnableTLS: c.KubernetesEnableTLS, + EnableHttp2Cleartext: c.EnableHttp2Cleartext, // swarm: EnableSwarm: c.EnableSwarm, diff --git a/net/shutdown_listener.go b/net/shutdown_listener.go new file mode 100644 index 0000000000..c0d2293d1b --- /dev/null +++ b/net/shutdown_listener.go @@ -0,0 +1,81 @@ +package net + +import ( + "context" + "net" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" +) + +type ( + ShutdownListener struct { + net.Listener + activeConns atomic.Int64 + } + + shutdownListenerConn struct { + net.Conn + listener *ShutdownListener + closeOnce sync.Once + } +) + +var _ net.Listener = &ShutdownListener{} + +func NewShutdownListener(l net.Listener) *ShutdownListener { + return &ShutdownListener{Listener: l} +} + +func (l *ShutdownListener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + + l.registerConn() + + return &shutdownListenerConn{Conn: c, listener: l}, nil +} + +func (l *ShutdownListener) Close() error { + err := l.Listener.Close() + return err +} + +func (l *ShutdownListener) Shutdown(ctx context.Context) error { + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + for { + n := l.activeConns.Load() + log.Debugf("ShutdownListener Shutdown: %d connections", n) + if n == 0 { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +func (c *shutdownListenerConn) Close() error { + err := c.Conn.Close() + + c.closeOnce.Do(func() { c.listener.unregisterConn() }) + + return err +} + +func (l *ShutdownListener) registerConn() { + n := l.activeConns.Add(1) + log.Debugf("ShutdownListener registerConn: %d connections", n) +} + +func (l *ShutdownListener) unregisterConn() { + n := l.activeConns.Add(-1) + log.Debugf("ShutdownListener unregisterConn: %d connections", n) +} diff --git a/skipper.go b/skipper.go index c331d999d6..4ffb6966ba 100644 --- a/skipper.go +++ b/skipper.go @@ -22,6 +22,8 @@ import ( ot "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" "github.com/zalando/skipper/circuit" "github.com/zalando/skipper/dataclients/kubernetes" @@ -42,7 +44,7 @@ import ( "github.com/zalando/skipper/loadbalancer" "github.com/zalando/skipper/logging" "github.com/zalando/skipper/metrics" - skpnet "github.com/zalando/skipper/net" + snet "github.com/zalando/skipper/net" pauth "github.com/zalando/skipper/predicates/auth" "github.com/zalando/skipper/predicates/content" "github.com/zalando/skipper/predicates/cookie" @@ -381,6 +383,9 @@ type Options struct { // a backend to always create a new connection. DisableHTTPKeepalives bool + // EnableHttp2Cleartext enables HTTP/2 connections over cleartext TCP. + EnableHttp2Cleartext bool + // Flag indicating to ignore trailing slashes in paths during route // lookup. IgnoreTrailingSlash bool @@ -1233,11 +1238,31 @@ func listenAndServeQuit( } } + if o.EnableHttp2Cleartext { + if serveTLS { + return fmt.Errorf("HTTP/2 connections over cleartext TCP are not supported when TLS is enabled") + } + + h2srv := &http2.Server{} + srv.Handler = h2c.NewHandler(srv.Handler, h2srv) + + // Work around https://github.com/golang/go/issues/26682 + // http2.ConfigureServer registers unexported h2srv graceful shutdown handler on srv shutdown - + // it calls srv.RegisterOnShutdown(h2srv.state.startGracefulShutdown). + // h2srv graceful shutdown handler sends GOAWAY frame to all connections and closes them after predefined delay. + // + // srv.Shutdown() runs h2srv shutdown handler in a goroutine so a special snet.ShutdownListener + // waits until all connections are closed. + http2.ConfigureServer(srv, h2srv) + } + log.Infof("Listen on %v", address) - l, err := listen(o, address, mtr) - if err != nil { + var listener *snet.ShutdownListener + if l, err := listen(o, address, mtr); err != nil { return err + } else { + listener = snet.NewShutdownListener(l) } // making idleConnsCH and sigs optional parameters is required to be able to tear down a server @@ -1258,10 +1283,16 @@ func listenAndServeQuit( log.Infof("Got shutdown signal, wait %v for health check", o.WaitForHealthcheckInterval) time.Sleep(o.WaitForHealthcheckInterval) - log.Info("Start shutdown") + log.Info("Start server shutdown") if err := srv.Shutdown(context.Background()); err != nil { - log.Errorf("Failed to graceful shutdown: %v", err) + log.Errorf("Failed to gracefully shutdown: %v", err) + } + + log.Info("Start listener shutdown") + if err := listener.Shutdown(context.Background()); err != nil { + log.Errorf("Failed to gracefully shutdown listener: %v", err) } + close(idleConnsCH) }() @@ -1281,20 +1312,21 @@ func listenAndServeQuit( }() } - if err := srv.ServeTLS(l, "", ""); err != http.ErrServerClosed { + if err := srv.ServeTLS(listener, "", ""); err != http.ErrServerClosed { log.Errorf("ServeTLS failed: %v", err) return err } } else { log.Infof("TLS settings not found, defaulting to HTTP") - if err := srv.Serve(l); err != http.ErrServerClosed { + if err := srv.Serve(listener); err != http.ErrServerClosed { log.Errorf("Serve failed: %v", err) return err } } <-idleConnsCH + log.Infof("done.") return nil } @@ -1580,13 +1612,13 @@ func run(o Options, sig chan os.Signal, idleConnsCH chan struct{}) error { } var swarmer ratelimit.Swarmer - var redisOptions *skpnet.RedisOptions + var redisOptions *snet.RedisOptions log.Infof("enable swarm: %v", o.EnableSwarm) if o.EnableSwarm { if len(o.SwarmRedisURLs) > 0 || o.KubernetesRedisServiceName != "" || o.SwarmRedisEndpointsRemoteURL != "" { log.Infof("Redis based swarm with %d shards", len(o.SwarmRedisURLs)) - redisOptions = &skpnet.RedisOptions{ + redisOptions = &snet.RedisOptions{ Addrs: o.SwarmRedisURLs, Password: o.SwarmRedisPassword, HashAlgorithm: o.SwarmRedisHashAlgorithm, diff --git a/skipper_test.go b/skipper_test.go index 4fe3445b5b..d74b4bd632 100644 --- a/skipper_test.go +++ b/skipper_test.go @@ -13,6 +13,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "golang.org/x/net/http2" "github.com/zalando/skipper/dataclients/routestring" "github.com/zalando/skipper/filters" @@ -38,6 +39,45 @@ const ( listenTimeout = 9 * listenDelay ) +type protocol int + +const ( + protoHTTP protocol = iota + protoHTTPS + protoH2C +) + +func (p protocol) scheme() string { + return [...]string{"http", "https", "http"}[p] +} + +func (p protocol) newClient() *http.Client { + switch p { + case protoHTTP: + return &http.Client{} + case protoHTTPS: + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + } + case protoH2C: + return &http.Client{ + Transport: &http2.Transport{ + // allow http scheme + AllowHTTP: true, + // ignore tls.Config and dial unencrypted TCP + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + return net.Dial(network, addr) + }, + }, + } + } + return nil +} + func listenAndServe(proxy http.Handler, o *Options) error { return listenAndServeQuit(proxy, o, nil, nil, nil, nil) } @@ -69,12 +109,9 @@ func waitConn(req func() (*http.Response, error)) (*http.Response, error) { } } -func waitConnGet(url string) (*http.Response, error) { +func waitConnGet(proto protocol, address string) (*http.Response, error) { return waitConn(func() (*http.Response, error) { - return (&http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true}}}).Get(url) + return proto.newClient().Get(proto.scheme() + "://" + address) }) } @@ -214,7 +251,7 @@ func TestHTTPSServer(t *testing.T) { defer proxy.Close() go listenAndServe(proxy, &o) - r, err := waitConnGet("https://" + o.Address) + r, err := waitConnGet(protoHTTPS, o.Address) if r != nil { defer r.Body.Close() } @@ -229,7 +266,7 @@ func TestHTTPSServer(t *testing.T) { t.Fatalf("Failed to stream response body: %v", err) } - r, err = waitConnGet("http://" + o.InsecureAddress) + r, err = waitConnGet(protoHTTP, o.InsecureAddress) if r != nil { defer r.Body.Close() } @@ -267,7 +304,7 @@ func TestHTTPServer(t *testing.T) { proxy := proxy.New(rt, proxy.OptionsNone) defer proxy.Close() go listenAndServe(proxy, &o) - r, err := waitConnGet("http://" + o.Address) + r, err := waitConnGet(protoHTTP, o.Address) if r != nil { defer r.Body.Close() } @@ -285,7 +322,7 @@ func TestHTTPServer(t *testing.T) { func TestServerShutdownHTTP(t *testing.T) { o := &Options{} - testServerShutdown(t, o, "http") + testServerShutdown(t, o, protoHTTP, nil) } func TestServerShutdownHTTPS(t *testing.T) { @@ -293,7 +330,60 @@ func TestServerShutdownHTTPS(t *testing.T) { CertPathTLS: "fixtures/test.crt", KeyPathTLS: "fixtures/test.key", } - testServerShutdown(t, o, "https") + testServerShutdown(t, o, protoHTTPS, nil) +} + +func TestServerShutdownH2C(t *testing.T) { + const connectionShutdownChecks = 2 + errc := make(chan error, connectionShutdownChecks) + + testGracefulConnectionShutdown := func(address string) { + for i := 0; i < connectionShutdownChecks; i++ { + go func() { + errc <- h2cConnectAndWaitForGoAwayFrame(address) + }() + } + } + + o := &Options{} + testServerShutdown(t, o, protoH2C, testGracefulConnectionShutdown) + + for i := 0; i < connectionShutdownChecks; i++ { + require.NoError(t, <-errc, "Expected to receive GOAWAY frame on shutdown") + } +} + +// h2cConnectAndWaitForGoAwayFrame connects to address using http2 over cleartext protocol and waits for GOAWAY frame. +// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.8 +func h2cConnectAndWaitForGoAwayFrame(address string) error { + var conn net.Conn + var err error + + for i := 0; i < 3; i++ { + if conn, err = net.Dial("tcp", address); err == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + if err != nil { + return err + } + defer conn.Close() + + if _, err := io.WriteString(conn, http2.ClientPreface); err != nil { + return err + } + + framer := http2.NewFramer(conn, conn) + for { + f, err := framer.ReadFrame() + if err != nil { + return fmt.Errorf("ReadFrame: %w", err) + } + if _, ok := f.(*http2.GoAwayFrame); ok { + return nil + } + } } type responseOrError struct { @@ -301,14 +391,13 @@ type responseOrError struct { err error } -func testServerShutdown(t *testing.T, o *Options, scheme string) { +func testServerShutdown(t *testing.T, o *Options, proto protocol, beforeShutdown func(string)) { const shutdownDelay = 1 * time.Second address, err := findAddress() require.NoError(t, err) o.Address, o.WaitForHealthcheckInterval = address, shutdownDelay - testUrl := scheme + "://" + address // simulate a backend that got a request and should be handled correctly dc, err := routestring.New(`r0: * -> latency("3s") -> inlineContent("OK") -> status(200) -> `) @@ -330,6 +419,10 @@ func testServerShutdown(t *testing.T, o *Options, scheme string) { require.NoError(t, err) }() + if beforeShutdown != nil { + beforeShutdown(address) + } + // initiate shutdown sigs <- syscall.SIGTERM @@ -339,10 +432,9 @@ func testServerShutdown(t *testing.T, o *Options, scheme string) { roeCh := make(chan responseOrError) go func() { - rsp, err := waitConnGet(testUrl) + rsp, err := waitConnGet(proto, address) roeCh <- responseOrError{rsp, err} }() - time.Sleep(shutdownDelay) t.Logf("We are 1.5x past the shutdown delay, so shutdown should have been started") @@ -351,7 +443,7 @@ func testServerShutdown(t *testing.T, o *Options, scheme string) { case <-roeCh: t.Fatalf("Request should still be in progress after shutdown started") default: - _, err = waitConnGet(testUrl) + _, err = waitConnGet(proto, address) assert.ErrorContains(t, err, "connection refused", "Another request should fail after shutdown started") } @@ -366,7 +458,7 @@ func testServerShutdown(t *testing.T, o *Options, scheme string) { select { case <-done: case <-time.After(1 * time.Second): - t.Errorf("Shutdown takes too long after request is finished") + t.Errorf("Shutdown takes too long after all requests are finished") } }