From 497c31ea3d70ebbe90db22958aa14d1d1c1736ca Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Wed, 10 Jul 2024 20:17:16 +0200 Subject: [PATCH] RateLimit credits only for Quota --- middleware/common.go | 2 -- middleware/middleware_ratelimit.go | 3 +++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/middleware/common.go b/middleware/common.go index 4528842..f435cda 100644 --- a/middleware/common.go +++ b/middleware/common.go @@ -7,7 +7,6 @@ import ( "time" "github.com/0xsequence/quotacontrol/proto" - "github.com/go-chi/httprate" ) const ( @@ -205,7 +204,6 @@ func getProjectID(ctx context.Context) (uint64, bool) { // WithComputeUnits sets the compute units and rate limit increment to the context. func WithComputeUnits(ctx context.Context, cu int64) context.Context { - ctx = httprate.WithIncrement(ctx, int(cu)) return context.WithValue(ctx, ctxKeyComputeUnits, cu) } diff --git a/middleware/middleware_ratelimit.go b/middleware/middleware_ratelimit.go index 9dac3f5..99466a1 100644 --- a/middleware/middleware_ratelimit.go +++ b/middleware/middleware_ratelimit.go @@ -71,6 +71,9 @@ func RateLimit(rlCfg RLConfig, redisCfg redis.Config) func(next http.Handler) ht if _, ok := GetService(ctx); ok { ctx = httprate.WithRequestLimit(ctx, serviceRPM) } else if q, ok := GetAccessQuota(ctx); ok { + if cu, ok := getComputeUnits(ctx); ok { + ctx = httprate.WithIncrement(ctx, int(cu)) + } ctx = httprate.WithRequestLimit(ctx, int(q.Limit.RateLimit)) } else if _, ok := GetAccount(ctx); ok { ctx = httprate.WithRequestLimit(ctx, accountRPM)