Skip to content

Commit

Permalink
override public rate limiting through context
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasJenicek committed Feb 26, 2024
1 parent bcc2818 commit 1911c42
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ jobs:
run: go build -v ./...

- name: Test
run: go test -v ./...
run: make test
23 changes: 17 additions & 6 deletions middleware/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ func (k *contextKey) String() string {
}

var (
ctxKeyAccessKey = &contextKey{"AccessKey"}
ctxKeyAccessQuota = &contextKey{"AccessQuota"}
ctxKeyComputeUnits = &contextKey{"ComputeUnits"}
ctxKeyRateLimitSkip = &contextKey{"RateLimitSkip"}
ctxKeyTime = &contextKey{"Time"}
ctxKeyResult = &contextKey{"Result"}
ctxKeyAccessKey = &contextKey{"AccessKey"}
ctxKeyAccessQuota = &contextKey{"AccessQuota"}
ctxKeyComputeUnits = &contextKey{"ComputeUnits"}
ctxKeyRateLimitSkip = &contextKey{"RateLimitSkip"}
ctxOverridePublicLimit = &contextKey{"OverridePublicLimit"}
ctxKeyTime = &contextKey{"Time"}
ctxKeyResult = &contextKey{"Result"}
)

// WithAccessKey adds the access key to the context.
Expand Down Expand Up @@ -86,6 +87,16 @@ func IsSkipRateLimit(ctx context.Context) bool {
return ok
}

// WithOverridePublicLimit override rate limiting configuration.
func WithOverridePublicLimit(ctx context.Context, limit int) context.Context {
return context.WithValue(ctx, ctxOverridePublicLimit, limit)
}

func GetOverridePublicLimit(ctx context.Context) (int, bool) {
limit, ok := ctx.Value(ctxOverridePublicLimit).(int)
return limit, ok
}

// WithComputeUnits sets the compute units to the context.
func WithComputeUnits(ctx context.Context, cu int64) context.Context {
return context.WithValue(ctx, ctxKeyComputeUnits, cu)
Expand Down
5 changes: 5 additions & 0 deletions ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ func NewHTTPRateLimiter(cfg Config, vary RateLimitVaryFn) func(next http.Handler
return
}

rpmPublicLimit, _ := middleware.GetOverridePublicLimit(ctx)
if rpmPublicLimit > 0 {
rlPublic = httprate.Limit(rpmPublicLimit, time.Minute, optsPublic...)(next)
}

// Rate limit
var rateLimitType RateLimitType
var rlKey string
Expand Down
32 changes: 32 additions & 0 deletions ratelimiter_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package quotacontrol

import (
"context"
"encoding/binary"
"encoding/json"
"github.com/0xsequence/quotacontrol/middleware"
"math/rand"
"net"
"net/http"
Expand Down Expand Up @@ -45,3 +47,33 @@ func TestRateLimiter(t *testing.T) {
assert.Equal(t, err.Message, _CustomErrorMessage)
}
}

func TestOverridePublicRateLimiting(t *testing.T) {
rl := NewHTTPRateLimiter(Config{
RateLimiter: RateLimiterConfig{
Enabled: true,
PublicRequestsPerMinute: 10,
ErrorMessage: "Custom error",
},
}, nil)
handler := rl(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
buf := make([]byte, 4)
for i := 11; i < 20; i++ {
ip := rand.Uint32()
binary.LittleEndian.PutUint32(buf, ip)
}
ipAddress := net.IP(buf).String()
srv := httptest.NewServer(handler)
for i := 0; i < 5; i++ {
ctx := middleware.WithOverridePublicLimit(context.Background(), 2)
req, _ := http.NewRequestWithContext(ctx, "GET", srv.URL, nil)
req.RemoteAddr = ipAddress
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if i < 2 {
assert.Equal(t, http.StatusOK, w.Code)
continue
}
assert.Equal(t, http.StatusTooManyRequests, w.Code)
}
}

0 comments on commit 1911c42

Please sign in to comment.