From 9aabb446c597292c0b387779582795d2e28feb28 Mon Sep 17 00:00:00 2001 From: Philippe PRADOS Date: Fri, 7 Jun 2024 23:17:02 +0200 Subject: [PATCH] community[minor]: Add SQL storage implementation (#22207) Hello @eyurtsev - package: langchain-comminity - **Description**: Add SQL implementation for docstore. A new implementation, in line with my other PR ([async PGVector](https://github.com/langchain-ai/langchain-postgres/pull/32), [SQLChatMessageMemory](https://github.com/langchain-ai/langchain/pull/22065)) - Twitter handler: pprados --------- Signed-off-by: ChengZi Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Piotr Mardziel Co-authored-by: ChengZi Co-authored-by: Eugene Yurtsev --- .../integrations/vectorstores/milvus.ipynb | 2 +- .../langchain_community/storage/__init__.py | 5 + .../langchain_community/storage/sql.py | 266 ++++++++++++++++++ .../integration_tests/storage/test_sql.py | 186 ++++++++++++ .../tests/unit_tests/storage/test_imports.py | 1 + .../tests/unit_tests/storage/test_sql.py | 89 ++++++ 6 files changed, 548 insertions(+), 1 deletion(-) create mode 100644 libs/community/langchain_community/storage/sql.py create mode 100644 libs/community/tests/integration_tests/storage/test_sql.py create mode 100644 libs/community/tests/unit_tests/storage/test_sql.py diff --git a/docs/docs/integrations/vectorstores/milvus.ipynb b/docs/docs/integrations/vectorstores/milvus.ipynb index 24cb43b436bdb..4c314dfa15f0b 100644 --- a/docs/docs/integrations/vectorstores/milvus.ipynb +++ b/docs/docs/integrations/vectorstores/milvus.ipynb @@ -390,4 +390,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/libs/community/langchain_community/storage/__init__.py b/libs/community/langchain_community/storage/__init__.py index 9a73d49110afc..d75b497bf584f 100644 --- a/libs/community/langchain_community/storage/__init__.py +++ b/libs/community/langchain_community/storage/__init__.py @@ -31,6 +31,9 @@ from 
langchain_community.storage.redis import ( RedisStore, ) + from langchain_community.storage.sql import ( + SQLStore, + ) from langchain_community.storage.upstash_redis import ( UpstashRedisByteStore, UpstashRedisStore, @@ -42,6 +45,7 @@ "CassandraByteStore", "MongoDBStore", "RedisStore", + "SQLStore", "UpstashRedisByteStore", "UpstashRedisStore", ] @@ -52,6 +56,7 @@ "CassandraByteStore": "langchain_community.storage.cassandra", "MongoDBStore": "langchain_community.storage.mongodb", "RedisStore": "langchain_community.storage.redis", + "SQLStore": "langchain_community.storage.sql", "UpstashRedisByteStore": "langchain_community.storage.upstash_redis", "UpstashRedisStore": "langchain_community.storage.upstash_redis", } diff --git a/libs/community/langchain_community/storage/sql.py b/libs/community/langchain_community/storage/sql.py new file mode 100644 index 0000000000000..a92daae1d8c67 --- /dev/null +++ b/libs/community/langchain_community/storage/sql.py @@ -0,0 +1,266 @@ +import contextlib +from pathlib import Path +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Dict, + Generator, + Iterator, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) + +from langchain_core.stores import BaseStore +from sqlalchemy import ( + Engine, + LargeBinary, + and_, + create_engine, + delete, + select, +) +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import ( + Mapped, + Session, + declarative_base, + mapped_column, + sessionmaker, +) + +Base = declarative_base() + + +def items_equal(x: Any, y: Any) -> bool: + return x == y + + +class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc] + """Table used to save values.""" + + # ATTENTION: + # Prior to modifying this table, please determine whether + # we should create migrations for this table to make sure + # users do not experience data loss. 
+ __tablename__ = "langchain_key_value_stores" + + namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False) + key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False) + value = mapped_column(LargeBinary, index=False, nullable=False) + + +# This is a fix of original SQLStore. +# This can will be removed when a PR will be merged. +class SQLStore(BaseStore[str, bytes]): + """BaseStore interface that works on an SQL database. + + Examples: + Create a SQLStore instance and perform operations on it: + + .. code-block:: python + + from langchain_rag.storage import SQLStore + + # Instantiate the SQLStore with the root path + sql_store = SQLStore(namespace="test", db_url="sqllite://:memory:") + + # Set values for keys + sql_store.mset([("key1", b"value1"), ("key2", b"value2")]) + + # Get values for keys + values = sql_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"] + + # Delete keys + sql_store.mdelete(["key1"]) + + # Iterate over keys + for key in sql_store.yield_keys(): + print(key) + + """ + + def __init__( + self, + *, + namespace: str, + db_url: Optional[Union[str, Path]] = None, + engine: Optional[Union[Engine, AsyncEngine]] = None, + engine_kwargs: Optional[Dict[str, Any]] = None, + async_mode: Optional[bool] = None, + ): + if db_url is None and engine is None: + raise ValueError("Must specify either db_url or engine") + + if db_url is not None and engine is not None: + raise ValueError("Must specify either db_url or engine, not both") + + _engine: Union[Engine, AsyncEngine] + if db_url: + if async_mode is None: + async_mode = False + if async_mode: + _engine = create_async_engine( + url=str(db_url), + **(engine_kwargs or {}), + ) + else: + _engine = create_engine(url=str(db_url), **(engine_kwargs or {})) + elif engine: + _engine = engine + + else: + raise AssertionError("Something went wrong with configuration of engine.") + + _session_maker: Union[sessionmaker[Session], 
async_sessionmaker[AsyncSession]] + if isinstance(_engine, AsyncEngine): + self.async_mode = True + _session_maker = async_sessionmaker(bind=_engine) + else: + self.async_mode = False + _session_maker = sessionmaker(bind=_engine) + + self.engine = _engine + self.dialect = _engine.dialect.name + self.session_maker = _session_maker + self.namespace = namespace + + def create_schema(self) -> None: + Base.metadata.create_all(self.engine) + + async def acreate_schema(self) -> None: + assert isinstance(self.engine, AsyncEngine) + async with self.engine.begin() as session: + await session.run_sync(Base.metadata.create_all) + + def drop(self) -> None: + Base.metadata.drop_all(bind=self.engine.connect()) + + async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]: + assert isinstance(self.engine, AsyncEngine) + result: Dict[str, bytes] = {} + async with self._make_async_session() as session: + stmt = select(LangchainKeyValueStores).filter( + and_( + LangchainKeyValueStores.key.in_(keys), + LangchainKeyValueStores.namespace == self.namespace, + ) + ) + for v in await session.scalars(stmt): + result[v.key] = v.value + return [result.get(key) for key in keys] + + def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: + result = {} + + with self._make_sync_session() as session: + stmt = select(LangchainKeyValueStores).filter( + and_( + LangchainKeyValueStores.key.in_(keys), + LangchainKeyValueStores.namespace == self.namespace, + ) + ) + for v in session.scalars(stmt): + result[v.key] = v.value + return [result.get(key) for key in keys] + + async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: + async with self._make_async_session() as session: + await self._amdelete([key for key, _ in key_value_pairs], session) + session.add_all( + [ + LangchainKeyValueStores(namespace=self.namespace, key=k, value=v) + for k, v in key_value_pairs + ] + ) + await session.commit() + + def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: 
+ values: Dict[str, bytes] = dict(key_value_pairs) + with self._make_sync_session() as session: + self._mdelete(list(values.keys()), session) + session.add_all( + [ + LangchainKeyValueStores(namespace=self.namespace, key=k, value=v) + for k, v in values.items() + ] + ) + session.commit() + + def _mdelete(self, keys: Sequence[str], session: Session) -> None: + stmt = delete(LangchainKeyValueStores).filter( + and_( + LangchainKeyValueStores.key.in_(keys), + LangchainKeyValueStores.namespace == self.namespace, + ) + ) + session.execute(stmt) + + async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None: + stmt = delete(LangchainKeyValueStores).filter( + and_( + LangchainKeyValueStores.key.in_(keys), + LangchainKeyValueStores.namespace == self.namespace, + ) + ) + await session.execute(stmt) + + def mdelete(self, keys: Sequence[str]) -> None: + with self._make_sync_session() as session: + self._mdelete(keys, session) + session.commit() + + async def amdelete(self, keys: Sequence[str]) -> None: + async with self._make_async_session() as session: + await self._amdelete(keys, session) + await session.commit() + + def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: + with self._make_sync_session() as session: + for v in session.query(LangchainKeyValueStores).filter( # type: ignore + LangchainKeyValueStores.namespace == self.namespace + ): + if str(v.key).startswith(prefix or ""): + yield str(v.key) + session.close() + + async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]: + async with self._make_async_session() as session: + stmt = select(LangchainKeyValueStores).filter( + LangchainKeyValueStores.namespace == self.namespace + ) + for v in await session.scalars(stmt): + if str(v.key).startswith(prefix or ""): + yield str(v.key) + await session.close() + + @contextlib.contextmanager + def _make_sync_session(self) -> Generator[Session, None, None]: + """Make an async session.""" + if self.async_mode: + 
raise ValueError( + "Attempting to use a sync method in when async mode is turned on. " + "Please use the corresponding async method instead." + ) + with cast(Session, self.session_maker()) as session: + yield cast(Session, session) + + @contextlib.asynccontextmanager + async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]: + """Make an async session.""" + if not self.async_mode: + raise ValueError( + "Attempting to use an async method in when sync mode is turned on. " + "Please use the corresponding async method instead." + ) + async with cast(AsyncSession, self.session_maker()) as session: + yield cast(AsyncSession, session) diff --git a/libs/community/tests/integration_tests/storage/test_sql.py b/libs/community/tests/integration_tests/storage/test_sql.py new file mode 100644 index 0000000000000..a454029b86cdf --- /dev/null +++ b/libs/community/tests/integration_tests/storage/test_sql.py @@ -0,0 +1,186 @@ +"""Implement integration tests for Redis storage.""" + +import pytest +from sqlalchemy import Engine, create_engine, text +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +from langchain_community.storage import SQLStore + +pytest.importorskip("sqlalchemy") + + +@pytest.fixture +def sql_engine() -> Engine: + """Yield redis client.""" + return create_engine(url="sqlite://", echo=True) + + +@pytest.fixture +def sql_aengine() -> AsyncEngine: + """Yield redis client.""" + return create_async_engine(url="sqlite+aiosqlite:///:memory:", echo=True) + + +def test_mget(sql_engine: Engine) -> None: + """Test mget method.""" + store = SQLStore(engine=sql_engine, namespace="test") + store.create_schema() + keys = ["key1", "key2"] + with sql_engine.connect() as session: + session.execute( + text( + "insert into langchain_key_value_stores ('namespace', 'key', 'value') " + "values('test','key1',:value)" + ).bindparams(value=b"value1"), + ) + session.execute( + text( + "insert into langchain_key_value_stores ('namespace', 'key', 
@pytest.mark.asyncio
async def test_amget(sql_aengine: AsyncEngine) -> None:
    """Test the amget method."""
    store = SQLStore(engine=sql_aengine, namespace="test")
    await store.acreate_schema()
    keys = ["key1", "key2"]
    async with sql_aengine.connect() as session:
        await session.execute(
            text(
                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
                "values('test','key1',:value)"
            ).bindparams(value=b"value1"),
        )
        await session.execute(
            text(
                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
                "values('test','key2',:value)"
            ).bindparams(value=b"value2"),
        )
        await session.commit()

    result = await store.amget(keys)
    assert result == [b"value1", b"value2"]


def test_mset(sql_engine: Engine) -> None:
    """Test that multiple keys can be set."""
    store = SQLStore(engine=sql_engine, namespace="test")
    store.create_schema()
    key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
    store.mset(key_value_pairs)

    with sql_engine.connect() as session:
        result = session.exec_driver_sql("select * from langchain_key_value_stores")
        assert result.keys() == ["namespace", "key", "value"]
        data = [(row[0], row[1]) for row in result]
        assert data == [("test", "key1"), ("test", "key2")]
        session.commit()


@pytest.mark.asyncio
async def test_amset(sql_aengine: AsyncEngine) -> None:
    """Test that multiple keys can be set asynchronously."""
    store = SQLStore(engine=sql_aengine, namespace="test")
    await store.acreate_schema()
    key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
    await store.amset(key_value_pairs)

    async with sql_aengine.connect() as session:
        result = await session.exec_driver_sql(
            "select * from langchain_key_value_stores"
        )
        assert result.keys() == ["namespace", "key", "value"]
        data = [(row[0], row[1]) for row in result]
        assert data == [("test", "key1"), ("test", "key2")]
        await session.commit()


def test_mdelete(sql_engine: Engine) -> None:
    """Test that deletion works as expected."""
    store = SQLStore(engine=sql_engine, namespace="test")
    store.create_schema()
    keys = ["key1", "key2"]
    with sql_engine.connect() as session:
        session.execute(
            text(
                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
                "values('test','key1',:value)"
            ).bindparams(value=b"value1"),
        )
        session.execute(
            text(
                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
                "values('test','key2',:value)"
            ).bindparams(value=b"value2"),
        )
        session.commit()
    store.mdelete(keys)
    with sql_engine.connect() as session:
        result = session.exec_driver_sql("select * from langchain_key_value_stores")
        assert result.keys() == ["namespace", "key", "value"]
        data = [row for row in result]
        assert data == []
        session.commit()


@pytest.mark.asyncio
async def test_amdelete(sql_aengine: AsyncEngine) -> None:
    """Test that async deletion works as expected."""
    store = SQLStore(engine=sql_aengine, namespace="test")
    await store.acreate_schema()
    keys = ["key1", "key2"]
    async with sql_aengine.connect() as session:
        await session.execute(
            text(
                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
                "values('test','key1',:value)"
            ).bindparams(value=b"value1"),
        )
        await session.execute(
            text(
                "insert into langchain_key_value_stores ('namespace', 'key', 'value') "
                "values('test','key2',:value)"
            ).bindparams(value=b"value2"),
        )
        await session.commit()
    await store.amdelete(keys)
    async with sql_aengine.connect() as session:
        result = await session.exec_driver_sql(
            "select * from langchain_key_value_stores"
        )
        assert result.keys() == ["namespace", "key", "value"]
        data = [row for row in result]
        assert data == []
        await session.commit()


def test_yield_keys(sql_engine: Engine) -> None:
    """Test iterating keys, with and without a prefix filter."""
    store = SQLStore(engine=sql_engine, namespace="test")
    store.create_schema()
    key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
    store.mset(key_value_pairs)
    assert sorted(store.yield_keys()) == ["key1", "key2"]
    assert sorted(store.yield_keys(prefix="key")) == ["key1", "key2"]
    assert sorted(store.yield_keys(prefix="lang")) == []


@pytest.mark.asyncio
async def test_ayield_keys(sql_aengine: AsyncEngine) -> None:
    """Test async iteration of keys, with and without a prefix filter."""
    store = SQLStore(engine=sql_aengine, namespace="test")
    await store.acreate_schema()
    key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
    await store.amset(key_value_pairs)
    assert sorted([k async for k in store.ayield_keys()]) == ["key1", "key2"]
    assert sorted([k async for k in store.ayield_keys(prefix="key")]) == [
        "key1",
        "key2",
    ]
    assert sorted([k async for k in store.ayield_keys(prefix="lang")]) == []
def test_create_lc_store(sql_store: SQLStore) -> None:
    """Test that a Document store is created from a base store."""
    docstore: BaseStore[str, Document] = cast(
        BaseStore[str, Document], create_lc_store(sql_store)
    )
    docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))])
    fetched_doc = docstore.mget(["key1"])[0]
    assert fetched_doc is not None
    assert fetched_doc.page_content == "hello"
    assert fetched_doc.metadata == {"key": "value"}


