community[minor]: Add SQL storage implementation (langchain-ai#22207)
Hello @eyurtsev

- package: langchain-community
- **Description**: Add SQL implementation for docstore. A new implementation, in line with my other PRs ([async PGVector](langchain-ai/langchain-postgres#32), [SQLChatMessageMemory](langchain-ai#22065))
- Twitter handle: pprados

---------

Signed-off-by: ChengZi <[email protected]>
Co-authored-by: Bagatur <[email protected]>
Co-authored-by: Piotr Mardziel <[email protected]>
Co-authored-by: ChengZi <[email protected]>
Co-authored-by: Eugene Yurtsev <[email protected]>
1 parent f2f0e0e · commit 9aabb44
Showing 6 changed files with 548 additions and 1 deletion.
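For orientation, here is a minimal sketch of the synchronous API that the new store exposes, using the import path and method names from the module's docstring below. The in-memory SQLite URL is illustrative; any SQLAlchemy-compatible database URL should work.

# Sketch only: exercises the sync path of the SQLStore added in this commit.
# The import path mirrors the docstring example below; adjust it to wherever
# the module lives in your environment.
from langchain_rag.storage import SQLStore

store = SQLStore(namespace="test", db_url="sqlite:///:memory:")
store.create_schema()  # creates the langchain_key_value_stores table

store.mset([("key1", b"value1"), ("key2", b"value2")])
print(store.mget(["key1", "key2"]))  # [b'value1', b'value2']
store.mdelete(["key1"])
print(list(store.yield_keys()))  # ['key2']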
@@ -390,4 +390,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 5
-}
+}
@@ -0,0 +1,266 @@
import contextlib
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)

from langchain_core.stores import BaseStore
from sqlalchemy import (
    Engine,
    LargeBinary,
    and_,
    create_engine,
    delete,
    select,
)
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import (
    Mapped,
    Session,
    declarative_base,
    mapped_column,
    sessionmaker,
)

Base = declarative_base()


def items_equal(x: Any, y: Any) -> bool:
    return x == y


class LangchainKeyValueStores(Base):  # type: ignore[valid-type,misc]
    """Table used to save values."""

    # ATTENTION:
    # Prior to modifying this table, please determine whether
    # we should create migrations for this table to make sure
    # users do not experience data loss.
    __tablename__ = "langchain_key_value_stores"

    namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
    key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
    value = mapped_column(LargeBinary, index=False, nullable=False)

# This is a fix of the original SQLStore.
# It can be removed once the upstream PR is merged.
class SQLStore(BaseStore[str, bytes]):
    """BaseStore interface that works on an SQL database.

    Examples:
        Create a SQLStore instance and perform operations on it:

        .. code-block:: python

            from langchain_rag.storage import SQLStore

            # Instantiate the SQLStore with an in-memory SQLite database
            sql_store = SQLStore(namespace="test", db_url="sqlite:///:memory:")

            # Set values for keys
            sql_store.mset([("key1", b"value1"), ("key2", b"value2")])

            # Get values for keys
            values = sql_store.mget(["key1", "key2"])  # Returns [b"value1", b"value2"]

            # Delete keys
            sql_store.mdelete(["key1"])

            # Iterate over keys
            for key in sql_store.yield_keys():
                print(key)
    """

    def __init__(
        self,
        *,
        namespace: str,
        db_url: Optional[Union[str, Path]] = None,
        engine: Optional[Union[Engine, AsyncEngine]] = None,
        engine_kwargs: Optional[Dict[str, Any]] = None,
        async_mode: Optional[bool] = None,
    ):
        if db_url is None and engine is None:
            raise ValueError("Must specify either db_url or engine")

        if db_url is not None and engine is not None:
            raise ValueError("Must specify either db_url or engine, not both")

        _engine: Union[Engine, AsyncEngine]
        if db_url:
            if async_mode is None:
                async_mode = False
            if async_mode:
                _engine = create_async_engine(
                    url=str(db_url),
                    **(engine_kwargs or {}),
                )
            else:
                _engine = create_engine(url=str(db_url), **(engine_kwargs or {}))
        elif engine:
            _engine = engine

        else:
            raise AssertionError("Something went wrong with configuration of engine.")

        _session_maker: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
        if isinstance(_engine, AsyncEngine):
            self.async_mode = True
            _session_maker = async_sessionmaker(bind=_engine)
        else:
            self.async_mode = False
            _session_maker = sessionmaker(bind=_engine)

        self.engine = _engine
        self.dialect = _engine.dialect.name
        self.session_maker = _session_maker
        self.namespace = namespace

    def create_schema(self) -> None:
        Base.metadata.create_all(self.engine)

    async def acreate_schema(self) -> None:
        assert isinstance(self.engine, AsyncEngine)
        async with self.engine.begin() as session:
            await session.run_sync(Base.metadata.create_all)

    def drop(self) -> None:
        Base.metadata.drop_all(bind=self.engine.connect())

    async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        assert isinstance(self.engine, AsyncEngine)
        result: Dict[str, bytes] = {}
        async with self._make_async_session() as session:
            stmt = select(LangchainKeyValueStores).filter(
                and_(
                    LangchainKeyValueStores.key.in_(keys),
                    LangchainKeyValueStores.namespace == self.namespace,
                )
            )
            for v in await session.scalars(stmt):
                result[v.key] = v.value
        return [result.get(key) for key in keys]

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        result = {}

        with self._make_sync_session() as session:
            stmt = select(LangchainKeyValueStores).filter(
                and_(
                    LangchainKeyValueStores.key.in_(keys),
                    LangchainKeyValueStores.namespace == self.namespace,
                )
            )
            for v in session.scalars(stmt):
                result[v.key] = v.value
        return [result.get(key) for key in keys]

    async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        async with self._make_async_session() as session:
            await self._amdelete([key for key, _ in key_value_pairs], session)
            session.add_all(
                [
                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
                    for k, v in key_value_pairs
                ]
            )
            await session.commit()

    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        values: Dict[str, bytes] = dict(key_value_pairs)
        with self._make_sync_session() as session:
            self._mdelete(list(values.keys()), session)
            session.add_all(
                [
                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
                    for k, v in values.items()
                ]
            )
            session.commit()

    def _mdelete(self, keys: Sequence[str], session: Session) -> None:
        stmt = delete(LangchainKeyValueStores).filter(
            and_(
                LangchainKeyValueStores.key.in_(keys),
                LangchainKeyValueStores.namespace == self.namespace,
            )
        )
        session.execute(stmt)

    async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None:
        stmt = delete(LangchainKeyValueStores).filter(
            and_(
                LangchainKeyValueStores.key.in_(keys),
                LangchainKeyValueStores.namespace == self.namespace,
            )
        )
        await session.execute(stmt)

    def mdelete(self, keys: Sequence[str]) -> None:
        with self._make_sync_session() as session:
            self._mdelete(keys, session)
            session.commit()

    async def amdelete(self, keys: Sequence[str]) -> None:
        async with self._make_async_session() as session:
            await self._amdelete(keys, session)
            await session.commit()

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        with self._make_sync_session() as session:
            for v in session.query(LangchainKeyValueStores).filter(  # type: ignore
                LangchainKeyValueStores.namespace == self.namespace
            ):
                if str(v.key).startswith(prefix or ""):
                    yield str(v.key)
            session.close()

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        async with self._make_async_session() as session:
            stmt = select(LangchainKeyValueStores).filter(
                LangchainKeyValueStores.namespace == self.namespace
            )
            for v in await session.scalars(stmt):
                if str(v.key).startswith(prefix or ""):
                    yield str(v.key)
            await session.close()

    @contextlib.contextmanager
    def _make_sync_session(self) -> Generator[Session, None, None]:
        """Make a sync session."""
        if self.async_mode:
            raise ValueError(
                "Attempting to use a sync method when async mode is turned on. "
                "Please use the corresponding async method instead."
            )
        with cast(Session, self.session_maker()) as session:
            yield cast(Session, session)

    @contextlib.asynccontextmanager
    async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Make an async session."""
        if not self.async_mode:
            raise ValueError(
                "Attempting to use an async method when sync mode is turned on. "
                "Please use the corresponding sync method instead."
            )
        async with cast(AsyncSession, self.session_maker()) as session:
            yield cast(AsyncSession, session)
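The store also works fully asynchronously when constructed with async_mode=True (or when given an AsyncEngine). A hedged sketch of that path, assuming an async SQLite driver such as aiosqlite is installed; the sqlite+aiosqlite URL and the import path (taken from the docstring above) are illustrative:

# Sketch only: async usage of SQLStore, assuming the aiosqlite driver is installed.
import asyncio

from langchain_rag.storage import SQLStore  # import path taken from the docstring above


async def main() -> None:
    store = SQLStore(
        namespace="test",
        db_url="sqlite+aiosqlite:///:memory:",
        async_mode=True,
    )
    await store.acreate_schema()  # async counterpart of create_schema()

    await store.amset([("key1", b"value1")])
    print(await store.amget(["key1"]))  # [b'value1']
    print([k async for k in store.ayield_keys()])  # ['key1']
    await store.amdelete(["key1"])


asyncio.run(main())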