diff --git a/tests/integration/transport/redis/test_redis.py b/tests/integration/transport/redis/test_redis.py index 50edbb9e..d7577074 100644 --- a/tests/integration/transport/redis/test_redis.py +++ b/tests/integration/transport/redis/test_redis.py @@ -14,8 +14,9 @@ def test_consumer_group( redis_consumer_factory: ConsumerFactory, ) -> None: redis_consumer = redis_consumer_factory._consumer - consumer_group = redis_consumer._consumer_group - assert consumer_group.keys.get(RedisStream.TEST) + with redis_consumer: + consumer_group = redis_consumer._consumer_group + assert consumer_group.keys.get(RedisStream.TEST) @pytest.mark.parametrize("enable_sentinel", [False, True], indirect=True) def test_consume_new( @@ -25,8 +26,9 @@ def test_consume_new( run_id: int, ) -> None: redis_consumer = redis_consumer_factory._consumer - task = TestRunTask(data=TestRunData(test_run_id=run_id)) - assert redis_producer.add_task(task) - task_from_consumer = redis_consumer._consume()[-1] - run_id_from_task = task_from_consumer.decoded_message["data"]["test_run_id"] - assert run_id_from_task == run_id + with redis_consumer: + task = TestRunTask(data=TestRunData(test_run_id=run_id)) + assert redis_producer.add_task(task) + task_from_consumer = redis_consumer._consume()[-1] + run_id_from_task = task_from_consumer.decoded_message["data"]["test_run_id"] + assert run_id_from_task == run_id