diff --git a/app.go b/app.go index 7f9193a1a1..f6a24546a8 100644 --- a/app.go +++ b/app.go @@ -900,16 +900,18 @@ func (app *App) ShutdownWithTimeout(timeout time.Duration) error { // // ShutdownWithContext does not close keepalive connections so its recommended to set ReadTimeout to something else than 0. func (app *App) ShutdownWithContext(ctx context.Context) error { - if app.hooks != nil { - // TODO: check should be defered? - app.hooks.executeOnShutdownHooks() - } - app.mutex.Lock() defer app.mutex.Unlock() + if app.server == nil { return ErrNotRunning } + + // Execute the Shutdown hook + if app.hooks != nil { + app.hooks.executeOnPreShutdownHooks() + } + return app.server.ShutdownWithContext(ctx) } diff --git a/app_test.go b/app_test.go index a99796a2c1..cd919b21f6 100644 --- a/app_test.go +++ b/app_test.go @@ -21,6 +21,7 @@ import ( "regexp" "runtime" "strings" + "sync/atomic" "testing" "time" @@ -861,6 +862,13 @@ func Test_App_ShutdownWithContext(t *testing.T) { t.Parallel() app := New() + var shutdownHookCalled atomic.Int32 + + app.Hooks().OnPreShutdown(func() error { + shutdownHookCalled.Store(1) + return nil + }) + app.Get("/", func(ctx Ctx) error { time.Sleep(5 * time.Second) return ctx.SendString("body") @@ -868,24 +876,27 @@ func Test_App_ShutdownWithContext(t *testing.T) { ln := fasthttputil.NewInmemoryListener() + serverErr := make(chan error, 1) go func() { - err := app.Listener(ln) - assert.NoError(t, err) + serverErr <- app.Listener(ln) }() - time.Sleep(1 * time.Second) + time.Sleep(100 * time.Millisecond) + clientDone := make(chan struct{}) go func() { conn, err := ln.Dial() assert.NoError(t, err) - - _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")) + _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")) assert.NoError(t, err) + close(clientDone) }() - time.Sleep(1 * time.Second) + <-clientDone + // Sleep to ensure the server has started processing the request + time.Sleep(100 * time.Millisecond) - shutdownErr := make(chan error) + shutdownErr := make(chan error, 1) go func() { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -893,13 +904,19 @@ func Test_App_ShutdownWithContext(t *testing.T) { }() select { - case <-time.After(5 * time.Second): - t.Fatal("idle connections not closed on shutdown") + case <-time.After(2 * time.Second): + t.Fatal("shutdown did not complete in time") case err := <-shutdownErr: - if err == nil || !errors.Is(err, context.DeadlineExceeded) { - t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded) - } + require.Error(t, err, "Expected shutdown to return an error due to timeout") + require.ErrorIs(t, err, context.DeadlineExceeded, "Expected DeadlineExceeded error") } + + assert.Equal(t, int32(1), shutdownHookCalled.Load(), "Shutdown hook was not called") + + err := <-serverErr + assert.NoError(t, err, "Server should have shut down without error") + // default: + // Server is still running, which is expected as the long-running request prevented full shutdown } // go test -run Test_App_Mixed_Routes_WithSameLen diff --git a/docs/api/fiber.md b/docs/api/fiber.md index a34a5d944c..d0723ad474 100644 --- a/docs/api/fiber.md +++ b/docs/api/fiber.md @@ -228,7 +228,7 @@ Shutdown gracefully shuts down the server without interrupting any active connec ShutdownWithTimeout will forcefully close any active connections after the timeout expires. -ShutdownWithContext shuts down the server including by force if the context's deadline is exceeded. +ShutdownWithContext shuts down the server including by force if the context's deadline is exceeded. Shutdown hooks will still be executed, even if an error occurs during the shutdown process, as they are deferred to ensure cleanup happens regardless of errors. ```go func (app *App) Shutdown() error diff --git a/hooks.go b/hooks.go index 3da5c671ff..314717d04b 100644 --- a/hooks.go +++ b/hooks.go @@ -6,14 +6,15 @@ import ( // OnRouteHandler Handlers define a function to create hooks for Fiber. type ( - OnRouteHandler = func(Route) error - OnNameHandler = OnRouteHandler - OnGroupHandler = func(Group) error - OnGroupNameHandler = OnGroupHandler - OnListenHandler = func(ListenData) error - OnShutdownHandler = func() error - OnForkHandler = func(int) error - OnMountHandler = func(*App) error + OnRouteHandler = func(Route) error + OnNameHandler = OnRouteHandler + OnGroupHandler = func(Group) error + OnGroupNameHandler = OnGroupHandler + OnListenHandler = func(ListenData) error + OnPreShutdownHandler = func() error + OnPostShutdownHandler = func(error) error + OnForkHandler = func(int) error + OnMountHandler = func(*App) error ) // Hooks is a struct to use it with App. @@ -22,14 +23,15 @@ type Hooks struct { app *App // Hooks - onRoute []OnRouteHandler - onName []OnNameHandler - onGroup []OnGroupHandler - onGroupName []OnGroupNameHandler - onListen []OnListenHandler - onShutdown []OnShutdownHandler - onFork []OnForkHandler - onMount []OnMountHandler + onRoute []OnRouteHandler + onName []OnNameHandler + onGroup []OnGroupHandler + onGroupName []OnGroupNameHandler + onListen []OnListenHandler + onPreShutdown []OnPreShutdownHandler + onPostShutdown []OnPostShutdownHandler + onFork []OnForkHandler + onMount []OnMountHandler } // ListenData is a struct to use it with OnListenHandler @@ -41,15 +43,16 @@ type ListenData struct { func newHooks(app *App) *Hooks { return &Hooks{ - app: app, - onRoute: make([]OnRouteHandler, 0), - onGroup: make([]OnGroupHandler, 0), - onGroupName: make([]OnGroupNameHandler, 0), - onName: make([]OnNameHandler, 0), - onListen: make([]OnListenHandler, 0), - onShutdown: make([]OnShutdownHandler, 0), - onFork: make([]OnForkHandler, 0), - onMount: make([]OnMountHandler, 0), + app: app, + onRoute: make([]OnRouteHandler, 0), + onGroup: make([]OnGroupHandler, 0), + onGroupName: make([]OnGroupNameHandler, 0), + onName: make([]OnNameHandler, 0), + onListen: make([]OnListenHandler, 0), + onPreShutdown: make([]OnPreShutdownHandler, 0), + onPostShutdown: make([]OnPostShutdownHandler, 0), + onFork: make([]OnForkHandler, 0), + onMount: make([]OnMountHandler, 0), } } @@ -96,10 +99,17 @@ func (h *Hooks) OnListen(handler ...OnListenHandler) { h.app.mutex.Unlock() } -// OnShutdown is a hook to execute user functions after Shutdown. -func (h *Hooks) OnShutdown(handler ...OnShutdownHandler) { +// OnPreShutdown is a hook to execute user functions before Shutdown. +func (h *Hooks) OnPreShutdown(handler ...OnPreShutdownHandler) { h.app.mutex.Lock() - h.onShutdown = append(h.onShutdown, handler...) + h.onPreShutdown = append(h.onPreShutdown, handler...) + h.app.mutex.Unlock() +} + +// OnPostShutdown is a hook to execute user functions after Shutdown. +func (h *Hooks) OnPostShutdown(handler ...OnPostShutdownHandler) { + h.app.mutex.Lock() + h.onPostShutdown = append(h.onPostShutdown, handler...) h.app.mutex.Unlock() } @@ -191,10 +201,18 @@ func (h *Hooks) executeOnListenHooks(listenData ListenData) error { return nil } -func (h *Hooks) executeOnShutdownHooks() { - for _, v := range h.onShutdown { +func (h *Hooks) executeOnPreShutdownHooks() { + for _, v := range h.onPreShutdown { if err := v(); err != nil { - log.Errorf("failed to call shutdown hook: %v", err) + log.Errorf("failed to call pre shutdown hook: %v", err) + } + } +} + +func (h *Hooks) executeOnPostShutdownHooks(err error) { + for _, v := range h.onPostShutdown { + if err := v(err); err != nil { + log.Errorf("failed to call post shutdown hook: %v", err) } } } diff --git a/hooks_test.go b/hooks_test.go index f96f570706..b39146b15e 100644 --- a/hooks_test.go +++ b/hooks_test.go @@ -181,22 +181,92 @@ func Test_Hook_OnGroupName_Error(t *testing.T) { grp.Get("/test", testSimpleHandler) } -func Test_Hook_OnShutdown(t *testing.T) { +func Test_Hook_OnPrehutdown(t *testing.T) { t.Parallel() app := New() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) - app.Hooks().OnShutdown(func() error { - _, err := buf.WriteString("shutdowning") + app.Hooks().OnPreShutdown(func() error { + _, err := buf.WriteString("pre-shutdowning") require.NoError(t, err) return nil }) require.NoError(t, app.Shutdown()) - require.Equal(t, "shutdowning", buf.String()) + require.Equal(t, "pre-shutdowning", buf.String()) +} + +func Test_Hook_OnPostShutdown(t *testing.T) { + t.Run("should execute post shutdown hook with error", func(t *testing.T) { + app := New() + + hookCalled := false + var receivedErr error + expectedErr := errors.New("test shutdown error") + + app.Hooks().OnPostShutdown(func(err error) error { + hookCalled = true + receivedErr = err + return nil + }) + + go func() { + _ = app.Listen(":0") + }() + + time.Sleep(100 * time.Millisecond) + + app.hooks.executeOnPostShutdownHooks(expectedErr) + + if !hookCalled { + t.Fatal("hook was not called") + } + + if receivedErr != expectedErr { + t.Fatalf("hook received wrong error: want %v, got %v", expectedErr, receivedErr) + } + }) + + t.Run("should execute multiple hooks in order", func(t *testing.T) { + app := New() + + execution := make([]int, 0) + + app.Hooks().OnPostShutdown(func(err error) error { + execution = append(execution, 1) + return nil + }) + + app.Hooks().OnPostShutdown(func(err error) error { + execution = append(execution, 2) + return nil + }) + + app.hooks.executeOnPostShutdownHooks(nil) + + if len(execution) != 2 { + t.Fatalf("expected 2 hooks to execute, got %d", len(execution)) + } + + if execution[0] != 1 || execution[1] != 2 { + t.Fatal("hooks executed in wrong order") + } + }) + + t.Run("should handle hook error", func(t *testing.T) { + app := New() + hookErr := errors.New("hook error") + + app.Hooks().OnPostShutdown(func(err error) error { + return hookErr + }) + + // Should not panic + app.hooks.executeOnPostShutdownHooks(nil) + }) } func Test_Hook_OnListen(t *testing.T) { diff --git a/listen.go b/listen.go index e0c5536968..8fdb0ff453 100644 --- a/listen.go +++ b/listen.go @@ -60,17 +60,6 @@ type ListenConfig struct { // Default: nil BeforeServeFunc func(app *App) error `json:"before_serve_func"` - // OnShutdownError allows to customize error behavior when to graceful shutdown server by given signal. - // - // Print error with log.Fatalf() by default. - // Default: nil - OnShutdownError func(err error) - - // OnShutdownSuccess allows to customize success behavior when to graceful shutdown server by given signal. - // - // Default: nil - OnShutdownSuccess func() - // AutoCertManager manages TLS certificates automatically using the ACME protocol, // Enables integration with Let's Encrypt or other ACME-compatible providers. // @@ -129,9 +118,6 @@ func listenConfigDefault(config ...ListenConfig) ListenConfig { if len(config) < 1 { return ListenConfig{ ListenerNetwork: NetworkTCP4, - OnShutdownError: func(err error) { - log.Fatalf("shutdown: %v", err) //nolint:revive // It's an option - }, ShutdownTimeout: 10 * time.Second, } } @@ -141,12 +127,6 @@ func listenConfigDefault(config ...ListenConfig) ListenConfig { cfg.ListenerNetwork = NetworkTCP4 } - if cfg.OnShutdownError == nil { - cfg.OnShutdownError = func(err error) { - log.Fatalf("shutdown: %v", err) //nolint:revive // It's an option - } - } - return cfg } @@ -502,11 +482,9 @@ func (app *App) gracefulShutdown(ctx context.Context, cfg ListenConfig) { } if err != nil { - cfg.OnShutdownError(err) + app.hooks.executeOnPostShutdownHooks(err) return } - if success := cfg.OnShutdownSuccess; success != nil { - success() - } + app.hooks.executeOnPostShutdownHooks(nil) } diff --git a/listen_test.go b/listen_test.go index 123cf2b3b8..f1376a6e7c 100644 --- a/listen_test.go +++ b/listen_test.go @@ -15,6 +15,7 @@ import ( "testing" "time" + "github.com/gofiber/utils/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" @@ -37,193 +38,111 @@ func Test_Listen(t *testing.T) { // go test -run Test_Listen_Graceful_Shutdown func Test_Listen_Graceful_Shutdown(t *testing.T) { - var mu sync.Mutex - var shutdown bool - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString(c.Hostname()) + t.Run("Basic Graceful Shutdown", func(t *testing.T) { + testGracefulShutdown(t, 0) }) - ln := fasthttputil.NewInmemoryListener() - errs := make(chan error) - - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - errs <- app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - GracefulContext: ctx, - OnShutdownSuccess: func() { - mu.Lock() - shutdown = true - mu.Unlock() - }, - }) - }() - - // Server readiness check - for i := 0; i < 10; i++ { - conn, err := ln.Dial() - if err == nil { - conn.Close() //nolint:errcheck // ignore error - break - } - // Wait a bit before retrying - time.Sleep(100 * time.Millisecond) - if i == 9 { - t.Fatalf("Server did not become ready in time: %v", err) - } - } - - testCases := []struct { - ExpectedErr error - ExpectedBody string - Time time.Duration - ExpectedStatusCode int - }{ - {Time: 500 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExpectedErr: nil}, - {Time: 3 * time.Second, ExpectedBody: "", ExpectedStatusCode: StatusOK, ExpectedErr: fasthttputil.ErrInmemoryListenerClosed}, - } - - for _, tc := range testCases { - time.Sleep(tc.Time) - - req := fasthttp.AcquireRequest() - req.SetRequestURI("http://example.com") - - client := fasthttp.HostClient{} - client.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - resp := fasthttp.AcquireResponse() - err := client.Do(req, resp) - - require.Equal(t, tc.ExpectedErr, err) - require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode()) - require.Equal(t, tc.ExpectedBody, string(resp.Body())) - - fasthttp.ReleaseRequest(req) - fasthttp.ReleaseResponse(resp) - } - - mu.Lock() - err := <-errs - require.True(t, shutdown) - require.NoError(t, err) - mu.Unlock() + t.Run("Shutdown With Timeout", func(t *testing.T) { + testGracefulShutdown(t, 500*time.Millisecond) + }) } -// go test -run Test_Listen_Graceful_Shutdown_Timeout -func Test_Listen_Graceful_Shutdown_Timeout(t *testing.T) { +func testGracefulShutdown(t *testing.T, shutdownTimeout time.Duration) { var mu sync.Mutex - var shutdownSuccess bool - var shutdownTimeoutError error + var shutdown bool app := New() - app.Get("/", func(c Ctx) error { return c.SendString(c.Hostname()) }) ln := fasthttputil.NewInmemoryListener() - errs := make(chan error) + errs := make(chan error, 1) - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() + app.hooks.OnPostShutdown(func(err error) error { + mu.Lock() + defer mu.Unlock() + shutdown = true + return nil + }) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + go func() { errs <- app.Listener(ln, ListenConfig{ DisableStartupMessage: true, GracefulContext: ctx, - ShutdownTimeout: 500 * time.Millisecond, - OnShutdownSuccess: func() { - mu.Lock() - shutdownSuccess = true - mu.Unlock() - }, - OnShutdownError: func(err error) { - mu.Lock() - shutdownTimeoutError = err - mu.Unlock() - }, + ShutdownTimeout: shutdownTimeout, }) }() - // Server readiness check - for i := 0; i < 10; i++ { + require.Eventually(t, func() bool { conn, err := ln.Dial() - // To test a graceful shutdown timeout, do not close the connection. if err == nil { - _ = conn - break - } - // Wait a bit before retrying - time.Sleep(100 * time.Millisecond) - if i == 9 { - t.Fatalf("Server did not become ready in time: %v", err) + conn.Close() + return true } + return false + }, time.Second, 100*time.Millisecond, "Server failed to become ready") + + client := fasthttp.HostClient{ + Dial: func(_ string) (net.Conn, error) { return ln.Dial() }, } testCases := []struct { - ExpectedErr error - ExpectedShutdownError error - ExpectedBody string - Time time.Duration - ExpectedStatusCode int - ExpectedShutdownSuccess bool + name string + waitTime time.Duration + expectedBody string + expectedStatusCode int + expectedErr error + closeConnection bool }{ { - Time: 100 * time.Millisecond, - ExpectedBody: "example.com", - ExpectedStatusCode: StatusOK, - ExpectedErr: nil, - ExpectedShutdownError: nil, - ExpectedShutdownSuccess: false, + name: "Server running normally", + waitTime: 500 * time.Millisecond, + expectedBody: "example.com", + expectedStatusCode: StatusOK, + expectedErr: nil, + closeConnection: true, }, { - Time: 3 * time.Second, - ExpectedBody: "", - ExpectedStatusCode: StatusOK, - ExpectedErr: fasthttputil.ErrInmemoryListenerClosed, - ExpectedShutdownError: context.DeadlineExceeded, - ExpectedShutdownSuccess: false, + name: "Server shutdown complete", + waitTime: 3 * time.Second, + expectedBody: "", + expectedStatusCode: StatusOK, + expectedErr: fasthttputil.ErrInmemoryListenerClosed, + closeConnection: true, }, } for _, tc := range testCases { - time.Sleep(tc.Time) - - req := fasthttp.AcquireRequest() - req.SetRequestURI("http://example.com") - - client := fasthttp.HostClient{} - client.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - resp := fasthttp.AcquireResponse() - err := client.Do(req, resp) - - if err == nil { - require.NoError(t, err) - require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode()) - require.Equal(t, tc.ExpectedBody, string(resp.Body())) - } else { - require.ErrorIs(t, err, tc.ExpectedErr) - } - - mu.Lock() - require.Equal(t, tc.ExpectedShutdownSuccess, shutdownSuccess) - require.Equal(t, tc.ExpectedShutdownError, shutdownTimeoutError) - mu.Unlock() - - fasthttp.ReleaseRequest(req) - fasthttp.ReleaseResponse(resp) + tc := tc + t.Run(tc.name, func(t *testing.T) { + time.Sleep(tc.waitTime) + + req := fasthttp.AcquireRequest() + defer fasthttp.ReleaseRequest(req) + req.SetRequestURI("http://example.com") + + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseResponse(resp) + + err := client.Do(req, resp) + + if tc.expectedErr == nil { + assert.NoError(t, err) + assert.Equal(t, tc.expectedStatusCode, resp.StatusCode()) + assert.Equal(t, tc.expectedBody, utils.UnsafeString(resp.Body())) + } else { + assert.ErrorIs(t, err, tc.expectedErr) + } + }) } mu.Lock() - err := <-errs - require.NoError(t, err) + assert.True(t, shutdown) + assert.NoError(t, <-errs) mu.Unlock() }