diff --git a/redis/cache.py b/redis/cache.py index 53faf8d055..61626beef7 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -201,7 +201,8 @@ def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]: keys_to_delete = [] for redis_key in redis_keys: - redis_key = redis_key.decode() + if isinstance(redis_key, bytes): + redis_key = redis_key.decode() for cache_key in self._cache: if redis_key in cache_key.get_redis_keys(): keys_to_delete.append(cache_key) diff --git a/redis/connection.py b/redis/connection.py index f1ca69fde0..a2d9f4f2c2 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -44,7 +44,7 @@ compare_versions, format_error_message, get_lib_version, - str_if_bytes, + str_if_bytes, ensure_string, ) if HIREDIS_AVAILABLE: @@ -735,19 +735,10 @@ def _host_error(self): return f"{self.host}:{self.port}" -def ensure_string(key): - if isinstance(key, bytes): - return key.decode("utf-8") - elif isinstance(key, str): - return key - else: - raise TypeError("Key must be either a string or bytes") - - class CacheProxyConnection(ConnectionInterface): DUMMY_CACHE_VALUE = b"foo" MIN_ALLOWED_VERSION = "7.4.0" - DEFAULT_SERVER_NAME = b"redis" + DEFAULT_SERVER_NAME = "redis" def __init__(self, conn: ConnectionInterface, cache: CacheInterface): self.pid = os.getpid() @@ -776,12 +767,17 @@ def set_parser(self, parser_class): def connect(self): self._conn.connect() - server_name = self._conn.handshake_metadata.get(b"server") - server_ver = self._conn.handshake_metadata.get(b"version") + server_name = self._conn.handshake_metadata.get(b"server", None) + if server_name is None: + server_name = self._conn.handshake_metadata.get("server", None) + server_ver = self._conn.handshake_metadata.get(b"version", None) if server_ver is None: + server_ver = self._conn.handshake_metadata.get("version", None) + if server_ver is None or server_ver is None: raise ConnectionError("Cannot retrieve information about server version") - server_ver = server_ver.decode("utf-8") + server_ver = ensure_string(server_ver) + server_name = ensure_string(server_name) if ( server_name != self.DEFAULT_SERVER_NAME diff --git a/redis/utils.py b/redis/utils.py index 4b3a4647dc..b4e9afb054 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -183,3 +183,12 @@ def compare_versions(version1: str, version2: str) -> int: return 1 return 0 + + +def ensure_string(key): + if isinstance(key, bytes): + return key.decode("utf-8") + elif isinstance(key, str): + return key + else: + raise TypeError("Key must be either a string or bytes") diff --git a/tests/conftest.py b/tests/conftest.py index 0755fd390e..0c98eee4d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -442,6 +442,7 @@ def sentinel_setup(request): cache = request.param.get("cache", None) cache_config = request.param.get("cache_config", None) force_master_ip = request.param.get("force_master_ip", None) + decode_responses = request.param.get("decode_responses", False) sentinel = Sentinel( sentinel_endpoints, force_master_ip=force_master_ip, @@ -449,6 +450,7 @@ def sentinel_setup(request): cache=cache, cache_config=cache_config, protocol=3, + decode_responses=decode_responses, **kwargs, ) yield sentinel diff --git a/tests/test_cache.py b/tests/test_cache.py index 1a26fa668d..e106cdb156 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -24,6 +24,7 @@ def r(request): protocol = request.param.get("protocol", 3) ssl = request.param.get("ssl", False) single_connection_client = request.param.get("single_connection_client", False) + decode_responses = request.param.get("decode_responses", False) with _get_client( redis.Redis, request, @@ -32,6 +33,7 @@ def r(request): single_connection_client=single_connection_client, cache=cache, cache_config=cache_config, + decode_responses=decode_responses, **kwargs, ) as client: yield client @@ -53,8 +55,13 @@ class TestCache: "cache": DefaultCache(CacheConfig(max_size=5)), "single_connection_client": False, }, + { + "cache": DefaultCache(CacheConfig(max_size=5)), + "single_connection_client": False, + "decode_responses": True, + }, ], - ids=["single", "pool"], + ids=["single", "pool", "decoded"], indirect=True, ) @pytest.mark.onlynoncluster @@ -63,20 +70,20 @@ def test_get_from_given_cache(self, r, r2): # add key to redis r.set("foo", "bar") # get key from redis and save in local cache - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize( @@ -90,8 +97,13 @@ def test_get_from_given_cache(self, r, r2): "cache_config": CacheConfig(max_size=128), "single_connection_client": False, }, + { + "cache_config": CacheConfig(max_size=128), + "single_connection_client": False, + "decode_responses": True, + }, ], - ids=["single", "pool"], + ids=["single", "pool", "decoded"], indirect=True, ) @pytest.mark.onlynoncluster @@ -103,20 +115,20 @@ def test_get_from_default_cache(self, r, r2): # add key to redis r.set("foo", "bar") # get key from redis and save in local cache - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize( @@ -351,7 +363,11 @@ class TestClusterCache: [ { "cache": DefaultCache(CacheConfig(max_size=128)), - } + }, + { + "cache": DefaultCache(CacheConfig(max_size=128)), + "decode_responses": True, + }, ], indirect=True, ) @@ -361,20 +377,20 @@ def test_get_from_cache(self, r): # add key to redis r.set("foo", "bar") # get key from redis and save in local cache - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) r.set("foo", "barbar") # Retrieves a new value from server and cache it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) # Make sure that cache is shared between nodes. assert ( @@ -387,6 +403,10 @@ def test_get_from_cache(self, r): { "cache_config": CacheConfig(max_size=128), }, + { + "cache_config": CacheConfig(max_size=128), + "decode_responses": True, + }, ], indirect=True, ) @@ -398,20 +418,20 @@ def test_get_from_custom_cache(self, r, r2): # add key to redis assert r.set("foo", "bar") # get key from redis and save in local cache - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize( @@ -615,6 +635,11 @@ class TestSentinelCache: { "cache": DefaultCache(CacheConfig(max_size=128)), "force_master_ip": "localhost", + }, + { + "cache": DefaultCache(CacheConfig(max_size=128)), + "force_master_ip": "localhost", + "decode_responses": True, } ], indirect=True, @@ -624,20 +649,20 @@ def test_get_from_cache(self, master): cache = master.get_cache() master.set("foo", "bar") # get key from redis and save in local cache_data - assert master.get("foo") == b"bar" + assert master.get("foo") in [b"bar", "bar"] # get key from local cache_data assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) master.set("foo", "barbar") # get key from redis - assert master.get("foo") == b"barbar" + assert master.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize( @@ -646,6 +671,10 @@ def test_get_from_cache(self, master): { "cache_config": CacheConfig(max_size=128), }, + { + "cache_config": CacheConfig(max_size=128), + "decode_responses": True, + }, ], indirect=True, ) @@ -656,20 +685,20 @@ def test_get_from_default_cache(self, r, r2): # add key to redis r.set("foo", "bar") # get key from redis and save in local cache_data - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache_data assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache_data it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize( @@ -731,7 +760,7 @@ def test_cache_clears_on_disconnect(self, master, cache): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -@skip_if_resp_version(2) +#@skip_if_resp_version(2) @skip_if_server_version_lt("7.4.0") class TestSSLCache: @pytest.mark.parametrize( @@ -740,6 +769,11 @@ class TestSSLCache: { "cache": DefaultCache(CacheConfig(max_size=128)), "ssl": True, + }, + { + "cache": DefaultCache(CacheConfig(max_size=128)), + "ssl": True, + "decode_responses": True, } ], indirect=True, @@ -750,11 +784,11 @@ def test_get_from_cache(self, r, r2, cache): # add key to redis r.set("foo", "bar") # get key from redis and save in local cache_data - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache_data assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) assert r2.set("foo", "barbar") @@ -762,11 +796,11 @@ def test_get_from_cache(self, r, r2, cache): # between data appears in socket buffer time.sleep(0.1) # Retrieves a new value from server and cache_data it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize( @@ -776,6 +810,11 @@ def test_get_from_cache(self, r, r2, cache): "cache_config": CacheConfig(max_size=128), "ssl": True, }, + { + "cache_config": CacheConfig(max_size=128), + "ssl": True, + "decode_responses": True, + }, ], indirect=True, ) @@ -786,11 +825,11 @@ def test_get_from_custom_cache(self, r, r2): # add key to redis r.set("foo", "bar") # get key from redis and save in local cache_data - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache_data assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) r2.set("foo", "barbar") @@ -798,11 +837,11 @@ def test_get_from_custom_cache(self, r, r2): # between data appears in socket buffer time.sleep(0.1) # Retrieves a new value from server and cache_data it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize(