Make usages of Redis resilient to absence of Redis (#3505)
Co-authored-by: sarayourfriend <[email protected]>
dhruvkb and sarayourfriend authored Dec 20, 2023
1 parent ba942f2 commit f56f731
Showing 27 changed files with 738 additions and 243 deletions.
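The recurring pattern across these files is the same: wrap each Redis call in try/except redis.exceptions.ConnectionError, log a warning, and fall back to the uncached code path so the request still succeeds, only slower. A minimal, self-contained sketch of that pattern follows; the helper name get_or_compute is illustrative and not part of this commit, and it assumes a Django cache backend that propagates Redis connection errors, as these changes do.

import logging

from django.core.cache import cache
from redis.exceptions import ConnectionError

logger = logging.getLogger(__name__)


def get_or_compute(key, compute, timeout=None):
    """Read from the cache, tolerating an unreachable Redis.

    If the cache read fails, compute the value directly; if the cache write
    fails, skip caching. Either way the caller gets a usable value.
    """
    try:
        value = cache.get(key)
    except ConnectionError:
        logger.warning("Redis connect failed, computing value without cache.")
        value = None

    if value is None:
        value = compute()
        try:
            cache.set(key, value, timeout=timeout)
        except ConnectionError:
            logger.warning("Redis connect failed, result not cached.")

    return value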
55 changes: 43 additions & 12 deletions api/api/controllers/search_controller.py
@@ -14,6 +14,7 @@
from elasticsearch_dsl import Q, Search
from elasticsearch_dsl.query import EMPTY_QUERY
from elasticsearch_dsl.response import Hit, Response
from redis.exceptions import ConnectionError

import api.models as models
from api.constants.media_types import OriginIndex, SearchIndex
@@ -184,21 +185,33 @@ def get_excluded_providers_query() -> Q | None:
`:FILTERED_PROVIDERS_CACHE_VERSION:FILTERED_PROVIDERS_CACHE_KEY` key.
"""

filtered_providers = cache.get(
key=FILTERED_PROVIDERS_CACHE_KEY, version=FILTERED_PROVIDERS_CACHE_VERSION
)
logger = module_logger.getChild("get_excluded_providers_query")

try:
filtered_providers = cache.get(
key=FILTERED_PROVIDERS_CACHE_KEY, version=FILTERED_PROVIDERS_CACHE_VERSION
)
except ConnectionError:
logger.warning("Redis connect failed, cannot get cached filtered providers.")
filtered_providers = None

if not filtered_providers:
filtered_providers = list(
models.ContentProvider.objects.filter(filter_content=True).values_list(
"provider_identifier", flat=True
)
)
cache.set(
key=FILTERED_PROVIDERS_CACHE_KEY,
version=FILTERED_PROVIDERS_CACHE_VERSION,
timeout=FILTER_CACHE_TIMEOUT,
value=filtered_providers,
)

try:
cache.set(
key=FILTERED_PROVIDERS_CACHE_KEY,
version=FILTERED_PROVIDERS_CACHE_VERSION,
timeout=FILTER_CACHE_TIMEOUT,
value=filtered_providers,
)
except ConnectionError:
logger.warning("Redis connect failed, cannot cache filtered providers.")

if filtered_providers:
return Q("terms", provider=filtered_providers)
return None
@@ -567,9 +580,19 @@ def get_sources(index):
cache_fetch_failed = True
sources = None
log.warning("Source cache fetch failed due to corruption")
except ConnectionError:
cache_fetch_failed = True
sources = None
log.warning("Redis connect failed, cannot get cached sources.")

if isinstance(sources, list) or cache_fetch_failed:
# Invalidate old provider format.
cache.delete(key=source_cache_name)
sources = None
try:
# Invalidate old provider format.
cache.delete(key=source_cache_name)
except ConnectionError:
log.warning("Redis connect failed, cannot invalidate cached sources.")

if not sources:
# Don't increase `size` without reading this issue first:
# https://github.com/elastic/elasticsearch/issues/18838
@@ -597,7 +620,15 @@ def get_sources(index):
except NotFoundError:
buckets = [{"key": "none_found", "doc_count": 0}]
sources = {result["key"]: result["doc_count"] for result in buckets}
cache.set(key=source_cache_name, timeout=SOURCE_CACHE_TIMEOUT, value=sources)

try:
cache.set(
key=source_cache_name, timeout=SOURCE_CACHE_TIMEOUT, value=sources
)
except ConnectionError:
log.warning("Redis connect failed, cannot cache sources.")

sources = {source: int(doc_count) for source, doc_count in sources.items()}
return sources
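A hypothetical pytest-style check of the fallback in get_excluded_providers_query above (not part of this commit; the module path is inferred from the file layout, and the test assumes the project's existing Django test setup):

from unittest import mock

import pytest
from redis.exceptions import ConnectionError

from api.controllers import search_controller


@pytest.mark.django_db
def test_excluded_providers_query_without_redis():
    with mock.patch.object(search_controller, "cache") as cache_mock:
        cache_mock.get.side_effect = ConnectionError()
        cache_mock.set.side_effect = ConnectionError()
        # Must not raise: the providers are read from the database instead,
        # and the failed cache read and write only produce warnings.
        search_controller.get_excluded_providers_query()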


18 changes: 14 additions & 4 deletions api/api/utils/check_dead_links/__init__.py
@@ -9,6 +9,7 @@
from asgiref.sync import async_to_sync
from decouple import config
from elasticsearch_dsl.response import Hit
from redis.exceptions import ConnectionError

from api.utils.aiohttp import get_aiohttp_session
from api.utils.check_dead_links.provider_status_mappings import provider_status_mappings
@@ -25,8 +26,15 @@


def _get_cached_statuses(redis, image_urls):
cached_statuses = redis.mget([CACHE_PREFIX + url for url in image_urls])
return [int(b.decode("utf-8")) if b is not None else None for b in cached_statuses]
try:
cached_statuses = redis.mget([CACHE_PREFIX + url for url in image_urls])
return [
int(b.decode("utf-8")) if b is not None else None for b in cached_statuses
]
except ConnectionError:
logger = parent_logger.getChild("_get_cached_statuses")
logger.warning("Redis connect failed, validating all URLs without cache.")
return [None] * len(image_urls)


def _get_expiry(status, default):
@@ -51,7 +59,6 @@ async def _head(url: str) -> tuple[str, int]:
# https://stackoverflow.com/q/55259755
@async_to_sync
async def _make_head_requests(urls: list[str]) -> list[tuple[str, int]]:
tasks = []
tasks = [asyncio.ensure_future(_head(url)) for url in urls]
responses = asyncio.gather(*tasks)
await responses
@@ -111,7 +118,10 @@ def check_dead_links(
logger.debug(f"caching status={status} expiry={expiry}")
pipe.expire(key, expiry)

pipe.execute()
try:
pipe.execute()
except ConnectionError:
logger.warning("Redis connect failed, cannot cache link liveness.")

# Merge newly verified results with cached statuses
for idx, url in enumerate(to_verify):
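The effect of the _get_cached_statuses guard can be illustrated with a stand-in connection whose mget always fails: every URL is then treated as uncached and gets re-verified. A standalone sketch, not code from this commit:

from redis.exceptions import ConnectionError

CACHE_PREFIX = "valid:"


class DownRedis:
    """Stand-in for a Redis connection that is unreachable."""

    def mget(self, keys):
        raise ConnectionError("connection refused")


def get_cached_statuses(redis, image_urls):
    try:
        cached = redis.mget([CACHE_PREFIX + url for url in image_urls])
        return [int(b.decode("utf-8")) if b is not None else None for b in cached]
    except ConnectionError:
        # No cache available: report every status as unknown.
        return [None] * len(image_urls)


print(get_cached_statuses(DownRedis(), ["https://example.com/a.jpg"]))  # [None]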
20 changes: 18 additions & 2 deletions api/api/utils/dead_link_mask.py
@@ -1,6 +1,12 @@
import logging

import django_redis
from deepdiff import DeepHash
from elasticsearch_dsl import Search
from redis.exceptions import ConnectionError


parent_logger = logging.getLogger(__name__)


# 3 hours (in seconds)
@@ -34,7 +40,12 @@ def get_query_mask(query_hash: str) -> list[int]:
"""
redis = django_redis.get_redis_connection("default")
key = f"{query_hash}:dead_link_mask"
return list(map(int, redis.lrange(key, 0, -1)))
try:
return list(map(int, redis.lrange(key, 0, -1)))
except ConnectionError:
logger = parent_logger.getChild("get_query_mask")
logger.warning("Redis connect failed, cannot get cached query mask.")
return []


def save_query_mask(query_hash: str, mask: list):
@@ -50,4 +61,9 @@ def save_query_mask(query_hash: str, mask: list):
redis_pipe.delete(key)
redis_pipe.rpush(key, *mask)
redis_pipe.expire(key, DEAD_LINK_MASK_TTL)
redis_pipe.execute()

try:
redis_pipe.execute()
except ConnectionError:
logger = parent_logger.getChild("save_query_mask")
logger.warning("Redis connect failed, cannot cache query mask.")
51 changes: 40 additions & 11 deletions api/api/utils/image_proxy/__init__.py
@@ -1,3 +1,4 @@
import itertools
import logging
from dataclasses import dataclass
from typing import Literal
@@ -12,6 +13,7 @@
import sentry_sdk
from aiohttp.client_exceptions import ClientResponseError
from asgiref.sync import sync_to_async
from redis.exceptions import ConnectionError
from sentry_sdk import push_scope, set_context

from api.utils.aiohttp import get_aiohttp_session
@@ -24,6 +26,8 @@

parent_logger = logging.getLogger(__name__)

exception_iterator = itertools.count()

HEADERS = {
"User-Agent": settings.OUTBOUND_USER_AGENT_TEMPLATE.format(
purpose="ThumbnailGeneration"
@@ -81,7 +85,7 @@ def get_request_params_for_extension(

@sync_to_async
def _tally_response(
tallies,
tallies_conn,
media_info: MediaInfo,
month: str,
domain: str,
@@ -93,14 +97,33 @@
Pulled into a separate function to help reduce overload when skimming
the `get` function, which is complex enough as is.
"""
tallies.incr(f"thumbnail_response_code:{month}:{response.status}")
tallies.incr(
f"thumbnail_response_code_by_domain:{domain}:" f"{month}:{response.status}"
)
tallies.incr(
f"thumbnail_response_code_by_provider:{media_info.media_provider}:"
f"{month}:{response.status}"
)

logger = parent_logger.getChild("_tally_response")

with tallies_conn.pipeline() as tallies:
tallies.incr(f"thumbnail_response_code:{month}:{response.status}")
tallies.incr(
f"thumbnail_response_code_by_domain:{domain}:" f"{month}:{response.status}"
)
tallies.incr(
f"thumbnail_response_code_by_provider:{media_info.media_provider}:"
f"{month}:{response.status}"
)
try:
tallies.execute()
except ConnectionError:
logger.warning(
"Redis connect failed, thumbnail response codes not tallied."
)


@sync_to_async
def _tally_client_response_errors(tallies, month: str, domain: str, status: int):
logger = parent_logger.getChild("_tally_client_response_errors")
try:
tallies.incr(f"thumbnail_http_error:{domain}:{month}:{status}")
except ConnectionError:
logger.warning("Redis connect failed, thumbnail HTTP errors not tallied.")


_UPSTREAM_TIMEOUT = aiohttp.ClientTimeout(15)
@@ -168,7 +191,13 @@ async def get(
except Exception as exc:
exception_name = f"{exc.__class__.__module__}.{exc.__class__.__name__}"
key = f"thumbnail_error:{exception_name}:{domain}:{month}"
count = await tallies_incr(key)
try:
count = await tallies_incr(key)
except ConnectionError:
logger.warning("Redis connect failed, thumbnail errors not tallied.")
# We will use a counter to space out Sentry logs.
count = next(exception_iterator)

if count <= settings.THUMBNAIL_ERROR_INITIAL_ALERT_THRESHOLD or (
count % settings.THUMBNAIL_ERROR_REPEATED_ALERT_FREQUENCY == 0
):
Expand All @@ -188,7 +217,7 @@ async def get(
if isinstance(exc, ClientResponseError):
status = exc.status
do_not_wait_for(
tallies_incr(f"thumbnail_http_error:{domain}:{month}:{status}")
_tally_client_response_errors(tallies, month, domain, status)
)
logger.warning(
f"Failed to render thumbnail "
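Note the new exception_iterator = itertools.count() above: when Redis cannot supply the running error count, a process-local counter stands in so Sentry alerts are still spaced out rather than fired on every failure. A small sketch of that spacing logic; the threshold and frequency values are illustrative, not the project's settings:

import itertools

fallback_counter = itertools.count()  # process-local stand-in for the Redis tally

INITIAL_ALERT_THRESHOLD = 3  # illustrative, not the real settings values
REPEATED_ALERT_FREQUENCY = 100


def should_alert(count_from_redis=None):
    """Decide whether to report a thumbnail error upstream.

    With Redis up, the shared tally is used; with Redis down, the local
    counter still limits how often this process raises an alert.
    """
    count = (
        count_from_redis
        if count_from_redis is not None
        else next(fallback_counter)
    )
    return count <= INITIAL_ALERT_THRESHOLD or count % REPEATED_ALERT_FREQUENCY == 0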
25 changes: 22 additions & 3 deletions api/api/utils/image_proxy/extension.py
@@ -1,29 +1,39 @@
import logging
from os.path import splitext
from urllib.parse import urlparse

import aiohttp
import django_redis
import sentry_sdk
from asgiref.sync import sync_to_async
from redis.exceptions import ConnectionError

from api.utils.aiohttp import get_aiohttp_session
from api.utils.asyncio import do_not_wait_for
from api.utils.image_proxy.exception import UpstreamThumbnailException


parent_logger = logging.getLogger(__name__)


_HEAD_TIMEOUT = aiohttp.ClientTimeout(10)


async def get_image_extension(image_url: str, media_identifier) -> str | None:
logger = parent_logger.getChild("get_image_extension")

cache = django_redis.get_redis_connection("default")
key = f"media:{media_identifier}:thumb_type"

ext = _get_file_extension_from_url(image_url)

if not ext:
# If the extension is not present in the URL, try to get it from the redis cache
ext = await sync_to_async(cache.get)(key)
ext = ext.decode("utf-8") if ext else None
try:
ext = await sync_to_async(cache.get)(key)
ext = ext.decode("utf-8") if ext else None
except ConnectionError:
logger.warning("Redis connect failed, cannot get cached image extension.")

if not ext:
# If the extension is still not present, try getting it from the content type
@@ -37,7 +47,7 @@ async def get_image_extension(image_url: str, media_identifier) -> str | None:
else:
ext = None

do_not_wait_for(sync_to_async(cache.set)(key, ext if ext else "unknown"))
do_not_wait_for(_cache_extension(cache, key, ext))
except Exception as exc:
sentry_sdk.capture_exception(exc)
raise UpstreamThumbnailException(
@@ -48,6 +58,15 @@ async def get_image_extension(image_url: str, media_identifier) -> str | None:
return ext


@sync_to_async
def _cache_extension(cache, key, ext):
logger = parent_logger.getChild("cache_extension")
try:
cache.set(key, ext if ext else "unknown")
except ConnectionError:
logger.warning("Redis connect failed, cannot cache image extension.")


def _get_file_extension_from_url(image_url: str) -> str:
"""Return the image extension if present in the URL."""
parsed = urlparse(image_url)
12 changes: 10 additions & 2 deletions api/api/utils/tallies.py
@@ -1,8 +1,13 @@
import logging
from collections import defaultdict
from datetime import datetime, timedelta

import django_redis
from django_redis.client.default import Redis
from redis.exceptions import ConnectionError


parent_logger = logging.getLogger(__name__)


def get_weekly_timestamp() -> str:
@@ -37,5 +42,8 @@ def count_provider_occurrences(results: list[dict], index: str) -> None:
for provider, occurrences in provider_occurrences.items():
pipe.incr(f"provider_occurrences:{index}:{week}:{provider}", occurrences)
pipe.incr(f"provider_appeared_in_searches:{index}:{week}:{provider}", 1)

pipe.execute()
try:
pipe.execute()
except ConnectionError:
logger = parent_logger.getChild("count_provider_occurrences")
logger.warning("Redis connect failed, cannot increment provider tallies.")
12 changes: 9 additions & 3 deletions api/api/utils/throttle.py
@@ -3,6 +3,8 @@

from rest_framework.throttling import SimpleRateThrottle as BaseSimpleRateThrottle

from redis.exceptions import ConnectionError

from api.utils.oauth2_helper import get_token_info


@@ -16,7 +18,12 @@ class SimpleRateThrottle(BaseSimpleRateThrottle, metaclass=abc.ABCMeta):
"""

def allow_request(self, request, view):
is_allowed = super().allow_request(request, view)
try:
is_allowed = super().allow_request(request, view)
except ConnectionError:
logger = parent_logger.getChild("allow_request")
logger.warning("Redis connect failed, allowing request.")
is_allowed = True
view.headers |= self.headers()
return is_allowed

@@ -44,10 +51,9 @@ def has_valid_token(self, request):
return token_info and token_info.valid

def get_cache_key(self, request, view):
ident = self.get_ident(request)
return self.cache_format % {
"scope": self.scope,
"ident": ident,
"ident": self.get_ident(request),
}
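The throttle change fails open by design: if the rate counters in Redis cannot be read, the request is allowed rather than rejected, since refusing all traffic during a cache outage would be worse than briefly losing rate limiting. A condensed, self-contained sketch of that choice (the class name, rate, and logger are illustrative, not this project's throttle classes):

import logging

from redis.exceptions import ConnectionError
from rest_framework.throttling import AnonRateThrottle

logger = logging.getLogger(__name__)


class FailOpenAnonThrottle(AnonRateThrottle):
    """Illustrative DRF throttle that allows requests when Redis is unreachable."""

    rate = "10/min"  # illustrative rate, normally configured in settings

    def allow_request(self, request, view):
        try:
            return super().allow_request(request, view)
        except ConnectionError:
            logger.warning("Redis connect failed, allowing request.")
            return True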

