Skip to content

Commit

Permalink
Integrate Gin router with HTTP server
Browse files Browse the repository at this point in the history
  • Loading branch information
gauravsarma1992 committed Dec 11, 2024
1 parent 795f105 commit 94d75e3
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 72 deletions.
140 changes: 79 additions & 61 deletions internal/middleware/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,83 +15,100 @@ import (
"time"

"github.com/dicedb/dicedb-go"
"github.com/gin-gonic/gin"
)

// RateLimiter middleware to limit requests based on a specified limit and duration
func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler {
configValue := config.LoadConfig()
cronFrequencyInterval := configValue.Server.CronCleanupFrequency
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handleCors(w, r) {
return
}
type (
RateLimiterMiddleware struct {
client *db.DiceDB
limit int64
window float64
conf *config.Config
cronFrequencyInterval time.Duration
}
)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
func NewRateLimiterMiddleware(client *db.DiceDB, limit int64, window float64) (rl *RateLimiterMiddleware) {
rl = &RateLimiterMiddleware{
client: client,
limit: limit,
window: window,
cronFrequencyInterval: config.LoadConfig().Server.CronCleanupFrequency,
}
return
}

// Only apply rate limiting for specific paths (e.g., "/cli/")
if !strings.Contains(r.URL.Path, "/shell/exec/") {
next.ServeHTTP(w, r)
return
}
// RateLimiter middleware to limit requests based on a specified limit and duration
func (rl *RateLimiterMiddleware) Exec(c *gin.Context) {
if handleCors(c.Writer, c.Request) {
return
}

// Generate the rate limiting key based on the current window
currentWindow := time.Now().Unix() / int64(window)
key := fmt.Sprintf("request_count:%d", currentWindow)
slog.Debug("Created rate limiter key", slog.Any("key", key))
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// Get the current request count for this window
val, err := client.Client.Get(ctx, key).Result()
if err != nil && !errors.Is(err, dicedb.Nil) {
slog.Error("Error fetching request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Only apply rate limiting for specific paths (e.g., "/cli/")
if !strings.Contains(c.Request.URL.Path, "/shell/exec/") {
c.Next()
return
}

// Parse the current request count or initialize to 0
var requestCount int64 = 0
if val != "" {
requestCount, err = strconv.ParseInt(val, 10, 64)
if err != nil {
slog.Error("Error converting request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
}
// Generate the rate limiting key based on the current window
currentWindow := time.Now().Unix() / int64(rl.window)
key := fmt.Sprintf("request_count:%d", currentWindow)
slog.Debug("Created rate limiter key", slog.Any("key", key))

// Get the current request count for this window
val, err := rl.client.Client.Get(ctx, key).Result()
if err != nil && !errors.Is(err, dicedb.Nil) {
slog.Error("Error fetching request count", "error", err)
http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError)
return
}

// Check if the request count exceeds the limit
if requestCount >= limit {
slog.Warn("Request limit exceeded", "count", requestCount)
addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), 0)
http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests)
// Parse the current request count or initialize to 0
var requestCount int64 = 0
if val != "" {
requestCount, err = strconv.ParseInt(val, 10, 64)
if err != nil {
slog.Error("Error converting request count", "error", err)
http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError)
return
}
}

// Increment the request count
if requestCount, err = client.Client.Incr(ctx, key).Result(); err != nil {
slog.Error("Error incrementing request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Check if the request count exceeds the limit
if requestCount >= rl.limit {
slog.Warn("Request limit exceeded", "count", requestCount)
addRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window), 0)
http.Error(c.Writer, "429 - Too Many Requests", http.StatusTooManyRequests)
return
}

// Set the key expiry if it's newly created
if requestCount == 1 {
if err := client.Client.Expire(ctx, key, time.Duration(window)*time.Second).Err(); err != nil {
slog.Error("Error setting expiry for request count", "error", err)
}
}
// Increment the request count
if requestCount, err = rl.client.Client.Incr(ctx, key).Result(); err != nil {
slog.Error("Error incrementing request count", "error", err)
http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError)
return
}

secondsDifference, err := calculateNextCleanupTime(ctx, client, cronFrequencyInterval)
if err != nil {
slog.Error("Error calculating next cleanup time", "error", err)
// Set the key expiry if it's newly created
if requestCount == 1 {
if err := rl.client.Client.Expire(ctx, key, time.Duration(rl.window)*time.Second).Err(); err != nil {
slog.Error("Error setting expiry for request count", "error", err)
}
}

addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window),
secondsDifference)
secondsDifference, err := calculateNextCleanupTime(ctx, rl.client, rl.cronFrequencyInterval)
if err != nil {
slog.Error("Error calculating next cleanup time", "error", err)
}

