Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add violation tracking and account suspension #150

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions examples/filters/rate_limit_filter_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def log_request(self, user_id: str):
if user_id not in self.user_requests:
self.user_requests[user_id] = []
self.user_requests[user_id].append(now)
self.increment_violation(user_id) # Increment violation counter if rate limited

def rate_limited(self, user_id: str) -> bool:
"""Check if a user is rate limited."""
Expand All @@ -99,16 +100,19 @@ def rate_limited(self, user_id: str) -> bool:
if self.valves.requests_per_minute is not None:
requests_last_minute = sum(1 for req in user_reqs if time.time() - req < 60)
if requests_last_minute >= self.valves.requests_per_minute:
self.handle_violation(user_id)
return True

if self.valves.requests_per_hour is not None:
requests_last_hour = sum(1 for req in user_reqs if time.time() - req < 3600)
if requests_last_hour >= self.valves.requests_per_hour:
self.handle_violation(user_id)
return True

if self.valves.sliding_window_limit is not None:
requests_in_window = len(user_reqs)
if requests_in_window >= self.valves.sliding_window_limit:
self.handle_violation(user_id)
return True

return False
Expand All @@ -124,4 +128,33 @@ async def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
raise Exception("Rate limit exceeded. Please try again later.")

self.log_request(user_id)
self.check_violations(user_id) # Check for violations and issue warnings or suspend account if necessary
return body


def increment_violation(self, user_id: str):
"""Increment the violation counter for a user."""
if user_id not in self.user_violations:
self.user_violations[user_id] = 0
self.user_violations[user_id] += 1

def handle_violation(self, user_id: str):
"""Handle a violation by a user."""
self.increment_violation(user_id)
self.check_violations(user_id)

def check_violations(self, user_id: str):
"""Check the user's violation count and issue warnings or suspend the account if necessary."""
violation_count = self.user_violations.get(user_id, 0)
if violation_count >= 5:
self.suspend_account(user_id)
elif violation_count >= 3:
self.issue_warning(user_id)

def issue_warning(self, user_id: str):
"""Issue a warning to the user."""
print(f"Warning issued to user {user_id}. Please adhere to the rate limits.")

def suspend_account(self, user_id: str):
"""Suspend the user's account."""
print(f"User {user_id} has been suspended due to excessive violations.")