diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index f8d3c52..bda9bd1 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -15,83 +15,100 @@ import ( "time" "github.com/dicedb/dicedb-go" + "github.com/gin-gonic/gin" ) -// RateLimiter middleware to limit requests based on a specified limit and duration -func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { - configValue := config.LoadConfig() - cronFrequencyInterval := configValue.Server.CronCleanupFrequency - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if handleCors(w, r) { - return - } +type ( + RateLimiterMiddleware struct { + client *db.DiceDB + limit int64 + window float64 + conf *config.Config + cronFrequencyInterval time.Duration + } +) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() +func NewRateLimiterMiddleware(client *db.DiceDB, limit int64, window float64) (rl *RateLimiterMiddleware) { + rl = &RateLimiterMiddleware{ + client: client, + limit: limit, + window: window, + cronFrequencyInterval: config.LoadConfig().Server.CronCleanupFrequency, + } + return +} - // Only apply rate limiting for specific paths (e.g., "/cli/") - if !strings.Contains(r.URL.Path, "/shell/exec/") { - next.ServeHTTP(w, r) - return - } +// RateLimiter middleware to limit requests based on a specified limit and duration +func (rl *RateLimiterMiddleware) Exec(c *gin.Context) { + if handleCors(c.Writer, c.Request) { + return + } - // Generate the rate limiting key based on the current window - currentWindow := time.Now().Unix() / int64(window) - key := fmt.Sprintf("request_count:%d", currentWindow) - slog.Debug("Created rate limiter key", slog.Any("key", key)) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() - // Get the current request count for this window - val, err := client.Client.Get(ctx, key).Result() - if err != nil && !errors.Is(err, dicedb.Nil) { - slog.Error("Error fetching request count", "error", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } + // Only apply rate limiting for specific paths (e.g., "/cli/") + if !strings.Contains(c.Request.URL.Path, "/shell/exec/") { + c.Next() + return + } - // Parse the current request count or initialize to 0 - var requestCount int64 = 0 - if val != "" { - requestCount, err = strconv.ParseInt(val, 10, 64) - if err != nil { - slog.Error("Error converting request count", "error", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - } + // Generate the rate limiting key based on the current window + currentWindow := time.Now().Unix() / int64(rl.window) + key := fmt.Sprintf("request_count:%d", currentWindow) + slog.Debug("Created rate limiter key", slog.Any("key", key)) + + // Get the current request count for this window + val, err := rl.client.Client.Get(ctx, key).Result() + if err != nil && !errors.Is(err, dicedb.Nil) { + slog.Error("Error fetching request count", "error", err) + http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError) + return + } - // Check if the request count exceeds the limit - if requestCount >= limit { - slog.Warn("Request limit exceeded", "count", requestCount) - addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), 0) - http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests) + // Parse the current request count or initialize to 0 + var requestCount int64 = 0 + if val != "" { + requestCount, err = strconv.ParseInt(val, 10, 64) + if err != nil { + slog.Error("Error converting request count", "error", err) + http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError) return } + } - // Increment the request count - if requestCount, err = client.Client.Incr(ctx, key).Result(); err != nil { - slog.Error("Error incrementing request count", "error", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } + // Check if the request count exceeds the limit + if requestCount >= rl.limit { + slog.Warn("Request limit exceeded", "count", requestCount) + addRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window), 0) + http.Error(c.Writer, "429 - Too Many Requests", http.StatusTooManyRequests) + return + } - // Set the key expiry if it's newly created - if requestCount == 1 { - if err := client.Client.Expire(ctx, key, time.Duration(window)*time.Second).Err(); err != nil { - slog.Error("Error setting expiry for request count", "error", err) - } - } + // Increment the request count + if requestCount, err = rl.client.Client.Incr(ctx, key).Result(); err != nil { + slog.Error("Error incrementing request count", "error", err) + http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError) + return + } - secondsDifference, err := calculateNextCleanupTime(ctx, client, cronFrequencyInterval) - if err != nil { - slog.Error("Error calculating next cleanup time", "error", err) + // Set the key expiry if it's newly created + if requestCount == 1 { + if err := rl.client.Client.Expire(ctx, key, time.Duration(rl.window)*time.Second).Err(); err != nil { + slog.Error("Error setting expiry for request count", "error", err) } + } - addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), - secondsDifference) + secondsDifference, err := calculateNextCleanupTime(ctx, rl.client, rl.cronFrequencyInterval) + if err != nil { + slog.Error("Error calculating next cleanup time", "error", err) + } - slog.Info("Request processed", "count", requestCount+1) - next.ServeHTTP(w, r) - }) + addRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window), + secondsDifference) + + slog.Info("Request processed", "count", requestCount+1) + c.Next() } func calculateNextCleanupTime(ctx context.Context, client *db.DiceDB, cronFrequencyInterval time.Duration) (int64, error) { @@ -186,6 +203,7 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi } func addRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime, secondsLeftForCleanup int64) { + w.Header().Set("x-ratelimit-limit", strconv.FormatInt(limit, 10)) w.Header().Set("x-ratelimit-remaining", strconv.FormatInt(remaining, 10)) w.Header().Set("x-ratelimit-used", strconv.FormatInt(used, 10)) diff --git a/internal/server/http.go b/internal/server/http.go index b5209c2..2553a97 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -12,6 +12,8 @@ import ( "server/internal/db" "server/internal/middleware" util "server/util" + + "github.com/gin-gonic/gin" ) type HTTPServer struct { @@ -50,19 +52,13 @@ func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { })).ServeHTTP(w, r) } -func NewHTTPServer(addr string, mux *http.ServeMux, diceDBAdminClient *db.DiceDB, diceClient *db.DiceDB, +func NewHTTPServer(router *gin.Engine, mux *http.ServeMux, diceDBAdminClient *db.DiceDB, diceClient *db.DiceDB, limit int64, window float64) *HTTPServer { - handlerMux := &HandlerMux{ - mux: mux, - rateLimiter: func(w http.ResponseWriter, r *http.Request, next http.Handler) { - middleware.RateLimiter(diceDBAdminClient, next, limit, window).ServeHTTP(w, r) - }, - } return &HTTPServer{ httpServer: &http.Server{ - Addr: addr, - Handler: handlerMux, + Addr: ":8080", + Handler: router, ReadHeaderTimeout: 5 * time.Second, }, DiceClient: diceClient, diff --git a/main.go b/main.go index 9d60133..f36bde0 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "os" "server/config" "server/internal/db" + "server/internal/middleware" "server/internal/server" "sync" @@ -56,7 +57,12 @@ func main() { c.Next() }) - httpServer := server.NewHTTPServer(":8080", nil, diceDBAdminClient, diceDBClient, configValue.Server.RequestLimitPerMin, + router.Use((middleware.NewRateLimiterMiddleware(diceDBAdminClient, + configValue.Server.RequestLimitPerMin, + configValue.Server.RequestWindowSec, + ).Exec)) + + httpServer := server.NewHTTPServer(router, nil, diceDBAdminClient, diceDBClient, configValue.Server.RequestLimitPerMin, configValue.Server.RequestWindowSec) // Register routes @@ -68,7 +74,7 @@ func main() { go func() { defer wg.Done() // Run the HTTP Server - if err := router.Run(":8080"); err != nil { + if err := httpServer.Run(context.Background()); err != nil { slog.Error("server failed: %v\n", slog.Any("err", err)) diceDBAdminClient.CloseDiceDB() cancel()