Skip to content

Commit

Permalink
Handle RESP3 sets as Python lists (#3324)
Browse files Browse the repository at this point in the history
* Handle RESP3 sets as Python lists

Although the RESP3 protocol defines the set data structure, sometimes
the responses from the Redis server contain sets with nested maps, which
cannot be represented in Python as sets with nested dicts, because dicts
are not hashable.

Versions of hiredis-py before 3.0.0 would cause a segmentation fault when
parsing such responses. Starting with version 3.0.0 the problem was
fixed, with the compromise that RESP3 sets are represented as Python
lists.

Until now, the embedded RESP3 parser tried to represent RESP3 sets as
Python sets whenever possible, and switched to the list representation
only when that was not possible. Arguably this is not the best user
experience, because callers cannot know in advance whether they will get
back a set or a list.

Upgrade the required hiredis-py version to be at least 3.0.0, and change
the embedded parser to always represent RESP3 sets as lists. This way we
get a consistent experience in all cases.

This is a breaking change.

* Also cover RESP2 sets

* Fix failing tests

* Fix async RESP3 parser
  • Loading branch information
gerzse authored Jul 22, 2024
1 parent 2ffcac3 commit fd0b0d3
Show file tree
Hide file tree
Showing 18 changed files with 95 additions and 144 deletions.
34 changes: 17 additions & 17 deletions doctests/dt_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@
r.sadd("bikes:racing:usa", "bike:1", "bike:4")
# HIDE_END
res7 = r.sinter("bikes:racing:france", "bikes:racing:usa")
print(res7) # >>> {'bike:1'}
print(res7) # >>> ['bike:1']
# STEP_END

# REMOVE_START
assert res7 == {"bike:1"}
assert res7 == ["bike:1"]
# REMOVE_END

# STEP_START scard
Expand All @@ -83,12 +83,12 @@
print(res9) # >>> 3

res10 = r.smembers("bikes:racing:france")
print(res10) # >>> {'bike:1', 'bike:2', 'bike:3'}
print(res10) # >>> ['bike:1', 'bike:2', 'bike:3']
# STEP_END

# REMOVE_START
assert res9 == 3
assert res10 == {"bike:1", "bike:2", "bike:3"}
assert res10 == ['bike:1', 'bike:2', 'bike:3']
# REMOVE_END

# STEP_START smismember
Expand All @@ -109,11 +109,11 @@
r.sadd("bikes:racing:usa", "bike:1", "bike:4")

res13 = r.sdiff("bikes:racing:france", "bikes:racing:usa")
print(res13) # >>> {'bike:2', 'bike:3'}
print(res13) # >>> ['bike:2', 'bike:3']
# STEP_END

# REMOVE_START
assert res13 == {"bike:2", "bike:3"}
assert res13 == ['bike:2', 'bike:3']
r.delete("bikes:racing:france")
r.delete("bikes:racing:usa")
# REMOVE_END
Expand All @@ -124,27 +124,27 @@
r.sadd("bikes:racing:italy", "bike:1", "bike:2", "bike:3", "bike:4")

res13 = r.sinter("bikes:racing:france", "bikes:racing:usa", "bikes:racing:italy")
print(res13) # >>> {'bike:1'}
print(res13) # >>> ['bike:1']

res14 = r.sunion("bikes:racing:france", "bikes:racing:usa", "bikes:racing:italy")
print(res14) # >>> {'bike:1', 'bike:2', 'bike:3', 'bike:4'}
print(res14) # >>> ['bike:1', 'bike:2', 'bike:3', 'bike:4']

res15 = r.sdiff("bikes:racing:france", "bikes:racing:usa", "bikes:racing:italy")
print(res15) # >>> set()
print(res15) # >>> []

res16 = r.sdiff("bikes:racing:usa", "bikes:racing:france")
print(res16) # >>> {'bike:4'}
print(res16) # >>> ['bike:4']

res17 = r.sdiff("bikes:racing:france", "bikes:racing:usa")
print(res17) # >>> {'bike:2', 'bike:3'}
print(res17) # >>> ['bike:2', 'bike:3']
# STEP_END

# REMOVE_START
assert res13 == {"bike:1"}
assert res14 == {"bike:1", "bike:2", "bike:3", "bike:4"}
assert res15 == set()
assert res16 == {"bike:4"}
assert res17 == {"bike:2", "bike:3"}
assert res13 == ['bike:1']
assert res14 == ['bike:1', 'bike:2', 'bike:3', 'bike:4']
assert res15 == []
assert res16 == ['bike:4']
assert res17 == ['bike:2', 'bike:3']
r.delete("bikes:racing:france")
r.delete("bikes:racing:usa")
r.delete("bikes:racing:italy")
Expand All @@ -160,7 +160,7 @@
print(res19) # >>> bike:3

res20 = r.smembers("bikes:racing:france")
print(res20) # >>> {'bike:2', 'bike:4', 'bike:5'}
print(res20) # >>> ['bike:2', 'bike:4', 'bike:5']

res21 = r.srandmember("bikes:racing:france")
print(res21) # >>> bike:4
Expand Down
3 changes: 0 additions & 3 deletions redis/_parsers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,9 +783,6 @@ def string_keys_to_dict(key_string, callback):


_RedisCallbacksRESP2 = {
**string_keys_to_dict(
"SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set()
),
**string_keys_to_dict(
"ZDIFF ZINTER ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZRANK ZREVRANGE "
"ZREVRANGEBYSCORE ZREVRANK ZUNION",
Expand Down
12 changes: 2 additions & 10 deletions redis/_parsers/resp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,11 @@ def _read_response(self, disable_decoding=False, push_request=False):
# set response
elif byte == b"~":
# redis can return unhashable types (like dict) in a set,
# so we need to first convert to a list, and then try to convert it to a set
# so we return sets as list, all the time, for predictability
response = [
self._read_response(disable_decoding=disable_decoding)
for _ in range(int(response))
]
try:
response = set(response)
except TypeError:
pass
# map response
elif byte == b"%":
# We cannot use a dict-comprehension to parse stream.
Expand Down Expand Up @@ -233,15 +229,11 @@ async def _read_response(
# set response
elif byte == b"~":
# redis can return unhashable types (like dict) in a set,
# so we need to first convert to a list, and then try to convert it to a set
# so we always convert to a list, to have predictable return types
response = [
(await self._read_response(disable_decoding=disable_decoding))
for _ in range(int(response))
]
try:
response = set(response)
except TypeError:
pass
# map response
elif byte == b"%":
# We cannot use a dict-comprehension to parse stream.
Expand Down
6 changes: 1 addition & 5 deletions redis/commands/bf/commands.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from redis.client import NEVER_DECODE
from redis.exceptions import ModuleError
from redis.utils import HIREDIS_AVAILABLE, deprecated_function
from redis.utils import deprecated_function

BF_RESERVE = "BF.RESERVE"
BF_ADD = "BF.ADD"
Expand Down Expand Up @@ -139,9 +138,6 @@ def scandump(self, key, iter):
This command will return successive (iter, data) pairs until (0, NULL) to indicate completion.
For more information see `BF.SCANDUMP <https://redis.io/commands/bf.scandump>`_.
""" # noqa
if HIREDIS_AVAILABLE:
raise ModuleError("This command cannot be used when hiredis is available.")

params = [key, iter]
options = {}
options[NEVER_DECODE] = []
Expand Down
4 changes: 2 additions & 2 deletions redis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
try:
import hiredis # noqa

# Only support Hiredis >= 1.0:
HIREDIS_AVAILABLE = not hiredis.__version__.startswith("0.")
# Only support Hiredis >= 3.0:
HIREDIS_AVAILABLE = int(hiredis.__version__.split(".")[0]) >= 3
HIREDIS_PACK_AVAILABLE = hasattr(hiredis, "pack_command")
except ImportError:
HIREDIS_AVAILABLE = False
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
"Programming Language :: Python :: Implementation :: PyPy",
],
extras_require={
"hiredis": ["hiredis>=1.0.0"],
"hiredis": ["hiredis>=3.0.0"],
"ocsp": ["cryptography>=36.0.1", "pyopenssl==23.2.1", "requests>=2.31.0"],
},
)
7 changes: 1 addition & 6 deletions tests/test_asyncio/test_bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import pytest
import pytest_asyncio
import redis.asyncio as redis
from redis.exceptions import ModuleError, RedisError
from redis.utils import HIREDIS_AVAILABLE
from redis.exceptions import RedisError
from tests.conftest import (
assert_resp_response,
is_resp2_connection,
Expand Down Expand Up @@ -105,10 +104,6 @@ async def do_verify():

await do_verify()
cmds = []
if HIREDIS_AVAILABLE:
with pytest.raises(ModuleError):
cur = await decoded_r.bf().scandump("myBloom", 0)
return

cur = await decoded_r.bf().scandump("myBloom", 0)
first = cur[0]
Expand Down
24 changes: 12 additions & 12 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1753,49 +1753,49 @@ async def test_cluster_rpoplpush(self, r: RedisCluster) -> None:

async def test_cluster_sdiff(self, r: RedisCluster) -> None:
await r.sadd("{foo}a", "1", "2", "3")
assert await r.sdiff("{foo}a", "{foo}b") == {b"1", b"2", b"3"}
assert set(await r.sdiff("{foo}a", "{foo}b")) == {b"1", b"2", b"3"}
await r.sadd("{foo}b", "2", "3")
assert await r.sdiff("{foo}a", "{foo}b") == {b"1"}
assert await r.sdiff("{foo}a", "{foo}b") == [b"1"]

async def test_cluster_sdiffstore(self, r: RedisCluster) -> None:
await r.sadd("{foo}a", "1", "2", "3")
assert await r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 3
assert await r.smembers("{foo}c") == {b"1", b"2", b"3"}
assert set(await r.smembers("{foo}c")) == {b"1", b"2", b"3"}
await r.sadd("{foo}b", "2", "3")
assert await r.sdiffstore("{foo}c", "{foo}a", "{foo}b") == 1
assert await r.smembers("{foo}c") == {b"1"}
assert await r.smembers("{foo}c") == [b"1"]

async def test_cluster_sinter(self, r: RedisCluster) -> None:
await r.sadd("{foo}a", "1", "2", "3")
assert await r.sinter("{foo}a", "{foo}b") == set()
assert await r.sinter("{foo}a", "{foo}b") == []
await r.sadd("{foo}b", "2", "3")
assert await r.sinter("{foo}a", "{foo}b") == {b"2", b"3"}
assert set(await r.sinter("{foo}a", "{foo}b")) == {b"2", b"3"}

async def test_cluster_sinterstore(self, r: RedisCluster) -> None:
await r.sadd("{foo}a", "1", "2", "3")
assert await r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 0
assert await r.smembers("{foo}c") == set()
assert await r.smembers("{foo}c") == []
await r.sadd("{foo}b", "2", "3")
assert await r.sinterstore("{foo}c", "{foo}a", "{foo}b") == 2
assert await r.smembers("{foo}c") == {b"2", b"3"}
assert set(await r.smembers("{foo}c")) == {b"2", b"3"}

async def test_cluster_smove(self, r: RedisCluster) -> None:
await r.sadd("{foo}a", "a1", "a2")
await r.sadd("{foo}b", "b1", "b2")
assert await r.smove("{foo}a", "{foo}b", "a1")
assert await r.smembers("{foo}a") == {b"a2"}
assert await r.smembers("{foo}b") == {b"b1", b"b2", b"a1"}
assert await r.smembers("{foo}a") == [b"a2"]
assert set(await r.smembers("{foo}b")) == {b"b1", b"b2", b"a1"}

async def test_cluster_sunion(self, r: RedisCluster) -> None:
await r.sadd("{foo}a", "1", "2")
await r.sadd("{foo}b", "2", "3")
assert await r.sunion("{foo}a", "{foo}b") == {b"1", b"2", b"3"}
assert set(await r.sunion("{foo}a", "{foo}b")) == {b"1", b"2", b"3"}

async def test_cluster_sunionstore(self, r: RedisCluster) -> None:
await r.sadd("{foo}a", "1", "2")
await r.sadd("{foo}b", "2", "3")
assert await r.sunionstore("{foo}c", "{foo}a", "{foo}b") == 3
assert await r.smembers("{foo}c") == {b"1", b"2", b"3"}
assert set(await r.smembers("{foo}c")) == {b"1", b"2", b"3"}

@skip_if_server_version_lt("6.2.0")
async def test_cluster_zdiff(self, r: RedisCluster) -> None:
Expand Down
36 changes: 17 additions & 19 deletions tests/test_asyncio/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ async def test_zscan_iter(self, r: redis.Redis):
async def test_sadd(self, r: redis.Redis):
members = {b"1", b"2", b"3"}
await r.sadd("a", *members)
assert await r.smembers("a") == members
assert set(await r.smembers("a")) == members

async def test_scard(self, r: redis.Redis):
await r.sadd("a", "1", "2", "3")
Expand All @@ -1415,34 +1415,34 @@ async def test_scard(self, r: redis.Redis):
@pytest.mark.onlynoncluster
async def test_sdiff(self, r: redis.Redis):
await r.sadd("a", "1", "2", "3")
assert await r.sdiff("a", "b") == {b"1", b"2", b"3"}
assert set(await r.sdiff("a", "b")) == {b"1", b"2", b"3"}
await r.sadd("b", "2", "3")
assert await r.sdiff("a", "b") == {b"1"}
assert await r.sdiff("a", "b") == [b"1"]

@pytest.mark.onlynoncluster
async def test_sdiffstore(self, r: redis.Redis):
await r.sadd("a", "1", "2", "3")
assert await r.sdiffstore("c", "a", "b") == 3
assert await r.smembers("c") == {b"1", b"2", b"3"}
assert set(await r.smembers("c")) == {b"1", b"2", b"3"}
await r.sadd("b", "2", "3")
assert await r.sdiffstore("c", "a", "b") == 1
assert await r.smembers("c") == {b"1"}
assert await r.smembers("c") == [b"1"]

@pytest.mark.onlynoncluster
async def test_sinter(self, r: redis.Redis):
await r.sadd("a", "1", "2", "3")
assert await r.sinter("a", "b") == set()
assert await r.sinter("a", "b") == []
await r.sadd("b", "2", "3")
assert await r.sinter("a", "b") == {b"2", b"3"}
assert set(await r.sinter("a", "b")) == {b"2", b"3"}

@pytest.mark.onlynoncluster
async def test_sinterstore(self, r: redis.Redis):
await r.sadd("a", "1", "2", "3")
assert await r.sinterstore("c", "a", "b") == 0
assert await r.smembers("c") == set()
assert await r.smembers("c") == []
await r.sadd("b", "2", "3")
assert await r.sinterstore("c", "a", "b") == 2
assert await r.smembers("c") == {b"2", b"3"}
assert set(await r.smembers("c")) == {b"2", b"3"}

async def test_sismember(self, r: redis.Redis):
await r.sadd("a", "1", "2", "3")
Expand All @@ -1453,22 +1453,22 @@ async def test_sismember(self, r: redis.Redis):

async def test_smembers(self, r: redis.Redis):
await r.sadd("a", "1", "2", "3")
assert await r.smembers("a") == {b"1", b"2", b"3"}
assert set(await r.smembers("a")) == {b"1", b"2", b"3"}

@pytest.mark.onlynoncluster
async def test_smove(self, r: redis.Redis):
await r.sadd("a", "a1", "a2")
await r.sadd("b", "b1", "b2")
assert await r.smove("a", "b", "a1")
assert await r.smembers("a") == {b"a2"}
assert await r.smembers("b") == {b"b1", b"b2", b"a1"}
assert await r.smembers("a") == [b"a2"]
assert set(await r.smembers("b")) == {b"b1", b"b2", b"a1"}

async def test_spop(self, r: redis.Redis):
s = [b"1", b"2", b"3"]
await r.sadd("a", *s)
value = await r.spop("a")
assert value in s
assert await r.smembers("a") == set(s) - {value}
assert set(await r.smembers("a")) == set(s) - {value}

@skip_if_server_version_lt("3.2.0")
async def test_spop_multi_value(self, r: redis.Redis):
Expand All @@ -1481,9 +1481,7 @@ async def test_spop_multi_value(self, r: redis.Redis):
assert value in s

response = await r.spop("a", 1)
assert_resp_response(
r, response, list(set(s) - set(values)), set(s) - set(values)
)
assert set(response) == set(s) - set(values)

async def test_srandmember(self, r: redis.Redis):
s = [b"1", b"2", b"3"]
Expand All @@ -1502,20 +1500,20 @@ async def test_srem(self, r: redis.Redis):
await r.sadd("a", "1", "2", "3", "4")
assert await r.srem("a", "5") == 0
assert await r.srem("a", "2", "4") == 2
assert await r.smembers("a") == {b"1", b"3"}
assert set(await r.smembers("a")) == {b"1", b"3"}

@pytest.mark.onlynoncluster
async def test_sunion(self, r: redis.Redis):
await r.sadd("a", "1", "2")
await r.sadd("b", "2", "3")
assert await r.sunion("a", "b") == {b"1", b"2", b"3"}
assert set(await r.sunion("a", "b")) == {b"1", b"2", b"3"}

@pytest.mark.onlynoncluster
async def test_sunionstore(self, r: redis.Redis):
await r.sadd("a", "1", "2")
await r.sadd("b", "2", "3")
assert await r.sunionstore("c", "a", "b") == 3
assert await r.smembers("c") == {b"1", b"2", b"3"}
assert set(await r.smembers("c")) == {b"1", b"2", b"3"}

# SORTED SET COMMANDS
async def test_zadd(self, r: redis.Redis):
Expand Down
5 changes: 2 additions & 3 deletions tests/test_asyncio/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from redis.commands.search.result import Result
from redis.commands.search.suggestion import Suggestion
from tests.conftest import (
assert_resp_response,
is_resp2_connection,
skip_if_redis_enterprise,
skip_if_resp_version,
Expand Down Expand Up @@ -862,7 +861,7 @@ async def test_tags(decoded_r: redis.Redis):
assert 1 == res["total_results"]

q2 = await decoded_r.ft().tagvals("tags")
assert set(tags.split(",") + tags2.split(",")) == q2
assert set(tags.split(",") + tags2.split(",")) == set(q2)


@pytest.mark.redismod
Expand Down Expand Up @@ -986,7 +985,7 @@ async def test_dict_operations(decoded_r: redis.Redis):

# Dump dict and inspect content
res = await decoded_r.ft().dict_dump("custom_dict")
assert_resp_response(decoded_r, res, ["item1", "item3"], {"item1", "item3"})
assert res == ["item1", "item3"]

# Remove rest of the items before reload
await decoded_r.ft().dict_del("custom_dict", *res)
Expand Down
Loading

0 comments on commit fd0b0d3

Please sign in to comment.