diff --git a/onadata/libs/tests/test_throttle.py b/onadata/libs/tests/test_throttle.py index c7b2c2a8e0..5f084d9fc6 100644 --- a/onadata/libs/tests/test_throttle.py +++ b/onadata/libs/tests/test_throttle.py @@ -1,20 +1,43 @@ from django.core.cache import cache +from django.contrib.auth.models import AnonymousUser, User from django.test import TestCase, override_settings from rest_framework.test import APIRequestFactory -from onadata.libs.throttle import RequestHeaderThrottle +from onadata.libs.throttle import RequestHeaderThrottle, CustomScopedRateThrottle + +class CustomScopedRateThrottleTest(TestCase): + def setUp(self): + # Reset the cache so that no throttles will be active + cache.clear() + self.factory = APIRequestFactory() + self.throttle = CustomScopedRateThrottle() + + def test_anonymous_users(self): + """Anonymous users get throttled base on URI path""" + request = self.factory.get("/enketo/1234/submission") + request.user = AnonymousUser() + self.throttle.scope = "submission" + cache_key = self.throttle.get_cache_key(request, None) + self.assertEqual( + cache_key, + "throttle_submission_/enketo/1234/submission_127.0.0.1" + ) + + def test_authenticated_users(self): + """Authenticated users get throttled base on user id""" + request = self.factory.get("/enketo/1234/submission") + user, _created = User.objects.get_or_create(username='throttleduser') + request.user = user + self.throttle.scope = "submission" + cache_key = self.throttle.get_cache_key(request, None) + self.assertEqual(cache_key, f"throttle_submission_{user.id}") class ThrottlingTests(TestCase): - """ - Test Renderer class. - """ def setUp(self): - """ - Reset the cache so that no throttles will be active - """ + # Reset the cache so that no throttles will be active cache.clear() self.factory = APIRequestFactory() self.throttle = RequestHeaderThrottle() diff --git a/onadata/libs/throttle.py b/onadata/libs/throttle.py index ff79442a0b..9af47dffdc 100644 --- a/onadata/libs/throttle.py +++ b/onadata/libs/throttle.py @@ -4,7 +4,19 @@ from django.conf import settings -from rest_framework.throttling import SimpleRateThrottle +from rest_framework.throttling import SimpleRateThrottle, ScopedRateThrottle + + +class CustomScopedRateThrottle(ScopedRateThrottle): + """ + Custom throttling for fair throttling for anonymous users sharing IP + """ + + def get_cache_key(self, request, view): + if request.user and request.user.is_authenticated: + return super().get_cache_key(request, view) + + return f'throttle_{self.scope}_{request.path}_{self.get_ident(request)}' class RequestHeaderThrottle(SimpleRateThrottle):