Skip to content

Commit

Permalink
Added test coverage for decoded responses
Browse files Browse the repository at this point in the history
  • Loading branch information
vladvildanov committed Sep 13, 2024
1 parent 9852b78 commit 97abacd
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 51 deletions.
3 changes: 2 additions & 1 deletion redis/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
keys_to_delete = []

for redis_key in redis_keys:
redis_key = redis_key.decode()
if isinstance(redis_key, bytes):
redis_key = redis_key.decode()
for cache_key in self._cache:
if redis_key in cache_key.get_redis_keys():
keys_to_delete.append(cache_key)
Expand Down
24 changes: 10 additions & 14 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
compare_versions,
format_error_message,
get_lib_version,
str_if_bytes,
str_if_bytes, ensure_string,
)

if HIREDIS_AVAILABLE:
Expand Down Expand Up @@ -735,19 +735,10 @@ def _host_error(self):
return f"{self.host}:{self.port}"


def ensure_string(key):
if isinstance(key, bytes):
return key.decode("utf-8")
elif isinstance(key, str):
return key
else:
raise TypeError("Key must be either a string or bytes")


class CacheProxyConnection(ConnectionInterface):
DUMMY_CACHE_VALUE = b"foo"
MIN_ALLOWED_VERSION = "7.4.0"
DEFAULT_SERVER_NAME = b"redis"
DEFAULT_SERVER_NAME = "redis"

def __init__(self, conn: ConnectionInterface, cache: CacheInterface):
self.pid = os.getpid()
Expand Down Expand Up @@ -776,12 +767,17 @@ def set_parser(self, parser_class):
def connect(self):
self._conn.connect()

server_name = self._conn.handshake_metadata.get(b"server")
server_ver = self._conn.handshake_metadata.get(b"version")
server_name = self._conn.handshake_metadata.get(b"server", None)
if server_name is None:
server_name = self._conn.handshake_metadata.get("server", None)
server_ver = self._conn.handshake_metadata.get(b"version", None)
if server_ver is None:
server_ver = self._conn.handshake_metadata.get("version", None)
if server_ver is None or server_ver is None:
raise ConnectionError("Cannot retrieve information about server version")

server_ver = server_ver.decode("utf-8")
server_ver = ensure_string(server_ver)
server_name = ensure_string(server_name)

if (
server_name != self.DEFAULT_SERVER_NAME
Expand Down
9 changes: 9 additions & 0 deletions redis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,12 @@ def compare_versions(version1: str, version2: str) -> int:
return 1

return 0


def ensure_string(key):
if isinstance(key, bytes):
return key.decode("utf-8")
elif isinstance(key, str):
return key
else:
raise TypeError("Key must be either a string or bytes")
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,13 +442,15 @@ def sentinel_setup(request):
cache = request.param.get("cache", None)
cache_config = request.param.get("cache_config", None)
force_master_ip = request.param.get("force_master_ip", None)
decode_responses = request.param.get("decode_responses", False)
sentinel = Sentinel(
sentinel_endpoints,
force_master_ip=force_master_ip,
socket_timeout=0.1,
cache=cache,
cache_config=cache_config,
protocol=3,
decode_responses=decode_responses,
**kwargs,
)
yield sentinel
Expand Down
Loading

0 comments on commit 97abacd

Please sign in to comment.