slog.Info("Request processed", "count", requestCount+1)
next.ServeHTTP(w, r)
})
addRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window),
secondsDifference)

slog.Info("Request processed", "count", requestCount+1)
c.Next()
}

func calculateNextCleanupTime(ctx context.Context, client *db.DiceDB, cronFrequencyInterval time.Duration) (int64, error) {
Expand Down Expand Up @@ -186,6 +203,7 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi
}

func addRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime, secondsLeftForCleanup int64) {

w.Header().Set("x-ratelimit-limit", strconv.FormatInt(limit, 10))
w.Header().Set("x-ratelimit-remaining", strconv.FormatInt(remaining, 10))
w.Header().Set("x-ratelimit-used", strconv.FormatInt(used, 10))
Expand Down
14 changes: 5 additions & 9 deletions internal/server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"server/internal/db"
"server/internal/middleware"
util "server/util"

"github.com/gin-gonic/gin"
)

type HTTPServer struct {
Expand Down Expand Up @@ -50,19 +52,13 @@ func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
})).ServeHTTP(w, r)
}

func NewHTTPServer(addr string, mux *http.ServeMux, diceDBAdminClient *db.DiceDB, diceClient *db.DiceDB,
func NewHTTPServer(router *gin.Engine, mux *http.ServeMux, diceDBAdminClient *db.DiceDB, diceClient *db.DiceDB,
limit int64, window float64) *HTTPServer {
handlerMux := &HandlerMux{
mux: mux,
rateLimiter: func(w http.ResponseWriter, r *http.Request, next http.Handler) {
middleware.RateLimiter(diceDBAdminClient, next, limit, window).ServeHTTP(w, r)
},
}

return &HTTPServer{
httpServer: &http.Server{
Addr: addr,
Handler: handlerMux,
Addr: ":8080",
Handler: router,
ReadHeaderTimeout: 5 * time.Second,
},
DiceClient: diceClient,
Expand Down
10 changes: 8 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"server/config"
"server/internal/db"
"server/internal/middleware"
"server/internal/server"
"sync"

Expand Down Expand Up @@ -56,7 +57,12 @@ func main() {
c.Next()
})

httpServer := server.NewHTTPServer(":8080", nil, diceDBAdminClient, diceDBClient, configValue.Server.RequestLimitPerMin,
router.Use((middleware.NewRateLimiterMiddleware(diceDBAdminClient,
configValue.Server.RequestLimitPerMin,
configValue.Server.RequestWindowSec,
).Exec))

httpServer := server.NewHTTPServer(router, nil, diceDBAdminClient, diceDBClient, configValue.Server.RequestLimitPerMin,
configValue.Server.RequestWindowSec)

// Register routes
Expand All @@ -68,7 +74,7 @@ func main() {
go func() {
defer wg.Done()
// Run the HTTP Server
if err := router.Run(":8080"); err != nil {
if err := httpServer.Run(context.Background()); err != nil {
slog.Error("server failed: %v\n", slog.Any("err", err))
diceDBAdminClient.CloseDiceDB()
cancel()
Expand Down

0 comments on commit 94d75e3

Please sign in to comment.