From f56f7310e81ac0266807192803e4a451a8983766 Mon Sep 17 00:00:00 2001 From: Dhruv Bhanushali Date: Wed, 20 Dec 2023 10:32:45 +0400 Subject: [PATCH] Make usages of Redis resilient to absence of Redis (#3505) Co-authored-by: sarayourfriend <24264157+sarayourfriend@users.noreply.github.com> --- api/api/controllers/search_controller.py | 55 ++++++-- api/api/utils/check_dead_links/__init__.py | 18 ++- api/api/utils/dead_link_mask.py | 20 ++- api/api/utils/image_proxy/__init__.py | 51 ++++++-- api/api/utils/image_proxy/extension.py | 25 +++- api/api/utils/tallies.py | 12 +- api/api/utils/throttle.py | 12 +- api/api/views/oauth2_views.py | 26 +++- api/conf/settings/__init__.py | 1 + api/conf/settings/caches.py | 29 +++++ api/conf/settings/databases.py | 32 ----- api/test/conftest.py | 71 +++------- api/test/factory/models/content_provider.py | 8 ++ api/test/fixtures/__init__.py | 0 api/test/fixtures/asynchronous.py | 53 ++++++++ api/test/fixtures/cache.py | 68 ++++++++++ api/test/test_auth.py | 60 ++++++++- api/test/test_dead_link_filter.py | 14 -- api/test/unit/conftest.py | 14 -- .../controllers/elasticsearch/test_related.py | 7 +- .../controllers/test_search_controller.py | 113 +++++++++++++--- .../test_search_controller_search_query.py | 7 +- api/test/unit/utils/conftest.py | 14 -- api/test/unit/utils/test_check_dead_links.py | 41 +++++- api/test/unit/utils/test_image_proxy.py | 121 ++++++++++++++---- api/test/unit/utils/test_tallies.py | 15 +++ api/test/unit/utils/test_throttle.py | 94 +++++++++++--- 27 files changed, 738 insertions(+), 243 deletions(-) create mode 100644 api/conf/settings/caches.py create mode 100644 api/test/factory/models/content_provider.py create mode 100644 api/test/fixtures/__init__.py create mode 100644 api/test/fixtures/asynchronous.py create mode 100644 api/test/fixtures/cache.py diff --git a/api/api/controllers/search_controller.py b/api/api/controllers/search_controller.py index f84cfc27254..351712dfc38 100644 --- a/api/api/controllers/search_controller.py +++ b/api/api/controllers/search_controller.py @@ -14,6 +14,7 @@ from elasticsearch_dsl import Q, Search from elasticsearch_dsl.query import EMPTY_QUERY from elasticsearch_dsl.response import Hit, Response +from redis.exceptions import ConnectionError import api.models as models from api.constants.media_types import OriginIndex, SearchIndex @@ -184,21 +185,33 @@ def get_excluded_providers_query() -> Q | None: `:FILTERED_PROVIDERS_CACHE_VERSION:FILTERED_PROVIDERS_CACHE_KEY` key. 
""" - filtered_providers = cache.get( - key=FILTERED_PROVIDERS_CACHE_KEY, version=FILTERED_PROVIDERS_CACHE_VERSION - ) + logger = module_logger.getChild("get_excluded_providers_query") + + try: + filtered_providers = cache.get( + key=FILTERED_PROVIDERS_CACHE_KEY, version=FILTERED_PROVIDERS_CACHE_VERSION + ) + except ConnectionError: + logger.warning("Redis connect failed, cannot get cached filtered providers.") + filtered_providers = None + if not filtered_providers: filtered_providers = list( models.ContentProvider.objects.filter(filter_content=True).values_list( "provider_identifier", flat=True ) ) - cache.set( - key=FILTERED_PROVIDERS_CACHE_KEY, - version=FILTERED_PROVIDERS_CACHE_VERSION, - timeout=FILTER_CACHE_TIMEOUT, - value=filtered_providers, - ) + + try: + cache.set( + key=FILTERED_PROVIDERS_CACHE_KEY, + version=FILTERED_PROVIDERS_CACHE_VERSION, + timeout=FILTER_CACHE_TIMEOUT, + value=filtered_providers, + ) + except ConnectionError: + logger.warning("Redis connect failed, cannot cache filtered providers.") + if filtered_providers: return Q("terms", provider=filtered_providers) return None @@ -567,9 +580,19 @@ def get_sources(index): cache_fetch_failed = True sources = None log.warning("Source cache fetch failed due to corruption") + except ConnectionError: + cache_fetch_failed = True + sources = None + log.warning("Redis connect failed, cannot get cached sources.") + if isinstance(sources, list) or cache_fetch_failed: - # Invalidate old provider format. - cache.delete(key=source_cache_name) + sources = None + try: + # Invalidate old provider format. + cache.delete(key=source_cache_name) + except ConnectionError: + log.warning("Redis connect failed, cannot invalidate cached sources.") + if not sources: # Don't increase `size` without reading this issue first: # https://github.com/elastic/elasticsearch/issues/18838 @@ -597,7 +620,15 @@ def get_sources(index): except NotFoundError: buckets = [{"key": "none_found", "doc_count": 0}] sources = {result["key"]: result["doc_count"] for result in buckets} - cache.set(key=source_cache_name, timeout=SOURCE_CACHE_TIMEOUT, value=sources) + + try: + cache.set( + key=source_cache_name, timeout=SOURCE_CACHE_TIMEOUT, value=sources + ) + except ConnectionError: + log.warning("Redis connect failed, cannot cache sources.") + + sources = {source: int(doc_count) for source, doc_count in sources.items()} return sources diff --git a/api/api/utils/check_dead_links/__init__.py b/api/api/utils/check_dead_links/__init__.py index d7603915e81..219e60480a9 100644 --- a/api/api/utils/check_dead_links/__init__.py +++ b/api/api/utils/check_dead_links/__init__.py @@ -9,6 +9,7 @@ from asgiref.sync import async_to_sync from decouple import config from elasticsearch_dsl.response import Hit +from redis.exceptions import ConnectionError from api.utils.aiohttp import get_aiohttp_session from api.utils.check_dead_links.provider_status_mappings import provider_status_mappings @@ -25,8 +26,15 @@ def _get_cached_statuses(redis, image_urls): - cached_statuses = redis.mget([CACHE_PREFIX + url for url in image_urls]) - return [int(b.decode("utf-8")) if b is not None else None for b in cached_statuses] + try: + cached_statuses = redis.mget([CACHE_PREFIX + url for url in image_urls]) + return [ + int(b.decode("utf-8")) if b is not None else None for b in cached_statuses + ] + except ConnectionError: + logger = parent_logger.getChild("_get_cached_statuses") + logger.warning("Redis connect failed, validating all URLs without cache.") + return [None] * len(image_urls) def 
_get_expiry(status, default): @@ -51,7 +59,6 @@ async def _head(url: str) -> tuple[str, int]: # https://stackoverflow.com/q/55259755 @async_to_sync async def _make_head_requests(urls: list[str]) -> list[tuple[str, int]]: - tasks = [] tasks = [asyncio.ensure_future(_head(url)) for url in urls] responses = asyncio.gather(*tasks) await responses @@ -111,7 +118,10 @@ def check_dead_links( logger.debug(f"caching status={status} expiry={expiry}") pipe.expire(key, expiry) - pipe.execute() + try: + pipe.execute() + except ConnectionError: + logger.warning("Redis connect failed, cannot cache link liveness.") # Merge newly verified results with cached statuses for idx, url in enumerate(to_verify): diff --git a/api/api/utils/dead_link_mask.py b/api/api/utils/dead_link_mask.py index fdd1034b0fc..1791daf1673 100644 --- a/api/api/utils/dead_link_mask.py +++ b/api/api/utils/dead_link_mask.py @@ -1,6 +1,12 @@ +import logging + import django_redis from deepdiff import DeepHash from elasticsearch_dsl import Search +from redis.exceptions import ConnectionError + + +parent_logger = logging.getLogger(__name__) # 3 hours minutes (in seconds) @@ -34,7 +40,12 @@ def get_query_mask(query_hash: str) -> list[int]: """ redis = django_redis.get_redis_connection("default") key = f"{query_hash}:dead_link_mask" - return list(map(int, redis.lrange(key, 0, -1))) + try: + return list(map(int, redis.lrange(key, 0, -1))) + except ConnectionError: + logger = parent_logger.getChild("get_query_mask") + logger.warning("Redis connect failed, cannot get cached query mask.") + return [] def save_query_mask(query_hash: str, mask: list): @@ -50,4 +61,9 @@ def save_query_mask(query_hash: str, mask: list): redis_pipe.delete(key) redis_pipe.rpush(key, *mask) redis_pipe.expire(key, DEAD_LINK_MASK_TTL) - redis_pipe.execute() + + try: + redis_pipe.execute() + except ConnectionError: + logger = parent_logger.getChild("save_query_mask") + logger.warning("Redis connect failed, cannot cache query mask.") diff --git a/api/api/utils/image_proxy/__init__.py b/api/api/utils/image_proxy/__init__.py index 54a796b2179..03da60f1888 100644 --- a/api/api/utils/image_proxy/__init__.py +++ b/api/api/utils/image_proxy/__init__.py @@ -1,3 +1,4 @@ +import itertools import logging from dataclasses import dataclass from typing import Literal @@ -12,6 +13,7 @@ import sentry_sdk from aiohttp.client_exceptions import ClientResponseError from asgiref.sync import sync_to_async +from redis.exceptions import ConnectionError from sentry_sdk import push_scope, set_context from api.utils.aiohttp import get_aiohttp_session @@ -24,6 +26,8 @@ parent_logger = logging.getLogger(__name__) +exception_iterator = itertools.count() + HEADERS = { "User-Agent": settings.OUTBOUND_USER_AGENT_TEMPLATE.format( purpose="ThumbnailGeneration" @@ -81,7 +85,7 @@ def get_request_params_for_extension( @sync_to_async def _tally_response( - tallies, + tallies_conn, media_info: MediaInfo, month: str, domain: str, @@ -93,14 +97,33 @@ def _tally_response( Pulled into a separate function to help reduce overload when skimming the `get` function, which is complex enough as is. 
""" - tallies.incr(f"thumbnail_response_code:{month}:{response.status}") - tallies.incr( - f"thumbnail_response_code_by_domain:{domain}:" f"{month}:{response.status}" - ) - tallies.incr( - f"thumbnail_response_code_by_provider:{media_info.media_provider}:" - f"{month}:{response.status}" - ) + + logger = parent_logger.getChild("_tally_response") + + with tallies_conn.pipeline() as tallies: + tallies.incr(f"thumbnail_response_code:{month}:{response.status}") + tallies.incr( + f"thumbnail_response_code_by_domain:{domain}:" f"{month}:{response.status}" + ) + tallies.incr( + f"thumbnail_response_code_by_provider:{media_info.media_provider}:" + f"{month}:{response.status}" + ) + try: + tallies.execute() + except ConnectionError: + logger.warning( + "Redis connect failed, thumbnail response codes not tallied." + ) + + +@sync_to_async +def _tally_client_response_errors(tallies, month: str, domain: str, status: int): + logger = parent_logger.getChild("_tally_client_response_errors") + try: + tallies.incr(f"thumbnail_http_error:{domain}:{month}:{status}") + except ConnectionError: + logger.warning("Redis connect failed, thumbnail HTTP errors not tallied.") _UPSTREAM_TIMEOUT = aiohttp.ClientTimeout(15) @@ -168,7 +191,13 @@ async def get( except Exception as exc: exception_name = f"{exc.__class__.__module__}.{exc.__class__.__name__}" key = f"thumbnail_error:{exception_name}:{domain}:{month}" - count = await tallies_incr(key) + try: + count = await tallies_incr(key) + except ConnectionError: + logger.warning("Redis connect failed, thumbnail errors not tallied.") + # We will use a counter to space out Sentry logs. + count = next(exception_iterator) + if count <= settings.THUMBNAIL_ERROR_INITIAL_ALERT_THRESHOLD or ( count % settings.THUMBNAIL_ERROR_REPEATED_ALERT_FREQUENCY == 0 ): @@ -188,7 +217,7 @@ async def get( if isinstance(exc, ClientResponseError): status = exc.status do_not_wait_for( - tallies_incr(f"thumbnail_http_error:{domain}:{month}:{status}") + _tally_client_response_errors(tallies, month, domain, status) ) logger.warning( f"Failed to render thumbnail " diff --git a/api/api/utils/image_proxy/extension.py b/api/api/utils/image_proxy/extension.py index 2f2612022c3..cb9fad261f2 100644 --- a/api/api/utils/image_proxy/extension.py +++ b/api/api/utils/image_proxy/extension.py @@ -1,3 +1,4 @@ +import logging from os.path import splitext from urllib.parse import urlparse @@ -5,16 +6,22 @@ import django_redis import sentry_sdk from asgiref.sync import sync_to_async +from redis.exceptions import ConnectionError from api.utils.aiohttp import get_aiohttp_session from api.utils.asyncio import do_not_wait_for from api.utils.image_proxy.exception import UpstreamThumbnailException +parent_logger = logging.getLogger(__name__) + + _HEAD_TIMEOUT = aiohttp.ClientTimeout(10) async def get_image_extension(image_url: str, media_identifier) -> str | None: + logger = parent_logger.getChild("get_image_extension") + cache = django_redis.get_redis_connection("default") key = f"media:{media_identifier}:thumb_type" @@ -22,8 +29,11 @@ async def get_image_extension(image_url: str, media_identifier) -> str | None: if not ext: # If the extension is not present in the URL, try to get it from the redis cache - ext = await sync_to_async(cache.get)(key) - ext = ext.decode("utf-8") if ext else None + try: + ext = await sync_to_async(cache.get)(key) + ext = ext.decode("utf-8") if ext else None + except ConnectionError: + logger.warning("Redis connect failed, cannot get cached image extension.") if not ext: # If the extension is 
still not present, try getting it from the content type @@ -37,7 +47,7 @@ async def get_image_extension(image_url: str, media_identifier) -> str | None: else: ext = None - do_not_wait_for(sync_to_async(cache.set)(key, ext if ext else "unknown")) + do_not_wait_for(_cache_extension(cache, key, ext)) except Exception as exc: sentry_sdk.capture_exception(exc) raise UpstreamThumbnailException( @@ -48,6 +58,15 @@ async def get_image_extension(image_url: str, media_identifier) -> str | None: return ext +@sync_to_async +def _cache_extension(cache, key, ext): + logger = parent_logger.getChild("cache_extension") + try: + cache.set(key, ext if ext else "unknown") + except ConnectionError: + logger.warning("Redis connect failed, cannot cache image extension.") + + def _get_file_extension_from_url(image_url: str) -> str: """Return the image extension if present in the URL.""" parsed = urlparse(image_url) diff --git a/api/api/utils/tallies.py b/api/api/utils/tallies.py index 6dd352a43a2..85710a38a70 100644 --- a/api/api/utils/tallies.py +++ b/api/api/utils/tallies.py @@ -1,8 +1,13 @@ +import logging from collections import defaultdict from datetime import datetime, timedelta import django_redis from django_redis.client.default import Redis +from redis.exceptions import ConnectionError + + +parent_logger = logging.getLogger(__name__) def get_weekly_timestamp() -> str: @@ -37,5 +42,8 @@ def count_provider_occurrences(results: list[dict], index: str) -> None: for provider, occurrences in provider_occurrences.items(): pipe.incr(f"provider_occurrences:{index}:{week}:{provider}", occurrences) pipe.incr(f"provider_appeared_in_searches:{index}:{week}:{provider}", 1) - - pipe.execute() + try: + pipe.execute() + except ConnectionError: + logger = parent_logger.getChild("count_provider_occurrences") + logger.warning("Redis connect failed, cannot increment provider tallies.") diff --git a/api/api/utils/throttle.py b/api/api/utils/throttle.py index 05213ed83da..4e3effaab34 100644 --- a/api/api/utils/throttle.py +++ b/api/api/utils/throttle.py @@ -3,6 +3,8 @@ from rest_framework.throttling import SimpleRateThrottle as BaseSimpleRateThrottle +from redis.exceptions import ConnectionError + from api.utils.oauth2_helper import get_token_info @@ -16,7 +18,12 @@ class SimpleRateThrottle(BaseSimpleRateThrottle, metaclass=abc.ABCMeta): """ def allow_request(self, request, view): - is_allowed = super().allow_request(request, view) + try: + is_allowed = super().allow_request(request, view) + except ConnectionError: + logger = parent_logger.getChild("allow_request") + logger.warning("Redis connect failed, allowing request.") + is_allowed = True view.headers |= self.headers() return is_allowed @@ -44,10 +51,9 @@ def has_valid_token(self, request): return token_info and token_info.valid def get_cache_key(self, request, view): - ident = self.get_ident(request) return self.cache_format % { "scope": self.scope, - "ident": ident, + "ident": self.get_ident(request), } diff --git a/api/api/views/oauth2_views.py b/api/api/views/oauth2_views.py index dc72cbfd76b..db4d780af72 100644 --- a/api/api/views/oauth2_views.py +++ b/api/api/views/oauth2_views.py @@ -14,6 +14,7 @@ from drf_spectacular.utils import extend_schema from oauth2_provider.generators import generate_client_secret from oauth2_provider.views import TokenView as BaseTokenView +from redis.exceptions import ConnectionError from api.docs.oauth2_docs import key_info, register, token from api.models import OAuth2Verification, ThrottledApplication @@ -25,6 +26,9 @@ from 
api.utils.throttle import OnePerSecond, TenPerDay
 
 
+module_logger = log.getLogger(__name__)
+
+
 @extend_schema(tags=["auth"])
 class Register(APIView):
     throttle_classes = (TenPerDay,)
 
@@ -218,12 +222,20 @@ def get(self, request, format=None):
             # TODO: Replace 500 response with exception.
             return Response(status=500, data="Unknown API key rate limit type")
 
-        sustained_requests_list = cache.get(sustained_throttle_key)
-        sustained_requests = (
-            len(sustained_requests_list) if sustained_requests_list else None
-        )
-        burst_requests_list = cache.get(burst_throttle_key)
-        burst_requests = len(burst_requests_list) if burst_requests_list else None
+        try:
+            sustained_requests_list = cache.get(sustained_throttle_key)
+            sustained_requests = (
+                len(sustained_requests_list) if sustained_requests_list else None
+            )
+            burst_requests_list = cache.get(burst_throttle_key)
+            burst_requests = len(burst_requests_list) if burst_requests_list else None
+            status = 200
+        except ConnectionError:
+            logger = module_logger.getChild("CheckRates.get")
+            logger.warning("Redis connect failed, cannot get key usage.")
+            burst_requests = None
+            sustained_requests = None
+            status = 424
 
         response_data = OAuth2KeyInfoSerializer(
             {
@@ -233,4 +245,4 @@ def get(self, request, format=None):
                 "verified": token_info.verified,
             }
         )
-        return Response(status=200, data=response_data.data)
+        return Response(status=status, data=response_data.data)
diff --git a/api/conf/settings/__init__.py b/api/conf/settings/__init__.py
index d3bf19d2d1f..2d5eb7bb70a 100644
--- a/api/conf/settings/__init__.py
+++ b/api/conf/settings/__init__.py
@@ -31,6 +31,7 @@
     "static.py",
     # services
     "databases.py",
+    "caches.py",
     "elasticsearch.py",
     "email.py",
     "aws.py",
diff --git a/api/conf/settings/caches.py b/api/conf/settings/caches.py
new file mode 100644
index 00000000000..52cbdc005ce
--- /dev/null
+++ b/api/conf/settings/caches.py
@@ -0,0 +1,29 @@
+from decouple import config
+
+
+# Caches
+
+REDIS_HOST = config("REDIS_HOST", default="localhost")
+REDIS_PORT = config("REDIS_PORT", default=6379, cast=int)
+REDIS_PASSWORD = config("REDIS_PASSWORD", default="")
+
+
+def _make_cache_config(dbnum: int, **overrides) -> dict:
+    return {
+        "BACKEND": "django_redis.cache.RedisCache",
+        "LOCATION": f"redis://{REDIS_HOST}:{REDIS_PORT}/{dbnum}",
+        "OPTIONS": {
+            "CLIENT_CLASS": "django_redis.client.DefaultClient",
+        }
+        | overrides.pop("OPTIONS", {}),
+    } | overrides
+
+
+CACHES = {
+    # Site cache writes to 'default'
+    "default": _make_cache_config(0),
+    # Used for tracking tallied figures that shouldn't expire and are indexed
+    # with a timestamp range (for example, the key could be a timestamp valid
+    # for a given week), allowing historical data analysis.
+ "tallies": _make_cache_config(3, TIMEOUT=None), +} diff --git a/api/conf/settings/databases.py b/api/conf/settings/databases.py index fb28e330a76..5473bcc66b1 100644 --- a/api/conf/settings/databases.py +++ b/api/conf/settings/databases.py @@ -25,35 +25,3 @@ }, } } - -# Caches - -REDIS_HOST = config("REDIS_HOST", default="localhost") -REDIS_PORT = config("REDIS_PORT", default=6379, cast=int) -REDIS_PASSWORD = config("REDIS_PASSWORD", default="") - - -def _make_cache_config(dbnum: int, **overrides) -> dict: - return { - "BACKEND": "django_redis.cache.RedisCache", - "LOCATION": f"redis://{REDIS_HOST}:{REDIS_PORT}/{dbnum}", - "OPTIONS": { - "CLIENT_CLASS": "django_redis.client.DefaultClient", - } - | overrides.pop("OPTIONS", {}), - } | overrides - - -CACHES = { - # Site cache writes to 'default' - "default": _make_cache_config(0), - # For rapidly changing stats that we don't want to hammer the database with - "traffic_stats": _make_cache_config(1), - # For ensuring consistency among multiple Django workers and servers. - # Used by Redlock. - "locks": _make_cache_config(2), - # Used for tracking tallied figures that shouldn't expire and are indexed - # with a timestamp range (for example, the key could a timestamp valid - # for a given week), allowing historical data analysis. - "tallies": _make_cache_config(3, TIMEOUT=None), -} diff --git a/api/test/conftest.py b/api/test/conftest.py index 5277a134a29..8ffefc50e7a 100644 --- a/api/test/conftest.py +++ b/api/test/conftest.py @@ -1,53 +1,18 @@ -import asyncio - -import pytest - -from conf.asgi import application - - -@pytest.fixture -def get_new_loop(): - loops: list[asyncio.AbstractEventLoop] = [] - - def _get_new_loop() -> asyncio.AbstractEventLoop: - loop = asyncio.new_event_loop() - loops.append(loop) - return loop - - yield _get_new_loop - - for loop in loops: - loop.close() - - -@pytest.fixture(scope="session") -def session_loop() -> asyncio.AbstractEventLoop: - loop = asyncio.new_event_loop() - yield loop - loop.close() - - -@pytest.fixture(scope="session", autouse=True) -def ensure_asgi_lifecycle(session_loop: asyncio.AbstractEventLoop): - """ - Call application shutdown lifecycle event. - - This cannot be an async fixture because the scope is session - and pytest-asynio's `event_loop` fixture, which is auto-used - for async tests and fixtures, is function scoped, which is - incomatible with session scoped fixtures. `async_to_sync` works - fine here, so it's not a problem. - - This cannot yet call the startup signal due to: - https://github.com/illagrenan/django-asgi-lifespan/pull/80 - """ - scope = {"type": "lifespan"} - - async def noop(*args, **kwargs): - ... 
-
-    async def shutdown():
-        return {"type": "lifespan.shutdown"}
-
-    yield
-    session_loop.run_until_complete(application(scope, shutdown, noop))
+from test.fixtures.asynchronous import ensure_asgi_lifecycle, get_new_loop, session_loop
+from test.fixtures.cache import (
+    django_cache,
+    redis,
+    unreachable_django_cache,
+    unreachable_redis,
+)
+
+
+__all__ = [
+    "ensure_asgi_lifecycle",
+    "get_new_loop",
+    "session_loop",
+    "django_cache",
+    "redis",
+    "unreachable_django_cache",
+    "unreachable_redis",
+]
diff --git a/api/test/factory/models/content_provider.py b/api/test/factory/models/content_provider.py
new file mode 100644
index 00000000000..04871aee5ed
--- /dev/null
+++ b/api/test/factory/models/content_provider.py
@@ -0,0 +1,8 @@
+from factory.django import DjangoModelFactory
+
+from api.models import ContentProvider
+
+
+class ContentProviderFactory(DjangoModelFactory):
+    class Meta:
+        model = ContentProvider
diff --git a/api/test/fixtures/__init__.py b/api/test/fixtures/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/api/test/fixtures/asynchronous.py b/api/test/fixtures/asynchronous.py
new file mode 100644
index 00000000000..5277a134a29
--- /dev/null
+++ b/api/test/fixtures/asynchronous.py
@@ -0,0 +1,53 @@
+import asyncio
+
+import pytest
+
+from conf.asgi import application
+
+
+@pytest.fixture
+def get_new_loop():
+    loops: list[asyncio.AbstractEventLoop] = []
+
+    def _get_new_loop() -> asyncio.AbstractEventLoop:
+        loop = asyncio.new_event_loop()
+        loops.append(loop)
+        return loop
+
+    yield _get_new_loop
+
+    for loop in loops:
+        loop.close()
+
+
+@pytest.fixture(scope="session")
+def session_loop() -> asyncio.AbstractEventLoop:
+    loop = asyncio.new_event_loop()
+    yield loop
+    loop.close()
+
+
+@pytest.fixture(scope="session", autouse=True)
+def ensure_asgi_lifecycle(session_loop: asyncio.AbstractEventLoop):
+    """
+    Call application shutdown lifecycle event.
+
+    This cannot be an async fixture because the scope is session
+    and pytest-asyncio's `event_loop` fixture, which is auto-used
+    for async tests and fixtures, is function scoped, which is
+    incompatible with session scoped fixtures. `async_to_sync` works
+    fine here, so it's not a problem.
+
+    This cannot yet call the startup signal due to:
+    https://github.com/illagrenan/django-asgi-lifespan/pull/80
+    """
+    scope = {"type": "lifespan"}
+
+    async def noop(*args, **kwargs):
+        ...
+
+    async def shutdown():
+        return {"type": "lifespan.shutdown"}
+
+    yield
+    session_loop.run_until_complete(application(scope, shutdown, noop))
diff --git a/api/test/fixtures/cache.py b/api/test/fixtures/cache.py
new file mode 100644
index 00000000000..e4d8971bdba
--- /dev/null
+++ b/api/test/fixtures/cache.py
@@ -0,0 +1,68 @@
+import pytest
+from django_redis.cache import RedisCache
+from fakeredis import FakeRedis, FakeServer
+
+
+@pytest.fixture(autouse=True)
+def redis(monkeypatch) -> FakeRedis:
+    """Emulate a Redis connection that does not affect the real cache."""
+
+    fake_redis = FakeRedis()
+
+    def get_redis_connection(*args, **kwargs):
+        return fake_redis
+
+    monkeypatch.setattr("django_redis.get_redis_connection", get_redis_connection)
+    yield fake_redis
+    fake_redis.client().close()
+
+
+@pytest.fixture
+def unreachable_redis(monkeypatch) -> FakeRedis:
+    """
+    Emulate a broken Redis connection that does not affect the real cache.
+
+    This fixture is useful for testing the resiliency of the API to withstand a
+    Redis outage. Attempts to read/write to this Redis instance will raise
+    ``ConnectionError``.
+ """ + + fake_server = FakeServer() + fake_server.connected = False + fake_redis = FakeRedis(server=fake_server) + + def get_redis_connection(*args, **kwargs): + return fake_redis + + monkeypatch.setattr("django_redis.get_redis_connection", get_redis_connection) + yield fake_redis + fake_server.connected = True + fake_redis.client().close() + + +@pytest.fixture(autouse=True) +def django_cache(redis, monkeypatch) -> RedisCache: + """Use the fake Redis fixture ``redis`` as Django's default cache.""" + + cache = RedisCache(" ", {}) + client = cache.client + client._clients = [redis] + monkeypatch.setattr("django.core.cache.cache", cache) + yield cache + + +@pytest.fixture +def unreachable_django_cache(unreachable_redis, monkeypatch) -> RedisCache: + """ + Use the fake Redis fixture ``unreachable_redis`` as Django's default cache. + + This fixture is useful for testing the resiliency of the API to withstand a + Redis outage. Attempts to read/write to this Redis instance will raise + ``ConnectionError``. + """ + + cache = RedisCache(" ", {}) + client = cache.client + client._clients = [unreachable_redis] + monkeypatch.setattr("django.core.cache.cache", cache) + yield cache diff --git a/api/test/test_auth.py b/api/test/test_auth.py index 80441b30c0f..6387b1818ed 100644 --- a/api/test/test_auth.py +++ b/api/test/test_auth.py @@ -11,6 +11,30 @@ from api.models import OAuth2Verification, ThrottledApplication +cache_availability_params = pytest.mark.parametrize( + "is_cache_reachable, cache_name", + [(True, "oauth_cache"), (False, "unreachable_oauth_cache")], +) +# This parametrize decorator runs the test function with two scenarios: +# - one where the API can connect to Redis +# - one where it cannot and raises ``ConnectionError`` +# The fixtures referenced here are defined below. + + +@pytest.fixture(autouse=True) +def oauth_cache(django_cache, monkeypatch): + cache = django_cache + monkeypatch.setattr("rest_framework.throttling.SimpleRateThrottle.cache", cache) + yield cache + + +@pytest.fixture +def unreachable_oauth_cache(unreachable_django_cache, monkeypatch): + cache = unreachable_django_cache + monkeypatch.setattr("api.views.oauth2_views.cache", cache) + yield cache + + @pytest.mark.django_db @pytest.fixture def test_auth_tokens_registration(client): @@ -74,6 +98,7 @@ def _integration_verify_most_recent_token(client): "rate_limit_model", [x[0] for x in ThrottledApplication.RATE_LIMIT_MODELS], ) +@cache_availability_params @pytest.mark.skipif( API_URL != "http://localhost:8000", reason=( @@ -82,11 +107,24 @@ def _integration_verify_most_recent_token(client): " that isn't possible." 
), ) -def test_auth_email_verification(client, rate_limit_model, test_auth_token_exchange): +def test_auth_email_verification( + request, + client, + is_cache_reachable, + cache_name, + rate_limit_model, + test_auth_token_exchange, +): res = _integration_verify_most_recent_token(client) assert res.status_code == 200 test_auth_rate_limit_reporting( - client, rate_limit_model, test_auth_token_exchange, verified=True + request, + client, + is_cache_reachable, + cache_name, + rate_limit_model, + test_auth_token_exchange, + verified=True, ) @@ -95,9 +133,18 @@ def test_auth_email_verification(client, rate_limit_model, test_auth_token_excha "rate_limit_model", [x[0] for x in ThrottledApplication.RATE_LIMIT_MODELS], ) +@cache_availability_params def test_auth_rate_limit_reporting( - client, rate_limit_model, test_auth_token_exchange, verified=False + request, + client, + is_cache_reachable, + cache_name, + rate_limit_model, + test_auth_token_exchange, + verified=False, ): + request.getfixturevalue(cache_name) + # We're anonymous still, so we need to wait a second before exchanging # the token. time.sleep(1) @@ -107,6 +154,13 @@ def test_auth_rate_limit_reporting( application.save() res = client.get("/v1/rate_limit/", HTTP_AUTHORIZATION=f"Bearer {token}") res_data = res.json() + if is_cache_reachable: + assert res.status_code == 200 + else: + assert res.status_code == 424 + assert res_data["requests_this_minute"] is None + assert res_data["requests_today"] is None + if verified: assert res_data["rate_limit_model"] == rate_limit_model assert res_data["verified"] is True diff --git a/api/test/test_dead_link_filter.py b/api/test/test_dead_link_filter.py index 676f804faac..bb37d8ed95d 100644 --- a/api/test/test_dead_link_filter.py +++ b/api/test/test_dead_link_filter.py @@ -6,24 +6,10 @@ import pytest import requests -from fakeredis import FakeRedis from api.controllers.elasticsearch.helpers import DEAD_LINK_RATIO -@pytest.fixture(autouse=True) -def redis(monkeypatch) -> FakeRedis: - fake_redis = FakeRedis() - - def get_redis_connection(*args, **kwargs): - return fake_redis - - monkeypatch.setattr("django_redis.get_redis_connection", get_redis_connection) - - yield fake_redis - fake_redis.client().close() - - @pytest.fixture(autouse=True) def turn_off_db_read(monkeypatch): """ diff --git a/api/test/unit/conftest.py b/api/test/unit/conftest.py index 155daedf17f..213dafd0ca9 100644 --- a/api/test/unit/conftest.py +++ b/api/test/unit/conftest.py @@ -12,7 +12,6 @@ import pook import pytest from elasticsearch import Elasticsearch -from fakeredis import FakeRedis from api.models import ( Audio, @@ -40,19 +39,6 @@ ) -@pytest.fixture() -def redis(monkeypatch) -> FakeRedis: - fake_redis = FakeRedis() - - def get_redis_connection(*args, **kwargs): - return fake_redis - - monkeypatch.setattr("django_redis.get_redis_connection", get_redis_connection) - - yield fake_redis - fake_redis.client().close() - - @pytest.fixture def api_client(): return APIClient() diff --git a/api/test/unit/controllers/elasticsearch/test_related.py b/api/test/unit/controllers/elasticsearch/test_related.py index b05066c8017..8da59714ecf 100644 --- a/api/test/unit/controllers/elasticsearch/test_related.py +++ b/api/test/unit/controllers/elasticsearch/test_related.py @@ -6,8 +6,6 @@ from test.factory.models import ImageFactory from unittest import mock -from django.core.cache import cache - import pook import pytest @@ -22,7 +20,10 @@ @pytest.fixture -def excluded_providers_cache(): +def excluded_providers_cache(django_cache, monkeypatch): 
+ cache = django_cache + monkeypatch.setattr("api.controllers.search_controller.cache", cache) + excluded_provider = "excluded_provider" cache_value = [excluded_provider] cache.set( diff --git a/api/test/unit/controllers/test_search_controller.py b/api/test/unit/controllers/test_search_controller.py index fc879317a84..652c43ea0f6 100644 --- a/api/test/unit/controllers/test_search_controller.py +++ b/api/test/unit/controllers/test_search_controller.py @@ -1,3 +1,4 @@ +import datetime import logging import random import re @@ -8,11 +9,11 @@ MOCK_LIVE_RESULT_URL_PREFIX, create_mock_es_http_image_search_response, ) +from test.factory.models.content_provider import ContentProviderFactory from unittest import mock +from unittest.mock import patch from uuid import uuid4 -from django.core.cache import cache - import pook import pytest from django_redis import get_redis_connection @@ -30,17 +31,28 @@ pytestmark = pytest.mark.django_db -@pytest.fixture() -def cache_setter(): - keys = [] +cache_availability_params = pytest.mark.parametrize( + "is_cache_reachable, cache_name", + [(True, "search_con_cache"), (False, "unreachable_search_con_cache")], +) +# This parametrize decorator runs the test function with two scenarios: +# - one where the API can connect to Redis +# - one where it cannot and raises ``ConnectionError`` +# The fixtures referenced here are defined below. + - def _cache_setter(key, value, version=1): - keys.append((key, version)) - cache.set(key, value=value, version=version, timeout=1) +@pytest.fixture(autouse=True) +def search_con_cache(django_cache, monkeypatch): + cache = django_cache + monkeypatch.setattr("api.controllers.search_controller.cache", cache) + yield cache - yield _cache_setter - for key, version in keys: - cache.delete(key, version=version) + +@pytest.fixture +def unreachable_search_con_cache(unreachable_django_cache, monkeypatch): + cache = unreachable_django_cache + monkeypatch.setattr("api.controllers.search_controller.cache", cache) + yield cache @pytest.mark.parametrize( @@ -824,12 +836,77 @@ def _delete_all_results_but_first(_, __, results, ___): assert "Nesting threshold breached" in caplog.text -def test_get_excluded_providers_query_returns_None_when_no_provider_is_excluded(): - assert search_controller.get_excluded_providers_query() is None +@pytest.mark.django_db +@cache_availability_params +@pytest.mark.parametrize( + "excluded_count, result", + [(2, Terms(provider=["provider1", "provider2"])), (0, None)], +) +def test_get_excluded_providers_query_returns_excluded( + excluded_count, result, is_cache_reachable, cache_name, request, caplog +): + cache = request.getfixturevalue(cache_name) + + if is_cache_reachable: + cache.set( + key=FILTERED_PROVIDERS_CACHE_KEY, + version=2, + timeout=30, + value=[f"provider{i + 1}" for i in range(excluded_count)], + ) + else: + for i in range(excluded_count): + ContentProviderFactory.create( + created_on=datetime.datetime.now(), + provider_identifier=f"provider{i + 1}", + provider_name=f"Provider {i + 1}", + filter_content=True, + ) + + assert search_controller.get_excluded_providers_query() == result + + if not is_cache_reachable: + assert all( + message in caplog.text + for message in [ + "Redis connect failed, cannot get cached filtered providers.", + "Redis connect failed, cannot cache filtered providers.", + ] + ) -def test_get_excluded_providers_query_returns_when_cache_is_set(cache_setter): - cache_setter(FILTERED_PROVIDERS_CACHE_KEY, ["provider1", "provider2"], version=2) - assert 
search_controller.get_excluded_providers_query() == Terms( - provider=["provider1", "provider2"] - ) +@cache_availability_params +def test_get_sources_returns_stats(is_cache_reachable, cache_name, request, caplog): + cache = request.getfixturevalue(cache_name) + + if is_cache_reachable: + cache.set( + "sources-multimedia", value={"provider_1": "1000", "provider_2": "1000"} + ) + + with patch( + "api.controllers.search_controller.get_raw_es_response", + return_value={ + "aggregations": { + "unique_sources": { + "buckets": [ + {"key": "provider_1", "doc_count": 1000}, + {"key": "provider_2", "doc_count": 1000}, + ] + } + } + }, + ): + assert search_controller.get_sources("multimedia") == { + "provider_1": 1000, + "provider_2": 1000, + } + + if not is_cache_reachable: + assert all( + message in caplog.text + for message in [ + "Redis connect failed, cannot get cached sources.", + "Redis connect failed, cannot cache sources.", + ] + ) diff --git a/api/test/unit/controllers/test_search_controller_search_query.py b/api/test/unit/controllers/test_search_controller_search_query.py index 9b574cd680f..3ccd56daded 100644 --- a/api/test/unit/controllers/test_search_controller_search_query.py +++ b/api/test/unit/controllers/test_search_controller_search_query.py @@ -1,5 +1,3 @@ -from django.core.cache import cache - import pytest from elasticsearch_dsl import Q @@ -15,7 +13,10 @@ @pytest.fixture -def excluded_providers_cache(): +def excluded_providers_cache(django_cache, monkeypatch): + cache = django_cache + monkeypatch.setattr("api.controllers.search_controller.cache", cache) + excluded_provider = "excluded_provider" cache_value = [excluded_provider] cache.set( diff --git a/api/test/unit/utils/conftest.py b/api/test/unit/utils/conftest.py index 4c726c29f3d..8ef48d758d7 100644 --- a/api/test/unit/utils/conftest.py +++ b/api/test/unit/utils/conftest.py @@ -2,20 +2,6 @@ from pathlib import Path import pytest -from fakeredis import FakeRedis - - -@pytest.fixture(autouse=True) -def redis(monkeypatch) -> FakeRedis: - fake_redis = FakeRedis() - - def get_redis_connection(*args, **kwargs): - return fake_redis - - monkeypatch.setattr("django_redis.get_redis_connection", get_redis_connection) - - yield fake_redis - fake_redis.client().close() @pytest.fixture(scope="session") diff --git a/api/test/unit/utils/test_check_dead_links.py b/api/test/unit/utils/test_check_dead_links.py index d36a2ef20c9..21eceb46781 100644 --- a/api/test/unit/utils/test_check_dead_links.py +++ b/api/test/unit/utils/test_check_dead_links.py @@ -54,6 +54,7 @@ async def raise_timeout_error(*args, **kwargs): assert len(results) == 0 +@pook.on @pytest.mark.parametrize("provider", ("thingiverse", "flickr")) def test_403_considered_dead(provider): query_hash = f"test_{provider}_403_considered_dead" @@ -62,6 +63,7 @@ def test_403_considered_dead(provider): {"identifier": i, "provider": provider if i % 2 else other_provider} for i in range(4) ] + len_results = len(results) image_urls = [f"https://example.org/{i}" for i in range(len(results))] start_slice = 0 @@ -74,7 +76,44 @@ def test_403_considered_dead(provider): check_dead_links(query_hash, start_slice, results, image_urls) - assert head_mock.calls == len(results) + assert head_mock.calls == len_results # All the provider's results should be filtered out, leaving only the "other" provider assert all([r["provider"] == other_provider for r in results]) + + +@pook.on +@pytest.mark.parametrize( + "is_cache_reachable, cache_name", + [(True, "redis"), (False, "unreachable_redis")], +) +def 
test_mset_and_expire_for_responses(is_cache_reachable, cache_name, request, caplog): + cache = request.getfixturevalue(cache_name) + + query_hash = "test_mset_and_expiry_for_responses" + results = [{"identifier": i, "provider": "best_provider_ever"} for i in range(40)] + image_urls = [f"https://example.org/{i}" for i in range(len(results))] + start_slice = 0 + + ( + pook.head(pook.regex(r"https://example.org/\d")) + .headers(HEADERS) + .times(len(results)) + .reply(200) + ) + + check_dead_links(query_hash, start_slice, results, image_urls) + + if is_cache_reachable: + for i in range(len(results)): + assert cache.get(f"valid:https://example.org/{i}") == b"200" + # TTL is 30 days for 2xx responses + assert cache.ttl(f"valid:https://example.org/{i}") == 2592000 + else: + assert all( + message in caplog.text + for message in [ + "Redis connect failed, validating all URLs without cache.", + "Redis connect failed, cannot cache link liveness.", + ] + ) diff --git a/api/test/unit/utils/test_image_proxy.py b/api/test/unit/utils/test_image_proxy.py index 7c81f32beb6..5cf0e2833a9 100644 --- a/api/test/unit/utils/test_image_proxy.py +++ b/api/test/unit/utils/test_image_proxy.py @@ -1,6 +1,8 @@ import asyncio +import itertools from dataclasses import replace from test.factory.models.image import ImageFactory +from unittest.mock import patch from urllib.parse import urlencode from django.conf import settings @@ -47,6 +49,16 @@ """ +cache_availability_params = pytest.mark.parametrize( + "is_cache_reachable, cache_name", + [(True, "redis"), (False, "unreachable_redis")], +) +# This parametrize decorator runs the test function with two scenarios: +# - one where the API can connect to Redis +# - one where it cannot and raises ``ConnectionError`` +# The fixtures referenced here are defined below. + + @pytest.fixture def auth_key(): test_key = "this is a test Photon Key boop boop, let me in" @@ -283,7 +295,11 @@ async def raise_exc(*args, **kwargs): @pook.on -def test_get_successful_records_response_code(photon_get, mock_image_data, redis): +@cache_availability_params +def test_get_successful_records_response_code( + photon_get, mock_image_data, is_cache_reachable, cache_name, request, caplog +): + cache = request.getfixturevalue(cache_name) ( pook.get(PHOTON_URL_FOR_TEST_IMAGE) .params( @@ -296,16 +312,22 @@ def test_get_successful_records_response_code(photon_get, mock_image_data, redis .header("Accept", "image/*") .reply(200) .body(MOCK_BODY) - .mock ) photon_get(TEST_MEDIA_INFO) month = get_monthly_timestamp() - assert redis.get(f"thumbnail_response_code:{month}:200") == b"1" - assert ( - redis.get(f"thumbnail_response_code_by_domain:{TEST_IMAGE_DOMAIN}:{month}:200") - == b"1" - ) + + keys = [ + f"thumbnail_response_code:{month}:200", + f"thumbnail_response_code_by_domain:{TEST_IMAGE_DOMAIN}:{month}:200", + ] + if is_cache_reachable: + for key in keys: + assert cache.get(key) == b"1" + else: + assert ( + "Redis connect failed, thumbnail response codes not tallied." 
in caplog.text + ) alert_count_params = pytest.mark.parametrize( @@ -356,6 +378,7 @@ def test_get_successful_records_response_code(photon_get, mock_image_data, redis ), ], ) +@cache_availability_params @alert_count_params def test_get_exception_handles_error( photon_get, @@ -365,14 +388,25 @@ def test_get_exception_handles_error( should_alert, sentry_capture_exception, setup_request_exception, - redis, + is_cache_reachable, + cache_name, + request, + caplog, ): + cache = request.getfixturevalue(cache_name) + setup_request_exception(exc) month = get_monthly_timestamp() key = f"thumbnail_error:{exc_name}:{TEST_IMAGE_DOMAIN}:{month}" - redis.set(key, count_start) + if is_cache_reachable: + cache.set(key, count_start) - with pytest.raises(UpstreamThumbnailException): + with ( + pytest.raises(UpstreamThumbnailException), + patch( + "api.utils.image_proxy.exception_iterator", itertools.count(count_start + 1) + ), + ): photon_get(TEST_MEDIA_INFO) assert_func = ( @@ -381,9 +415,14 @@ def test_get_exception_handles_error( else sentry_capture_exception.assert_not_called ) assert_func() - assert redis.get(key) == str(count_start + 1).encode() + + if is_cache_reachable: + assert cache.get(key) == str(count_start + 1).encode() + else: + assert "Redis connect failed, thumbnail errors not tallied." in caplog.text +@cache_availability_params @alert_count_params @pytest.mark.parametrize( "status_code, text", @@ -401,15 +440,23 @@ def test_get_http_exception_handles_error( count_start, should_alert, sentry_capture_exception, - redis, + is_cache_reachable, + cache_name, + request, + caplog, ): + cache = request.getfixturevalue(cache_name) + month = get_monthly_timestamp() key = f"thumbnail_error:aiohttp.client_exceptions.ClientResponseError:{TEST_IMAGE_DOMAIN}:{month}" - redis.set(key, count_start) + if is_cache_reachable: + cache.set(key, count_start) - with pytest.raises(UpstreamThumbnailException): + with pytest.raises(UpstreamThumbnailException), patch( + "api.utils.image_proxy.exception_iterator", itertools.count(count_start + 1) + ): with pook.use(): - pook.get(PHOTON_URL_FOR_TEST_IMAGE).reply(status_code, text).mock + pook.get(PHOTON_URL_FOR_TEST_IMAGE).reply(status_code, text) photon_get(TEST_MEDIA_INFO) assert_func = ( @@ -418,13 +465,21 @@ def test_get_http_exception_handles_error( else sentry_capture_exception.assert_not_called ) assert_func() - assert redis.get(key) == str(count_start + 1).encode() - # Assertions about the HTTP error specific message - assert ( - redis.get(f"thumbnail_http_error:{TEST_IMAGE_DOMAIN}:{month}:{status_code}") - == b"1" - ) + if is_cache_reachable: + assert cache.get(key) == str(count_start + 1).encode() + assert ( + cache.get(f"thumbnail_http_error:{TEST_IMAGE_DOMAIN}:{month}:{status_code}") + == b"1" + ) + else: + assert all( + message in caplog.text + for message in [ + "Redis connect failed, thumbnail HTTP errors not tallied.", + "Redis connect failed, thumbnail errors not tallied.", + ] + ) @pook.on @@ -511,9 +566,18 @@ def test_photon_get_raises_by_not_allowed_types(photon_get, image_type): ({"Content-Type": "unknown"}, b"unknown"), ], ) +@cache_availability_params def test_photon_get_saves_image_type_to_cache( - photon_get, redis, headers, expected_cache_val + photon_get, + headers, + expected_cache_val, + is_cache_reachable, + cache_name, + request, + caplog, ): + cache = request.getfixturevalue(cache_name) + image_url = TEST_IMAGE_URL.replace(".jpg", "") image = ImageFactory.create(url=image_url) media_info = MediaInfo( @@ -526,5 +590,14 @@ def 
test_photon_get_saves_image_type_to_cache( with pytest.raises(UnsupportedMediaType): photon_get(media_info) - key = f"media:{image.identifier}:thumb_type" - assert redis.get(key) == expected_cache_val + key = f"media:{image.identifier}:thumb_type" + if is_cache_reachable: + assert cache.get(key) == expected_cache_val + else: + assert all( + message in caplog.text + for message in [ + "Redis connect failed, cannot get cached image extension.", + "Redis connect failed, cannot cache image extension.", + ] + ) diff --git a/api/test/unit/utils/test_tallies.py b/api/test/unit/utils/test_tallies.py index 5d82a68bcd0..41fad62eb24 100644 --- a/api/test/unit/utils/test_tallies.py +++ b/api/test/unit/utils/test_tallies.py @@ -111,3 +111,18 @@ def test_count_provider_occurrences_increments_existing_tallies(redis): ) == b"1" ) + + +def test_writes_error_logs_for_redis_connection_errors(unreachable_redis, caplog): + provider_counts = {"flickr": 4, "stocksnap": 6} + + results = [ + {"provider": provider} + for provider, count in provider_counts.items() + for _ in range(count) + ] + now = datetime(2023, 1, 19) # 16th is start of week + with freeze_time(now): + tallies.count_provider_occurrences(results, FAKE_MEDIA_TYPE) + + assert "Redis connect failed, cannot increment provider tallies." in caplog.text diff --git a/api/test/unit/utils/test_throttle.py b/api/test/unit/utils/test_throttle.py index b24cd3d3f53..531aecf441e 100644 --- a/api/test/unit/utils/test_throttle.py +++ b/api/test/unit/utils/test_throttle.py @@ -1,7 +1,7 @@ from test.factory.models.image import ImageFactory from test.factory.models.oauth2 import AccessTokenFactory -from django.core.cache import cache +from django.http import HttpResponse from rest_framework.settings import api_settings from rest_framework.test import force_authenticate from rest_framework.views import APIView @@ -12,6 +12,30 @@ from api.views.media_views import MediaViewSet +cache_availability_params = pytest.mark.parametrize( + "is_cache_reachable, cache_name", + [(True, "throttle_cache"), (False, "unreachable_throttle_cache")], +) +# This parametrize decorator runs the test function with two scenarios: +# - one where the API can connect to Redis +# - one where it cannot and raises ``ConnectionError`` +# The fixtures referenced here are defined below. 
+ + +@pytest.fixture(autouse=True) +def throttle_cache(django_cache, monkeypatch): + cache = django_cache + monkeypatch.setattr("rest_framework.throttling.SimpleRateThrottle.cache", cache) + yield cache + + +@pytest.fixture +def unreachable_throttle_cache(unreachable_django_cache, monkeypatch): + cache = unreachable_django_cache + monkeypatch.setattr("rest_framework.throttling.SimpleRateThrottle.cache", cache) + yield cache + + @pytest.fixture(autouse=True) def enable_throttles(settings): # Stash current settings so we can revert them after the test @@ -216,9 +240,11 @@ def test_oauth_rate_limit_used_thumbnail( @pytest.mark.django_db -def test_rate_limit_headers(request_factory): - cache.delete_pattern("throttle_*") - limit = 2 +@cache_availability_params +def test_rate_limit_headers(request_factory, is_cache_reachable, cache_name, request): + request.getfixturevalue(cache_name) + + limit = 2 # number of allowed requests, we will go 1 above this limit class DummyThrottle(throttle.BurstRateThrottle): THROTTLE_RATES = {"anon_burst": f"{limit}/hour"} @@ -226,6 +252,9 @@ class DummyThrottle(throttle.BurstRateThrottle): class ThrottledView(APIView): throttle_classes = [DummyThrottle] + def get(self, request): + return HttpResponse("ok") + view = ThrottledView().as_view() request = request_factory.get("/") @@ -234,29 +263,54 @@ class ThrottledView(APIView): response = view(request) headers = [h for h in response.headers.items() if "X-RateLimit" in h[0]] - # Assert that request returns 429 response if limit has been exceeded. - assert response.status_code == 429 if idx == limit + 1 else 200 - - # Assert that the 'Available' header constantly decrements, but not below zero. - assert [ - ("X-RateLimit-Limit-anon_burst", f"{limit}/hour"), - ("X-RateLimit-Available-anon_burst", str(max(0, limit - idx))), - ] == headers + if is_cache_reachable: + # Assert that request returns 429 response if limit has been exceeded. + assert response.status_code == 429 if idx == limit + 1 else 200 + # Assert that the 'Available' header constantly decrements, but not below zero. + assert [ + ("X-RateLimit-Limit-anon_burst", f"{limit}/hour"), + ("X-RateLimit-Available-anon_burst", str(max(0, limit - idx))), + ] == headers + else: + # Throttling gets disabled if Redis cannot cache request history. + assert response.status_code == 200 + # Headers are not set if Redis cannot cache request history. + assert not headers @pytest.mark.django_db -def test_rate_limit_headers_when_no_scope(request_factory): - cache.delete_pattern("throttle_*") +@cache_availability_params +def test_rate_limit_headers_when_no_scope( + request_factory, is_cache_reachable, cache_name, request +): + request.getfixturevalue(cache_name) + + limit = 10 # number of allowed requests, we will go 1 above this limit class ThrottledView(APIView): throttle_classes = [throttle.TenPerDay] + def get(self, request): + return HttpResponse("ok") + view = ThrottledView().as_view() request = request_factory.get("/") - response = view(request) - headers = [h for h in response.headers.items() if "X-RateLimit" in h[0]] - assert [ - ("X-RateLimit-Limit-tenperday", "10/day"), - ("X-RateLimit-Available-tenperday", "9"), - ] == headers + # Send limit + 1 requests. The last one should be throttled. + for idx in range(1, limit + 2): + response = view(request) + headers = [h for h in response.headers.items() if "X-RateLimit" in h[0]] + + if is_cache_reachable: + # Assert that request returns 429 response if limit has been exceeded. 
+ assert response.status_code == 429 if idx == limit + 1 else 200 + # Assert that headers match the throttle class. + assert [ + ("X-RateLimit-Limit-tenperday", "10/day"), + ("X-RateLimit-Available-tenperday", str(max(0, limit - idx))), + ] == headers + else: + # Throttling gets disabled if Redis cannot cache request history. + assert response.status_code == 200 + # Headers are not set if Redis cannot cache request history. + assert not headers
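
The guard applied throughout this patch can be summarized outside the diff as follows. This is a minimal sketch, not code taken from the patch: the function name and cache key are hypothetical, and it only assumes Django's cache backed by django-redis and `redis.exceptions.ConnectionError`, as the patch does.

    import logging

    from django.core.cache import cache
    from redis.exceptions import ConnectionError

    logger = logging.getLogger(__name__)


    def get_cached_or_compute(key: str, compute, timeout: int = 300):
        """Read through the cache, but degrade gracefully if Redis is down."""
        try:
            value = cache.get(key)
        except ConnectionError:
            # Hypothetical fallback mirroring the patch: log a warning and
            # continue without Redis instead of failing the request.
            logger.warning("Redis connect failed, cannot get cached value.")
            value = None

        if value is None:
            value = compute()
            try:
                cache.set(key, value, timeout=timeout)
            except ConnectionError:
                logger.warning("Redis connect failed, cannot cache value.")

        return value

The test changes follow the same idea: each affected test is parametrized over a reachable fake Redis and an unreachable one (a FakeServer with connected = False), asserting normal behavior in the first case and only the warning log plus a sensible fallback in the second.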