Skip to content

Commit

Permalink
Support index embed specification via string
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Feb 5, 2025
1 parent 04a0443 commit 3cf60de
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 3 deletions.
3 changes: 2 additions & 1 deletion libs/checkpoint/langgraph/store/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,13 +493,14 @@ class IndexConfig(TypedDict, total=False):
- cohere:embed-multilingual-light-v3.0: 384
"""

embed: Union[Embeddings, EmbeddingsFunc, AEmbeddingsFunc]
embed: Union[Embeddings, EmbeddingsFunc, AEmbeddingsFunc, str]
"""Optional function to generate embeddings from text.
Can be specified in four ways:
1. A LangChain Embeddings instance
2. A synchronous embedding function (EmbeddingsFunc)
3. An asynchronous embedding function (AEmbeddingsFunc)
4. A provider string (e.g., "openai:text-embedding-3-small")
???+ example "Examples"
Using LangChain's initialization with InMemoryStore:
Expand Down
41 changes: 40 additions & 1 deletion libs/checkpoint/langgraph/store/base/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import asyncio
import functools
import json
from typing import Any, Awaitable, Callable, Optional, Sequence, Union

Expand All @@ -28,7 +29,7 @@


def ensure_embeddings(
embed: Union[Embeddings, EmbeddingsFunc, AEmbeddingsFunc, None],
embed: Union[Embeddings, EmbeddingsFunc, AEmbeddingsFunc, str, None],
) -> Embeddings:
"""Ensure that an embedding function conforms to LangChain's Embeddings interface.
Expand Down Expand Up @@ -62,9 +63,37 @@ async def my_async_fn(texts):
embeddings = ensure_embeddings(my_async_fn)
result = await embeddings.aembed_query("hello") # Returns [0.1, 0.2]
```
Initialize embeddings using a provider string:
```python
# Requires langchain>=0.3.9 and langgraph-checkpoint>=2.0.11
embeddings = ensure_embeddings("openai:text-embedding-3-small")
result = embeddings.embed_query("hello")
```
"""
if embed is None:
raise ValueError("embed must be provided")
if isinstance(embed, str):
init_embeddings = _get_init_embeddings()
if init_embeddings is None:
from importlib.metadata import PackageNotFoundError, version

try:
lc_version = version("langchain")
version_info = f"Found langchain version {lc_version}, but"
except PackageNotFoundError:
version_info = "langchain is not installed;"

raise ValueError(
f"Could not load embeddings from string '{embed}'. {version_info} "
"loading embeddings by provider:identifier string requires langchain>=0.3.9 "
"as well as the provider-specific package. "
"Install LangChain with: pip install 'langchain>=0.3.9' "
"and the provider-specific package (e.g., 'langchain-openai>=0.3.0'). "
"Alternatively, specify 'embed' as a compatible Embeddings object or python function."
)
return init_embeddings(embed)

if isinstance(embed, Embeddings):
return embed
return EmbeddingsLambda(embed)
Expand Down Expand Up @@ -373,6 +402,16 @@ def _is_async_callable(
)


@functools.lru_cache
def _get_init_embeddings() -> Callable[[str, ...], "Embeddings"] | None:
try:
from langchain.embeddings import init_embeddings # noqa: [import-not-found]

return init_embeddings
except ImportError:
return None


__all__ = [
"ensure_embeddings",
"EmbeddingsFunc",
Expand Down
2 changes: 1 addition & 1 deletion libs/checkpoint/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph-checkpoint"
version = "2.0.10"
version = "2.0.11"
description = "Library with base interfaces for LangGraph checkpoint savers."
authors = []
license = "MIT"
Expand Down

0 comments on commit 3cf60de

Please sign in to comment.