Skip to content

Commit

Permalink
Merge pull request #2689 from onaio/fair-throttling-for-enketo-users
Browse files Browse the repository at this point in the history
Add CustomScopedRateThrolle throttling class
  • Loading branch information
FrankApiyo authored Aug 30, 2024
2 parents f42ddb6 + c9312f8 commit a69d9f2
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
37 changes: 30 additions & 7 deletions onadata/libs/tests/test_throttle.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,43 @@
from django.core.cache import cache
from django.contrib.auth.models import AnonymousUser, User
from django.test import TestCase, override_settings

from rest_framework.test import APIRequestFactory

from onadata.libs.throttle import RequestHeaderThrottle
from onadata.libs.throttle import RequestHeaderThrottle, CustomScopedRateThrottle

class CustomScopedRateThrottleTest(TestCase):
def setUp(self):
# Reset the cache so that no throttles will be active
cache.clear()
self.factory = APIRequestFactory()
self.throttle = CustomScopedRateThrottle()

def test_anonymous_users(self):
"""Anonymous users get throttled base on URI path"""
request = self.factory.get("/enketo/1234/submission")
request.user = AnonymousUser()
self.throttle.scope = "submission"
cache_key = self.throttle.get_cache_key(request, None)
self.assertEqual(
cache_key,
"throttle_submission_/enketo/1234/submission_127.0.0.1"
)

def test_authenticated_users(self):
"""Authenticated users get throttled base on user id"""
request = self.factory.get("/enketo/1234/submission")
user, _created = User.objects.get_or_create(username='throttleduser')
request.user = user
self.throttle.scope = "submission"
cache_key = self.throttle.get_cache_key(request, None)
self.assertEqual(cache_key, f"throttle_submission_{user.id}")


class ThrottlingTests(TestCase):
"""
Test Renderer class.
"""

def setUp(self):
"""
Reset the cache so that no throttles will be active
"""
# Reset the cache so that no throttles will be active
cache.clear()
self.factory = APIRequestFactory()
self.throttle = RequestHeaderThrottle()
Expand Down
14 changes: 13 additions & 1 deletion onadata/libs/throttle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,19 @@

from django.conf import settings

from rest_framework.throttling import SimpleRateThrottle
from rest_framework.throttling import SimpleRateThrottle, ScopedRateThrottle


class CustomScopedRateThrottle(ScopedRateThrottle):
"""
Custom throttling for fair throttling for anonymous users sharing IP
"""

def get_cache_key(self, request, view):
if request.user and request.user.is_authenticated:
return super().get_cache_key(request, view)

return f'throttle_{self.scope}_{request.path}_{self.get_ident(request)}'


class RequestHeaderThrottle(SimpleRateThrottle):
Expand Down

0 comments on commit a69d9f2

Please sign in to comment.