diff --git a/.local.dev b/.local.dev index 46b96a84..d0fab552 100644 --- a/.local.dev +++ b/.local.dev @@ -30,3 +30,4 @@ WEB3_TOKEN_ADDRESS_=0xa513E6E4b8f2a923D98304ec87F64353C4D5C853 WEB3_USERS_ADDRESS=0x0DCd1Bf9A1b36cE34237eEaFef220932846BCD82 BACALHAU_API_HOST=localhost BACALHAU_API_PORT=1234 +SERVER_RATE_EXEMPTED_IPS=127.0.0.1,::1 diff --git a/pkg/http/types.go b/pkg/http/types.go index c50d143f..cc571a88 100644 --- a/pkg/http/types.go +++ b/pkg/http/types.go @@ -21,6 +21,7 @@ type ValidationToken struct { type RateLimiterOptions struct { RequestLimit int WindowLength int + ExemptedIPs []string } type ClientOptions struct { diff --git a/pkg/http/utils.go b/pkg/http/utils.go index 17aa2e73..9696ee86 100644 --- a/pkg/http/utils.go +++ b/pkg/http/utils.go @@ -9,6 +9,7 @@ import ( "fmt" "io" stdlog "log" + "net" "net/http" "net/url" "strings" @@ -443,3 +444,30 @@ func newRetryClient() *retryablehttp.Client { } return retryClient } + +func CanonicalizeIP(ip string) string { + isIPv6 := false + // This is how net.ParseIP decides if an address is IPv6 + // https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/ip.go;l=704 + for i := 0; !isIPv6 && i < len(ip); i++ { + switch ip[i] { + case '.': + // IPv4 + return ip + case ':': + // IPv6 + isIPv6 = true + break + } + } + if !isIPv6 { + // Not an IP address at all + return ip + } + + ipv6 := net.ParseIP(ip) + if ipv6 == nil { + return ip + } + return ipv6.Mask(net.CIDRMask(64, 128)).String() +} diff --git a/pkg/options/server.go b/pkg/options/server.go index 29b1776d..ac8fa0bd 100644 --- a/pkg/options/server.go +++ b/pkg/options/server.go @@ -2,7 +2,7 @@ package options import ( "fmt" - + "net" "github.com/lilypad-tech/lilypad/pkg/http" "github.com/spf13/cobra" ) @@ -29,6 +29,7 @@ func GetDefaultRateLimiterOptions() http.RateLimiterOptions { return http.RateLimiterOptions{ RequestLimit: GetDefaultServeOptionInt("SERVER_RATE_REQUEST_LIMIT", 5), WindowLength: GetDefaultServeOptionInt("SERVER_RATE_WINDOW_LENGTH", 10), + ExemptedIPs: GetDefaultServeOptionStringArray("SERVER_RATE_EXEMPTED_IPS", []string{}), } } @@ -68,6 +69,10 @@ func AddServerCliFlags(cmd *cobra.Command, serverOptions *http.ServerOptions) { &serverOptions.RateLimiter.WindowLength, "server-rate-window-length", serverOptions.RateLimiter.WindowLength, `The time window over which to limit in seconds (SERVER_RATE_WINDOW_LENGTH).`, ) + cmd.PersistentFlags().StringArrayVar( + &serverOptions.RateLimiter.ExemptedIPs, "server-rate-exempted-ips", serverOptions.RateLimiter.ExemptedIPs, + `The IPs to exempt from rate limiting (SERVER_RATE_EXEMPTED_IPS).`, + ) } func CheckServerOptions(options http.ServerOptions) error { @@ -80,5 +85,12 @@ func CheckServerOptions(options http.ServerOptions) error { if options.AccessControl.ValidationTokenKid == "" { return fmt.Errorf("SERVER_VALIDATION_TOKEN_KID is required") } + if len(options.RateLimiter.ExemptedIPs) > 0 { + for _, ip := range options.RateLimiter.ExemptedIPs { + if net.ParseIP(ip) == nil { + return fmt.Errorf("invalid IP address: %s", ip) + } + } + } return nil } diff --git a/pkg/solver/ratelimit_test.go b/pkg/solver/ratelimit_test.go index f2976646..3e4a8c3c 100644 --- a/pkg/solver/ratelimit_test.go +++ b/pkg/solver/ratelimit_test.go @@ -4,8 +4,8 @@ package solver_test import ( "fmt" + "math/rand" "net/http" - "os" "sync" "testing" "time" @@ -17,9 +17,20 @@ type rateResult struct { limitedCount int } -// This test suite sends 100 requests over approximately half a second. +type rateTestCase struct { + name string + headers map[string]string + expectedOK int + expectedLimit int +} + +// This test suite sends 200 requests to three different paths. We send the +// requests in rate limited and exempt test groups. The rate limited group +// should allow 5/100 requests through and the exempt group should allow 100/100. +// // We assume the solver uses the default rate limiting settings with -// a request limit of 5 and window length of 10 seconds. +// a request limit of 5 and window length of 10 seconds. In addition, the solver +// should be configured to exempt localhost. func TestRateLimiter(t *testing.T) { paths := []string{ "/api/v1/resource_offers", @@ -27,45 +38,94 @@ func TestRateLimiter(t *testing.T) { "/api/v1/deals", } + // The solver should rate limit when forwarded + // headers are set to 1.2.3.4. + nonExemptHeaders := []map[string]string{ + {"True-Client-IP": "1.2.3.4"}, + {"X-Real-IP": "1.2.3.4"}, + {"X-Forwarded-For": "1.2.3.4"}, + } + + // The running solver is configured to exempt localhost. + // When no headers are set, test using the IP address from + // the underlying connection (also localhost) + exemptHeaders := []map[string]string{ + {"True-Client-IP": "127.0.0.1"}, + {"X-Real-IP": "127.0.0.1"}, + {"X-Forwarded-For": "127.0.0.1"}, + {}, // No headers case - uses RemoteAddr + } + + t.Run("non-exempt IP is rate limited", func(t *testing.T) { + // Select a random header on each test run. Over time we test them all. + headers := nonExemptHeaders[rand.Intn(len(nonExemptHeaders))] + tc := rateTestCase{ + name: fmt.Sprintf("rate limited with headers %v", headers), + headers: headers, + expectedOK: 5, + expectedLimit: 95, + } + runRateLimitTest(t, paths, tc) + }) + + t.Run("exempt IP is not rate limited", func(t *testing.T) { + // Select a random header on each test run. Over time we test them all. + headers := exemptHeaders[rand.Intn(len(exemptHeaders))] + tc := rateTestCase{ + name: fmt.Sprintf("exempt with headers %v", headers), + headers: headers, + expectedOK: 100, + expectedLimit: 0, + } + runRateLimitTest(t, paths, tc) + }) +} + +func runRateLimitTest(t *testing.T, paths []string, tc rateTestCase) { var wg sync.WaitGroup ch := make(chan rateResult, len(paths)) - // Send off callers to run concurrently + // Run the calls against paths in parallel for _, path := range paths { wg.Add(1) - - go func() { + go func(path string) { defer wg.Done() - makeCalls(t, path, ch) - }() + makeCalls(t, path, ch, tc) + }(path) } wg.Wait() close(ch) - expectedOkCount := 5 for result := range ch { - if result.okCount > expectedOkCount { - t.Errorf( - "%s allowed %d requests and limited %d requests, but expected limiting after %d requests\n", - result.path, result.okCount, result.limitedCount, expectedOkCount, - ) + if result.okCount != tc.expectedOK { + t.Errorf("%s: Expected %d successful requests, got %d", + result.path, tc.expectedOK, result.okCount) + } + if result.limitedCount != tc.expectedLimit { + t.Errorf("%s: Expected %d rate limited requests, got %d", + result.path, tc.expectedLimit, result.limitedCount) } } } -func makeCalls(t *testing.T, path string, ch chan rateResult) { +func makeCalls(t *testing.T, path string, ch chan rateResult, tc rateTestCase) { var okCount int var limitedCount int + client := &http.Client{} + + for i := 0; i < 100; i++ { + req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:%d%s", 8081, path), nil) - // Make 100 requests - for range 100 { - requestURL := fmt.Sprintf("http://localhost:%d%s", 8081, path) - res, err := http.Get(requestURL) + // Set test case headers + for key, value := range tc.headers { + req.Header.Set(key, value) + } + res, err := client.Do(req) if err != nil { - t.Errorf("Get request failed on %s: %s\n", path, err) - os.Exit(1) + t.Errorf("Request failed on %s: %s\n", path, err) + return } if res.StatusCode == 200 { @@ -76,7 +136,6 @@ func makeCalls(t *testing.T, path string, ch chan rateResult) { t.Errorf("Expected a 200 or 429 status code, but received a %d\n", res.StatusCode) } - // Wait before making next call time.Sleep(5 * time.Millisecond) } diff --git a/pkg/solver/server.go b/pkg/solver/server.go index 3fd5ef61..995c1f11 100644 --- a/pkg/solver/server.go +++ b/pkg/solver/server.go @@ -51,17 +51,16 @@ func NewSolverServer( } /* - * - * - * +* +* +* - Routes +# Routes - * - * - * +* +* +* */ - func (solverServer *solverServer) ListenAndServe(ctx context.Context, cm *system.CleanupManager, tracerProvider *trace.TracerProvider) error { router := mux.NewRouter() @@ -69,10 +68,24 @@ func (solverServer *solverServer) ListenAndServe(ctx context.Context, cm *system subrouter.Use(http.CorsMiddleware) subrouter.Use(otelmux.Middleware("solver", otelmux.WithTracerProvider(tracerProvider))) + + exemptIPs := solverServer.options.RateLimiter.ExemptedIPs + + log.Debug().Strs("exemptIPs", exemptIPs).Msg("Loaded rate limit exemptions") + subrouter.Use(httprate.Limit( solverServer.options.RateLimiter.RequestLimit, time.Duration(solverServer.options.RateLimiter.WindowLength)*time.Second, - httprate.WithKeyFuncs(httprate.KeyByRealIP, httprate.KeyByEndpoint), + httprate.WithKeyFuncs( + exemptIPKeyFunc(exemptIPs), + httprate.KeyByEndpoint, + ), + httprate.WithErrorHandler(func(w corehttp.ResponseWriter, r *corehttp.Request, err error) { + if err.Error() == "RATE_LIMIT_EXEMPT" { + return + } + corehttp.Error(w, err.Error(), corehttp.StatusTooManyRequests) + }), )) subrouter.HandleFunc("/job_offers", http.GetHandler(solverServer.getJobOffers)).Methods("GET") @@ -180,6 +193,25 @@ func (solverServer *solverServer) disconnectCB(connParams http.WSConnectionParam } } +func exemptIPKeyFunc(exemptIPs []string) func(r *corehttp.Request) (string, error) { + return func(r *corehttp.Request) (string, error) { + ip, err := httprate.KeyByRealIP(r) + if err != nil { + log.Error().Err(err).Msg("error getting real ip") + return ip, err + } + + // Check if the IP is in the exempt list + for _, exemptIP := range exemptIPs { + if http.CanonicalizeIP(exemptIP) == ip { + return "", fmt.Errorf("RATE_LIMIT_EXEMPT") + } + } + + return ip, nil + } +} + /* * * @@ -639,51 +671,51 @@ func (solverServer *solverServer) jobOfferDownloadFiles(res corehttp.ResponseWri } func (solverServer *solverServer) handleFileDownload(dirPath string, res corehttp.ResponseWriter) *http.HTTPError { - // Read directory contents - files, err := os.ReadDir(dirPath) - if err != nil { - return &http.HTTPError{ - Message: fmt.Sprintf("error reading directory: %s", err.Error()), - StatusCode: corehttp.StatusNotFound, - } - } - - // Find the first regular file - var targetFile os.DirEntry - for _, file := range files { - info, err := file.Info() - if err != nil { - continue - } - if info.Mode().IsRegular() { - targetFile = file - break - } - } - - if targetFile == nil { - return &http.HTTPError{ - Message: "no regular files found in directory", - StatusCode: corehttp.StatusNotFound, - } - } - - // Get the actual filename and path - filename := targetFile.Name() - filePath := filepath.Join(dirPath, filename) - - // Open and serve the file - file, err := os.Open(filePath) - if err != nil { - return &http.HTTPError{ - Message: err.Error(), - StatusCode: corehttp.StatusInternalServerError, - } - } - defer file.Close() - - res.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename)) - res.Header().Set("Content-Type", "application/x-tar") + // Read directory contents + files, err := os.ReadDir(dirPath) + if err != nil { + return &http.HTTPError{ + Message: fmt.Sprintf("error reading directory: %s", err.Error()), + StatusCode: corehttp.StatusNotFound, + } + } + + // Find the first regular file + var targetFile os.DirEntry + for _, file := range files { + info, err := file.Info() + if err != nil { + continue + } + if info.Mode().IsRegular() { + targetFile = file + break + } + } + + if targetFile == nil { + return &http.HTTPError{ + Message: "no regular files found in directory", + StatusCode: corehttp.StatusNotFound, + } + } + + // Get the actual filename and path + filename := targetFile.Name() + filePath := filepath.Join(dirPath, filename) + + // Open and serve the file + file, err := os.Open(filePath) + if err != nil { + return &http.HTTPError{ + Message: err.Error(), + StatusCode: corehttp.StatusInternalServerError, + } + } + defer file.Close() + + res.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename)) + res.Header().Set("Content-Type", "application/x-tar") _, err = io.Copy(res, file) if err != nil {