Skip to content

Commit

Permalink
Refactor ip based rate limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
lmoe committed Oct 27, 2023
1 parent 9efd991 commit 913bdb6
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 21 deletions.
3 changes: 1 addition & 2 deletions packages/webapi/services/evm.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

"github.com/ethereum/go-ethereum/rpc"
"github.com/labstack/echo/v4"
"golang.org/x/time/rate"

hivedb "github.com/iotaledger/hive.go/kvstore/database"
"github.com/iotaledger/hive.go/logger"
Expand Down Expand Up @@ -125,7 +124,7 @@ func (e *EVMService) getWebsocketContext(chainID isc.ChainID) *websocketContext
syncPool: new(sync.Pool),
jsonRPCParams: e.jsonrpcParams,
rateLimiterMutex: sync.Mutex{},
rateLimiter: map[string]*rate.Limiter{},
rateLimiter: map[string]*activityRateLimiter{},
}

return e.websocketContexts[chainID]
Expand Down
91 changes: 72 additions & 19 deletions packages/webapi/services/evm_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,88 @@ This way we can inject our rate limiting logic inside.
The downside is that it's not possible to limit per JSON message, but it should work regardless.
*/

type activityRateLimiter struct {
rateLimiter *rate.Limiter
lastActivity time.Time
}

func newActivityRateLimiter(rateLimiter *rate.Limiter) *activityRateLimiter {
return &activityRateLimiter{
rateLimiter: rateLimiter,
lastActivity: time.Now(),
}
}

func (a *activityRateLimiter) Allow() bool {
a.lastActivity = time.Now()
allowed := a.rateLimiter.Allow()
fmt.Printf("Rate token deducted. Remaining tokens:[%v]\n", a.Tokens())
return allowed
}

func (a *activityRateLimiter) LastActivity() time.Time {
return a.lastActivity
}

func (a *activityRateLimiter) UpdateLastActivity() {
a.lastActivity = time.Now()
}

func (a *activityRateLimiter) Tokens() float64 {
return a.rateLimiter.TokensAt(time.Now())
}

type websocketContext struct {
rateLimiterMutex sync.Mutex
rateLimiter map[string]*rate.Limiter
syncPool *sync.Pool
jsonRPCParams *jsonrpc.Parameters
rateLimiter map[string]*activityRateLimiter

syncPool *sync.Pool
jsonRPCParams *jsonrpc.Parameters
}

