From 70e2a1321019a4df873230ac64f11f1215198624 Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Sun, 21 Jan 2024 12:10:26 +0200 Subject: [PATCH 01/23] Add modules support to async RedisCluster (#3115) --- redis/commands/cluster.py | 3 +- tests/test_asyncio/test_bloom.py | 26 +--------------- tests/test_asyncio/test_graph.py | 20 ------------- tests/test_asyncio/test_json.py | 43 +-------------------------- tests/test_asyncio/test_timeseries.py | 25 +--------------- 5 files changed, 5 insertions(+), 112 deletions(-) diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index 8dd463ed18..f31b88bc4e 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -44,7 +44,7 @@ ScriptCommands, ) from .helpers import list_or_args -from .redismodules import RedisModuleCommands +from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands if TYPE_CHECKING: from redis.asyncio.cluster import TargetNodesT @@ -907,6 +907,7 @@ class AsyncRedisClusterCommands( AsyncFunctionCommands, AsyncGearsCommands, AsyncModuleCommands, + AsyncRedisModuleCommands, ): """ A class for all Redis Cluster commands diff --git a/tests/test_asyncio/test_bloom.py b/tests/test_asyncio/test_bloom.py index d0a25e5625..278844416f 100644 --- a/tests/test_asyncio/test_bloom.py +++ b/tests/test_asyncio/test_bloom.py @@ -15,7 +15,6 @@ def intlist(obj): return [int(v) for v in obj] -@pytest.mark.redismod async def test_create(decoded_r: redis.Redis): """Test CREATE/RESERVE calls""" assert await decoded_r.bf().create("bloom", 0.01, 1000) @@ -30,13 +29,11 @@ async def test_create(decoded_r: redis.Redis): assert await decoded_r.topk().reserve("topk", 5, 100, 5, 0.9) -@pytest.mark.redismod @pytest.mark.experimental async def test_tdigest_create(decoded_r: redis.Redis): assert await decoded_r.tdigest().create("tDigest", 100) -@pytest.mark.redismod async def test_bf_add(decoded_r: redis.Redis): assert await decoded_r.bf().create("bloom", 0.01, 1000) assert 1 == await decoded_r.bf().add("bloom", "foo") @@ -49,7 +46,6 @@ async def test_bf_add(decoded_r: redis.Redis): assert [1, 0] == intlist(await decoded_r.bf().mexists("bloom", "foo", "noexist")) -@pytest.mark.redismod async def test_bf_insert(decoded_r: redis.Redis): assert await decoded_r.bf().create("bloom", 0.01, 1000) assert [1] == intlist(await decoded_r.bf().insert("bloom", ["foo"])) @@ -80,7 +76,6 @@ async def test_bf_insert(decoded_r: redis.Redis): ) -@pytest.mark.redismod async def test_bf_scandump_and_loadchunk(decoded_r: redis.Redis): # Store a filter await decoded_r.bf().create("myBloom", "0.0001", "1000") @@ -132,7 +127,6 @@ async def do_verify(): await decoded_r.bf().create("myBloom", "0.0001", "10000000") -@pytest.mark.redismod async def test_bf_info(decoded_r: redis.Redis): expansion = 4 # Store a filter @@ -164,7 +158,6 @@ async def test_bf_info(decoded_r: redis.Redis): assert True -@pytest.mark.redismod async def test_bf_card(decoded_r: redis.Redis): # return 0 if the key does not exist assert await decoded_r.bf().card("not_exist") == 0 @@ -179,7 +172,6 @@ async def test_bf_card(decoded_r: redis.Redis): await decoded_r.bf().card("setKey") -@pytest.mark.redismod async def test_cf_add_and_insert(decoded_r: redis.Redis): assert await decoded_r.cf().create("cuckoo", 1000) assert await decoded_r.cf().add("cuckoo", "filter") @@ -205,7 +197,6 @@ async def test_cf_add_and_insert(decoded_r: redis.Redis): ) -@pytest.mark.redismod async def test_cf_exists_and_del(decoded_r: redis.Redis): assert await 
decoded_r.cf().create("cuckoo", 1000) assert await decoded_r.cf().add("cuckoo", "filter") @@ -217,7 +208,6 @@ async def test_cf_exists_and_del(decoded_r: redis.Redis): assert 0 == await decoded_r.cf().count("cuckoo", "filter") -@pytest.mark.redismod async def test_cms(decoded_r: redis.Redis): assert await decoded_r.cms().initbydim("dim", 1000, 5) assert await decoded_r.cms().initbyprob("prob", 0.01, 0.01) @@ -233,7 +223,6 @@ async def test_cms(decoded_r: redis.Redis): assert 25 == info["count"] -@pytest.mark.redismod @pytest.mark.onlynoncluster async def test_cms_merge(decoded_r: redis.Redis): assert await decoded_r.cms().initbydim("A", 1000, 5) @@ -251,7 +240,6 @@ async def test_cms_merge(decoded_r: redis.Redis): assert [16, 15, 21] == await decoded_r.cms().query("C", "foo", "bar", "baz") -@pytest.mark.redismod async def test_topk(decoded_r: redis.Redis): # test list with empty buckets assert await decoded_r.topk().reserve("topk", 3, 50, 4, 0.9) @@ -332,7 +320,6 @@ async def test_topk(decoded_r: redis.Redis): assert 0.9 == round(float(info["decay"]), 1) -@pytest.mark.redismod async def test_topk_incrby(decoded_r: redis.Redis): await decoded_r.flushdb() assert await decoded_r.topk().reserve("topk", 3, 10, 3, 1) @@ -347,7 +334,6 @@ async def test_topk_incrby(decoded_r: redis.Redis): ) -@pytest.mark.redismod @pytest.mark.experimental async def test_tdigest_reset(decoded_r: redis.Redis): assert await decoded_r.tdigest().create("tDigest", 10) @@ -364,7 +350,6 @@ async def test_tdigest_reset(decoded_r: redis.Redis): ) -@pytest.mark.redismod @pytest.mark.onlynoncluster async def test_tdigest_merge(decoded_r: redis.Redis): assert await decoded_r.tdigest().create("to-tDigest", 10) @@ -392,7 +377,6 @@ async def test_tdigest_merge(decoded_r: redis.Redis): assert 4.0 == await decoded_r.tdigest().max("to-tDigest") -@pytest.mark.redismod @pytest.mark.experimental async def test_tdigest_min_and_max(decoded_r: redis.Redis): assert await decoded_r.tdigest().create("tDigest", 100) @@ -403,7 +387,6 @@ async def test_tdigest_min_and_max(decoded_r: redis.Redis): assert 1 == await decoded_r.tdigest().min("tDigest") -@pytest.mark.redismod @pytest.mark.experimental @skip_ifmodversion_lt("2.4.0", "bf") async def test_tdigest_quantile(decoded_r: redis.Redis): @@ -432,7 +415,6 @@ async def test_tdigest_quantile(decoded_r: redis.Redis): assert [3.0, 5.0] == res -@pytest.mark.redismod @pytest.mark.experimental async def test_tdigest_cdf(decoded_r: redis.Redis): assert await decoded_r.tdigest().create("tDigest", 100) @@ -444,7 +426,6 @@ async def test_tdigest_cdf(decoded_r: redis.Redis): assert [0.1, 0.9] == [round(x, 1) for x in res] -@pytest.mark.redismod @pytest.mark.experimental @skip_ifmodversion_lt("2.4.0", "bf") async def test_tdigest_trimmed_mean(decoded_r: redis.Redis): @@ -455,7 +436,6 @@ async def test_tdigest_trimmed_mean(decoded_r: redis.Redis): assert 4.5 == await decoded_r.tdigest().trimmed_mean("tDigest", 0.4, 0.5) -@pytest.mark.redismod @pytest.mark.experimental async def test_tdigest_rank(decoded_r: redis.Redis): assert await decoded_r.tdigest().create("t-digest", 500) @@ -466,7 +446,6 @@ async def test_tdigest_rank(decoded_r: redis.Redis): assert [-1, 20, 9] == await decoded_r.tdigest().rank("t-digest", -20, 20, 9) -@pytest.mark.redismod @pytest.mark.experimental async def test_tdigest_revrank(decoded_r: redis.Redis): assert await decoded_r.tdigest().create("t-digest", 500) @@ -476,7 +455,6 @@ async def test_tdigest_revrank(decoded_r: redis.Redis): assert [-1, 19, 9] == await 
decoded_r.tdigest().revrank("t-digest", 21, 0, 10) -@pytest.mark.redismod @pytest.mark.experimental async def test_tdigest_byrank(decoded_r: redis.Redis): assert await decoded_r.tdigest().create("t-digest", 500) @@ -488,7 +466,6 @@ async def test_tdigest_byrank(decoded_r: redis.Redis): (await decoded_r.tdigest().byrank("t-digest", -1))[0] -@pytest.mark.redismod @pytest.mark.experimental async def test_tdigest_byrevrank(decoded_r: redis.Redis): assert await decoded_r.tdigest().create("t-digest", 500) @@ -500,8 +477,7 @@ async def test_tdigest_byrevrank(decoded_r: redis.Redis): (await decoded_r.tdigest().byrevrank("t-digest", -1))[0] -# @pytest.mark.redismod -# async def test_pipeline(decoded_r: redis.Redis): +# # async def test_pipeline(decoded_r: redis.Redis): # pipeline = await decoded_r.bf().pipeline() # assert not await decoded_r.bf().execute_command("get pipeline") # diff --git a/tests/test_asyncio/test_graph.py b/tests/test_asyncio/test_graph.py index 22195901e6..4caf79470e 100644 --- a/tests/test_asyncio/test_graph.py +++ b/tests/test_asyncio/test_graph.py @@ -6,14 +6,12 @@ from tests.conftest import skip_if_redis_enterprise -@pytest.mark.redismod async def test_bulk(decoded_r): with pytest.raises(NotImplementedError): await decoded_r.graph().bulk() await decoded_r.graph().bulk(foo="bar!") -@pytest.mark.redismod async def test_graph_creation(decoded_r: redis.Redis): graph = decoded_r.graph() @@ -58,7 +56,6 @@ async def test_graph_creation(decoded_r: redis.Redis): await graph.delete() -@pytest.mark.redismod async def test_array_functions(decoded_r: redis.Redis): graph = decoded_r.graph() @@ -81,7 +78,6 @@ async def test_array_functions(decoded_r: redis.Redis): assert [a] == result.result_set[0][0] -@pytest.mark.redismod async def test_path(decoded_r: redis.Redis): node0 = Node(node_id=0, label="L1") node1 = Node(node_id=1, label="L1") @@ -101,7 +97,6 @@ async def test_path(decoded_r: redis.Redis): assert expected_results == result.result_set -@pytest.mark.redismod async def test_param(decoded_r: redis.Redis): params = [1, 2.3, "str", True, False, None, [0, 1, 2]] query = "RETURN $param" @@ -111,7 +106,6 @@ async def test_param(decoded_r: redis.Redis): assert expected_results == result.result_set -@pytest.mark.redismod async def test_map(decoded_r: redis.Redis): query = "RETURN {a:1, b:'str', c:NULL, d:[1,2,3], e:True, f:{x:1, y:2}}" @@ -128,7 +122,6 @@ async def test_map(decoded_r: redis.Redis): assert actual == expected -@pytest.mark.redismod async def test_point(decoded_r: redis.Redis): query = "RETURN point({latitude: 32.070794860, longitude: 34.820751118})" expected_lat = 32.070794860 @@ -145,7 +138,6 @@ async def test_point(decoded_r: redis.Redis): assert abs(actual["longitude"] - expected_lon) < 0.001 -@pytest.mark.redismod async def test_index_response(decoded_r: redis.Redis): result_set = await decoded_r.graph().query("CREATE INDEX ON :person(age)") assert 1 == result_set.indices_created @@ -160,7 +152,6 @@ async def test_index_response(decoded_r: redis.Redis): await decoded_r.graph().query("DROP INDEX ON :person(age)") -@pytest.mark.redismod async def test_stringify_query_result(decoded_r: redis.Redis): graph = decoded_r.graph() @@ -214,7 +205,6 @@ async def test_stringify_query_result(decoded_r: redis.Redis): await graph.delete() -@pytest.mark.redismod async def test_optional_match(decoded_r: redis.Redis): # Build a graph of form (a)-[R]->(b) node0 = Node(node_id=0, label="L1", properties={"value": "a"}) @@ -239,7 +229,6 @@ async def test_optional_match(decoded_r: 
redis.Redis): await graph.delete() -@pytest.mark.redismod async def test_cached_execution(decoded_r: redis.Redis): await decoded_r.graph().query("CREATE ()") @@ -259,7 +248,6 @@ async def test_cached_execution(decoded_r: redis.Redis): assert cached_result.cached_execution -@pytest.mark.redismod async def test_slowlog(decoded_r: redis.Redis): create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), @@ -272,7 +260,6 @@ async def test_slowlog(decoded_r: redis.Redis): assert results[0][2] == create_query -@pytest.mark.redismod @pytest.mark.xfail(strict=False) async def test_query_timeout(decoded_r: redis.Redis): # Build a sample graph with 1000 nodes. @@ -287,7 +274,6 @@ async def test_query_timeout(decoded_r: redis.Redis): assert False is False -@pytest.mark.redismod async def test_read_only_query(decoded_r: redis.Redis): with pytest.raises(Exception): # Issue a write query, specifying read-only true, @@ -296,7 +282,6 @@ async def test_read_only_query(decoded_r: redis.Redis): assert False is False -@pytest.mark.redismod async def test_profile(decoded_r: redis.Redis): q = """UNWIND range(1, 3) AS x CREATE (p:Person {v:x})""" profile = (await decoded_r.graph().profile(q)).result_set @@ -311,7 +296,6 @@ async def test_profile(decoded_r: redis.Redis): assert "Node By Label Scan | (p:Person) | Records produced: 3" in profile -@pytest.mark.redismod @skip_if_redis_enterprise() async def test_config(decoded_r: redis.Redis): config_name = "RESULTSET_SIZE" @@ -343,7 +327,6 @@ async def test_config(decoded_r: redis.Redis): await decoded_r.graph().config("RESULTSET_SIZE", -100, set=True) -@pytest.mark.redismod @pytest.mark.onlynoncluster async def test_list_keys(decoded_r: redis.Redis): result = await decoded_r.graph().list_keys() @@ -367,7 +350,6 @@ async def test_list_keys(decoded_r: redis.Redis): assert result == [] -@pytest.mark.redismod async def test_multi_label(decoded_r: redis.Redis): redis_graph = decoded_r.graph("g") @@ -393,7 +375,6 @@ async def test_multi_label(decoded_r: redis.Redis): assert True -@pytest.mark.redismod async def test_execution_plan(decoded_r: redis.Redis): redis_graph = decoded_r.graph("execution_plan") create_query = """CREATE @@ -412,7 +393,6 @@ async def test_execution_plan(decoded_r: redis.Redis): await redis_graph.delete() -@pytest.mark.redismod async def test_explain(decoded_r: redis.Redis): redis_graph = decoded_r.graph("execution_plan") # graph creation / population diff --git a/tests/test_asyncio/test_json.py b/tests/test_asyncio/test_json.py index a35bd4795f..920ec71dce 100644 --- a/tests/test_asyncio/test_json.py +++ b/tests/test_asyncio/test_json.py @@ -5,7 +5,6 @@ from tests.conftest import assert_resp_response, skip_ifmodversion_lt -@pytest.mark.redismod async def test_json_setbinarykey(decoded_r: redis.Redis): d = {"hello": "world", b"some": "value"} with pytest.raises(TypeError): @@ -13,7 +12,6 @@ async def test_json_setbinarykey(decoded_r: redis.Redis): assert await decoded_r.json().set("somekey", Path.root_path(), d, decode_keys=True) -@pytest.mark.redismod async def test_json_setgetdeleteforget(decoded_r: redis.Redis): assert await decoded_r.json().set("foo", Path.root_path(), "bar") assert_resp_response(decoded_r, await decoded_r.json().get("foo"), "bar", [["bar"]]) @@ -23,13 +21,11 @@ async def test_json_setgetdeleteforget(decoded_r: redis.Redis): assert await decoded_r.exists("foo") == 0 -@pytest.mark.redismod async def test_jsonget(decoded_r: redis.Redis): await decoded_r.json().set("foo", Path.root_path(), "bar") 
assert_resp_response(decoded_r, await decoded_r.json().get("foo"), "bar", [["bar"]]) -@pytest.mark.redismod async def test_json_get_jset(decoded_r: redis.Redis): assert await decoded_r.json().set("foo", Path.root_path(), "bar") assert_resp_response(decoded_r, await decoded_r.json().get("foo"), "bar", [["bar"]]) @@ -38,7 +34,6 @@ async def test_json_get_jset(decoded_r: redis.Redis): assert await decoded_r.exists("foo") == 0 -@pytest.mark.redismod async def test_nonascii_setgetdelete(decoded_r: redis.Redis): assert await decoded_r.json().set("notascii", Path.root_path(), "hyvää-élève") res = "hyvää-élève" @@ -49,7 +44,6 @@ async def test_nonascii_setgetdelete(decoded_r: redis.Redis): assert await decoded_r.exists("notascii") == 0 -@pytest.mark.redismod @skip_ifmodversion_lt("2.6.0", "ReJSON") async def test_json_merge(decoded_r: redis.Redis): # Test with root path $ @@ -84,7 +78,6 @@ async def test_json_merge(decoded_r: redis.Redis): } -@pytest.mark.redismod async def test_jsonsetexistentialmodifiersshouldsucceed(decoded_r: redis.Redis): obj = {"foo": "bar"} assert await decoded_r.json().set("obj", Path.root_path(), obj) @@ -102,7 +95,6 @@ async def test_jsonsetexistentialmodifiersshouldsucceed(decoded_r: redis.Redis): await decoded_r.json().set("obj", Path("foo"), "baz", nx=True, xx=True) -@pytest.mark.redismod async def test_mgetshouldsucceed(decoded_r: redis.Redis): await decoded_r.json().set("1", Path.root_path(), 1) await decoded_r.json().set("2", Path.root_path(), 2) @@ -111,7 +103,6 @@ async def test_mgetshouldsucceed(decoded_r: redis.Redis): assert await decoded_r.json().mget([1, 2], Path.root_path()) == [1, 2] -@pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("2.6.0", "ReJSON") async def test_mset(decoded_r: redis.Redis): @@ -123,7 +114,6 @@ async def test_mset(decoded_r: redis.Redis): assert await decoded_r.json().mget(["1", "2"], Path.root_path()) == [1, 2] -@pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release async def test_clear(decoded_r: redis.Redis): await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) @@ -131,7 +121,6 @@ async def test_clear(decoded_r: redis.Redis): assert_resp_response(decoded_r, await decoded_r.json().get("arr"), [], [[[]]]) -@pytest.mark.redismod async def test_type(decoded_r: redis.Redis): await decoded_r.json().set("1", Path.root_path(), 1) assert_resp_response( @@ -145,7 +134,6 @@ async def test_type(decoded_r: redis.Redis): ) -@pytest.mark.redismod async def test_numincrby(decoded_r): await decoded_r.json().set("num", Path.root_path(), 1) assert_resp_response( @@ -157,7 +145,6 @@ async def test_numincrby(decoded_r): assert_resp_response(decoded_r, res, 1.25, [1.25]) -@pytest.mark.redismod async def test_nummultby(decoded_r: redis.Redis): await decoded_r.json().set("num", Path.root_path(), 1) @@ -170,7 +157,6 @@ async def test_nummultby(decoded_r: redis.Redis): assert_resp_response(decoded_r, res, 2.5, [2.5]) -@pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release async def test_toggle(decoded_r: redis.Redis): await decoded_r.json().set("bool", Path.root_path(), False) @@ -182,7 +168,6 @@ async def test_toggle(decoded_r: redis.Redis): await decoded_r.json().toggle("num", Path.root_path()) -@pytest.mark.redismod async def test_strappend(decoded_r: redis.Redis): await decoded_r.json().set("jsonkey", Path.root_path(), "foo") assert 6 == await decoded_r.json().strappend("jsonkey", "bar") @@ -190,7 +175,6 @@ async def 
test_strappend(decoded_r: redis.Redis): assert_resp_response(decoded_r, res, "foobar", [["foobar"]]) -@pytest.mark.redismod async def test_strlen(decoded_r: redis.Redis): await decoded_r.json().set("str", Path.root_path(), "foo") assert 3 == await decoded_r.json().strlen("str", Path.root_path()) @@ -199,7 +183,6 @@ async def test_strlen(decoded_r: redis.Redis): assert 6 == await decoded_r.json().strlen("str") -@pytest.mark.redismod async def test_arrappend(decoded_r: redis.Redis): await decoded_r.json().set("arr", Path.root_path(), [1]) assert 2 == await decoded_r.json().arrappend("arr", Path.root_path(), 2) @@ -207,7 +190,6 @@ async def test_arrappend(decoded_r: redis.Redis): assert 7 == await decoded_r.json().arrappend("arr", Path.root_path(), *[5, 6, 7]) -@pytest.mark.redismod async def test_arrindex(decoded_r: redis.Redis): r_path = Path.root_path() await decoded_r.json().set("arr", r_path, [0, 1, 2, 3, 4]) @@ -220,7 +202,6 @@ async def test_arrindex(decoded_r: redis.Redis): assert -1 == await decoded_r.json().arrindex("arr", r_path, 4, start=1, stop=3) -@pytest.mark.redismod async def test_arrinsert(decoded_r: redis.Redis): await decoded_r.json().set("arr", Path.root_path(), [0, 4]) assert 5 == await decoded_r.json().arrinsert("arr", Path.root_path(), 1, *[1, 2, 3]) @@ -234,7 +215,6 @@ async def test_arrinsert(decoded_r: redis.Redis): assert_resp_response(decoded_r, await decoded_r.json().get("val2"), res, [[res]]) -@pytest.mark.redismod async def test_arrlen(decoded_r: redis.Redis): await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) assert 5 == await decoded_r.json().arrlen("arr", Path.root_path()) @@ -242,7 +222,6 @@ async def test_arrlen(decoded_r: redis.Redis): assert await decoded_r.json().arrlen("fakekey") is None -@pytest.mark.redismod async def test_arrpop(decoded_r: redis.Redis): await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) assert 4 == await decoded_r.json().arrpop("arr", Path.root_path(), 4) @@ -260,7 +239,6 @@ async def test_arrpop(decoded_r: redis.Redis): assert await decoded_r.json().arrpop("arr") is None -@pytest.mark.redismod async def test_arrtrim(decoded_r: redis.Redis): await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) assert 3 == await decoded_r.json().arrtrim("arr", Path.root_path(), 1, 3) @@ -284,7 +262,6 @@ async def test_arrtrim(decoded_r: redis.Redis): assert 0 == await decoded_r.json().arrtrim("arr", Path.root_path(), 9, 11) -@pytest.mark.redismod async def test_resp(decoded_r: redis.Redis): obj = {"foo": "bar", "baz": 1, "qaz": True} await decoded_r.json().set("obj", Path.root_path(), obj) @@ -294,7 +271,6 @@ async def test_resp(decoded_r: redis.Redis): assert isinstance(await decoded_r.json().resp("obj"), list) -@pytest.mark.redismod async def test_objkeys(decoded_r: redis.Redis): obj = {"foo": "bar", "baz": "qaz"} await decoded_r.json().set("obj", Path.root_path(), obj) @@ -311,7 +287,6 @@ async def test_objkeys(decoded_r: redis.Redis): assert await decoded_r.json().objkeys("fakekey") is None -@pytest.mark.redismod async def test_objlen(decoded_r: redis.Redis): obj = {"foo": "bar", "baz": "qaz"} await decoded_r.json().set("obj", Path.root_path(), obj) @@ -345,7 +320,6 @@ async def test_objlen(decoded_r: redis.Redis): # assert await decoded_r.get("foo") is None -@pytest.mark.redismod async def test_json_delete_with_dollar(decoded_r: redis.Redis): doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} assert await decoded_r.json().set("doc1", "$", doc1) @@ -399,7 +373,6 @@ async def 
test_json_delete_with_dollar(decoded_r: redis.Redis): await decoded_r.json().delete("not_a_document", "..a") -@pytest.mark.redismod async def test_json_forget_with_dollar(decoded_r: redis.Redis): doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} assert await decoded_r.json().set("doc1", "$", doc1) @@ -452,7 +425,7 @@ async def test_json_forget_with_dollar(decoded_r: redis.Redis): await decoded_r.json().forget("not_a_document", "..a") -@pytest.mark.redismod +@pytest.mark.onlynoncluster async def test_json_mget_dollar(decoded_r: redis.Redis): # Test mget with multi paths await decoded_r.json().set( @@ -488,7 +461,6 @@ async def test_json_mget_dollar(decoded_r: redis.Redis): assert res == [None, None] -@pytest.mark.redismod async def test_numby_commands_dollar(decoded_r: redis.Redis): # Test NUMINCRBY await decoded_r.json().set( @@ -543,7 +515,6 @@ async def test_numby_commands_dollar(decoded_r: redis.Redis): await decoded_r.json().nummultby("doc1", ".b[0].a", 3) == 6 -@pytest.mark.redismod async def test_strappend_dollar(decoded_r: redis.Redis): await decoded_r.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} @@ -574,7 +545,6 @@ async def test_strappend_dollar(decoded_r: redis.Redis): await decoded_r.json().strappend("doc1", "piu") -@pytest.mark.redismod async def test_strlen_dollar(decoded_r: redis.Redis): # Test multi await decoded_r.json().set( @@ -595,7 +565,6 @@ async def test_strlen_dollar(decoded_r: redis.Redis): await decoded_r.json().strlen("non_existing_doc", "$..a") -@pytest.mark.redismod async def test_arrappend_dollar(decoded_r: redis.Redis): await decoded_r.json().set( "doc1", @@ -669,7 +638,6 @@ async def test_arrappend_dollar(decoded_r: redis.Redis): await decoded_r.json().arrappend("non_existing_doc", "$..a") -@pytest.mark.redismod async def test_arrinsert_dollar(decoded_r: redis.Redis): await decoded_r.json().set( "doc1", @@ -708,7 +676,6 @@ async def test_arrinsert_dollar(decoded_r: redis.Redis): await decoded_r.json().arrappend("non_existing_doc", "$..a") -@pytest.mark.redismod async def test_arrlen_dollar(decoded_r: redis.Redis): await decoded_r.json().set( "doc1", @@ -754,7 +721,6 @@ async def test_arrlen_dollar(decoded_r: redis.Redis): assert await decoded_r.json().arrlen("non_existing_doc", "..a") is None -@pytest.mark.redismod async def test_arrpop_dollar(decoded_r: redis.Redis): await decoded_r.json().set( "doc1", @@ -796,7 +762,6 @@ async def test_arrpop_dollar(decoded_r: redis.Redis): await decoded_r.json().arrpop("non_existing_doc", "..a") -@pytest.mark.redismod async def test_arrtrim_dollar(decoded_r: redis.Redis): await decoded_r.json().set( "doc1", @@ -848,7 +813,6 @@ async def test_arrtrim_dollar(decoded_r: redis.Redis): await decoded_r.json().arrtrim("non_existing_doc", "..a", 1, 1) -@pytest.mark.redismod async def test_objkeys_dollar(decoded_r: redis.Redis): await decoded_r.json().set( "doc1", @@ -878,7 +842,6 @@ async def test_objkeys_dollar(decoded_r: redis.Redis): assert await decoded_r.json().objkeys("doc1", "$..nowhere") == [] -@pytest.mark.redismod async def test_objlen_dollar(decoded_r: redis.Redis): await decoded_r.json().set( "doc1", @@ -914,7 +877,6 @@ async def test_objlen_dollar(decoded_r: redis.Redis): await decoded_r.json().objlen("doc1", ".nowhere") -@pytest.mark.redismod def load_types_data(nested_key_name): td = { "object": {}, @@ -934,7 +896,6 @@ def load_types_data(nested_key_name): return jdata, types -@pytest.mark.redismod async def test_type_dollar(decoded_r: redis.Redis): jdata, jtypes = 
load_types_data("a") await decoded_r.json().set("doc1", "$", jdata) @@ -953,7 +914,6 @@ async def test_type_dollar(decoded_r: redis.Redis): ) -@pytest.mark.redismod async def test_clear_dollar(decoded_r: redis.Redis): await decoded_r.json().set( "doc1", @@ -1007,7 +967,6 @@ async def test_clear_dollar(decoded_r: redis.Redis): await decoded_r.json().clear("non_existing_doc", "$..a") -@pytest.mark.redismod async def test_toggle_dollar(decoded_r: redis.Redis): await decoded_r.json().set( "doc1", diff --git a/tests/test_asyncio/test_timeseries.py b/tests/test_asyncio/test_timeseries.py index 91c15c3db2..b44219707e 100644 --- a/tests/test_asyncio/test_timeseries.py +++ b/tests/test_asyncio/test_timeseries.py @@ -10,7 +10,6 @@ ) -@pytest.mark.redismod async def test_create(decoded_r: redis.Redis): assert await decoded_r.ts().create(1) assert await decoded_r.ts().create(2, retention_msecs=5) @@ -28,7 +27,6 @@ async def test_create(decoded_r: redis.Redis): assert_resp_response(decoded_r, 128, info.get("chunk_size"), info.get("chunkSize")) -@pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") async def test_create_duplicate_policy(decoded_r: redis.Redis): # Test for duplicate policy @@ -44,7 +42,6 @@ async def test_create_duplicate_policy(decoded_r: redis.Redis): ) -@pytest.mark.redismod async def test_alter(decoded_r: redis.Redis): assert await decoded_r.ts().create(1) res = await decoded_r.ts().info(1) @@ -67,7 +64,6 @@ async def test_alter(decoded_r: redis.Redis): ) -@pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") async def test_alter_diplicate_policy(decoded_r: redis.Redis): assert await decoded_r.ts().create(1) @@ -82,7 +78,6 @@ async def test_alter_diplicate_policy(decoded_r: redis.Redis): ) -@pytest.mark.redismod async def test_add(decoded_r: redis.Redis): assert 1 == await decoded_r.ts().add(1, 1, 1) assert 2 == await decoded_r.ts().add(2, 2, 3, retention_msecs=10) @@ -105,7 +100,6 @@ async def test_add(decoded_r: redis.Redis): assert_resp_response(decoded_r, 128, info.get("chunk_size"), info.get("chunkSize")) -@pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") async def test_add_duplicate_policy(r: redis.Redis): # Test for duplicate policy BLOCK @@ -146,7 +140,6 @@ async def test_add_duplicate_policy(r: redis.Redis): assert 5.0 == res[1] -@pytest.mark.redismod async def test_madd(decoded_r: redis.Redis): await decoded_r.ts().create("a") assert [1, 2, 3] == await decoded_r.ts().madd( @@ -154,7 +147,6 @@ async def test_madd(decoded_r: redis.Redis): ) -@pytest.mark.redismod async def test_incrby_decrby(decoded_r: redis.Redis): for _ in range(100): assert await decoded_r.ts().incrby(1, 1) @@ -183,7 +175,6 @@ async def test_incrby_decrby(decoded_r: redis.Redis): assert_resp_response(decoded_r, 128, info.get("chunk_size"), info.get("chunkSize")) -@pytest.mark.redismod async def test_create_and_delete_rule(decoded_r: redis.Redis): # test rule creation time = 100 @@ -207,7 +198,6 @@ async def test_create_and_delete_rule(decoded_r: redis.Redis): assert not info["rules"] -@pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "timeseries") async def test_del_range(decoded_r: redis.Redis): try: @@ -224,7 +214,6 @@ async def test_del_range(decoded_r: redis.Redis): ) -@pytest.mark.redismod async def test_range(r: redis.Redis): for i in range(100): await r.ts().add(1, i, i % 7) @@ -239,7 +228,6 @@ async def test_range(r: redis.Redis): assert 10 == len(await r.ts().range(1, 0, 500, count=10)) -@pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", 
"timeseries") async def test_range_advanced(decoded_r: redis.Redis): for i in range(100): @@ -270,7 +258,6 @@ async def test_range_advanced(decoded_r: redis.Redis): assert_resp_response(decoded_r, res, [(0, 2.55), (10, 3.0)], [[0, 2.55], [10, 3.0]]) -@pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "timeseries") async def test_rev_range(decoded_r: redis.Redis): for i in range(100): @@ -314,7 +301,6 @@ async def test_rev_range(decoded_r: redis.Redis): ) -@pytest.mark.redismod @pytest.mark.onlynoncluster async def test_multi_range(decoded_r: redis.Redis): await decoded_r.ts().create(1, labels={"Test": "This", "team": "ny"}) @@ -369,7 +355,6 @@ async def test_multi_range(decoded_r: redis.Redis): assert {"Test": "This", "team": "ny"} == res["1"][0] -@pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("99.99.99", "timeseries") async def test_multi_range_advanced(decoded_r: redis.Redis): @@ -487,7 +472,6 @@ async def test_multi_range_advanced(decoded_r: redis.Redis): assert [[0, 5.0], [5, 6.0]] == res["1"][2] -@pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("99.99.99", "timeseries") async def test_multi_reverse_range(decoded_r: redis.Redis): @@ -651,7 +635,6 @@ async def test_multi_reverse_range(decoded_r: redis.Redis): assert [[1, 10.0], [0, 1.0]] == res["1"][2] -@pytest.mark.redismod async def test_get(decoded_r: redis.Redis): name = "test" await decoded_r.ts().create(name) @@ -662,7 +645,6 @@ async def test_get(decoded_r: redis.Redis): assert 4 == (await decoded_r.ts().get(name))[1] -@pytest.mark.redismod @pytest.mark.onlynoncluster async def test_mget(decoded_r: redis.Redis): await decoded_r.ts().create(1, labels={"Test": "This"}) @@ -698,7 +680,6 @@ async def test_mget(decoded_r: redis.Redis): assert {"Taste": "That", "Test": "This"} == res["2"][0] -@pytest.mark.redismod async def test_info(decoded_r: redis.Redis): await decoded_r.ts().create( 1, retention_msecs=5, labels={"currentLabel": "currentData"} @@ -710,7 +691,6 @@ async def test_info(decoded_r: redis.Redis): assert info["labels"]["currentLabel"] == "currentData" -@pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") async def testInfoDuplicatePolicy(decoded_r: redis.Redis): await decoded_r.ts().create( @@ -728,7 +708,6 @@ async def testInfoDuplicatePolicy(decoded_r: redis.Redis): ) -@pytest.mark.redismod @pytest.mark.onlynoncluster async def test_query_index(decoded_r: redis.Redis): await decoded_r.ts().create(1, labels={"Test": "This"}) @@ -740,8 +719,7 @@ async def test_query_index(decoded_r: redis.Redis): ) -# @pytest.mark.redismod -# async def test_pipeline(r: redis.Redis): +# # async def test_pipeline(r: redis.Redis): # pipeline = await r.ts().pipeline() # pipeline.create("with_pipeline") # for i in range(100): @@ -754,7 +732,6 @@ async def test_query_index(decoded_r: redis.Redis): # assert await r.ts().get("with_pipeline")[1] == 99 * 1.1 -@pytest.mark.redismod async def test_uncompressed(decoded_r: redis.Redis): await decoded_r.ts().create("compressed") await decoded_r.ts().create("uncompressed", uncompressed=True) From 1a7d474268fe7072686369adc20aa498d63f063e Mon Sep 17 00:00:00 2001 From: Ahmed Ashraf <104530599+ahmedabdou14@users.noreply.github.com> Date: Tue, 23 Jan 2024 14:35:54 +0300 Subject: [PATCH 02/23] Fix grammer in BlockingConnectionPool class documentation (#3120) Co-authored-by: ahmedabdou14 --- redis/asyncio/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py 
index 77aa21f034..07c4262233 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1280,7 +1280,7 @@ class BlockingConnectionPool(ConnectionPool): connection from the pool when all of connections are in use, rather than raising a :py:class:`~redis.ConnectionError` (as the default :py:class:`~redis.asyncio.ConnectionPool` implementation does), it - makes blocks the current `Task` for a specified number of seconds until + blocks the current `Task` for a specified number of seconds until a connection becomes available. Use ``max_connections`` to increase / decrease the pool size:: From 2f88840383453d713859244b8206d7f942c3bcc4 Mon Sep 17 00:00:00 2001 From: Dongkeun Lee <3315213+zakaf@users.noreply.github.com> Date: Mon, 5 Feb 2024 02:04:10 +0900 Subject: [PATCH 03/23] release already acquired connections on ClusterPipeline, when get_connection raises an exception (#3133) Signed-off-by: zach.lee --- redis/cluster.py | 2 ++ tests/test_cluster.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/redis/cluster.py b/redis/cluster.py index c36665eb5c..ba25b92246 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2143,6 +2143,8 @@ def _send_cluster_commands( try: connection = get_connection(redis_node, c.args) except ConnectionError: + for n in nodes.values(): + n.connection_pool.release(n.connection) # Connection retries are being handled in the node's # Retry object. Reinitialize the node -> slot table. self.nodes_manager.initialize() diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 854b64c563..8a44d45ea3 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -10,6 +10,7 @@ from unittest.mock import DEFAULT, Mock, call, patch import pytest +import redis from redis import Redis from redis._parsers import CommandsParser from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff @@ -3250,6 +3251,31 @@ def raise_ask_error(): assert ask_node.redis_connection.connection.read_response.called assert res == ["MOCK_OK"] + def test_return_previously_acquired_connections(self, r): + # in order to ensure that a pipeline will make use of connections + # from different nodes + assert r.keyslot("a") != r.keyslot("b") + + orig_func = redis.cluster.get_connection + with patch("redis.cluster.get_connection") as get_connection: + + def raise_error(target_node, *args, **kwargs): + if get_connection.call_count == 2: + raise ConnectionError("mocked error") + else: + return orig_func(target_node, *args, **kwargs) + + get_connection.side_effect = raise_error + + r.pipeline().get("a").get("b").execute() + + # 4 = 2 get_connections per execution * 2 executions + assert get_connection.call_count == 4 + for cluster_node in r.nodes_manager.nodes_cache.values(): + connection_pool = cluster_node.redis_connection.connection_pool + num_of_conns = len(connection_pool._available_connections) + assert num_of_conns == connection_pool._created_connections + def test_empty_stack(self, r): """ If pipeline is executed with no commands it should From b1ee455ec2e384f9f4092717b02e640d06e8e1a6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 01:58:17 +0200 Subject: [PATCH 04/23] Bump actions/stale from 3 to 9 (#3132) Bumps [actions/stale](https://github.com/actions/stale) from 3 to 9. 
- [Release notes](https://github.com/actions/stale/releases) - [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/stale/compare/v3...v9) --- updated-dependencies: - dependency-name: actions/stale dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/stale-issues.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stale-issues.yml b/.github/workflows/stale-issues.yml index 32fd9e8179..445af1c818 100644 --- a/.github/workflows/stale-issues.yml +++ b/.github/workflows/stale-issues.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/stale@v3 + - uses: actions/stale@v9 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: 'This issue is marked stale. It will be closed in 30 days if it is not updated.' From 7f632962d124d74a361a961a5ee452593098d1b8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 01:58:39 +0200 Subject: [PATCH 05/23] Bump codecov/codecov-action from 3 to 4 (#3131) Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 3 to 4. - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v3...v4) --- updated-dependencies: - dependency-name: codecov/codecov-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/integration.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 7aaf346170..695e1b307c 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -88,7 +88,7 @@ jobs: path: '${{matrix.test-type}}*results.xml' - name: Upload codecov coverage - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: fail_ci_if_error: false From 7df57e5a32ad44446c1209ed7b8a38e996036354 Mon Sep 17 00:00:00 2001 From: poiuj <1099644+poiuj@users.noreply.github.com> Date: Mon, 5 Feb 2024 15:12:59 +0200 Subject: [PATCH 06/23] Allow to control the minimum SSL version (#3127) * Allow to control the minimum SSL version It's useful for applications that have strict security requirements. * Add tests for minimum SSL version The commit updates test_tcp_ssl_connect for both sync and async connections. Now it sets the minimum SSL version. The test is run with both TLSv1.2 and TLSv1.3 (if supported). A new test case is test_tcp_ssl_version_mismatch. The test is added for both sync and async connections. It uses TLS 1.3 on the client side, and TLS 1.2 on the server side. It expects a connection error. The test is skipped if TLS 1.3 is not supported.
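For illustration, a minimal sketch of the new keyword in use (not part of the patch; the host and port are placeholders for a TLS-enabled server, and the relaxed certificate checking is a demo shortcut):

    import ssl
    import redis

    r = redis.Redis(
        host="localhost",      # placeholder endpoint
        port=6666,
        ssl=True,
        ssl_cert_reqs="none",  # demo only; verify certificates in production
        ssl_min_version=ssl.TLSVersion.TLSv1_2,  # refuse anything older than TLS 1.2
    )
    r.ping()

The same `ssl_min_version` keyword is threaded through the asyncio client, the async cluster client, and `SSLConnection`, as the hunks below show.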
* Add example of using a minimum TLS version --- CHANGES | 1 + docs/examples/ssl_connection_examples.ipynb | 36 +++++++++++++ redis/asyncio/client.py | 3 ++ redis/asyncio/cluster.py | 3 ++ redis/asyncio/connection.py | 11 ++++ redis/client.py | 2 + redis/connection.py | 5 ++ tests/test_asyncio/test_connect.py | 54 ++++++++++++++++++-- tests/test_connect.py | 56 ++++++++++++++++++--- 9 files changed, 161 insertions(+), 10 deletions(-) diff --git a/CHANGES b/CHANGES index 3d9d6292a1..e0959b0ef3 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Allow to control the minimum SSL version * Add an optional lock_name attribute to LockError. * Fix return types for `get`, `set_path` and `strappend` in JSONCommands * Connection.register_connect_callback() is made public. diff --git a/docs/examples/ssl_connection_examples.ipynb b/docs/examples/ssl_connection_examples.ipynb index ab3b4415ae..a3d015619f 100644 --- a/docs/examples/ssl_connection_examples.ipynb +++ b/docs/examples/ssl_connection_examples.ipynb @@ -76,6 +76,42 @@ "ssl_connection.ping()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Connecting to a Redis instance via SSL, while specifying a minimum TLS version" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import redis\n", + "import ssl\n", + "\n", + "ssl_conn = redis.Redis(\n", + " host=\"localhost\",\n", + " port=6666,\n", + " ssl=True,\n", + " ssl_min_version=ssl.TLSVersion.TLSv1_3,\n", + ")\n", + "ssl_conn.ping()" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 88de893f5b..62bdc7dd5c 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -2,6 +2,7 @@ import copy import inspect import re +import ssl import warnings from typing import ( TYPE_CHECKING, @@ -226,6 +227,7 @@ def __init__( ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = False, + ssl_min_version: Optional[ssl.TLSVersion] = None, max_connections: Optional[int] = None, single_connection_client: bool = False, health_check_interval: int = 0, @@ -332,6 +334,7 @@ def __init__( "ssl_ca_certs": ssl_ca_certs, "ssl_ca_data": ssl_ca_data, "ssl_check_hostname": ssl_check_hostname, + "ssl_min_version": ssl_min_version, } ) # This arg only used if no pool is passed in diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 337c7bbdcc..4fb2fc4647 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -2,6 +2,7 @@ import collections import random import socket +import ssl import warnings from typing import ( Any, @@ -271,6 +272,7 @@ def __init__( ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, + ssl_min_version: Optional[ssl.TLSVersion] = None, protocol: Optional[int] = 2, address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, cache_enabled: bool = False, @@ -344,6 +346,7 @@ def __init__( "ssl_certfile": ssl_certfile, "ssl_check_hostname": ssl_check_hostname, "ssl_keyfile": ssl_keyfile, + "ssl_min_version": ssl_min_version, } ) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 07c4262233..81df3b3543 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -823,6 +823,7 @@ def __init__( ssl_ca_certs: Optional[str] = None, 
ssl_ca_data: Optional[str] = None, ssl_check_hostname: bool = False, + ssl_min_version: Optional[ssl.TLSVersion] = None, **kwargs, ): self.ssl_context: RedisSSLContext = RedisSSLContext( @@ -832,6 +833,7 @@ def __init__( ca_certs=ssl_ca_certs, ca_data=ssl_ca_data, check_hostname=ssl_check_hostname, + min_version=ssl_min_version, ) super().__init__(**kwargs) @@ -864,6 +866,10 @@ def ca_data(self): def check_hostname(self): return self.ssl_context.check_hostname + @property + def min_version(self): + return self.ssl_context.min_version + class RedisSSLContext: __slots__ = ( @@ -874,6 +880,7 @@ class RedisSSLContext: "ca_data", "context", "check_hostname", + "min_version", ) def __init__( @@ -884,6 +891,7 @@ def __init__( ca_certs: Optional[str] = None, ca_data: Optional[str] = None, check_hostname: bool = False, + min_version: Optional[ssl.TLSVersion] = None, ): self.keyfile = keyfile self.certfile = certfile @@ -903,6 +911,7 @@ def __init__( self.ca_certs = ca_certs self.ca_data = ca_data self.check_hostname = check_hostname + self.min_version = min_version self.context: Optional[ssl.SSLContext] = None def get(self) -> ssl.SSLContext: @@ -914,6 +923,8 @@ def get(self) -> ssl.SSLContext: context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) if self.ca_certs or self.ca_data: context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data) + if self.min_version is not None: + context.minimum_version = self.min_version self.context = context return self.context diff --git a/redis/client.py b/redis/client.py index 2d4c512699..1209a978d2 100755 --- a/redis/client.py +++ b/redis/client.py @@ -198,6 +198,7 @@ def __init__( ssl_validate_ocsp_stapled=False, ssl_ocsp_context=None, ssl_ocsp_expected_cert=None, + ssl_min_version=None, max_connections=None, single_connection_client=False, health_check_interval=0, @@ -311,6 +312,7 @@ def __init__( "ssl_validate_ocsp": ssl_validate_ocsp, "ssl_ocsp_context": ssl_ocsp_context, "ssl_ocsp_expected_cert": ssl_ocsp_expected_cert, + "ssl_min_version": ssl_min_version, } ) connection_pool = ConnectionPool(**kwargs) diff --git a/redis/connection.py b/redis/connection.py index 1f46267146..c9f7fc55d0 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -769,6 +769,7 @@ def __init__( ssl_validate_ocsp_stapled=False, ssl_ocsp_context=None, ssl_ocsp_expected_cert=None, + ssl_min_version=None, **kwargs, ): """Constructor @@ -787,6 +788,7 @@ def __init__( ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service. + ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module. 
Raises: RedisError @@ -819,6 +821,7 @@ def __init__( self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled self.ssl_ocsp_context = ssl_ocsp_context self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert + self.ssl_min_version = ssl_min_version super().__init__(**kwargs) def _connect(self): @@ -841,6 +844,8 @@ def _connect(self): context.load_verify_locations( cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data ) + if self.ssl_min_version is not None: + context.minimum_version = self.ssl_min_version sslsock = context.wrap_socket(sock, server_hostname=self.host) if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False: raise RedisError("cryptography is not installed.") diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index 5e6b120fb3..5497501258 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -10,6 +10,7 @@ SSLConnection, UnixDomainSocketConnection, ) +from redis.exceptions import ConnectionError from ..ssl_utils import get_ssl_filename @@ -50,7 +51,17 @@ async def test_uds_connect(uds_address): @pytest.mark.ssl -async def test_tcp_ssl_connect(tcp_address): +@pytest.mark.parametrize( + "ssl_min_version", + [ + ssl.TLSVersion.TLSv1_2, + pytest.param( + ssl.TLSVersion.TLSv1_3, + marks=pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3"), + ), + ], +) +async def test_tcp_ssl_connect(tcp_address, ssl_min_version): host, port = tcp_address certfile = get_ssl_filename("server-cert.pem") keyfile = get_ssl_filename("server-key.pem") @@ -60,12 +71,44 @@ async def test_tcp_ssl_connect(tcp_address): client_name=_CLIENT_NAME, ssl_ca_certs=certfile, socket_timeout=10, + ssl_min_version=ssl_min_version, ) await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) await conn.disconnect() -async def _assert_connect(conn, server_address, certfile=None, keyfile=None): +@pytest.mark.ssl +@pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3") +async def test_tcp_ssl_version_mismatch(tcp_address): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=1, + ssl_min_version=ssl.TLSVersion.TLSv1_3, + ) + with pytest.raises(ConnectionError): + await _assert_connect( + conn, + tcp_address, + certfile=certfile, + keyfile=keyfile, + ssl_version=ssl.TLSVersion.TLSv1_2, + ) + await conn.disconnect() + + +async def _assert_connect( + conn, + server_address, + certfile=None, + keyfile=None, + ssl_version=None, +): stop_event = asyncio.Event() finished = asyncio.Event() @@ -82,7 +125,9 @@ async def _handler(reader, writer): elif certfile: host, port = server_address context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - context.minimum_version = ssl.TLSVersion.TLSv1_2 + if ssl_version is not None: + context.minimum_version = ssl_version + context.maximum_version = ssl_version context.load_cert_chain(certfile=certfile, keyfile=keyfile) server = await asyncio.start_server(_handler, host=host, port=port, ssl=context) else: @@ -94,6 +139,9 @@ async def _handler(reader, writer): try: await conn.connect() await conn.disconnect() + except ConnectionError: + finished.set() + raise finally: stop_event.set() aserver.close() diff --git a/tests/test_connect.py b/tests/test_connect.py index 696e69ceea..0fdbb7005f 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -7,6 +7,7 @@ 
import pytest from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection +from redis.exceptions import ConnectionError from .ssl_utils import get_ssl_filename @@ -45,7 +46,17 @@ def test_uds_connect(uds_address): @pytest.mark.ssl -def test_tcp_ssl_connect(tcp_address): +@pytest.mark.parametrize( + "ssl_min_version", + [ + ssl.TLSVersion.TLSv1_2, + pytest.param( + ssl.TLSVersion.TLSv1_3, + marks=pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3"), + ), + ], +) +def test_tcp_ssl_connect(tcp_address, ssl_min_version): host, port = tcp_address certfile = get_ssl_filename("server-cert.pem") keyfile = get_ssl_filename("server-key.pem") @@ -55,19 +66,42 @@ def test_tcp_ssl_connect(tcp_address): client_name=_CLIENT_NAME, ssl_ca_certs=certfile, socket_timeout=10, + ssl_min_version=ssl_min_version, ) _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) -def _assert_connect(conn, server_address, certfile=None, keyfile=None): +@pytest.mark.ssl +@pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3") +def test_tcp_ssl_version_mismatch(tcp_address): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=10, + ssl_min_version=ssl.TLSVersion.TLSv1_3, + ) + with pytest.raises(ConnectionError): + _assert_connect( + conn, + tcp_address, + certfile=certfile, + keyfile=keyfile, + ssl_version=ssl.PROTOCOL_TLSv1_2, + ) + + +def _assert_connect(conn, server_address, **tcp_kw): if isinstance(server_address, str): if not _RedisUDSServer: pytest.skip("Unix domain sockets are not supported on this platform") server = _RedisUDSServer(server_address, _RedisRequestHandler) else: - server = _RedisTCPServer( - server_address, _RedisRequestHandler, certfile=certfile, keyfile=keyfile - ) + server = _RedisTCPServer(server_address, _RedisRequestHandler, **tcp_kw) with server as aserver: t = threading.Thread(target=aserver.serve_forever) t.start() @@ -81,11 +115,19 @@ def _assert_connect(conn, server_address, certfile=None, keyfile=None): class _RedisTCPServer(socketserver.TCPServer): - def __init__(self, *args, certfile=None, keyfile=None, **kw) -> None: + def __init__( + self, + *args, + certfile=None, + keyfile=None, + ssl_version=ssl.PROTOCOL_TLS, + **kw, + ) -> None: self._ready_event = threading.Event() self._stop_requested = False self._certfile = certfile self._keyfile = keyfile + self._ssl_version = ssl_version super().__init__(*args, **kw) def service_actions(self): @@ -110,7 +152,7 @@ def get_request(self): server_side=True, certfile=self._certfile, keyfile=self._keyfile, - ssl_version=ssl.PROTOCOL_TLSv1_2, + ssl_version=self._ssl_version, ) return connstream, fromaddr From 6240ea1e46bf5c8d79862c5eac30d48ff1f9a62e Mon Sep 17 00:00:00 2001 From: Qiangning Hong Date: Mon, 5 Feb 2024 22:18:18 +0800 Subject: [PATCH 07/23] docs: Add timeout parameter for get_message example (#3129) The `get_message()` method in asyncio PubSub has a `timeout` argument that defaults to 0.0, causing it to immediately return. This can cause high CPU usage with the given code example and should not be recommended. By setting `timeout=None`, it works with much more efficient resource usage. 
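A minimal reader sketch showing the difference (a sketch assuming a local Redis on the default port; everything except the `timeout=None` argument is illustrative):

    import asyncio
    import redis.asyncio as redis

    async def main():
        client = redis.Redis()  # assumed localhost:6379
        async with client.pubsub() as pubsub:
            await pubsub.subscribe("channel:1")
            while True:
                # timeout=None parks the coroutine until a message arrives,
                # instead of returning None immediately and spinning the loop.
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=None
                )
                if message is not None:
                    print(message)
                    break

    asyncio.run(main())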
--- docs/examples/asyncio_examples.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/asyncio_examples.ipynb b/docs/examples/asyncio_examples.ipynb index 5eab4db1f7..5029e907da 100644 --- a/docs/examples/asyncio_examples.ipynb +++ b/docs/examples/asyncio_examples.ipynb @@ -201,7 +201,7 @@ "\n", "async def reader(channel: redis.client.PubSub):\n", " while True:\n", - " message = await channel.get_message(ignore_subscribe_messages=True)\n", + " message = await channel.get_message(ignore_subscribe_messages=True, timeout=None)\n", " if message is not None:\n", " print(f\"(Reader) Message Received: {message}\")\n", " if message[\"data\"].decode() == STOPWORD:\n", @@ -264,7 +264,7 @@ "\n", "async def reader(channel: redis.client.PubSub):\n", " while True:\n", - " message = await channel.get_message(ignore_subscribe_messages=True)\n", + " message = await channel.get_message(ignore_subscribe_messages=True, timeout=None)\n", " if message is not None:\n", " print(f\"(Reader) Message Received: {message}\")\n", " if message[\"data\"].decode() == STOPWORD:\n", From 6b89786a2957025b2540700baed7878943a7d401 Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Mon, 12 Feb 2024 18:30:42 +0200 Subject: [PATCH 08/23] Revert stale issue version update (#3142) --- .github/workflows/stale-issues.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stale-issues.yml b/.github/workflows/stale-issues.yml index 445af1c818..32fd9e8179 100644 --- a/.github/workflows/stale-issues.yml +++ b/.github/workflows/stale-issues.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/stale@v9 + - uses: actions/stale@v3 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: 'This issue is marked stale. It will be closed in 30 days if it is not updated.' From 4099d5e0e3a15c377a10754ed6a376f3a6c20676 Mon Sep 17 00:00:00 2001 From: wKollendorf <83725977+wKollendorf@users.noreply.github.com> Date: Mon, 19 Feb 2024 10:48:42 +0100 Subject: [PATCH 09/23] Update connection.py (#3149) Exception ignored in: Traceback .... TypeError: 'NoneType' object cannot be interpreted as an integer. This happens when closing the connection within a spawned Process (multiprocess). --- redis/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/connection.py b/redis/connection.py index c9f7fc55d0..617d04af5c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -451,7 +451,7 @@ def disconnect(self, *args): if os.getpid() == self.pid: try: conn_sock.shutdown(socket.SHUT_RDWR) - except OSError: + except (OSError, TypeError): pass try: From 2b2a2e0bfe34584c6d8728ba5d03b70648b47f36 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 19 Feb 2024 11:49:18 +0200 Subject: [PATCH 10/23] Remove typing-extensions from deps (#3146) It's not used, and this library requires Python 3.8+. 
--- requirements.txt | 1 - setup.py | 1 - 2 files changed, 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 82c46c92c6..a716b84463 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ async-timeout>=4.0.2 -typing-extensions; python_version<"3.8" diff --git a/setup.py b/setup.py index aca2244218..c6a9e205f5 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,6 @@ author_email="oss@redis.com", python_requires=">=3.8", install_requires=[ - 'typing-extensions; python_version<"3.8"', 'async-timeout>=4.0.3', ], classifiers=[ From ebb6171832f284d61df4fe3afa0a23af88fbf9b6 Mon Sep 17 00:00:00 2001 From: Will Miller Date: Mon, 19 Feb 2024 10:32:01 +0000 Subject: [PATCH 11/23] Fix retry logic for pubsub and pipeline (#3134) * Fix retry logic for pubsub and pipeline Extend the fix from bea72995fd39b01e2f0a1682b16b6c7690933f36 to apply to pipeline and pubsub as well. Fixes #2973 * fix isort --------- Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/asyncio/client.py | 36 ++++++++++++++++++----------- redis/client.py | 50 ++++++++++++++++++++++++++++------------- 2 files changed, 57 insertions(+), 29 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 62bdc7dd5c..3e2912bfca 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -927,11 +927,15 @@ async def connect(self): async def _disconnect_raise_connect(self, conn, error): """ Close the connection and raise an exception - if retry_on_timeout is not set or the error - is not a TimeoutError. Otherwise, try to reconnect + if retry_on_error is not set or the error is not one + of the specified error types. Otherwise, try to + reconnect """ await conn.disconnect() - if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + if ( + conn.retry_on_error is None + or isinstance(error, tuple(conn.retry_on_error)) is False + ): raise error await conn.connect() @@ -1344,8 +1348,8 @@ async def _disconnect_reset_raise(self, conn, error): """ Close the connection, reset watching state and raise an exception if we were watching, - retry_on_timeout is not set, - or the error is not a TimeoutError + if retry_on_error is not set or the error is not one + of the specified error types. """ await conn.disconnect() # if we were already watching a variable, the watch is no longer @@ -1356,9 +1360,12 @@ async def _disconnect_reset_raise(self, conn, error): raise WatchError( "A ConnectionError occurred on while watching one or more keys" ) - # if retry_on_timeout is not set, or the error is not - # a TimeoutError, raise it - if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + # if retry_on_error is not set or the error is not one + # of the specified error types, raise it + if ( + conn.retry_on_error is None + or isinstance(error, tuple(conn.retry_on_error)) is False + ): await self.aclose() raise @@ -1533,8 +1540,8 @@ async def load_scripts(self): async def _disconnect_raise_reset(self, conn: Connection, error: Exception): """ Close the connection, raise an exception if we were watching, - and raise an exception if retry_on_timeout is not set, - or the error is not a TimeoutError + and raise an exception if retry_on_error is not set or the + error is not one of the specified error types. 
""" await conn.disconnect() # if we were watching a variable, the watch is no longer valid @@ -1544,9 +1551,12 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception): raise WatchError( "A ConnectionError occurred on while watching one or more keys" ) - # if retry_on_timeout is not set, or the error is not - # a TimeoutError, raise it - if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + # if retry_on_error is not set or the error is not one + # of the specified error types, raise it + if ( + conn.retry_on_error is None + or isinstance(error, tuple(conn.retry_on_error)) is False + ): await self.reset() raise diff --git a/redis/client.py b/redis/client.py index 1209a978d2..85ed7380a8 100755 --- a/redis/client.py +++ b/redis/client.py @@ -25,7 +25,12 @@ SentinelCommands, list_or_args, ) -from redis.connection import ConnectionPool, SSLConnection, UnixDomainSocketConnection +from redis.connection import ( + AbstractConnection, + ConnectionPool, + SSLConnection, + UnixDomainSocketConnection, +) from redis.credentials import CredentialProvider from redis.exceptions import ( ConnectionError, @@ -839,11 +844,15 @@ def clean_health_check_responses(self) -> None: def _disconnect_raise_connect(self, conn, error) -> None: """ Close the connection and raise an exception - if retry_on_timeout is not set or the error - is not a TimeoutError. Otherwise, try to reconnect + if retry_on_error is not set or the error is not one + of the specified error types. Otherwise, try to + reconnect """ conn.disconnect() - if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + if ( + conn.retry_on_error is None + or isinstance(error, tuple(conn.retry_on_error)) is False + ): raise error conn.connect() @@ -1320,8 +1329,8 @@ def _disconnect_reset_raise(self, conn, error) -> None: """ Close the connection, reset watching state and raise an exception if we were watching, - retry_on_timeout is not set, - or the error is not a TimeoutError + if retry_on_error is not set or the error is not one + of the specified error types. """ conn.disconnect() # if we were already watching a variable, the watch is no longer @@ -1332,9 +1341,12 @@ def _disconnect_reset_raise(self, conn, error) -> None: raise WatchError( "A ConnectionError occurred on while watching one or more keys" ) - # if retry_on_timeout is not set, or the error is not - # a TimeoutError, raise it - if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + # if retry_on_error is not set or the error is not one + # of the specified error types, raise it + if ( + conn.retry_on_error is None + or isinstance(error, tuple(conn.retry_on_error)) is False + ): self.reset() raise @@ -1492,11 +1504,15 @@ def load_scripts(self): if not exist: s.sha = immediate("SCRIPT LOAD", s.script) - def _disconnect_raise_reset(self, conn: Redis, error: Exception) -> None: + def _disconnect_raise_reset( + self, + conn: AbstractConnection, + error: Exception, + ) -> None: """ Close the connection, raise an exception if we were watching, - and raise an exception if TimeoutError is not part of retry_on_error, - or the error is not a TimeoutError + and raise an exception if retry_on_error is not set or the + error is not one of the specified error types. 
""" conn.disconnect() # if we were watching a variable, the watch is no longer valid @@ -1506,11 +1522,13 @@ def _disconnect_raise_reset(self, conn: Redis, error: Exception) -> None: raise WatchError( "A ConnectionError occurred on while watching one or more keys" ) - # if TimeoutError is not part of retry_on_error, or the error - # is not a TimeoutError, raise it - if not ( - TimeoutError in conn.retry_on_error and isinstance(error, TimeoutError) + # if retry_on_error is not set or the error is not one + # of the specified error types, raise it + if ( + conn.retry_on_error is None + or isinstance(error, tuple(conn.retry_on_error)) is False ): + self.reset() raise error From d529c2ad8d2cf4dcfb41bfd93ea68cfefd81aa66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 22 Feb 2024 12:48:00 +0000 Subject: [PATCH 12/23] Fix incorrect asserts in test and ensure connections are closed (#3004) --- tests/test_ssl.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 465fdabb89..dfd8837262 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -26,13 +26,15 @@ def test_ssl_with_invalid_cert(self, request): sslclient = redis.from_url(ssl_url) with pytest.raises(ConnectionError) as e: sslclient.ping() - assert "SSL: CERTIFICATE_VERIFY_FAILED" in str(e) + assert "SSL: CERTIFICATE_VERIFY_FAILED" in str(e) + sslclient.close() def test_ssl_connection(self, request): ssl_url = request.config.option.redis_ssl_url p = urlparse(ssl_url)[1].split(":") r = redis.Redis(host=p[0], port=p[1], ssl=True, ssl_cert_reqs="none") assert r.ping() + r.close() def test_ssl_connection_without_ssl(self, request): ssl_url = request.config.option.redis_ssl_url @@ -41,7 +43,8 @@ def test_ssl_connection_without_ssl(self, request): with pytest.raises(ConnectionError) as e: r.ping() - assert "Connection closed by server" in str(e) + assert "Connection closed by server" in str(e) + r.close() def test_validating_self_signed_certificate(self, request): ssl_url = request.config.option.redis_ssl_url @@ -56,6 +59,7 @@ def test_validating_self_signed_certificate(self, request): ssl_ca_certs=self.SERVER_CERT, ) assert r.ping() + r.close() def test_validating_self_signed_string_certificate(self, request): with open(self.SERVER_CERT) as f: @@ -72,6 +76,7 @@ def test_validating_self_signed_string_certificate(self, request): ssl_ca_data=cert_data, ) assert r.ping() + r.close() def _create_oscp_conn(self, request): ssl_url = request.config.option.redis_ssl_url @@ -92,22 +97,25 @@ def _create_oscp_conn(self, request): def test_ssl_ocsp_called(self, request): r = self._create_oscp_conn(request) with pytest.raises(RedisError) as e: - assert r.ping() - assert "cryptography not installed" in str(e) + r.ping() + assert "cryptography is not installed" in str(e) + r.close() @skip_if_nocryptography() def test_ssl_ocsp_called_withcrypto(self, request): r = self._create_oscp_conn(request) with pytest.raises(ConnectionError) as e: assert r.ping() - assert "No AIA information present in ssl certificate" in str(e) + assert "No AIA information present in ssl certificate" in str(e) + r.close() # rediss://, url based ssl_url = request.config.option.redis_ssl_url sslclient = redis.from_url(ssl_url) with pytest.raises(ConnectionError) as e: sslclient.ping() - assert "No AIA information present in ssl certificate" in str(e) + assert "No AIA information present in ssl certificate" in str(e) + sslclient.close() 
@skip_if_nocryptography() def test_valid_ocsp_cert_http(self): @@ -132,7 +140,7 @@ def test_revoked_ocsp_certificate(self): ocsp = OCSPVerifier(wrapped, hostname, 443) with pytest.raises(ConnectionError) as e: assert ocsp.is_valid() - assert "REVOKED" in str(e) + assert "REVOKED" in str(e) @skip_if_nocryptography() def test_unauthorized_ocsp(self): @@ -157,7 +165,7 @@ def test_ocsp_not_present_in_response(self): ocsp = OCSPVerifier(wrapped, hostname, 443) with pytest.raises(ConnectionError) as e: assert ocsp.is_valid() - assert "from the" in str(e) + assert "from the" in str(e) @skip_if_nocryptography() def test_unauthorized_then_direct(self): @@ -193,6 +201,7 @@ def test_mock_ocsp_staple(self, request): with pytest.raises(RedisError): r.ping() + r.close() ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) ctx.use_certificate_file(self.SERVER_CERT) @@ -213,7 +222,8 @@ def test_mock_ocsp_staple(self, request): with pytest.raises(ConnectionError) as e: r.ping() - assert "no ocsp response present" in str(e) + assert "no ocsp response present" in str(e) + r.close() r = redis.Redis( host=p[0], @@ -228,4 +238,5 @@ def test_mock_ocsp_staple(self, request): with pytest.raises(ConnectionError) as e: r.ping() - assert "no ocsp response present" in str(e) + assert "no ocsp response present" in str(e) + r.close() From c573bc4ab61d0d57726f872fdfca31962d44b534 Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Wed, 28 Feb 2024 14:31:59 +0200 Subject: [PATCH 13/23] Fix bug: client side caching causes unexpected disconnections (#3160) * fix disconnects * skip test in cluster --------- Co-authored-by: Chayim --- redis/_parsers/resp3.py | 4 +++- redis/client.py | 14 +++++++------- redis/commands/core.py | 2 +- redis/connection.py | 17 +++++++--------- tests/test_cache.py | 43 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 61 insertions(+), 19 deletions(-) diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 13aa1ffccb..88c8d5e52b 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -117,7 +117,9 @@ def _read_response(self, disable_decoding=False, push_request=False): ) for _ in range(int(response)) ] - self.handle_push_response(response, disable_decoding, push_request) + response = self.handle_push_response( + response, disable_decoding, push_request + ) else: raise InvalidResponse(f"Protocol Error: {raw!r}") diff --git a/redis/client.py b/redis/client.py index 85ed7380a8..79f52cc989 100755 --- a/redis/client.py +++ b/redis/client.py @@ -563,10 +563,10 @@ def execute_command(self, *args, **options): pool = self.connection_pool conn = self.connection or pool.get_connection(command_name, **options) response_from_cache = conn._get_from_local_cache(args) - if response_from_cache is not None: - return response_from_cache - else: - try: + try: + if response_from_cache is not None: + return response_from_cache + else: response = conn.retry.call_with_retry( lambda: self._send_command_parse_response( conn, command_name, *args, **options @@ -575,9 +575,9 @@ def execute_command(self, *args, **options): ) conn._add_to_local_cache(args, response, keys) return response - finally: - if not self.connection: - pool.release(conn) + finally: + if not self.connection: + pool.release(conn) def parse_response(self, connection, command_name, **options): """Parses a response from the Redis server""" diff --git a/redis/commands/core.py b/redis/commands/core.py index 6d81d76035..464e8d8c85 100644 --- a/redis/commands/core.py +++ 
b/redis/commands/core.py @@ -2011,7 +2011,7 @@ def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: options = {} if not args: options[EMPTY_RESPONSE] = [] - options["keys"] = keys + options["keys"] = args return self.execute_command("MGET", *args, **options) def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: diff --git a/redis/connection.py b/redis/connection.py index 617d04af5c..b89ce0e94b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,6 +1,5 @@ import copy import os -import select import socket import ssl import sys @@ -609,11 +608,6 @@ def pack_commands(self, commands): output.append(SYM_EMPTY.join(pieces)) return output - def _socket_is_empty(self): - """Check if the socket is empty""" - r, _, _ = select.select([self._sock], [], [], 0) - return not bool(r) - def _cache_invalidation_process( self, data: List[Union[str, Optional[List[str]]]] ) -> None: @@ -639,7 +633,7 @@ def _get_from_local_cache(self, command: str): or command[0] not in self.cache_whitelist ): return None - while not self._socket_is_empty(): + while self.can_read(): self.read_response(push_request=True) return self.client_cache.get(command) @@ -1187,12 +1181,15 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection": try: # ensure this connection is connected to Redis connection.connect() - # connections that the pool provides should be ready to send - # a command. if not, the connection was either returned to the + # if client caching is not enabled connections that the pool + # provides should be ready to send a command. + # if not, the connection was either returned to the # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. + # (if caching enabled the connection will not always be ready + # to send a command because it may contain invalidation messages) try: - if connection.can_read(): + if connection.can_read() and connection.client_cache is None: raise ConnectionError("Connection has data") except (ConnectionError, OSError): connection.disconnect() diff --git a/tests/test_cache.py b/tests/test_cache.py index 4eb5160ecc..dd33afd23e 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -146,6 +146,49 @@ def test_cache_return_copy(self, r): check = cache.get(("LRANGE", "mylist", 0, -1)) assert check == [b"baz", b"bar", b"foo"] + @pytest.mark.onlynoncluster + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + def test_csc_not_cause_disconnects(self, r): + r, cache = r + id1 = r.client_id() + r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1, "f": 1}) + assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"] + id2 = r.client_id() + + # client should get value from client cache + assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"] + assert cache.get(("MGET", "a", "b", "c", "d", "e", "f")) == [ + "1", + "1", + "1", + "1", + "1", + "1", + ] + + r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2, "f": 2}) + id3 = r.client_id() + # client should get value from redis server post invalidate messages + assert r.mget("a", "b", "c", "d", "e", "f") == ["2", "2", "2", "2", "2", "2"] + + r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3, "f": 3}) + # need to check that we get correct value 3 and not 2 + assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"] + # client should get value from client cache + assert r.mget("a", "b", "c", "d", "e", "f") == ["3", 
"3", "3", "3", "3", "3"] + + r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4, "f": 4}) + # need to check that we get correct value 4 and not 3 + assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"] + # client should get value from client cache + assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"] + id4 = r.client_id() + assert id1 == id2 == id3 == id4 + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlycluster From 26ab964ec18ec255672abaec90de439705151b5c Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Thu, 29 Feb 2024 11:59:31 +0200 Subject: [PATCH 14/23] Fix bug: client side caching causes unexpected disconnections (async version) (#3165) * fix disconnects * skip test in cluster * add test * save return value from handle_push_response (without it 'read_response' return the push message) * insert return response from cache to the try block to prevent connection leak * enable to get connection with data avaliable to read in csc mode and change can_read_destructive to not read data * fix check if socket is empty (at_eof() can return False but this doesn't mean there's definitely more data to read) --------- Co-authored-by: Chayim --- redis/_parsers/base.py | 2 +- redis/_parsers/resp3.py | 4 ++- redis/asyncio/client.py | 40 +++++++++++++++--------------- redis/asyncio/connection.py | 14 ++++++++--- tests/test_asyncio/test_cache.py | 42 ++++++++++++++++++++++++++++++++ 5 files changed, 77 insertions(+), 25 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 8e59249bef..0137539d66 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -182,7 +182,7 @@ async def can_read_destructive(self) -> bool: return True try: async with async_timeout(0): - return await self._stream.read(1) + return self._stream.at_eof() except TimeoutError: return False diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 88c8d5e52b..7afa43a0c2 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -261,7 +261,9 @@ async def _read_response( ) for _ in range(int(response)) ] - await self.handle_push_response(response, disable_decoding, push_request) + response = await self.handle_push_response( + response, disable_decoding, push_request + ) else: raise InvalidResponse(f"Protocol Error: {raw!r}") diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 3e2912bfca..9ff2e3917f 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -629,25 +629,27 @@ async def execute_command(self, *args, **options): pool = self.connection_pool conn = self.connection or await pool.get_connection(command_name, **options) response_from_cache = await conn._get_from_local_cache(args) - if response_from_cache is not None: - return response_from_cache - else: - if self.single_connection_client: - await self._single_conn_lock.acquire() - try: - response = await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - conn._add_to_local_cache(args, response, keys) - return response - finally: - if self.single_connection_client: - self._single_conn_lock.release() - if not self.connection: - await pool.release(conn) + try: + if response_from_cache is not None: + return response_from_cache + else: + try: + if self.single_connection_client: + await self._single_conn_lock.acquire() + response = await 
conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + conn._add_to_local_cache(args, response, keys) + return response + finally: + if self.single_connection_client: + self._single_conn_lock.release() + finally: + if not self.connection: + await pool.release(conn) async def parse_response( self, connection: Connection, command_name: Union[str, bytes], **options diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 81df3b3543..6c5c58c683 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -685,7 +685,7 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes] def _socket_is_empty(self): """Check if the socket is empty""" - return not self._reader.at_eof() + return len(self._reader._buffer) == 0 def _cache_invalidation_process( self, data: List[Union[str, Optional[List[str]]]] @@ -1192,12 +1192,18 @@ def make_connection(self): async def ensure_connection(self, connection: AbstractConnection): """Ensure that the connection object is connected and valid""" await connection.connect() - # connections that the pool provides should be ready to send - # a command. if not, the connection was either returned to the + # if client caching is not enabled connections that the pool + # provides should be ready to send a command. + # if not, the connection was either returned to the # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. + # (if caching enabled the connection will not always be ready + # to send a command because it may contain invalidation messages) try: - if await connection.can_read_destructive(): + if ( + await connection.can_read_destructive() + and connection.client_cache is None + ): raise ConnectionError("Connection has data") from None except (ConnectionError, OSError): await connection.disconnect() diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py index bf20337dfb..4762bb7c05 100644 --- a/tests/test_asyncio/test_cache.py +++ b/tests/test_asyncio/test_cache.py @@ -142,6 +142,48 @@ async def test_cache_return_copy(self, r): check = cache.get(("LRANGE", "mylist", 0, -1)) assert check == [b"baz", b"bar", b"foo"] + @pytest.mark.onlynoncluster + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_csc_not_cause_disconnects(self, r): + r, cache = r + id1 = await r.client_id() + await r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1}) + assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"] + id2 = await r.client_id() + + # client should get value from client cache + assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"] + assert cache.get(("MGET", "a", "b", "c", "d", "e")) == [ + "1", + "1", + "1", + "1", + "1", + ] + + await r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2}) + id3 = await r.client_id() + # client should get value from redis server post invalidate messages + assert await r.mget("a", "b", "c", "d", "e") == ["2", "2", "2", "2", "2"] + + await r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3}) + # need to check that we get correct value 3 and not 2 + assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"] + # client should get value from client cache + assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"] + + await r.mset({"a": 4, 
"b": 4, "c": 4, "d": 4, "e": 4}) + # need to check that we get correct value 4 and not 3 + assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"] + # client should get value from client cache + assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"] + id4 = await r.client_id() + assert id1 == id2 == id3 == id4 + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlycluster From 9df2225ba6309d50742959328958755210d757bd Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Thu, 29 Feb 2024 13:07:37 +0200 Subject: [PATCH 15/23] Version 5.1.0b4 (#3166) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c6a9e205f5..68bfc25c42 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ long_description_content_type="text/markdown", keywords=["Redis", "key-value store", "database"], license="MIT", - version="5.1.0b3", + version="5.1.0b4", packages=find_packages( include=[ "redis", From 9ad1546c06bb7321e7e19bbb9c8dd758343c0390 Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:12:48 +0200 Subject: [PATCH 16/23] Fix lock error (#3176) --- redis/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/exceptions.py b/redis/exceptions.py index ddb4041da3..8af58cb0db 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -82,7 +82,7 @@ class LockError(RedisError, ValueError): # NOTE: For backwards compatibility, this class derives from ValueError. # This was originally chosen to behave like threading.Lock. - def __init__(self, message, lock_name=None): + def __init__(self, message=None, lock_name=None): self.message = message self.lock_name = lock_name From 22957070f72615bae090a6550143a1f091efdae1 Mon Sep 17 00:00:00 2001 From: Kamil Monicz Date: Sun, 10 Mar 2024 10:41:06 +0100 Subject: [PATCH 17/23] Remove redundant async-timeout dependency in modern Python (#3177) https://github.com/redis/redis-py/issues/3174 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 68bfc25c42..1098e115c6 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ author_email="oss@redis.com", python_requires=">=3.8", install_requires=[ - 'async-timeout>=4.0.3', + 'async-timeout>=4.0.3; python_full_version<"3.11.3"', ], classifiers=[ "Development Status :: 5 - Production/Stable", From 0a0321523078ee514801a47bd83b7fe5a5ba198e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 10 Mar 2024 11:41:34 +0200 Subject: [PATCH 18/23] Bump rojopolis/spellcheck-github-actions from 0.35.0 to 0.36.0 (#3172) Bumps [rojopolis/spellcheck-github-actions](https://github.com/rojopolis/spellcheck-github-actions) from 0.35.0 to 0.36.0. - [Release notes](https://github.com/rojopolis/spellcheck-github-actions/releases) - [Changelog](https://github.com/rojopolis/spellcheck-github-actions/blob/master/CHANGELOG.md) - [Commits](https://github.com/rojopolis/spellcheck-github-actions/compare/0.35.0...0.36.0) --- updated-dependencies: - dependency-name: rojopolis/spellcheck-github-actions dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/spellcheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/spellcheck.yml b/.github/workflows/spellcheck.yml index a48781aa84..f739a54242 100644 --- a/.github/workflows/spellcheck.yml +++ b/.github/workflows/spellcheck.yml @@ -8,7 +8,7 @@ jobs: - name: Checkout uses: actions/checkout@v4 - name: Check Spelling - uses: rojopolis/spellcheck-github-actions@0.35.0 + uses: rojopolis/spellcheck-github-actions@0.36.0 with: config_path: .github/spellcheck-settings.yml task_name: Markdown
From 4f8dfae5aa58fc6ddc9e75f82877f5e23d69de24 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 10 Mar 2024 12:32:12 +0200 Subject: [PATCH 19/23] Bump release-drafter/release-drafter from 5 to 6 (#3171) Bumps [release-drafter/release-drafter](https://github.com/release-drafter/release-drafter) from 5 to 6. - [Release notes](https://github.com/release-drafter/release-drafter/releases) - [Commits](https://github.com/release-drafter/release-drafter/compare/v5...v6) --- updated-dependencies: - dependency-name: release-drafter/release-drafter dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release-drafter.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/release-drafter.yml b/.github/workflows/release-drafter.yml index eebb3e678b..6695abfe4b 100644 --- a/.github/workflows/release-drafter.yml +++ b/.github/workflows/release-drafter.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest steps: # Drafts your next Release notes as Pull Requests are merged into "master" - - uses: release-drafter/release-drafter@v5 + - uses: release-drafter/release-drafter@v6 with: # (Optional) specify config name to use, relative to .github/. Default: release-drafter.yml config-name: release-drafter-config.yml
From ddff7b54b5db4ace94203dc141151b3f08060d2a Mon Sep 17 00:00:00 2001 From: Willian Moreira Date: Tue, 12 Mar 2024 08:14:25 -0300 Subject: [PATCH 20/23] Optimize cluster initialization by changing the check for the cluster-enabled flag (#3158) * check whether cluster mode is enabled by trying to run CLUSTER SLOTS instead of INFO * fix typo * fix the "cluster mode is not enabled on this node" tests * remove changes on asyncio * rename mock flag to be more consistent * optimize async cluster creation by using the CLUSTER SLOTS command instead of the INFO command * fix test.
Before, both INFO and CLUSTER SLOTS were issued when establishing the connection; now only CLUSTER SLOTS is, so one fewer command is sent per node * remove dot at the end of string * remove unnecessary print from test * fix lint problems -------- Co-authored-by: Willian Moreira --- redis/asyncio/cluster.py | 7 +++---- redis/cluster.py | 5 +++-- tests/test_asyncio/test_cluster.py | 20 ++++++++++++++------ tests/test_cluster.py | 16 ++++++++++++---- 4 files changed, 32 insertions(+), 16 deletions(-)
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 4fb2fc4647..11c423b848 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1253,13 +1253,12 @@ async def initialize(self) -> None: for startup_node in self.startup_nodes.values(): try: # Make sure cluster mode is enabled on this node - if not (await startup_node.execute_command("INFO")).get( - "cluster_enabled" - ): + try: + cluster_slots = await startup_node.execute_command("CLUSTER SLOTS") + except ResponseError: raise RedisClusterException( "Cluster mode is not enabled on this node" ) - cluster_slots = await startup_node.execute_command("CLUSTER SLOTS") startup_nodes_reachable = True except Exception as e: # Try the next startup node.
diff --git a/redis/cluster.py b/redis/cluster.py index ba25b92246..a9213f4235 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1525,11 +1525,12 @@ def initialize(self): ) self.startup_nodes[startup_node.name].redis_connection = r # Make sure cluster mode is enabled on this node - if bool(r.info().get("cluster_enabled")) is False: + try: + cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) + except ResponseError: raise RedisClusterException( "Cluster mode is not enabled on this node" ) - cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) startup_nodes_reachable = True except Exception as e: # Try the next startup node.
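The two hunks above swap the INFO-based cluster-mode check for a single CLUSTER SLOTS probe, treating a ResponseError as "cluster mode is not enabled". A minimal sketch of the same detection pattern written against a standalone client; the function name, host, and port are placeholders for illustration:

from redis import Redis
from redis.exceptions import RedisClusterException, ResponseError

def probe_cluster_slots(node: Redis):
    # One round trip: the slot map the client needs anyway doubles as
    # the cluster-mode check, since non-cluster nodes reject CLUSTER SLOTS.
    try:
        return node.execute_command("CLUSTER SLOTS")
    except ResponseError:
        raise RedisClusterException("Cluster mode is not enabled on this node")

# Hypothetical usage against a node listening on localhost:7000:
# slots = probe_cluster_slots(Redis(host="localhost", port=7000))

This saves one command per startup node compared with calling INFO first and CLUSTER SLOTS second.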
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index a57d32f5d2..d7554b12a5 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -127,7 +127,9 @@ async def slowlog(r: RedisCluster) -> None: await r.config_set("slowlog-max-len", old_max_length_value) -async def get_mocked_redis_client(*args, **kwargs) -> RedisCluster: +async def get_mocked_redis_client( + cluster_slots_raise_error=False, *args, **kwargs +) -> RedisCluster: """ Return a stable RedisCluster object that have deterministic nodes and slots setup to remove the problem of different IP addresses @@ -139,9 +141,13 @@ async def get_mocked_redis_client(*args, **kwargs) -> RedisCluster: with mock.patch.object(ClusterNode, "execute_command") as execute_command_mock: async def execute_command(*_args, **_kwargs): + if _args[0] == "CLUSTER SLOTS": - mock_cluster_slots = cluster_slots - return mock_cluster_slots + if cluster_slots_raise_error: + raise ResponseError() + else: + mock_cluster_slots = cluster_slots + return mock_cluster_slots elif _args[0] == "COMMAND": return {"get": [], "set": []} elif _args[0] == "INFO": @@ -2458,7 +2464,10 @@ async def test_init_slots_cache_cluster_mode_disabled(self) -> None: """ with pytest.raises(RedisClusterException) as e: rc = await get_mocked_redis_client( - host=default_host, port=default_port, cluster_enabled=False + cluster_slots_raise_error=True, + host=default_host, + port=default_port, + cluster_enabled=False, ) await rc.aclose() assert "Cluster mode is not enabled on this node" in str(e.value) @@ -2719,10 +2728,9 @@ async def parse_response( async with r.pipeline() as pipe: with pytest.raises(ClusterDownError): await pipe.get(key).execute() - assert ( node.parse_response.await_count - == 4 * r.cluster_error_retry_attempts - 3 + == 3 * r.cluster_error_retry_attempts - 2 ) async def test_connection_error_not_raised(self, r: RedisCluster) -> None: diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 8a44d45ea3..1f505b816d 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -151,7 +151,9 @@ def cleanup(): r.config_set("slowlog-max-len", 128) -def get_mocked_redis_client(func=None, *args, **kwargs): +def get_mocked_redis_client( + func=None, cluster_slots_raise_error=False, *args, **kwargs +): """ Return a stable RedisCluster object that have deterministic nodes and slots setup to remove the problem of different IP addresses @@ -164,8 +166,11 @@ def get_mocked_redis_client(func=None, *args, **kwargs): def execute_command(*_args, **_kwargs): if _args[0] == "CLUSTER SLOTS": - mock_cluster_slots = cluster_slots - return mock_cluster_slots + if cluster_slots_raise_error: + raise ResponseError() + else: + mock_cluster_slots = cluster_slots + return mock_cluster_slots elif _args[0] == "COMMAND": return {"get": [], "set": []} elif _args[0] == "INFO": @@ -2654,7 +2659,10 @@ def test_init_slots_cache_cluster_mode_disabled(self): """ with pytest.raises(RedisClusterException) as e: get_mocked_redis_client( - host=default_host, port=default_port, cluster_enabled=False + cluster_slots_raise_error=True, + host=default_host, + port=default_port, + cluster_enabled=False, ) assert "Cluster mode is not enabled on this node" in str(e.value) From 5090875fe5ab84ea8142a28909a93d6ef35c784f Mon Sep 17 00:00:00 2001 From: Mathieu Rey Date: Thu, 14 Mar 2024 18:30:05 +0800 Subject: [PATCH 21/23] Update asyncio_examples.ipynb (#3125) clarified wording + one case of deprecated .close() changed to .aclose() --- 
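The wording changes below hinge on pool ownership: a pool created through Redis.from_pool is owned by the client, so a single aclose() tears down both, while a pool passed through the connection_pool argument is shared and must be disconnected explicitly. A minimal sketch of both patterns, assuming a server at redis://localhost:

import asyncio

import redis.asyncio as redis

async def main():
    # Owned pool: from_pool hands ownership to the client, so aclose()
    # disconnects the pool as well.
    pool = redis.ConnectionPool.from_url("redis://localhost")
    client = redis.Redis.from_pool(pool)
    await client.aclose()

    # Shared pool: no single client owns it, so disconnect it
    # explicitly once every client is done with it.
    shared = redis.ConnectionPool.from_url("redis://localhost")
    client1 = redis.Redis(connection_pool=shared)
    client2 = redis.Redis(connection_pool=shared)
    await client1.aclose()
    await client2.aclose()
    await shared.disconnect()

asyncio.run(main())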
docs/examples/asyncio_examples.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/examples/asyncio_examples.ipynb b/docs/examples/asyncio_examples.ipynb index 5029e907da..d2b11b56be 100644 --- a/docs/examples/asyncio_examples.ipynb +++ b/docs/examples/asyncio_examples.ipynb @@ -15,7 +15,7 @@ "\n", "## Connecting and Disconnecting\n", "\n", - "Utilizing asyncio Redis requires an explicit disconnect of the connection since there is no asyncio deconstructor magic method. By default, a connection pool is created on `redis.Redis()` and attached to this `Redis` instance. The connection pool closes automatically on the call to `Redis.aclose` which disconnects all connections." + "Using asyncio Redis requires an explicit disconnect of the connection since there is no asyncio deconstructor magic method. By default, an internal connection pool is created on `redis.Redis()` and attached to the `Redis` instance. When calling `Redis.aclose` this internal connection pool closes automatically, which disconnects all connections." ] }, { @@ -48,7 +48,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If you create custom `ConnectionPool` for the `Redis` instance to use alone, use the `from_pool` class method to create it. This will cause the pool to be disconnected along with the Redis instance. Disconnecting the connection pool simply disconnects all connections hosted in the pool." + "If you create a custom `ConnectionPool` to be used by a single `Redis` instance, use the `Redis.from_pool` class method. The Redis client will take ownership of the connection pool. This will cause the pool to be disconnected along with the Redis instance. Disconnecting the connection pool simply disconnects all connections hosted in the pool." ] }, { @@ -61,7 +61,7 @@ "\n", "pool = redis.ConnectionPool.from_url(\"redis://localhost\")\n", "client = redis.Redis.from_pool(pool)\n", - "await client.close()" + "await client.aclose()" ] }, { @@ -74,7 +74,7 @@ }, "source": [ "\n", - "However, If you supply a `ConnectionPool` that is shared several `Redis` instances, you may want to disconnect the connection pool explicitly. use the `connection_pool` argument in that case." + "However, if the `ConnectionPool` is to be shared by several `Redis` instances, you should use the `connection_pool` argument, and you may want to disconnect the connection pool explicitly." 
] }, { From 037d10826e609fa0de7a4a0f56bca440cc1d245c Mon Sep 17 00:00:00 2001 From: Gabriel Erzse Date: Mon, 18 Mar 2024 14:57:46 +0200 Subject: [PATCH 22/23] Avoid workflows canceling each other out (#3183) Co-authored-by: Gabriel Erzse --- .github/workflows/docs.yaml | 2 +- .github/workflows/integration.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index c5c74aa4d3..a3512b46dc 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -13,7 +13,7 @@ on: - cron: '0 1 * * *' # nightly build concurrency: - group: ${{ github.event.pull_request.number || github.ref }} + group: ${{ github.event.pull_request.number || github.ref }}-docs cancel-in-progress: true permissions: diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 695e1b307c..8f60efe6c7 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -17,7 +17,7 @@ on: - cron: '0 1 * * *' # nightly build concurrency: - group: ${{ github.event.pull_request.number || github.ref }} + group: ${{ github.event.pull_request.number || github.ref }}-integration cancel-in-progress: true permissions: From 07fc339b4a4088c1ff052527685ebdde43dfc4be Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Wed, 27 Mar 2024 13:13:32 +0200 Subject: [PATCH 23/23] Update black version to 24.3.0 (#3193) * Update black version to 24.3.0 * fix black changes --- .flake8 | 2 ++ dev_requirements.txt | 2 +- redis/_parsers/helpers.py | 31 +++++++++++++++-------------- redis/asyncio/client.py | 12 ++++------- redis/asyncio/cluster.py | 14 ++++++------- redis/asyncio/connection.py | 14 ++++++------- redis/asyncio/sentinel.py | 8 +++++--- redis/cluster.py | 6 +++--- redis/commands/core.py | 4 +--- redis/exceptions.py | 3 +-- redis/sentinel.py | 8 +++++--- redis/typing.py | 6 ++---- tests/test_asyncio/test_commands.py | 1 + 13 files changed, 55 insertions(+), 56 deletions(-) diff --git a/.flake8 b/.flake8 index 73b4a96bb6..d2ee181447 100644 --- a/.flake8 +++ b/.flake8 @@ -16,6 +16,8 @@ exclude = ignore = E126 E203 + E701 + E704 F405 N801 N802 diff --git a/dev_requirements.txt b/dev_requirements.txt index 3715599af0..ef3b1aa22d 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,5 +1,5 @@ click==8.0.4 -black==22.3.0 +black==24.3.0 flake8==5.0.4 flake8-isort==6.0.0 flynt~=0.69.0 diff --git a/redis/_parsers/helpers.py b/redis/_parsers/helpers.py index bdd749a5bc..a1df927bf3 100644 --- a/redis/_parsers/helpers.py +++ b/redis/_parsers/helpers.py @@ -819,18 +819,19 @@ def string_keys_to_dict(key_string, callback): lambda r, **kwargs: r, ), **string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3), - "ACL LOG": lambda r: [ - {str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} for x in r - ] - if isinstance(r, list) - else bool_ok(r), + "ACL LOG": lambda r: ( + [ + {str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} + for x in r + ] + if isinstance(r, list) + else bool_ok(r) + ), "COMMAND": parse_command_resp3, "CONFIG GET": lambda r: { - str_if_bytes(key) - if key is not None - else None: str_if_bytes(value) - if value is not None - else None + str_if_bytes(key) if key is not None else None: ( + str_if_bytes(value) if value is not None else None + ) for key, value in r.items() }, "MEMORY STATS": lambda r: {str_if_bytes(key): value for key, value in r.items()}, @@ -838,11 +839,11 @@ def string_keys_to_dict(key_string, 
callback): "SENTINEL MASTERS": parse_sentinel_masters_resp3, "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3, "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3, - "STRALGO": lambda r, **options: { - str_if_bytes(key): str_if_bytes(value) for key, value in r.items() - } - if isinstance(r, dict) - else str_if_bytes(r), + "STRALGO": lambda r, **options: ( + {str_if_bytes(key): str_if_bytes(value) for key, value in r.items()} + if isinstance(r, dict) + else str_if_bytes(r) + ), "XINFO CONSUMERS": lambda r: [ {str_if_bytes(key): value for key, value in x.items()} for x in r ], diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 9ff2e3917f..e153f0cd37 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -88,13 +88,11 @@ class ResponseCallbackProtocol(Protocol): - def __call__(self, response: Any, **kwargs): - ... + def __call__(self, response: Any, **kwargs): ... class AsyncResponseCallbackProtocol(Protocol): - async def __call__(self, response: Any, **kwargs): - ... + async def __call__(self, response: Any, **kwargs): ... ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol] @@ -1220,13 +1218,11 @@ async def run( class PubsubWorkerExceptionHandler(Protocol): - def __call__(self, e: BaseException, pubsub: PubSub): - ... + def __call__(self, e: BaseException, pubsub: PubSub): ... class AsyncPubsubWorkerExceptionHandler(Protocol): - async def __call__(self, e: BaseException, pubsub: PubSub): - ... + async def __call__(self, e: BaseException, pubsub: PubSub): ... PSWorkerThreadExcHandlerT = Union[ diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 11c423b848..ff2bd10c9d 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -402,10 +402,10 @@ def __init__( self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.response_callbacks = kwargs["response_callbacks"] self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy() - self.result_callbacks[ - "CLUSTER SLOTS" - ] = lambda cmd, res, **kwargs: parse_cluster_slots( - list(res.values())[0], **kwargs + self.result_callbacks["CLUSTER SLOTS"] = ( + lambda cmd, res, **kwargs: parse_cluster_slots( + list(res.values())[0], **kwargs + ) ) self._initialize = True @@ -1318,9 +1318,9 @@ async def initialize(self) -> None: ) tmp_slots[i].append(target_replica_node) # add this node to the nodes cache - tmp_nodes_cache[ - target_replica_node.name - ] = target_replica_node + tmp_nodes_cache[target_replica_node.name] = ( + target_replica_node + ) else: # Validate that 2 nodes want to use the same slot cache # setup diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 6c5c58c683..2e470bfcfb 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -87,13 +87,11 @@ class _Sentinel(enum.Enum): class ConnectCallbackProtocol(Protocol): - def __call__(self, connection: "AbstractConnection"): - ... + def __call__(self, connection: "AbstractConnection"): ... class AsyncConnectCallbackProtocol(Protocol): - async def __call__(self, connection: "AbstractConnection"): - ... + async def __call__(self, connection: "AbstractConnection"): ... 
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol] @@ -319,9 +317,11 @@ async def connect(self): await self.on_connect() else: # Use the passed function redis_connect_func - await self.redis_connect_func(self) if asyncio.iscoroutinefunction( - self.redis_connect_func - ) else self.redis_connect_func(self) + ( + await self.redis_connect_func(self) + if asyncio.iscoroutinefunction(self.redis_connect_func) + else self.redis_connect_func(self) + ) except RedisError: # clean up after any error in on_connect await self.disconnect() diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index d88babc59c..6fd233adc8 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -108,9 +108,11 @@ class SentinelConnectionPool(ConnectionPool): def __init__(self, service_name, sentinel_manager, **kwargs): kwargs["connection_class"] = kwargs.get( "connection_class", - SentinelManagedSSLConnection - if kwargs.pop("ssl", False) - else SentinelManagedConnection, + ( + SentinelManagedSSLConnection + if kwargs.pop("ssl", False) + else SentinelManagedConnection + ), ) self.is_master = kwargs.pop("is_master", True) self.check_connection = kwargs.pop("check_connection", False) diff --git a/redis/cluster.py b/redis/cluster.py index a9213f4235..cfe902115e 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1582,9 +1582,9 @@ def initialize(self): ) tmp_slots[i].append(target_replica_node) # add this node to the nodes cache - tmp_nodes_cache[ - target_replica_node.name - ] = target_replica_node + tmp_nodes_cache[target_replica_node.name] = ( + target_replica_node + ) else: # Validate that 2 nodes want to use the same slot cache # setup diff --git a/redis/commands/core.py b/redis/commands/core.py index 464e8d8c85..566846a20e 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -3399,9 +3399,7 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]: """ return self.execute_command("SMEMBERS", name, keys=[name]) - def smismember( - self, name: str, values: List, *args: List - ) -> Union[ + def smismember(self, name: str, values: List, *args: List) -> Union[ Awaitable[List[Union[Literal[0], Literal[1]]]], List[Union[Literal[0], Literal[1]]], ]: diff --git a/redis/exceptions.py b/redis/exceptions.py index 8af58cb0db..dcc06774b0 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -217,5 +217,4 @@ class SlotNotCoveredError(RedisClusterException): pass -class MaxConnectionsError(ConnectionError): - ... +class MaxConnectionsError(ConnectionError): ... diff --git a/redis/sentinel.py b/redis/sentinel.py index dfcd8ff64b..72b5bef548 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -145,9 +145,11 @@ class SentinelConnectionPool(ConnectionPool): def __init__(self, service_name, sentinel_manager, **kwargs): kwargs["connection_class"] = kwargs.get( "connection_class", - SentinelManagedSSLConnection - if kwargs.pop("ssl", False) - else SentinelManagedConnection, + ( + SentinelManagedSSLConnection + if kwargs.pop("ssl", False) + else SentinelManagedConnection + ), ) self.is_master = kwargs.pop("is_master", True) self.check_connection = kwargs.pop("check_connection", False) diff --git a/redis/typing.py b/redis/typing.py index a5d1369d63..838219fbb6 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -54,12 +54,10 @@ class CommandsProtocol(Protocol): connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] - def execute_command(self, *args, **options): - ... + def execute_command(self, *args, **options): ... 
class ClusterCommandsProtocol(CommandsProtocol, Protocol): encoder: "Encoder" - def execute_command(self, *args, **options) -> Union[Any, Awaitable]: - ... + def execute_command(self, *args, **options) -> Union[Any, Awaitable]: ... diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 35b9f2a29f..7102450fe4 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -1,6 +1,7 @@ """ Tests async overrides of commands from their mixins """ + import asyncio import binascii import datetime
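A closing note on the retry changes earlier in this series (PATCH 11): pubsub, pipeline, and single-command paths now consult the connection's retry_on_error list rather than only the retry_on_timeout flag when deciding whether to reconnect and retry. A minimal sketch of opting in on a standalone client; the host, port, and backoff settings are assumptions for illustration, not library defaults:

from redis import Redis
from redis.backoff import ExponentialBackoff
from redis.exceptions import ConnectionError, TimeoutError
from redis.retry import Retry

# retry_on_error lists the exception types that trigger a reconnect
# and retry; the Retry object bounds how many attempts are made and
# how long to back off between them.
r = Redis(
    host="localhost",
    port=6379,
    retry=Retry(ExponentialBackoff(), retries=3),
    retry_on_error=[ConnectionError, TimeoutError],
)
r.ping()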