From 3c61a9521b5718754e34600f86a7397b4f8d1856 Mon Sep 17 00:00:00 2001
From: Robert Brennan
Date: Tue, 19 Nov 2024 13:46:14 -0500
Subject: [PATCH] Simple initial rate limiting implementation (#4976)

---
Reviewer note: this revision differs from the original commit in
openhands/server/middleware.py. InMemoryRateLimiter.__init__ now stores
sleep_seconds (it was never assigned, so the throttling path in __call__
raised AttributeError), requests without a client address share the
'unknown' bucket instead of crashing, and docstrings were added. The blob
ids on the index lines still refer to the original commit.

 openhands/server/listen.py     | 11 ++++++-
 openhands/server/middleware.py | 77 ++++++++++++++++++++++++++++++++++
 2 files changed, 87 insertions(+), 1 deletion(-)

diff --git a/openhands/server/listen.py b/openhands/server/listen.py
index 929a26ec987d..433b13bde208 100644
--- a/openhands/server/listen.py
+++ b/openhands/server/listen.py
@@ -64,7 +64,12 @@
 from openhands.llm import bedrock
 from openhands.runtime.base import Runtime
 from openhands.server.auth.auth import get_sid_from_token, sign_token
-from openhands.server.middleware import LocalhostCORSMiddleware, NoCacheMiddleware
+from openhands.server.middleware import (
+    InMemoryRateLimiter,
+    LocalhostCORSMiddleware,
+    NoCacheMiddleware,
+    RateLimitMiddleware,
+)
 from openhands.server.session import SessionManager
 
 load_dotenv()
@@ -84,6 +89,10 @@
 
 
 app.add_middleware(NoCacheMiddleware)
+app.add_middleware(
+    RateLimitMiddleware, rate_limiter=InMemoryRateLimiter(requests=2, seconds=1)
+)
+
 
 security_scheme = HTTPBearer()
 
diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py
index 218a949fca58..872241fc865f 100644
--- a/openhands/server/middleware.py
+++ b/openhands/server/middleware.py
@@ -1,6 +1,11 @@
+import asyncio
+from collections import defaultdict
+from datetime import datetime, timedelta
 from urllib.parse import urlparse
 
+from fastapi import Request
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.types import ASGIApp
 
@@ -41,3 +46,75 @@ async def dispatch(self, request, call_next):
         response.headers['Pragma'] = 'no-cache'
         response.headers['Expires'] = '0'
         return response
+
+
+class InMemoryRateLimiter:
+    """Per-client sliding-window rate limiter held in process memory.
+
+    Up to ``requests`` calls per ``seconds`` window pass immediately. Beyond
+    that, calls are delayed by ``sleep_seconds`` (or rejected when it is 0),
+    until the window holds more than ``2 * requests`` calls, at which point
+    they are rejected. Rejected and delayed calls still count in the window.
+
+    NOTE: ``history`` keeps one entry per client host and entries are never
+    removed, so memory grows with the number of distinct clients seen.
+    """
+
+    history: dict
+    requests: int
+    seconds: int
+    sleep_seconds: int
+
+    def __init__(self, requests: int = 2, seconds: int = 1, sleep_seconds: int = 1):
+        self.requests = requests
+        self.seconds = seconds
+        # __call__ reads this on the throttling path; leaving it unset raises
+        # AttributeError as soon as a client exceeds ``requests``.
+        self.sleep_seconds = sleep_seconds
+        self.history = defaultdict(list)
+
+    def _clean_old_requests(self, key: str) -> None:
+        """Drop the timestamps recorded for ``key`` that left the window."""
+        now = datetime.now()
+        cutoff = now - timedelta(seconds=self.seconds)
+        self.history[key] = [ts for ts in self.history[key] if ts > cutoff]
+
+    async def __call__(self, request: Request) -> bool:
+        """Record ``request`` and return True if it may proceed."""
+        # Starlette sets request.client to None when the peer address is
+        # unknown; share one bucket for those requests instead of crashing.
+        key = request.client.host if request.client else 'unknown'
+        now = datetime.now()
+
+        self._clean_old_requests(key)
+
+        self.history[key].append(now)
+
+        if len(self.history[key]) > self.requests * 2:
+            return False
+        elif len(self.history[key]) > self.requests:
+            if self.sleep_seconds > 0:
+                await asyncio.sleep(self.sleep_seconds)
+                return True
+            else:
+                return False
+
+        return True
+
+
+class RateLimitMiddleware(BaseHTTPMiddleware):
+    """Answer with HTTP 429 when ``rate_limiter`` refuses a request."""
+
+    def __init__(self, app: ASGIApp, rate_limiter: InMemoryRateLimiter):
+        super().__init__(app)
+        self.rate_limiter = rate_limiter
+
+    async def dispatch(self, request, call_next):
+        ok = await self.rate_limiter(request)
+        if not ok:
+            return JSONResponse(
+                status_code=429,
+                content={'message': 'Too many requests'},
+                headers={'Retry-After': '1'},
+            )
+        return await call_next(request)