diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index cfc7b3b..f8d3c52 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "net/http" + "server/config" "server/internal/db" "server/internal/server/utils" mock "server/internal/tests/dbmocks" @@ -18,6 +19,8 @@ import ( // RateLimiter middleware to limit requests based on a specified limit and duration func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { + configValue := config.LoadConfig() + cronFrequencyInterval := configValue.Server.CronCleanupFrequency return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if handleCors(w, r) { return @@ -78,28 +81,40 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float } } - // Get the cron last cleanup run time - var lastCronCleanupTime int64 - resp := client.Client.Get(ctx, utils.LastCronCleanupTimeUnixMs) - if resp.Err() != nil && !errors.Is(resp.Err(), dicedb.Nil) { - slog.Error("Failed to get last cron cleanup time for headers", slog.Any("err", resp.Err().Error())) - } - - if resp.Val() != "" { - lastCronCleanupTime, err = strconv.ParseInt(resp.Val(), 10, 64) - if err != nil { - slog.Error("Error converting last cron cleanup time", "error", err) - } + secondsDifference, err := calculateNextCleanupTime(ctx, client, cronFrequencyInterval) + if err != nil { + slog.Error("Error calculating next cleanup time", "error", err) } addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), - lastCronCleanupTime) + secondsDifference) slog.Info("Request processed", "count", requestCount+1) next.ServeHTTP(w, r) }) } +func calculateNextCleanupTime(ctx context.Context, client *db.DiceDB, cronFrequencyInterval time.Duration) (int64, error) { + var lastCronCleanupTime int64 + resp := client.Client.Get(ctx, utils.LastCronCleanupTimeUnixMs) + if resp.Err() != nil && !errors.Is(resp.Err(), dicedb.Nil) { + return -1, resp.Err() + } + + if resp.Val() != "" { + var err error + lastCronCleanupTime, err = strconv.ParseInt(resp.Val(), 10, 64) // directly assign here + if err != nil { + return -1, err + } + } + + lastCleanupTime := time.UnixMilli(lastCronCleanupTime) + nextCleanupTime := lastCleanupTime.Add(cronFrequencyInterval) + timeDifference := nextCleanupTime.Sub(time.Now()) + return int64(timeDifference.Seconds()), nil +} + func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if handleCors(w, r) { @@ -170,14 +185,14 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi }) } -func addRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime, cronLastCleanupTime int64) { +func addRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime, secondsLeftForCleanup int64) { w.Header().Set("x-ratelimit-limit", strconv.FormatInt(limit, 10)) w.Header().Set("x-ratelimit-remaining", strconv.FormatInt(remaining, 10)) w.Header().Set("x-ratelimit-used", strconv.FormatInt(used, 10)) w.Header().Set("x-ratelimit-reset", strconv.FormatInt(resetTime, 10)) - w.Header().Set("x-last-cleanup-time", strconv.FormatInt(cronLastCleanupTime, 10)) + w.Header().Set("x-next-cleanup-time", strconv.FormatInt(secondsLeftForCleanup, 10)) // Expose the rate limit headers to the client w.Header().Set("Access-Control-Expose-Headers", "x-ratelimit-limit, x-ratelimit-remaining,"+ - "x-ratelimit-used, x-ratelimit-reset, x-last-cleanup-time") + "x-ratelimit-used, x-ratelimit-reset, x-next-cleanup-time") }