Skip to content

Commit

Permalink
feat: solver rate limiter exemption list (#507)
Browse files Browse the repository at this point in the history
Adds a way to exempt ips from the solver rate limitter
  • Loading branch information
kelindi committed Feb 10, 2025
1 parent f70545b commit 81a446d
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 77 deletions.
1 change: 1 addition & 0 deletions .local.dev
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ WEB3_TOKEN_ADDRESS_=0xa513E6E4b8f2a923D98304ec87F64353C4D5C853
WEB3_USERS_ADDRESS=0x0DCd1Bf9A1b36cE34237eEaFef220932846BCD82
BACALHAU_API_HOST=localhost
BACALHAU_API_PORT=1234
SERVER_RATE_EXEMPTED_IPS=127.0.0.1,::1
1 change: 1 addition & 0 deletions pkg/http/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type ValidationToken struct {
type RateLimiterOptions struct {
RequestLimit int
WindowLength int
ExemptedIPs []string
}

type ClientOptions struct {
Expand Down
28 changes: 28 additions & 0 deletions pkg/http/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io"
stdlog "log"
"net"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -443,3 +444,30 @@ func newRetryClient() *retryablehttp.Client {
}
return retryClient
}

func CanonicalizeIP(ip string) string {
isIPv6 := false
// This is how net.ParseIP decides if an address is IPv6
// https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/ip.go;l=704
for i := 0; !isIPv6 && i < len(ip); i++ {
switch ip[i] {
case '.':
// IPv4
return ip
case ':':
// IPv6
isIPv6 = true
break
}
}
if !isIPv6 {
// Not an IP address at all
return ip
}

ipv6 := net.ParseIP(ip)
if ipv6 == nil {
return ip
}
return ipv6.Mask(net.CIDRMask(64, 128)).String()
}
14 changes: 13 additions & 1 deletion pkg/options/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package options

import (
"fmt"

"net"
"github.com/lilypad-tech/lilypad/pkg/http"
"github.com/spf13/cobra"
)
Expand All @@ -29,6 +29,7 @@ func GetDefaultRateLimiterOptions() http.RateLimiterOptions {
return http.RateLimiterOptions{
RequestLimit: GetDefaultServeOptionInt("SERVER_RATE_REQUEST_LIMIT", 5),
WindowLength: GetDefaultServeOptionInt("SERVER_RATE_WINDOW_LENGTH", 10),
ExemptedIPs: GetDefaultServeOptionStringArray("SERVER_RATE_EXEMPTED_IPS", []string{}),
}
}

Expand Down Expand Up @@ -68,6 +69,10 @@ func AddServerCliFlags(cmd *cobra.Command, serverOptions *http.ServerOptions) {
&serverOptions.RateLimiter.WindowLength, "server-rate-window-length", serverOptions.RateLimiter.WindowLength,
`The time window over which to limit in seconds (SERVER_RATE_WINDOW_LENGTH).`,
)
cmd.PersistentFlags().StringArrayVar(
&serverOptions.RateLimiter.ExemptedIPs, "server-rate-exempted-ips", serverOptions.RateLimiter.ExemptedIPs,
`The IPs to exempt from rate limiting (SERVER_RATE_EXEMPTED_IPS).`,
)
}

func CheckServerOptions(options http.ServerOptions) error {
Expand All @@ -80,5 +85,12 @@ func CheckServerOptions(options http.ServerOptions) error {
if options.AccessControl.ValidationTokenKid == "" {
return fmt.Errorf("SERVER_VALIDATION_TOKEN_KID is required")
}
if len(options.RateLimiter.ExemptedIPs) > 0 {
for _, ip := range options.RateLimiter.ExemptedIPs {
if net.ParseIP(ip) == nil {
return fmt.Errorf("invalid IP address: %s", ip)
}
}
}
return nil
}
103 changes: 81 additions & 22 deletions pkg/solver/ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ package solver_test

import (
"fmt"
"math/rand"
"net/http"
"os"
"sync"
"testing"
"time"
Expand All @@ -17,55 +17,115 @@ type rateResult struct {
limitedCount int
}

// This test suite sends 100 requests over approximately half a second.
type rateTestCase struct {
name string
headers map[string]string
expectedOK int
expectedLimit int
}

// This test suite sends 200 requests to three different paths. We send the
// requests in rate limited and exempt test groups. The rate limited group
// should allow 5/100 requests through and the exempt group should allow 100/100.
//
// We assume the solver uses the default rate limiting settings with
// a request limit of 5 and window length of 10 seconds.
// a request limit of 5 and window length of 10 seconds. In addition, the solver
// should be configured to exempt localhost.
func TestRateLimiter(t *testing.T) {
paths := []string{
"/api/v1/resource_offers",
"/api/v1/job_offers",
"/api/v1/deals",
}

// The solver should rate limit when forwarded
// headers are set to 1.2.3.4.
nonExemptHeaders := []map[string]string{
{"True-Client-IP": "1.2.3.4"},
{"X-Real-IP": "1.2.3.4"},
{"X-Forwarded-For": "1.2.3.4"},
}

// The running solver is configured to exempt localhost.
// When no headers are set, test using the IP address from
// the underlying connection (also localhost)
exemptHeaders := []map[string]string{
{"True-Client-IP": "127.0.0.1"},
{"X-Real-IP": "127.0.0.1"},
{"X-Forwarded-For": "127.0.0.1"},
{}, // No headers case - uses RemoteAddr
}

t.Run("non-exempt IP is rate limited", func(t *testing.T) {
// Select a random header on each test run. Over time we test them all.
headers := nonExemptHeaders[rand.Intn(len(nonExemptHeaders))]
tc := rateTestCase{
name: fmt.Sprintf("rate limited with headers %v", headers),
headers: headers,
expectedOK: 5,
expectedLimit: 95,
}
runRateLimitTest(t, paths, tc)
})

t.Run("exempt IP is not rate limited", func(t *testing.T) {
// Select a random header on each test run. Over time we test them all.
headers := exemptHeaders[rand.Intn(len(exemptHeaders))]
tc := rateTestCase{
name: fmt.Sprintf("exempt with headers %v", headers),
headers: headers,
expectedOK: 100,
expectedLimit: 0,
}
runRateLimitTest(t, paths, tc)
})
}

func runRateLimitTest(t *testing.T, paths []string, tc rateTestCase) {
var wg sync.WaitGroup
ch := make(chan rateResult, len(paths))

// Send off callers to run concurrently
// Run the calls against paths in parallel
for _, path := range paths {
wg.Add(1)

go func() {
go func(path string) {
defer wg.Done()
makeCalls(t, path, ch)
}()
makeCalls(t, path, ch, tc)
}(path)
}

wg.Wait()
close(ch)

expectedOkCount := 5
for result := range ch {
if result.okCount > expectedOkCount {
t.Errorf(
"%s allowed %d requests and limited %d requests, but expected limiting after %d requests\n",
result.path, result.okCount, result.limitedCount, expectedOkCount,
)
if result.okCount != tc.expectedOK {
t.Errorf("%s: Expected %d successful requests, got %d",
result.path, tc.expectedOK, result.okCount)
}
if result.limitedCount != tc.expectedLimit {
t.Errorf("%s: Expected %d rate limited requests, got %d",
result.path, tc.expectedLimit, result.limitedCount)
}
}
}

func makeCalls(t *testing.T, path string, ch chan rateResult) {
func makeCalls(t *testing.T, path string, ch chan rateResult, tc rateTestCase) {
var okCount int
var limitedCount int
client := &http.Client{}

for i := 0; i < 100; i++ {
req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:%d%s", 8081, path), nil)

// Make 100 requests
for range 100 {
requestURL := fmt.Sprintf("http://localhost:%d%s", 8081, path)
res, err := http.Get(requestURL)
// Set test case headers
for key, value := range tc.headers {
req.Header.Set(key, value)
}

res, err := client.Do(req)
if err != nil {
t.Errorf("Get request failed on %s: %s\n", path, err)
os.Exit(1)
t.Errorf("Request failed on %s: %s\n", path, err)
return
}

if res.StatusCode == 200 {
Expand All @@ -76,7 +136,6 @@ func makeCalls(t *testing.T, path string, ch chan rateResult) {
t.Errorf("Expected a 200 or 429 status code, but received a %d\n", res.StatusCode)
}

// Wait before making next call
time.Sleep(5 * time.Millisecond)
}

Expand Down
Loading

0 comments on commit 81a446d

Please sign in to comment.