Skip to content

Commit

Permalink
Merge pull request #12 from vasayxtx/avoid-panics-in-middleware
Browse files Browse the repository at this point in the history
Avoid implicit panics in middleware and middleware/throttle packages
  • Loading branch information
vasayxtx authored Sep 9, 2024
2 parents 55e5fab + d31d722 commit 6d0330c
Show file tree
Hide file tree
Showing 11 changed files with 132 additions and 85 deletions.
2 changes: 1 addition & 1 deletion httpserver/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func NewConfig() *Config {
}

// NewConfigWithKeyPrefix creates a new instance of the Config.
// Allows to specify key prefix which will be used for parsing configuration parameters.
// Allows specifying key prefix which will be used for parsing configuration parameters.
func NewConfigWithKeyPrefix(keyPrefix string) *Config {
return &Config{keyPrefix: keyPrefix}
}
Expand Down
3 changes: 0 additions & 3 deletions httpserver/http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ import (
"github.com/acronis/go-appkit/service"
)

// ErrInvalidMaxServingRequests error is returned when maximum number of currently serving requests is negative.
var ErrInvalidMaxServingRequests = errors.New("maximum number of currently serving requests must not be negative")

const (
networkTCP = "tcp"
networkUnix = "unix"
Expand Down
4 changes: 2 additions & 2 deletions httpserver/middleware/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func Example() {
ExcludedEndpoints: []string{"/metrics", "/healthz"}, // Metrics will not be collected for "/metrics" and "/healthz" endpoints.
}))

userCreateInFlightLimitMiddleware := InFlightLimitWithOpts(32, errDomain, InFlightLimitOpts{
userCreateInFlightLimitMiddleware := MustInFlightLimitWithOpts(32, errDomain, InFlightLimitOpts{
GetKey: func(r *http.Request) (string, bool, error) {
key := r.Header.Get("X-Client-ID")
return key, key == "", nil
Expand All @@ -49,7 +49,7 @@ func Example() {
BacklogTimeout: time.Second * 10,
})

usersListRateLimitMiddleware := RateLimit(Rate{Count: 100, Duration: time.Second}, errDomain)
usersListRateLimitMiddleware := MustRateLimit(Rate{Count: 100, Duration: time.Second}, errDomain)

router.Route("/users", func(r chi.Router) {
r.With(usersListRateLimitMiddleware).Get("/", func(rw http.ResponseWriter, req *http.Request) {
Expand Down
28 changes: 23 additions & 5 deletions httpserver/middleware/in_flight_limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,28 @@ type InFlightLimitOpts struct {

// InFlightLimit is a middleware that limits the total number of currently served (in-flight) HTTP requests.
// It checks how many requests are in-flight and rejects with 503 if exceeded.
func InFlightLimit(limit int, errDomain string) func(next http.Handler) http.Handler {
func InFlightLimit(limit int, errDomain string) (func(next http.Handler) http.Handler, error) {
return InFlightLimitWithOpts(limit, errDomain, InFlightLimitOpts{})
}

// MustInFlightLimit is a version of InFlightLimit that panics on error.
func MustInFlightLimit(limit int, errDomain string) func(next http.Handler) http.Handler {
mw, err := InFlightLimit(limit, errDomain)
if err != nil {
panic(err)
}
return mw
}

// InFlightLimitWithOpts is a configurable version of a middleware to limit in-flight HTTP requests.
func InFlightLimitWithOpts(limit int, errDomain string, opts InFlightLimitOpts) func(next http.Handler) http.Handler {
func InFlightLimitWithOpts(limit int, errDomain string, opts InFlightLimitOpts) (func(next http.Handler) http.Handler, error) {
if limit <= 0 {
panic(fmt.Errorf("limit should be positive, got %d", limit))
return nil, fmt.Errorf("limit should be positive, got %d", limit)
}

backlogLimit := opts.BacklogLimit
if backlogLimit < 0 {
panic(fmt.Errorf("backlog limit should not be negative, got %d", backlogLimit))
return nil, fmt.Errorf("backlog limit should not be negative, got %d", backlogLimit)
}

backlogTimeout := opts.BacklogTimeout
Expand All @@ -120,7 +129,7 @@ func InFlightLimitWithOpts(limit int, errDomain string, opts InFlightLimitOpts)

getSlots, err := makeInFlightLimitSlotsProvider(limit, backlogLimit, maxKeys)
if err != nil {
panic(fmt.Errorf("make in-flight limit slots provider: %w", err))
return nil, fmt.Errorf("make in-flight limit slots provider: %w", err)
}

respStatusCode := opts.ResponseStatusCode
Expand All @@ -141,7 +150,16 @@ func InFlightLimitWithOpts(limit int, errDomain string, opts InFlightLimitOpts)
onReject: makeInFlightLimitOnRejectFunc(opts),
onError: makeInFlightLimitOnErrorFunc(opts),
}
}, nil
}

// MustInFlightLimitWithOpts is a version of InFlightLimitWithOpts that panics on error.
func MustInFlightLimitWithOpts(limit int, errDomain string, opts InFlightLimitOpts) func(next http.Handler) http.Handler {
mw, err := InFlightLimitWithOpts(limit, errDomain, opts)
if err != nil {
panic(err)
}
return mw
}

func (h *inFlightLimitHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
Expand Down
22 changes: 11 additions & 11 deletions httpserver/middleware/in_flight_limit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) {
}
rw.WriteHeader(http.StatusOK)
})
handler := InFlightLimit(1, errDomain)(next)
handler := MustInFlightLimit(1, errDomain)(next)

respCode := make(chan int)
go func() {
Expand Down Expand Up @@ -84,7 +84,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) {
}
rw.WriteHeader(http.StatusOK)
})
handler := InFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{BacklogLimit: 1, BacklogTimeout: backlogTimeout})(next)
handler := MustInFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{BacklogLimit: 1, BacklogTimeout: backlogTimeout})(next)

resp1Code := make(chan int)
go func() {
Expand Down Expand Up @@ -139,7 +139,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) {
rw.WriteHeader(http.StatusOK)
})

handler := InFlightLimit(limit, errDomain)(next)
handler := MustInFlightLimit(limit, errDomain)(next)
var wg sync.WaitGroup
for i := 0; i < reqsNum; i++ {
wg.Add(1)
Expand Down Expand Up @@ -184,7 +184,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) {
}
rw.WriteHeader(http.StatusOK)
})
handler := InFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{
handler := MustInFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{
GetKey: makeInFlightLimitGetKeyByHeader(headerClientID),
ResponseStatusCode: http.StatusTooManyRequests,
})(next)
Expand Down Expand Up @@ -235,7 +235,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) {
}
rw.WriteHeader(http.StatusOK)
})
handler := InFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{
handler := MustInFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{
GetKey: func(r *http.Request) (string, bool, error) {
key := r.Header.Get(headerClientID)
return key, key == "", nil
Expand Down Expand Up @@ -308,7 +308,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) {
}
rw.WriteHeader(http.StatusOK)
})
handler := InFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{
handler := MustInFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{
GetKey: func(r *http.Request) (string, bool, error) {
return r.Header.Get(headerClientID), false, nil
},
Expand Down Expand Up @@ -358,7 +358,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) {
rw.WriteHeader(http.StatusOK)
})

