diff --git a/examples/filters/rate_limit_filter_pipeline.py b/examples/filters/rate_limit_filter_pipeline.py index d1e88236..49b2eea3 100644 --- a/examples/filters/rate_limit_filter_pipeline.py +++ b/examples/filters/rate_limit_filter_pipeline.py @@ -89,6 +89,7 @@ def log_request(self, user_id: str): if user_id not in self.user_requests: self.user_requests[user_id] = [] self.user_requests[user_id].append(now) + self.increment_violation(user_id) # Increment violation counter if rate limited def rate_limited(self, user_id: str) -> bool: """Check if a user is rate limited.""" @@ -99,16 +100,19 @@ def rate_limited(self, user_id: str) -> bool: if self.valves.requests_per_minute is not None: requests_last_minute = sum(1 for req in user_reqs if time.time() - req < 60) if requests_last_minute >= self.valves.requests_per_minute: + self.handle_violation(user_id) return True if self.valves.requests_per_hour is not None: requests_last_hour = sum(1 for req in user_reqs if time.time() - req < 3600) if requests_last_hour >= self.valves.requests_per_hour: + self.handle_violation(user_id) return True if self.valves.sliding_window_limit is not None: requests_in_window = len(user_reqs) if requests_in_window >= self.valves.sliding_window_limit: + self.handle_violation(user_id) return True return False @@ -124,4 +128,33 @@ async def inlet(self, body: dict, user: Optional[dict] = None) -> dict: raise Exception("Rate limit exceeded. Please try again later.") self.log_request(user_id) + self.check_violations(user_id) # Check for violations and issue warnings or suspend account if necessary return body + + + def increment_violation(self, user_id: str): + """Increment the violation counter for a user.""" + if user_id not in self.user_violations: + self.user_violations[user_id] = 0 + self.user_violations[user_id] += 1 + + def handle_violation(self, user_id: str): + """Handle a violation by a user.""" + self.increment_violation(user_id) + self.check_violations(user_id) + + def check_violations(self, user_id: str): + """Check the user's violation count and issue warnings or suspend the account if necessary.""" + violation_count = self.user_violations.get(user_id, 0) + if violation_count >= 5: + self.suspend_account(user_id) + elif violation_count >= 3: + self.issue_warning(user_id) + + def issue_warning(self, user_id: str): + """Issue a warning to the user.""" + print(f"Warning issued to user {user_id}. Please adhere to the rate limits.") + + def suspend_account(self, user_id: str): + """Suspend the user's account.""" + print(f"User {user_id} has been suspended due to excessive violations.")