From 2e45060357a0e7f302b5aa74a4e886488507216b Mon Sep 17 00:00:00 2001 From: Gaurav Sarma Date: Wed, 11 Dec 2024 19:48:29 +0800 Subject: [PATCH] Modified the TrailingSlashMiddleware --- internal/middleware/trailingslash.go | 24 ++++++++++++------------ internal/server/http.go | 16 +--------------- main.go | 10 ++++++++-- 3 files changed, 21 insertions(+), 29 deletions(-) diff --git a/internal/middleware/trailingslash.go b/internal/middleware/trailingslash.go index 80871ce..a4f93a0 100644 --- a/internal/middleware/trailingslash.go +++ b/internal/middleware/trailingslash.go @@ -3,19 +3,19 @@ package middleware import ( "net/http" "strings" + + "github.com/gin-gonic/gin" ) -func TrailingSlashMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { - newPath := strings.TrimSuffix(r.URL.Path, "/") - newURL := newPath - if r.URL.RawQuery != "" { - newURL += "?" + r.URL.RawQuery - } - http.Redirect(w, r, newURL, http.StatusMovedPermanently) - return +func TrailingSlashMiddleware(c *gin.Context) { + if c.Request.URL.Path != "/" && strings.HasSuffix(c.Request.URL.Path, "/") { + newPath := strings.TrimSuffix(c.Request.URL.Path, "/") + newURL := newPath + if c.Request.URL.RawQuery != "" { + newURL += "?" + c.Request.URL.RawQuery } - next.ServeHTTP(w, r) - }) + http.Redirect(c.Writer, c.Request, newURL, http.StatusMovedPermanently) + return + } + c.Next() } diff --git a/internal/server/http.go b/internal/server/http.go index 2553a97..f37a8f0 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -6,11 +6,9 @@ import ( "errors" "log/slog" "net/http" - "strings" "time" "server/internal/db" - "server/internal/middleware" util "server/util" "github.com/gin-gonic/gin" @@ -21,11 +19,6 @@ type HTTPServer struct { DiceClient *db.DiceDB } -type HandlerMux struct { - mux *http.ServeMux - rateLimiter func(http.ResponseWriter, *http.Request, http.Handler) -} - type HTTPResponse struct { Data interface{} `json:"data"` } @@ -45,14 +38,7 @@ func errorResponse(response string) string { return string(jsonResponse) } -func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.URL.Path = strings.ToLower(r.URL.Path) - cim.rateLimiter(w, r, cim.mux) - })).ServeHTTP(w, r) -} - -func NewHTTPServer(router *gin.Engine, mux *http.ServeMux, diceDBAdminClient *db.DiceDB, diceClient *db.DiceDB, +func NewHTTPServer(router *gin.Engine, diceDBAdminClient *db.DiceDB, diceClient *db.DiceDB, limit int64, window float64) *HTTPServer { return &HTTPServer{ diff --git a/main.go b/main.go index f36bde0..1106be4 100644 --- a/main.go +++ b/main.go @@ -57,13 +57,19 @@ func main() { c.Next() }) + router.Use(middleware.TrailingSlashMiddleware) router.Use((middleware.NewRateLimiterMiddleware(diceDBAdminClient, configValue.Server.RequestLimitPerMin, configValue.Server.RequestWindowSec, ).Exec)) - httpServer := server.NewHTTPServer(router, nil, diceDBAdminClient, diceDBClient, configValue.Server.RequestLimitPerMin, - configValue.Server.RequestWindowSec) + httpServer := server.NewHTTPServer( + router, + diceDBAdminClient, + diceDBClient, + configValue.Server.RequestLimitPerMin, + configValue.Server.RequestWindowSec, + ) // Register routes router.GET("/health", gin.WrapF(httpServer.HealthCheck))