handler := InFlightLimitWithOpts(limitPerClient, errDomain, InFlightLimitOpts{
handler := MustInFlightLimitWithOpts(limitPerClient, errDomain, InFlightLimitOpts{
GetKey: makeInFlightLimitGetKeyByHeader(headerClientID),
ResponseStatusCode: http.StatusTooManyRequests,
})(next)
Expand Down Expand Up @@ -417,7 +417,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) {
}
rw.WriteHeader(http.StatusOK)
})
handler := InFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{
handler := MustInFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{
GetKey: makeInFlightLimitGetKeyByHeader(headerClientID),
MaxKeys: 1,
ResponseStatusCode: http.StatusTooManyRequests,
Expand Down Expand Up @@ -460,7 +460,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) {
rw.WriteHeader(http.StatusOK)
})

handler := InFlightLimitWithOpts(limit, errDomain, InFlightLimitOpts{
handler := MustInFlightLimitWithOpts(limit, errDomain, InFlightLimitOpts{
GetKey: makeInFlightLimitGetKeyByHeader(headerClientID),
MaxKeys: 1,
ResponseStatusCode: http.StatusTooManyRequests,
Expand Down Expand Up @@ -499,7 +499,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) {
<-reqContinued
rw.WriteHeader(http.StatusOK)
})
handler := InFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{GetRetryAfter: func(r *http.Request) time.Duration {
handler := MustInFlightLimitWithOpts(1, errDomain, InFlightLimitOpts{GetRetryAfter: func(r *http.Request) time.Duration {
return retryAfter
}})(next)

Expand Down Expand Up @@ -530,7 +530,7 @@ func TestInFlightLimitHandler_ServeHTTP(t *testing.T) {
rw.WriteHeader(http.StatusOK)
})

handler := InFlightLimitWithOpts(limit, errDomain, InFlightLimitOpts{
handler := MustInFlightLimitWithOpts(limit, errDomain, InFlightLimitOpts{
DryRun: true,
BacklogLimit: backlogLimit,
BacklogTimeout: time.Millisecond * 10,
Expand Down
28 changes: 23 additions & 5 deletions httpserver/middleware/rate_limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,24 @@ type Rate struct {
}

// RateLimit is a middleware that limits the rate of HTTP requests.
func RateLimit(maxRate Rate, errDomain string) func(next http.Handler) http.Handler {
func RateLimit(maxRate Rate, errDomain string) (func(next http.Handler) http.Handler, error) {
return RateLimitWithOpts(maxRate, errDomain, RateLimitOpts{GetRetryAfter: GetRetryAfterEstimatedTime})
}

// MustRateLimit is a version of RateLimit that panics if an error occurs.
func MustRateLimit(maxRate Rate, errDomain string) func(next http.Handler) http.Handler {
mw, err := RateLimit(maxRate, errDomain)
if err != nil {
panic(err)
}
return mw
}

// RateLimitWithOpts is a configurable version of a middleware to limit the rate of HTTP requests.
func RateLimitWithOpts(maxRate Rate, errDomain string, opts RateLimitOpts) func(next http.Handler) http.Handler {
func RateLimitWithOpts(maxRate Rate, errDomain string, opts RateLimitOpts) (func(next http.Handler) http.Handler, error) {
backlogLimit := opts.BacklogLimit
if backlogLimit < 0 {
panic(fmt.Errorf("backlog limit should not be negative, got %d", backlogLimit))
return nil, fmt.Errorf("backlog limit should not be negative, got %d", backlogLimit)
}
if opts.DryRun {
backlogLimit = 0
Expand Down Expand Up @@ -148,12 +157,12 @@ func RateLimitWithOpts(maxRate Rate, errDomain string, opts RateLimitOpts) func(
}
limiter, err := makeLimiter()
if err != nil {
panic(err)
return nil, err
}

getBacklogSlots, err := makeRateLimitBacklogSlotsProvider(backlogLimit, maxKeys)
if err != nil {
panic(err)
return nil, fmt.Errorf("make rate limit backlog slots provider: %w", err)
}

backlogTimeout := opts.BacklogTimeout
Expand All @@ -174,7 +183,16 @@ func RateLimitWithOpts(maxRate Rate, errDomain string, opts RateLimitOpts) func(
onReject: makeRateLimitOnRejectFunc(opts),
onError: makeRateLimitOnErrorFunc(opts),
}
}, nil
}

// MustRateLimitWithOpts is a version of RateLimitWithOpts that panics if an error occurs.
func MustRateLimitWithOpts(maxRate Rate, errDomain string, opts RateLimitOpts) func(next http.Handler) http.Handler {
mw, err := RateLimitWithOpts(maxRate, errDomain, opts)
if err != nil {
panic(err)
}
return mw
}

//nolint:funlen,gocyclo
Expand Down
22 changes: 11 additions & 11 deletions httpserver/middleware/rate_limit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) {

t.Run("leaky bucket, maxRate=1r/s, maxBurst=0, no key", func(t *testing.T) {
next, nextServedCount := makeNext()
handler := RateLimit(Rate{1, time.Second}, errDomain)(next)
handler := MustRateLimit(Rate{1, time.Second}, errDomain)(next)
_ = sendReqAndCheckCode(t, handler, http.StatusOK, nil)
retryAfter := sendReqAndCheckCode(t, handler, http.StatusServiceUnavailable, nil)
time.Sleep(retryAfter)
Expand All @@ -84,7 +84,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) {
wantRetryAfter := time.Second * time.Duration(math.Ceil(emissionInterval.Seconds()))

next, nextServedCount := makeNext()
handler := RateLimitWithOpts(rate, errDomain, RateLimitOpts{MaxBurst: maxBurst, GetRetryAfter: GetRetryAfterEstimatedTime})(next)
handler := MustRateLimitWithOpts(rate, errDomain, RateLimitOpts{MaxBurst: maxBurst, GetRetryAfter: GetRetryAfterEstimatedTime})(next)

sendNReqsConcurrentlyAndCheck := func(n int) {
var okCount, tooManyReqsCount, unexpectedCodeReqsCount, wrongRetryAfterReqsCount, getRetryAfterErrsCount atomic.Int32
Expand Down Expand Up @@ -140,7 +140,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) {
t.Run("leaky bucket, maxRate=1r/s, maxBurst=0, by key", func(t *testing.T) {
const headerClientID = "X-Client-ID"
next, nextServedCount := makeNext()
handler := RateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{
handler := MustRateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{
GetKey: makeRateLimitGetKeyByHeader(headerClientID),
GetRetryAfter: GetRetryAfterEstimatedTime,
ResponseStatusCode: http.StatusTooManyRequests,
Expand All @@ -167,7 +167,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) {
t.Run("leaky bucket, maxRate=1r/s, maxBurst=0, by key, no bypass empty key", func(t *testing.T) {
const headerClientID = "X-Client-ID"
next, nextServedCount := makeNext()
handler := RateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{
handler := MustRateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{
GetKey: func(r *http.Request) (key string, bypass bool, err error) {
return r.Header.Get(headerClientID), false, nil
},
Expand Down Expand Up @@ -204,7 +204,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) {
wantRetryAfter := time.Second * time.Duration(math.Ceil(emissionInterval.Seconds()))

next, nextServedCount := makeNext()
handler := RateLimitWithOpts(rate, errDomain, RateLimitOpts{
handler := MustRateLimitWithOpts(rate, errDomain, RateLimitOpts{
MaxBurst: maxBurst,
GetKey: makeRateLimitGetKeyByHeader(headerClientID),
GetRetryAfter: GetRetryAfterEstimatedTime,
Expand Down Expand Up @@ -311,7 +311,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) {

t.Run("sliding window, maxRate=1r/s, no key", func(t *testing.T) {
next, nextServedCount := makeNext()
handler := RateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{
handler := MustRateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{
Alg: RateLimitAlgSlidingWindow,
GetRetryAfter: GetRetryAfterEstimatedTime,
})(next)
Expand All @@ -330,7 +330,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) {
rate := Rate{2, time.Second}

next, nextServedCount := makeNext()
handler := RateLimitWithOpts(rate, errDomain, RateLimitOpts{
handler := MustRateLimitWithOpts(rate, errDomain, RateLimitOpts{
Alg: RateLimitAlgSlidingWindow,
GetKey: makeRateLimitGetKeyByHeader(headerClientID),
GetRetryAfter: GetRetryAfterEstimatedTime,
Expand Down Expand Up @@ -387,7 +387,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) {

t.Run("RetryAfter custom", func(t *testing.T) {
next, _ := makeNext()
handler := RateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{
handler := MustRateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{
GetRetryAfter: func(r *http.Request, estimatedTime time.Duration) time.Duration {
return estimatedTime * 3
},
Expand All @@ -399,7 +399,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) {

t.Run("leaky bucket, maxRate=1r/s, maxBurst=0, backlogLimit=1, no key", func(t *testing.T) {
next, nextServedCount := makeNext()
handler := RateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{BacklogLimit: 1})(next)
handler := MustRateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{BacklogLimit: 1})(next)
sendReqAndCheckCode(t, handler, http.StatusOK, nil)
startTime := time.Now()
sendReqAndCheckCode(t, handler, http.StatusOK, nil)
Expand All @@ -425,7 +425,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) {
t.Run("leaky bucket, maxRate=1r/m, maxBurst=0, backlogLimit=1, backlogTimeout=1s, no key", func(t *testing.T) {
next, nextServedCount := makeNext()
rateLimitOpts := RateLimitOpts{BacklogLimit: 1, BacklogTimeout: time.Second, GetRetryAfter: GetRetryAfterEstimatedTime}
handler := RateLimitWithOpts(Rate{1, time.Minute}, errDomain, rateLimitOpts)(next)
handler := MustRateLimitWithOpts(Rate{1, time.Minute}, errDomain, rateLimitOpts)(next)
sendReqAndCheckCode(t, handler, http.StatusOK, nil)
startTime := time.Now()
sendReqAndCheckCode(t, handler, http.StatusServiceUnavailable, nil)
Expand All @@ -436,7 +436,7 @@ func TestRateLimitHandler_ServeHTTP(t *testing.T) {
t.Run("leaky bucket, maxRate=1r/s, maxBurst=0, backlogLimit=1, by key", func(t *testing.T) {
const headerClientID = "X-Client-ID"
next, nextServedCount := makeNext()
handler := RateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{
handler := MustRateLimitWithOpts(Rate{1, time.Second}, errDomain, RateLimitOpts{
GetKey: makeRateLimitGetKeyByHeader(headerClientID),
BacklogLimit: 1,
GetRetryAfter: GetRetryAfterEstimatedTime,
Expand Down
Loading

0 comments on commit 6d0330c

Please sign in to comment.