func (w *websocketContext) getRateLimiter(remoteIP string) *rate.Limiter {
func (w *websocketContext) getRateLimiter(remoteIP string) *activityRateLimiter {
w.rateLimiterMutex.Lock()
defer w.rateLimiterMutex.Unlock()

if w.rateLimiter[remoteIP] != nil {
return w.rateLimiter[remoteIP]
}

w.rateLimiter[remoteIP] = rate.NewLimiter(rate.Every(time.Minute), w.jsonRPCParams.WebsocketRateLimitMessagesPerMinute)
limiter := rate.NewLimiter(rate.Limit(w.jsonRPCParams.WebsocketRateLimitMessagesPerMinute), 1)
w.rateLimiter[remoteIP] = newActivityRateLimiter(limiter)

return w.rateLimiter[remoteIP]
}

//nolint:unused
func (w *websocketContext) deleteRateLimiter(remoteIP string) {
func (w *websocketContext) cleanupRateLimiters() {
w.rateLimiterMutex.Lock()
defer w.rateLimiterMutex.Unlock()

if w.rateLimiter[remoteIP] != nil {
delete(w.rateLimiter, remoteIP)
for ip, rateLimiter := range w.rateLimiter {
fmt.Printf("Found rate limiter for ip:[%v], lastActivity:[%v]\n", ip, rateLimiter.LastActivity().Format(time.RFC822))

if time.Since(rateLimiter.LastActivity()) > 30*time.Minute {
fmt.Printf("Removing rate limiter for ip:[%v], lastActivity:[%v]\n", ip, rateLimiter.LastActivity().Format(time.RFC822))
delete(w.rateLimiter, ip)
} else {
fmt.Printf("Keeping rate limiter for ip:[%v], lastActivity:[%v]\n", ip, rateLimiter.LastActivity().Format(time.RFC822))
}
}
}

type rateLimitedConn struct {
net.Conn
logger *logger.Logger
limiter *rate.Limiter
limiter *activityRateLimiter
realIP string
}

func newRateLimitedConn(conn net.Conn, logger *logger.Logger, r *rate.Limiter) *rateLimitedConn {
func newRateLimitedConn(conn net.Conn, logger *logger.Logger, r *activityRateLimiter, realIP string) *rateLimitedConn {
return &rateLimitedConn{
Conn: conn,
logger: logger,
limiter: r,
realIP: realIP,
}
}

Expand All @@ -77,8 +118,10 @@ func (rlc *rateLimitedConn) Read(b []byte) (int, error) {
return n, err
}

fmt.Printf("READ MSG: %v\n", string(b))

if !rlc.limiter.Allow() {
rlc.logger.Info("[EVM WS Conn] rate limit exceeded")
rlc.logger.Infof("[EVM WS Conn] rate limit exceeded for ip:[%v]", rlc.realIP)
return 0, rlc.Conn.Close()
}

Expand All @@ -92,11 +135,12 @@ func (rlc *rateLimitedConn) Write(b []byte) (int, error) {
type rateLimitedEchoResponse struct {
*echo.Response
logger *logger.Logger
limiter *rate.Limiter
limiter *activityRateLimiter
realIP string
}

func newRateLimitedEchoResponse(r *echo.Response, logger *logger.Logger, limiter *rate.Limiter) *rateLimitedEchoResponse {
return &rateLimitedEchoResponse{r, logger, limiter}
func newRateLimitedEchoResponse(r *echo.Response, logger *logger.Logger, limiter *activityRateLimiter, realIP string) *rateLimitedEchoResponse {
return &rateLimitedEchoResponse{r, logger, limiter, realIP}
}

func (r rateLimitedEchoResponse) Hijack() (net.Conn, *bufio.ReadWriter, error) {
Expand All @@ -108,12 +152,12 @@ func (r rateLimitedEchoResponse) Hijack() (net.Conn, *bufio.ReadWriter, error) {
buffer.Reader = bufio.NewReader(conn)
buffer.Writer = bufio.NewWriter(conn)

return newRateLimitedConn(conn, r.logger, r.limiter), buffer, err
return newRateLimitedConn(conn, r.logger, r.limiter, r.realIP), buffer, err
}

const (
readBufferSize = 1024
writeBufferSize = 1024
readBufferSize = 4096
writeBufferSize = 4096
)

func websocketHandler(logger *logger.Logger, server *chainServer, wsContext *websocketContext, realIP string) http.Handler {
Expand All @@ -127,8 +171,11 @@ func websocketHandler(logger *logger.Logger, server *chainServer, wsContext *web
}

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wsContext.cleanupRateLimiters()

rateLimiter := wsContext.getRateLimiter(realIP)
if !rateLimiter.Allow() {
logger.Info("[EVM WS Conn] Connection from ip:[%v] dropped (previous rate limit exceeded) current tokens:[%v]\n", realIP, rateLimiter.Tokens())
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
Expand All @@ -139,14 +186,20 @@ func websocketHandler(logger *logger.Logger, server *chainServer, wsContext *web
return
}

rateLimitedResponseWriter := newRateLimitedEchoResponse(echoResponse, logger, rateLimiter)
rateLimitedResponseWriter := newRateLimitedEchoResponse(echoResponse, logger, rateLimiter, realIP)
conn, err := upgrader.Upgrade(rateLimitedResponseWriter, r, nil)
if err != nil {
logger.Info(fmt.Sprintf("[EVM WS] %s", err))
return
}

codec := rpc.NewWebSocketCodec(conn, r.Host, r.Header)
conn.SetPongHandler(func(appData string) error {
_ = conn.SetReadDeadline(time.Time{})
rateLimiter.UpdateLastActivity()
return nil
})

server.rpc.ServeCodec(codec, 0)
})
}

0 comments on commit 913bdb6

Please sign in to comment.