Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add graceful shutdown #599

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 27 additions & 25 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"log"
"net"
"net/http"
"os"
"strings"
"time"

Expand Down Expand Up @@ -71,29 +72,33 @@ func main() {

flag.Parse()

sh := utils.NewShutdownHandler(2 * time.Second)

// Create KV store for persistence
options := gomap.DefaultOptions
options.Codec = utils.ProtoCodec{}
// TODO: we can change to redis or badger at any given time
store := gomap.NewStore(options)
defer func(store gokv.Store) {
err := store.Close()
if err != nil {
log.Panic(err)
}
}(store)
sh.AddGokvStore(store)

go runGatewayServer(grpcPort, httpPort)
runGrpcServer(grpcPort, useKvm, store, spdkAddress, qmpAddress, ctrlrDir, busesStr, tlsFiles)
runGrpcServer(grpcPort, useKvm, store, spdkAddress, qmpAddress, ctrlrDir, busesStr, tlsFiles, sh)
runGatewayServer(grpcPort, httpPort, sh)

if err := sh.RunAndWait(); err != nil {
log.Printf("Bridge error: %v", err)
os.Exit(-1)
}
log.Print("Bridge successfully stopped")
}

func runGrpcServer(grpcPort int, useKvm bool, store gokv.Store, spdkAddress, qmpAddress, ctrlrDir, busesStr, tlsFiles string) {
func runGrpcServer(
grpcPort int,
useKvm bool,
store gokv.Store,
spdkAddress, qmpAddress, ctrlrDir, busesStr, tlsFiles string,
sh *utils.ShutdownHandler) {
tp := utils.InitTracerProvider("opi-spdk-bridge")
defer func() {
if err := tp.Shutdown(context.Background()); err != nil {
log.Panicf("Tracer Provider Shutdown: %v", err)
}
}()
sh.AddTraceProvider(tp)

buses := splitBusesBySeparator(busesStr)

Expand Down Expand Up @@ -171,16 +176,17 @@ func runGrpcServer(grpcPort int, useKvm bool, store gokv.Store, spdkAddress, qmp

reflection.Register(s)

log.Printf("gRPC server listening at %v", lis.Addr())
if err := s.Serve(lis); err != nil {
log.Panicf("failed to serve: %v", err)
}
sh.AddGrpcServer(s, lis)
}

func runGatewayServer(grpcPort int, httpPort int) {
func runGatewayServer(grpcPort int, httpPort int, sh *utils.ShutdownHandler) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
sh.AddShutdown(func(_ context.Context) error {
log.Println("Canceling context to close HTTP gateway endpoint to gRPC server")
cancel()
return nil
})

// Register gRPC server endpoint
// Note: Make sure the gRPC server is running properly and accessible
Expand All @@ -192,15 +198,11 @@ func runGatewayServer(grpcPort int, httpPort int) {
}

// Start HTTP server (and proxy calls to gRPC server endpoint)
log.Printf("HTTP Server listening at %v", httpPort)
server := &http.Server{
Addr: fmt.Sprintf(":%d", httpPort),
Handler: mux,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
}
err = server.ListenAndServe()
if err != nil {
log.Panic("cannot start HTTP gateway server")
}
sh.AddHTTPServer(server)
}
217 changes: 217 additions & 0 deletions pkg/utils/shutdown.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2023 Intel Corporation

// Package utils contains utility functions
package utils

import (
"context"
"errors"
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"

"github.com/philippgille/gokv"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
)

// ServeFunc is a function that runs a service job and returns the error
// the job ended with, if any.
type ServeFunc func() error

// ShutdownFunc is a function that performs the shutdown of a service.
// It receives a context and returns the error the shutdown ended with, if any.
type ShutdownFunc func(ctx context.Context) error

// ShutdownHandler is responsible for running services and performing their
// shutdowns on a service error or on a received signal.
type ShutdownHandler struct {
// waitSignal is the channel RunAndWait subscribes to SIGINT/SIGTERM.
waitSignal chan os.Signal
// timeoutPerShutdown is the timeout of the context handed to each
// individual shutdown function.
timeoutPerShutdown time.Duration

// mu is locked by AddServe/AddShutdown while appending and by RunAndWait
// while the shutdown functions execute.
mu sync.Mutex
// serves are the service jobs started by RunAndWait.
serves []ServeFunc
// shutdowns are executed by RunAndWait from the last added to the first.
shutdowns []ShutdownFunc

// eg and egCtx are created together by errgroup.WithContext in
// NewShutdownHandler; RunAndWait starts shutdown when egCtx is done.
eg *errgroup.Group
egCtx context.Context
}

// NewShutdownHandler builds a ShutdownHandler with no registered services.
// timeoutPerShutdown is the time budget given to every single shutdown call.
func NewShutdownHandler(timeoutPerShutdown time.Duration) *ShutdownHandler {
	group, groupCtx := errgroup.WithContext(context.Background())

	handler := new(ShutdownHandler)
	handler.waitSignal = make(chan os.Signal, 1)
	handler.timeoutPerShutdown = timeoutPerShutdown
	handler.serves = make([]ServeFunc, 0)
	handler.shutdowns = make([]ShutdownFunc, 0)
	handler.eg = group
	handler.egCtx = groupCtx

	return handler
}

// AddServe registers a service job together with the shutdown procedure
// that corresponds to it.
func (s *ShutdownHandler) AddServe(serve ServeFunc, shutdown ShutdownFunc) {
	s.mu.Lock()
	s.serves, s.shutdowns = append(s.serves, serve), append(s.shutdowns, shutdown)
	s.mu.Unlock()
}

// AddShutdown registers a shutdown procedure to execute.
// Registered shutdowns are executed in reverse order of registration.
func (s *ShutdownHandler) AddShutdown(shutdown ShutdownFunc) {
	s.mu.Lock()
	s.shutdowns = append(s.shutdowns, shutdown)
	s.mu.Unlock()
}

// AddGrpcServer adds serve and shutdown procedures for the provided gRPC
// server. The serve procedure serves on lis. The shutdown procedure first
// attempts a graceful stop; if that does not finish before ctx is done, the
// server is stopped forcibly and the context error is returned.
func (s *ShutdownHandler) AddGrpcServer(server *grpc.Server, lis net.Listener) {
	s.AddServe(
		func() error {
			log.Printf("gRPC Server listening at %v", lis.Addr())
			return server.Serve(lis)
		},
		func(ctx context.Context) error {
			log.Println("Stopping gRPC Server")
			err := runWithCtx(ctx, func() error {
				server.GracefulStop()
				return nil
			})
			if err != nil {
				// GracefulStop is still blocked on in-flight RPCs. Without a
				// forced stop Serve would never return and the timeout would
				// not be honoured, so fall back to Stop the same way the HTTP
				// server falls back to Close.
				log.Println("gRPC server graceful stop did not finish, forcing stop:", err)
				server.Stop()
			}
			return err
		},
	)
}

// AddHTTPServer adds serve and shutdown procedures for the provided HTTP
// server. The serve procedure treats http.ErrServerClosed as a clean exit.
// The shutdown procedure calls Shutdown with ctx and, if that fails, falls
// back to Close; the Shutdown error is what gets returned.
func (s *ShutdownHandler) AddHTTPServer(server *http.Server) {
	s.AddServe(
		func() error {
			log.Printf("HTTP Server listening at %v", server.Addr)
			err := server.ListenAndServe()
			if errors.Is(err, http.ErrServerClosed) {
				return nil
			}

			return err
		},
		func(ctx context.Context) error {
			log.Println("Stopping HTTP Server")
			err := server.Shutdown(ctx)
			if err != nil {
				// Only report a close failure when there is one; previously
				// a successful Close was logged as "close error: <nil>".
				if cerr := server.Close(); cerr != nil {
					log.Println("HTTP server close error:", cerr)
				}
			}
			return err
		},
	)
}

// AddGokvStore registers a shutdown procedure that closes the given gokv
// store, giving up waiting for Close once the shutdown context is done.
func (s *ShutdownHandler) AddGokvStore(store gokv.Store) {
	closeStore := func() error {
		return store.Close()
	}
	stopStorage := func(ctx context.Context) error {
		log.Println("Stopping gokv storage")
		return runWithCtx(ctx, closeStore)
	}
	s.AddShutdown(stopStorage)
}

// AddTraceProvider registers a shutdown procedure that shuts down the given
// tracer provider with the shutdown context.
func (s *ShutdownHandler) AddTraceProvider(tp *sdktrace.TracerProvider) {
	stopTracer := func(ctx context.Context) error {
		log.Println("Stopping tracer")
		return tp.Shutdown(ctx)
	}
	s.AddShutdown(stopTracer)
}

// RunAndWait starts every registered serve function and then blocks until
// all of them, and the shutdown sequence, have finished.
//
// Shutdown starts when SIGINT or SIGTERM is received or when the errgroup
// context is done (for example because a serve function returned an error).
// Shutdown functions run in reverse order of registration, each with its own
// timeoutPerShutdown-bounded context, and panics inside serve or shutdown
// functions are converted to errors.
//
// It returns the result of waiting on the errgroup: the first non-nil error
// returned by a serve function or by the shutdown sequence (whose own errors
// are joined), or nil.
func (s *ShutdownHandler) RunAndWait() error {
	// Subscribe before any service is started so that a signal delivered
	// during startup is captured, and unsubscribe when done.
	signal.Notify(s.waitSignal, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(s.waitSignal)

	// Snapshot under the lock: AddServe appends to s.serves while holding mu.
	s.mu.Lock()
	serves := make([]ServeFunc, len(s.serves))
	copy(serves, s.serves)
	s.mu.Unlock()

	for _, serve := range serves {
		s.eg.Go(wrapServeFuncPanic(serve))
	}

	s.eg.Go(func() error {
		select {
		case sig := <-s.waitSignal:
			log.Printf("Got signal: %v", sig)
		case <-s.egCtx.Done():
			// can be reached if any Serve returned an error. Thus, initiating shutdown
			log.Println("A process from errgroup exited with error:", s.egCtx.Err())
		}
		log.Printf("Start graceful shutdown with timeout per shutdown call: %v", s.timeoutPerShutdown)

		// runShutdown gives one shutdown function its own timeout context and
		// releases that context as soon as the call returns, instead of
		// deferring every cancel to the end of the whole sequence.
		runShutdown := func(fn ShutdownFunc) error {
			timeoutCtx, cancel := context.WithTimeout(context.Background(), s.timeoutPerShutdown)
			defer cancel()
			return wrapShutdownFuncPanic(fn)(timeoutCtx)
		}

		s.mu.Lock()
		defer s.mu.Unlock()

		var err error
		for i := len(s.shutdowns) - 1; i >= 0; i-- {
			err = errors.Join(err, runShutdown(s.shutdowns[i]))
		}

		return err
	})

	return s.eg.Wait()
}

// wrapServeFuncPanic returns a ServeFunc that calls fn and, if fn panics,
// recovers and reports the panic value as an error instead.
func wrapServeFuncPanic(fn ServeFunc) ServeFunc {
	guarded := func() (err error) {
		defer func() {
			recovered := recover()
			if recovered == nil {
				return
			}
			err = fmt.Errorf("was panic for serve function, recovered value: %v", recovered)
		}()
		return fn()
	}
	return guarded
}

// wrapShutdownFuncPanic returns a ShutdownFunc that forwards ctx to fn and,
// if fn panics, recovers and reports the panic value as an error instead.
func wrapShutdownFuncPanic(fn ShutdownFunc) ShutdownFunc {
	guarded := func(ctx context.Context) (err error) {
		defer func() {
			recovered := recover()
			if recovered == nil {
				return
			}
			err = fmt.Errorf("was panic for shutdown function, recovered value: %v", recovered)
		}()
		return fn(ctx)
	}
	return guarded
}

// runWithCtx runs fn in a separate goroutine and waits for either fn to
// finish or ctx to be done, whichever happens first.
//
// It returns fn's error when fn finishes first and ctx.Err() when ctx is done
// first. In the latter case fn keeps running in the background; its result is
// discarded.
func runWithCtx(ctx context.Context, fn func() error) error {
	// The result travels over a channel rather than a shared variable so the
	// background goroutine never writes memory the caller is reading.
	// Capacity 1 lets the goroutine finish even if nobody receives.
	done := make(chan error, 1)
	go func() {
		done <- fn()
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case err := <-done:
		return err
	}
}
Loading