def test_create_kv_store(sql_store: SQLStore) -> None:
    """Test that a key-value docstore is created from a base store."""
    docstore = create_kv_docstore(sql_store)
    docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))])
    fetched_doc = docstore.mget(["key1"])[0]
    assert isinstance(fetched_doc, Document)
    assert fetched_doc.page_content == "hello"
    assert fetched_doc.metadata == {"key": "value"}


@pytest.mark.requires("aiosqlite")
async def test_async_create_kv_store(async_sql_store: SQLStore) -> None:
    """Test that a key-value docstore is created from an async base store."""
    docstore = create_kv_docstore(async_sql_store)
    await docstore.amset(
        [("key1", Document(page_content="hello", metadata={"key": "value"}))]
    )
    fetched_doc = (await docstore.amget(["key1"]))[0]
    assert isinstance(fetched_doc, Document)
    assert fetched_doc.page_content == "hello"
    assert fetched_doc.metadata == {"key": "value"}


def test_sample_sql_docstore(sql_store: SQLStore) -> None:
    """Exercise the sync set/get/delete/iterate round trip."""
    # Set values for keys
    sql_store.mset([("key1", b"value1"), ("key2", b"value2")])

    # Get values for keys
    values = sql_store.mget(["key1", "key2"])  # Returns [b"value1", b"value2"]
    assert values == [b"value1", b"value2"]

    # Delete keys
    sql_store.mdelete(["key1"])

    # Iterate over keys
    assert [key for key in sql_store.yield_keys()] == ["key2"]


@pytest.mark.requires("aiosqlite")
async def test_async_sample_sql_docstore(async_sql_store: SQLStore) -> None:
    """Exercise the async set/get/delete/iterate round trip."""
    # Set values for keys
    await async_sql_store.amset([("key1", b"value1"), ("key2", b"value2")])

    # Get values for keys
    values = await async_sql_store.amget(
        ["key1", "key2"]
    )  # Returns [b"value1", b"value2"]
    assert values == [b"value1", b"value2"]

    # Delete keys
    await async_sql_store.amdelete(["key1"])

    # Iterate over keys
    assert [key async for key in async_sql_store.ayield_keys()] == ["key2"]