diff --git a/django_redis/cache.py b/django_redis/cache.py index d26c33fa..f29acbbf 100644 --- a/django_redis/cache.py +++ b/django_redis/cache.py @@ -186,6 +186,71 @@ def touch(self, *args, **kwargs): return self.client.touch(*args, **kwargs) @omit_exception + def sadd(self, *args, **kwargs): + return self.client.sadd(*args, **kwargs) + + @omit_exception + def scard(self, *args, **kwargs): + return self.client.scard(*args, **kwargs) + + @omit_exception + def sdiff(self, *args, **kwargs): + return self.client.sdiff(*args, **kwargs) + + @omit_exception + def sdiffstore(self, *args, **kwargs): + return self.client.sdiffstore(*args, **kwargs) + + @omit_exception + def sinter(self, *args, **kwargs): + return self.client.sinter(*args, **kwargs) + + @omit_exception + def sinterstore(self, *args, **kwargs): + return self.client.sinterstore(*args, **kwargs) + + @omit_exception + def sismember(self, *args, **kwargs): + return self.client.sismember(*args, **kwargs) + + @omit_exception + def smembers(self, *args, **kwargs): + return self.client.smembers(*args, **kwargs) + + @omit_exception + def smove(self, *args, **kwargs): + return self.client.smove(*args, **kwargs) + + @omit_exception + def spop(self, *args, **kwargs): + return self.client.spop(*args, **kwargs) + + @omit_exception + def srandmember(self, *args, **kwargs): + return self.client.srandmember(*args, **kwargs) + + @omit_exception + def srem(self, *args, **kwargs): + return self.client.srem(*args, **kwargs) + + @omit_exception + def sscan(self, *args, **kwargs): + return self.client.sscan(*args, **kwargs) + @omit_exception + def sscan_iter(self, *args, **kwargs): + return self.client.sscan_iter(*args, **kwargs) + + @omit_exception + def smismember(self, *args, **kwargs): + return self.client.smismember(*args, **kwargs) + @omit_exception + def sunion(self, *args, **kwargs): + return self.client.sunion(*args, **kwargs) + + @omit_exception + def sunionstore(self, *args, **kwargs): + return self.client.sunionstore(*args, **kwargs) + @omit_exception def hset(self, *args, **kwargs): return self.client.hset(*args, **kwargs) diff --git a/django_redis/client/default.py b/django_redis/client/default.py index b9a5c1b0..d61a0b35 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -3,7 +3,7 @@ import socket from collections import OrderedDict from contextlib import suppress -from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union, Set from django.conf import settings from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache, get_key_func @@ -778,6 +778,246 @@ def make_pattern( return CacheKey(self._backend.key_func(pattern, prefix, version_str)) + def sadd( + self, + key: Any, + *values: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + values = [self.encode(value) for value in values] + return int(client.sadd(key, *values)) + + def scard( + self, + key: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + return int(client.scard(key)) + + def sdiff( + self, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> set: + if client is None: + client = self.get_client(write=False) + + keys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sdiff(*keys)} + + def sdiffstore( + self, + dest: Any, + *keys, + version_dest: Optional[int] = None, + version_keys: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + dest = self.make_key(dest, version=version_dest) + keys = [self.make_key(key, version=version_keys) for key in keys] + return int(client.sdiffstore(dest, *keys)) + + def sinter( + self, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> set: + if client is None: + client = self.get_client(write=False) + + keys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sinter(*keys)} + + def sinterstore( + self, + dest: Any, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + dest = self.make_key(dest, version=version) + keys = [self.make_key(key, version=version) for key in keys] + return int(client.sinterstore(dest, *keys)) + + def smismember( + self, + key: Any, + *members, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + encoded_members = [self.encode(member) for member in members] + + return [bool(value) for value in client.smismember(key, *encoded_members)] + + def sismember( + self, + key: Any, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + member = self.encode(member) + return bool(client.sismember(key, member)) + + def smembers( + self, + key: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> set: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + return {self.decode(value) for value in client.smembers(key)} + + def smove( + self, + source: Any, + destination: Any, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + client = self.get_client(write=True) + + source = self.make_key(source, version=version) + destination = self.make_key(destination) + member = self.encode(member) + return bool(client.smove(source, destination, member)) + + def spop( + self, + key: Any, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[set, Any]: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + result = client.spop(key, count) + if type(result) == list: + return {self.decode(value) for value in result} + return self.decode(result) + + def srandmember( + self, + key: Any, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[set, Any]: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + result = client.srandmember(key, count) + if type(result) == list: + return {self.decode(value) for value in result} + return self.decode(result) + + def srem( + self, + key: Any, + *members, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + members = [self.decode(member) for member in members] + return int(client.srem(key, *members)) + + def sscan( + self, + key: Any, + match: Optional[str] = None, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set[Any]: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + + cursor, result = client.sscan(key, match=self.encode(match), count=count) + return {self.decode(value) for value in result} + + def sscan_iter( + self, + key: Any, + match: Optional[str] = None, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Iterator[Any]: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + for value in client.sscan_iter(key, match=match, count=count): + yield self.decode(value) + + + def sunion( + self, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> set: + if client is None: + client = self.get_client(write=False) + + keys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sunion(*keys)} + + def sunionstore( + self, + destination: Any, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + destination = self.make_key(destination, version=version) + keys = [self.make_key(key, version=version) for key in keys] + return int(client.sunionstore(destination, *keys)) + def close(self) -> None: close_flag = self._options.get( "CLOSE_CONNECTION", diff --git a/django_redis/compressors/lz4.py b/django_redis/compressors/lz4.py index 32183321..940c96d5 100644 --- a/django_redis/compressors/lz4.py +++ b/django_redis/compressors/lz4.py @@ -16,5 +16,5 @@ def compress(self, value: bytes) -> bytes: def decompress(self, value: bytes) -> bytes: try: return _decompress(value) - except Exception as e: # noqa: BLE001 + except Exception as e: raise CompressorError from e diff --git a/tests/test_backend.py b/tests/test_backend.py index 4ff60983..fbc40fc5 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -856,3 +856,113 @@ def test_hexists(self, cache: RedisCache): cache.hset("foo_hash5", "foo1", "bar1") assert cache.hexists("foo_hash5", "foo1") assert not cache.hexists("foo_hash5", "foo") + + def test_sadd(self, cache: RedisCache): + assert cache.sadd("foo", "bar") == 1 + assert cache.smembers("foo") == {"bar"} + + def test_scard(self, cache: RedisCache): + cache.sadd("foo", "bar", "bar2") + assert cache.scard("foo") == 2 + + def test_sdiff(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sdiff("foo1", "foo2") == {"bar1"} + + def test_sdiffstore(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sdiffstore("foo3", "foo1", "foo2") == 1 + assert cache.smembers("foo3") == {"bar1"} + + def test_sdiffstore_with_keys_version(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2", version=2) + cache.sadd("foo2", "bar2", "bar3", version=2) + assert cache.sdiffstore("foo3", "foo1", "foo2", version_keys=2) == 1 + assert cache.smembers("foo3") == {"bar1"} + def test_sdiffstore_with_different_keys_versions_without_initial_set_in_version(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2", version=1) + cache.sadd("foo2", "bar2", "bar3", version=2) + assert cache.sdiffstore("foo3", "foo1", "foo2", version_keys=2) == 0 + + def test_sdiffstore_with_different_keys_versions_with_initial_set_in_version(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2", version=2) + cache.sadd("foo2", "bar2", "bar3", version=1) + assert cache.sdiffstore("foo3", "foo1", "foo2", version_keys=2) == 2 + def test_sinter(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sinter("foo1", "foo2") == {"bar2"} + + def test_interstore(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sinterstore("foo3", "foo1", "foo2") == 1 + assert cache.smembers("foo3") == {"bar2"} + + def test_sismember(self, cache: RedisCache): + cache.sadd("foo", "bar") + assert cache.sismember("foo", "bar") is True + assert cache.sismember("foo", "bar2") is False + + def test_smove(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.smove("foo1", "foo2", "bar1") is True + assert cache.smove("foo1", "foo2", "bar4") is False + assert cache.smembers("foo1") == {"bar2"} + assert cache.smembers("foo2") == {"bar1", "bar2", "bar3"} + + def test_spop_default_count(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.spop("foo") in {"bar1", "bar2"} + assert cache.smembers("foo") in {{"bar1"}, {"bar2"}} + + def test_spop(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.spop("foo", 1) in {{"bar1"}, {"bar2"}} + assert cache.smembers("foo") in {{"bar1"}, {"bar2"}} + + def test_srandmember_default_count(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srandmember("foo") in {"bar1", "bar2"} + + def test_srandmember(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srandmember("foo", 1) in {{"bar1"}, {"bar2"}} + + def test_srem(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srem("foo", "bar1") == 1 + assert cache.srem("foo", "bar3") == 0 + + def test_sscan(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + items = cache.sscan("foo") + assert items == {"bar1", "bar2"} + + def test_sscan_with_match(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2", "zoo") + items = cache.sscan("foo", match="zoo") + assert items == {"zoo"} + + def test_sscan_iter(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + items = cache.sscan_iter("foo") + assert set(items) == {"bar1", "bar2"} + + def test_smismember(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2", "bar3") + assert cache.smismember("foo", "bar1", "bar2", "xyz") == [True, True, False] + + def test_sunion(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sunion("foo1", "foo2") == {"bar1", "bar2", "bar3"} + + def test_sunionstore(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sunionstore("foo3", "foo1", "foo2") == 3 + assert cache.smembers("foo3") == {"bar1", "bar2", "bar3"}