Skip to content

Commit

Permalink
Add async support for Redis state persister
Browse files Browse the repository at this point in the history
Implements the async function from redis-py.asyncio for state
persistance.
  • Loading branch information
jernejfrank committed Jan 13, 2025
1 parent d3ba49d commit 304eb21
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 1 deletion.
166 changes: 166 additions & 0 deletions burr/integrations/persisters/b_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

try:
import redis # can't name module redis because this import wouldn't work.
import redis.asyncio as aredis

except ImportError as e:
base.require_plugin(e, "redis")
Expand Down Expand Up @@ -217,6 +218,171 @@ def __init__(
super(RedisPersister, self).__init__(connection, serde_kwargs, namespace)


class AsyncRedisBasePersister(persistence.BaseStatePersister):
"""Main class for Async Redis persister.
Use this class if you want to directly control injecting the async Redis client.
.. warning::
The synchronous persister closes the connection on deletion of the class using the ``__del__`` method.
In an async context that is not reliable (the event loop may already be closed by the time ``__del__``
gets invoked). Therefore, you are responsible for closing the connection yourself (i.e. manual cleanup).
This class is responsible for async persisting state data to a Redis database.
It inherits from the BaseStatePersister class.
"""

@classmethod
def from_values(
cls,
host: str,
port: int,
db: int,
password: str = None,
serde_kwargs: dict = None,
redis_client_kwargs: dict = None,
namespace: str = None,
) -> "AsyncRedisBasePersister":
"""Creates a new instance of the AsyncRedisBasePersister from passed in values."""
if redis_client_kwargs is None:
redis_client_kwargs = {}
connection = aredis.Redis(
host=host, port=port, db=db, password=password, **redis_client_kwargs
)
return cls(connection, serde_kwargs, namespace)

def __init__(
self,
connection,
serde_kwargs: dict = None,
namespace: str = None,
):
"""Initializes the AsyncRedisPersister class.
:param connection: the redis connection object.
:param serde_kwargs: serialization and deserialization keyword arguments to pass to state SERDE.
:param namespace: The name of the project to optionally use in the key prefix.
"""
self.connection = connection
self.serde_kwargs = serde_kwargs or {}
self.namespace = namespace if namespace else ""

async def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
"""List the app ids for a given partition key."""
namespaced_partition_key = (
f"{self.namespace}:{partition_key}" if self.namespace else partition_key
)
app_ids = await self.connection.zrevrange(namespaced_partition_key, 0, -1)
return [app_id.decode() for app_id in app_ids]

async def load(
self, partition_key: str, app_id: str, sequence_id: int = None, **kwargs
) -> Optional[persistence.PersistedStateData]:
"""Load the state data for a given partition key, app id, and sequence id.
If the sequence id is not given, it will be looked up in the Redis database. If it is not found, None will be returned.
:param partition_key:
:param app_id:
:param sequence_id:
:param kwargs:
:return: Value or None.
"""
namespaced_partition_key = (
f"{self.namespace}:{partition_key}" if self.namespace else partition_key
)
if sequence_id is None:
sequence_id = await self.connection.zscore(namespaced_partition_key, app_id)
if sequence_id is None:
return None
sequence_id = int(sequence_id)
key = await self.create_key(app_id, partition_key, sequence_id)
data = await self.connection.hgetall(key)
if not data:
return None
_state = state.State.deserialize(json.loads(data[b"state"].decode()), **self.serde_kwargs)
return {
"partition_key": partition_key,
"app_id": app_id,
"sequence_id": sequence_id,
"position": data[b"position"].decode(),
"state": _state,
"created_at": data[b"created_at"].decode(),
"status": data[b"status"].decode(),
}

async def create_key(self, app_id, partition_key, sequence_id):
"""Create a key for the Redis database."""
if self.namespace:
key = f"{self.namespace}:{partition_key}:{app_id}:{sequence_id}"
else:
key = f"{partition_key}:{app_id}:{sequence_id}"
return key

async def save(
self,
partition_key: str,
app_id: str,
sequence_id: int,
position: str,
state: state.State,
status: Literal["completed", "failed"],
**kwargs,
):
"""Save the state data to the Redis database.
:param partition_key:
:param app_id:
:param sequence_id:
:param position:
:param state:
:param status:
:param kwargs:
:return:
"""
key = await self.create_key(app_id, partition_key, sequence_id)
if await self.connection.exists(key):
raise ValueError(f"partition_key:app_id:sequence_id[{key}] already exists.")
json_state = json.dumps(state.serialize(**self.serde_kwargs))
await self.connection.hset(
key,
mapping={
"partition_key": partition_key,
"app_id": app_id,
"sequence_id": sequence_id,
"position": position,
"state": json_state,
"status": status,
"created_at": datetime.now(timezone.utc).isoformat(),
},
)
namespaced_partition_key = (
f"{self.namespace}:{partition_key}" if self.namespace else partition_key
)
await self.connection.zadd(namespaced_partition_key, {app_id: sequence_id})

def __getstate__(self) -> dict:
state = self.__dict__.copy()
if not hasattr(self.connection, "connection_pool"):
logger.warning("Redis connection is not serializable.")
return state
state["connection_params"] = {
"host": self.connection.connection_pool.connection_kwargs["host"],
"port": self.connection.connection_pool.connection_kwargs["port"],
"db": self.connection.connection_pool.connection_kwargs["db"],
"password": self.connection.connection_pool.connection_kwargs["password"],
}
del state["connection"]
return state

def __setstate__(self, state: dict):
connection_params = state.pop("connection_params")
# we assume normal redis client.
self.connection = aredis.Redis(**connection_params)
self.__dict__.update(state)


if __name__ == "__main__":
# test the RedisBasePersister class
persister = RedisBasePersister.from_values("localhost", 6379, 0)
Expand Down
5 changes: 5 additions & 0 deletions docs/reference/persister.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,8 @@ Currently we support the following, although we highly recommend you contribute
:members:

.. automethod:: __init__

.. autoclass:: burr.integrations.persisters.b_redis.AsyncRedisBasePersister
:members:

.. automethod:: __init__
86 changes: 85 additions & 1 deletion tests/integrations/persisters/test_b_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import pytest

from burr.core import state
from burr.integrations.persisters.b_redis import RedisBasePersister, RedisPersister
from burr.integrations.persisters.b_redis import (
AsyncRedisBasePersister,
RedisBasePersister,
RedisPersister,
)

if not os.environ.get("BURR_CI_INTEGRATION_TESTS") == "true":
pytest.skip("Skipping integration tests", allow_module_level=True)
Expand Down Expand Up @@ -89,3 +93,83 @@ def test_serialization_with_pickle(redis_persister_with_ns):
data = deserialized_persister.load("pk", "app_id_serde", 1)

assert data["state"].get_all() == {"a": 1, "b": 2}


@pytest.fixture
async def async_redis_persister():
persister = AsyncRedisBasePersister.from_values(host="localhost", port=6379, db=0)
yield persister
await persister.connection.aclose()


@pytest.fixture
async def async_redis_persister_with_ns():
persister = AsyncRedisBasePersister.from_values(
host="localhost", port=6379, db=0, namespace="test"
)
yield persister
await persister.connection.aclose()


async def test_async_save_and_load_state(async_redis_persister):
await async_redis_persister.save(
"pk", "app_id", 1, "pos", state.State({"a": 1, "b": 2}), "completed"
)
data = await async_redis_persister.load("pk", "app_id", 1)
assert data["state"].get_all() == {"a": 1, "b": 2}


async def test_async_list_app_ids(async_redis_persister):
await async_redis_persister.save("pk", "app_id1", 1, "pos1", state.State({"a": 1}), "completed")
await async_redis_persister.save("pk", "app_id2", 2, "pos2", state.State({"b": 2}), "completed")
app_ids = await async_redis_persister.list_app_ids("pk")
assert "app_id1" in app_ids
assert "app_id2" in app_ids


async def test_async_load_nonexistent_key(async_redis_persister):
state_data = await async_redis_persister.load("pk", "nonexistent_key")
assert state_data is None


async def test_async_save_and_load_state_ns(async_redis_persister_with_ns):
await async_redis_persister_with_ns.save(
"pk", "app_id", 1, "pos", state.State({"a": 1, "b": 2}), "completed"
)
data = await async_redis_persister_with_ns.load("pk", "app_id", 1)
assert data["state"].get_all() == {"a": 1, "b": 2}


async def test_async_list_app_ids_with_ns(async_redis_persister_with_ns):
await async_redis_persister_with_ns.save(
"pk", "app_id1", 1, "pos1", state.State({"a": 1}), "completed"
)
await async_redis_persister_with_ns.save(
"pk", "app_id2", 2, "pos2", state.State({"b": 2}), "completed"
)
app_ids = await async_redis_persister_with_ns.list_app_ids("pk")
assert "app_id1" in app_ids
assert "app_id2" in app_ids


async def test_async_load_nonexistent_key_with_ns(async_redis_persister_with_ns):
state_data = await async_redis_persister_with_ns.load("pk", "nonexistent_key")
assert state_data is None


async def test_async_serialization_with_pickle(async_redis_persister_with_ns):
# Save some state
await async_redis_persister_with_ns.save(
"pk", "app_id_serde", 1, "pos", state.State({"a": 1, "b": 2}), "completed"
)

# Serialize the persister
serialized_persister = pickle.dumps(async_redis_persister_with_ns)

# Deserialize the persister
deserialized_persister = pickle.loads(serialized_persister)

# Load the state from the deserialized persister
data = await deserialized_persister.load("pk", "app_id_serde", 1)

assert data["state"].get_all() == {"a": 1, "b": 2}

0 comments on commit 304eb21

Please sign in to comment.