From 68a459db122f317b7a82c4d9c7dbdf99b32061e1 Mon Sep 17 00:00:00 2001 From: Ayush Kumar Date: Mon, 3 Feb 2025 13:55:59 -0500 Subject: [PATCH] ip parsing, cli flags, cloudflare proxy support --- pkg/options/server.go | 13 +++- pkg/solver/server.go | 171 ++++++++++++++++++++++++------------------ 2 files changed, 110 insertions(+), 74 deletions(-) diff --git a/pkg/options/server.go b/pkg/options/server.go index aee61354..ac8fa0bd 100644 --- a/pkg/options/server.go +++ b/pkg/options/server.go @@ -2,7 +2,7 @@ package options import ( "fmt" - + "net" "github.com/lilypad-tech/lilypad/pkg/http" "github.com/spf13/cobra" ) @@ -69,6 +69,10 @@ func AddServerCliFlags(cmd *cobra.Command, serverOptions *http.ServerOptions) { &serverOptions.RateLimiter.WindowLength, "server-rate-window-length", serverOptions.RateLimiter.WindowLength, `The time window over which to limit in seconds (SERVER_RATE_WINDOW_LENGTH).`, ) + cmd.PersistentFlags().StringArrayVar( + &serverOptions.RateLimiter.ExemptedIPs, "server-rate-exempted-ips", serverOptions.RateLimiter.ExemptedIPs, + `The IPs to exempt from rate limiting (SERVER_RATE_EXEMPTED_IPS).`, + ) } func CheckServerOptions(options http.ServerOptions) error { @@ -81,5 +85,12 @@ func CheckServerOptions(options http.ServerOptions) error { if options.AccessControl.ValidationTokenKid == "" { return fmt.Errorf("SERVER_VALIDATION_TOKEN_KID is required") } + if len(options.RateLimiter.ExemptedIPs) > 0 { + for _, ip := range options.RateLimiter.ExemptedIPs { + if net.ParseIP(ip) == nil { + return fmt.Errorf("invalid IP address: %s", ip) + } + } + } return nil } diff --git a/pkg/solver/server.go b/pkg/solver/server.go index fdae4a95..8dde2a20 100644 --- a/pkg/solver/server.go +++ b/pkg/solver/server.go @@ -6,8 +6,8 @@ import ( "errors" "fmt" "io" - corehttp "net/http" "net" + corehttp "net/http" "os" "path/filepath" "strings" @@ -52,35 +52,16 @@ func NewSolverServer( } /* - * - * - * +* +* +* - Routes +# Routes - * - * - * +* +* +* */ - -func exemptIPKeyFunc(exemptIPs []string) func(r *corehttp.Request) (string, error) { - return func(r *corehttp.Request) (string, error) { - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - ip = r.RemoteAddr - } - - - for _, exemptIP := range exemptIPs { - if ip == exemptIP { - return "", nil - } - } - - return ip, nil - } -} - func (solverServer *solverServer) ListenAndServe(ctx context.Context, cm *system.CleanupManager, tracerProvider *trace.TracerProvider) error { router := mux.NewRouter() @@ -89,7 +70,6 @@ func (solverServer *solverServer) ListenAndServe(ctx context.Context, cm *system subrouter.Use(http.CorsMiddleware) subrouter.Use(otelmux.Middleware("solver", otelmux.WithTracerProvider(tracerProvider))) - exemptIPs := solverServer.options.RateLimiter.ExemptedIPs log.Debug().Strs("exemptIPs", exemptIPs).Msg("Loaded rate limit exemptions") @@ -208,6 +188,51 @@ func (solverServer *solverServer) disconnectCB(connParams http.WSConnectionParam } } +func exemptIPKeyFunc(exemptIPs []string) func(r *corehttp.Request) (string, error) { + return func(r *corehttp.Request) (string, error) { + + ip, err := getRealIPWithCloudflare(r) + + if err != nil { + log.Error().Err(err).Msgf("error getting real ip") + return httprate.KeyByEndpoint(r) + } + + for _, exemptIP := range exemptIPs { + if ip == canonicalizeIP(exemptIP) { + return "", nil + } + } + + return httprate.KeyByEndpoint(r) + } +} + +func getRealIPWithCloudflare(r *corehttp.Request) (string, error) { + if ip := r.Header.Get("CF-Connecting-IP"); ip != "" { + return canonicalizeIP(ip), nil + } + + return httprate.KeyByRealIP(r) +} + +func canonicalizeIP(ip string) string { + for _, c := range ip { + switch c { + case '.': + return ip + case ':': + parsed := net.ParseIP(ip) + if parsed == nil { + return ip + } + return parsed.Mask(net.CIDRMask(64, 128)).String() + } + } + + return ip +} + /* * * @@ -655,51 +680,51 @@ func (solverServer *solverServer) jobOfferDownloadFiles(res corehttp.ResponseWri } func (solverServer *solverServer) handleFileDownload(dirPath string, res corehttp.ResponseWriter) *http.HTTPError { - // Read directory contents - files, err := os.ReadDir(dirPath) - if err != nil { - return &http.HTTPError{ - Message: fmt.Sprintf("error reading directory: %s", err.Error()), - StatusCode: corehttp.StatusNotFound, - } - } - - // Find the first regular file - var targetFile os.DirEntry - for _, file := range files { - info, err := file.Info() - if err != nil { - continue - } - if info.Mode().IsRegular() { - targetFile = file - break - } - } - - if targetFile == nil { - return &http.HTTPError{ - Message: "no regular files found in directory", - StatusCode: corehttp.StatusNotFound, - } - } - - // Get the actual filename and path - filename := targetFile.Name() - filePath := filepath.Join(dirPath, filename) - - // Open and serve the file - file, err := os.Open(filePath) - if err != nil { - return &http.HTTPError{ - Message: err.Error(), - StatusCode: corehttp.StatusInternalServerError, - } - } - defer file.Close() - - res.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename)) - res.Header().Set("Content-Type", "application/x-tar") + // Read directory contents + files, err := os.ReadDir(dirPath) + if err != nil { + return &http.HTTPError{ + Message: fmt.Sprintf("error reading directory: %s", err.Error()), + StatusCode: corehttp.StatusNotFound, + } + } + + // Find the first regular file + var targetFile os.DirEntry + for _, file := range files { + info, err := file.Info() + if err != nil { + continue + } + if info.Mode().IsRegular() { + targetFile = file + break + } + } + + if targetFile == nil { + return &http.HTTPError{ + Message: "no regular files found in directory", + StatusCode: corehttp.StatusNotFound, + } + } + + // Get the actual filename and path + filename := targetFile.Name() + filePath := filepath.Join(dirPath, filename) + + // Open and serve the file + file, err := os.Open(filePath) + if err != nil { + return &http.HTTPError{ + Message: err.Error(), + StatusCode: corehttp.StatusInternalServerError, + } + } + defer file.Close() + + res.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename)) + res.Header().Set("Content-Type", "application/x-tar") _, err = io.Copy(res, file) if err != nil {