Skip to content

adding MimeNode type #503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
100 changes: 93 additions & 7 deletions libs/knowledge-store/ragstack_knowledge_store/embedding_model.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,108 @@
from abc import ABC, abstractmethod
from typing import List

from abc import ABC
from typing import List, Any, Optional
from collections import defaultdict

class EmbeddingModel(ABC):
"""Embedding model."""

@abstractmethod
def __init__(self, embeddings: Any, method_map: Optional[dict] = None, other_methods: Optional[List[str]] = None):
# Wrap an embeddings object and set up the table used to resolve method names on it.
#   embeddings:    the wrapped object; stored as-is on self.embeddings.
#   method_map:    optional {logical name: attribute name on `embeddings`} aliases.
#   other_methods: optional extra logical names to probe besides the built-in ones.
self.embeddings = embeddings
# Resolution table, populated further down in this method:
# logical method name -> attribute name found on `embeddings`, or None.
self.method_name = {}
# A missing or empty argument is replaced by a fresh empty container.
method_map = method_map if method_map else {}
other_methods = other_methods if other_methods else []

# Logical names that are always probed on the wrapped object.
base_methods = ['embed_texts', 'aembed_texts', 'embed_query', 'aembed_query']
extended_methods = ['embed_images', 'aembed_images', 'embed_image', 'aembed_image']
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should try to add all of these as methods; it gets pretty messy.

I think we should just have embed_mime(self, mime_type: str, content: Union[str, bytes]) or something like that. Then there is only a single abstract method to implement for any MIME type, and the underlying method names can differ between implementations.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100% but right now, LangChain doesn't have "embed_mime" :)


# Collect every logical name to resolve: the built-in names, the caller's
# extras, and BOTH sides of method_map. Including the keys matters: a mapping
# whose key is not one of the built-in names would otherwise never be
# registered and the alias would be silently ignored.
all_methods = set(base_methods + extended_methods + other_methods)
all_methods.update(method_map.keys(), method_map.values())

for method in all_methods:
    mapped_method = method_map.get(method)
    if hasattr(embeddings, method):
        # The wrapped object exposes the logical name directly; use it.
        self.method_name[method] = method
    elif mapped_method and hasattr(embeddings, mapped_method):
        # Fall back to the alias supplied through method_map.
        self.method_name[method] = mapped_method
    else:
        # Record the name as known-but-unavailable so lookups return None.
        self.method_name[method] = None

def does_implement(self, method_name: str) -> bool:
    """Report whether ``method_name`` has a resolved target.

    A name counts as implemented only when it is present in
    ``self.method_name`` with a non-``None`` value; unknown names
    yield ``False``.
    """
    resolved = self.method_name.get(method_name)
    if resolved is None:
        return False
    return True

def implements(self) -> List[str]:
    """Return the registered method names whose target is not ``None``.

    Names are returned in the iteration order of ``self.method_name``.
    """
    available = []
    for name, target in self.method_name.items():
        if target is None:
            continue
        available.append(name)
    return available

def invoke(self, method_name: str, *args, **kwargs):
    """Call the synchronous embeddings method registered under ``method_name``.

    Args:
        method_name: Logical method name to look up in ``self.method_name``.
        *args: Positional arguments forwarded to the resolved method.
        **kwargs: Keyword arguments forwarded to the resolved method.

    Returns:
        Whatever the resolved method on ``self.embeddings`` returns.

    Raises:
        NotImplementedError: If the name is unknown, resolved to ``None``,
            or the resolved attribute is missing on ``self.embeddings``.
    """
    target_method = self.method_name.get(method_name)
    if not target_method or not hasattr(self.embeddings, target_method):
        # Report the name the caller asked for: on this path the resolved
        # target is usually None, which would make the message useless.
        raise NotImplementedError(
            f"{self.embeddings.__class__.__name__} does not implement {method_name}"
        )
    return getattr(self.embeddings, target_method)(*args, **kwargs)

async def ainvoke(self, method_name: str, *args, **kwargs):
    """Await the asynchronous embeddings method registered under ``method_name``.

    Args:
        method_name: Logical method name to look up in ``self.method_name``.
        *args: Positional arguments forwarded to the resolved method.
        **kwargs: Keyword arguments forwarded to the resolved method.

    Returns:
        The awaited result of the resolved method on ``self.embeddings``.

    Raises:
        NotImplementedError: If the name is unknown, resolved to ``None``,
            or the resolved attribute is missing on ``self.embeddings``.
    """
    target_method = self.method_name.get(method_name)
    if not target_method or not hasattr(self.embeddings, target_method):
        # Report the name the caller asked for: on this path the resolved
        # target is usually None, which would make the message useless.
        raise NotImplementedError(
            f"{self.embeddings.__class__.__name__} does not implement {method_name}"
        )
    return await getattr(self.embeddings, target_method)(*args, **kwargs)

def embed_mimes(self, texts: List[str], mime_types: List[str]) -> List[List[float]]:
    """Embed items of mixed MIME types, returning embeddings in input order.

    Items are grouped by main MIME type (the part before ``/``, compared
    case-insensitively). Each group is embedded with the bulk method
    ``embed_<type>s`` when available, otherwise item by item with
    ``embed_<type>``.

    Args:
        texts: Content of each item.
        mime_types: MIME type of each item, aligned with ``texts``.

    Returns:
        One embedding per input item, in the same order as ``texts``.

    Raises:
        ValueError: If ``texts`` and ``mime_types`` differ in length, or if
            any item ended up without an embedding.
        NotImplementedError: If neither the bulk nor the singular method is
            implemented for one of the main MIME types.
    """
    # zip() would silently drop the unmatched tail; fail loudly instead.
    if len(texts) != len(mime_types):
        raise ValueError(
            "texts and mime_types must have the same length, "
            f"got {len(texts)} and {len(mime_types)}."
        )

    # Group input positions by main MIME type so results can be put back in order.
    index_mapping = defaultdict(list)
    for index, full_type in enumerate(mime_types):
        main_type = full_type.split('/')[0].strip().lower()
        index_mapping[main_type].append(index)

    # Pre-filled with None so out-of-order assignment preserves input order.
    embeddings = [None] * len(texts)

    for mime_type, indices in index_mapping.items():
        group_texts = [texts[i] for i in indices]
        method_name = f"embed_{mime_type}s"
        singular_method_name = f"embed_{mime_type}"

        if self.does_implement(method_name):
            # Bulk embedding method exists: one call for the whole group.
            group_embeddings = self.invoke(method_name, group_texts)
        elif self.does_implement(singular_method_name):
            # No bulk method: fall back to one call per item.
            group_embeddings = [
                self.invoke(singular_method_name, text) for text in group_texts
            ]
        else:
            raise NotImplementedError(f"No embedding method available for MIME type: {mime_type}, implemented methods: {self.implements()}.")

        for idx, emb in zip(indices, group_embeddings):
            embeddings[idx] = emb

    # Identity check rather than `None in embeddings`: the latter compares with
    # ==, which is ambiguous when an embedding is an array-like object.
    if any(emb is None for emb in embeddings):
        raise ValueError("Some embeddings were not computed correctly.")

    return embeddings

def embed_texts(self, texts: List[str]) -> List[List[float]]:
    """Embed a batch of texts by dispatching to the ``'embed_texts'`` method."""
    vectors = self.invoke('embed_texts', texts)
    return vectors

@abstractmethod
def embed_query(self, text: str) -> List[float]:
    """Embed a single query string by dispatching to the ``'embed_query'`` method."""
    vector = self.invoke('embed_query', text)
    return vector

@abstractmethod
async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
    """Asynchronously embed a batch of texts via the ``'aembed_texts'`` method."""
    vectors = await self.ainvoke('aembed_texts', texts)
    return vectors

@abstractmethod
async def aembed_query(self, text: str) -> List[float]:
    """Asynchronously embed a single query string via the ``'aembed_query'`` method."""
    vector = await self.ainvoke('aembed_query', text)
    return vector

10 changes: 9 additions & 1 deletion libs/knowledge-store/ragstack_knowledge_store/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,19 @@ class Node:

id: Optional[str] = None
"""Unique ID for the node. Will be generated by the GraphStore if not set."""

text: str = None
"""Text contained by the node."""
metadata: dict = field(default_factory=dict)
"""Metadata for the node."""
links: Set[Link] = field(default_factory=set)
"""Links for the node."""

mime_type: str = "text/plain"
"""Type of content, e.g. text/plain or image/png."""

mime_encoding: str = None
"""Encoding format"""

class SetupMode(Enum):
SYNC = 1
Expand Down Expand Up @@ -338,6 +344,7 @@ def add_nodes(
node_ids = []
texts = []
metadatas = []
mime_types = []
nodes_links: List[Set[Link]] = []
for node in nodes:
if not node.id:
Expand All @@ -346,9 +353,10 @@ def add_nodes(
node_ids.append(node.id)
texts.append(node.text)
metadatas.append(node.metadata)
mime_types.append(node.mime_type)
nodes_links.append(node.links)

text_embeddings = self._embedding.embed_texts(texts)
text_embeddings = self._embedding.embed_mimes(texts,mime_types)

with self._concurrent_queries() as cq:
tuples = zip(node_ids, texts, text_embeddings, metadatas, nodes_links)
Expand Down
16 changes: 14 additions & 2 deletions libs/langchain/ragstack_langchain/graph_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from langchain_core.pydantic_v1 import Field

from ragstack_langchain.graph_store.links import METADATA_LINKS_KEY, Link


def _has_next(iterator: Iterator) -> bool:
"""Checks if the iterator has more elements.
Warning: consumes an element from the iterator"""
Expand All @@ -41,13 +39,21 @@ class Node(Serializable):

id: Optional[str] = None
"""Unique ID for the node. Will be generated by the GraphStore if not set."""

text: str
"""Text contained by the node."""

metadata: dict = Field(default_factory=dict)
"""Metadata for the node."""

links: Set[Link] = Field(default_factory=set)
"""Links associated with the node."""

mime_type: str = "text/plain"
"""Type of content, e.g. text/plain or image/png."""

mime_encoding: str = None
"""Encoding format"""

def _texts_to_nodes(
texts: Iterable[str],
Expand All @@ -67,13 +73,15 @@ def _texts_to_nodes(
except StopIteration:
raise ValueError("texts iterable longer than ids")

mime_type = _metadata.get("mime_type", "text/plain")
links = _metadata.pop(METADATA_LINKS_KEY, set())
if not isinstance(links, Set):
links = set(links)
yield Node(
id=_id,
metadata=_metadata,
text=text,
mime_type=mime_type,
links=links,
)
if ids_it and _has_next(ids_it):
Expand All @@ -94,12 +102,16 @@ def _documents_to_nodes(
raise ValueError("documents iterable longer than ids")
metadata = doc.metadata.copy()
links = metadata.pop(METADATA_LINKS_KEY, set())
mime_type = metadata.get("mime_type","text/plain")
mime_encoding = metadata.get("mime_encoding")
if not isinstance(links, Set):
links = set(links)
yield Node(
id=_id,
metadata=metadata,
text=doc.page_content,
mime_type=mime_type,
mime_encoding=mime_encoding,
links=links,
)
if ids_it and _has_next(ids_it):
Expand Down
29 changes: 5 additions & 24 deletions libs/langchain/ragstack_langchain/graph_store/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,8 @@
from langchain_core.embeddings import Embeddings

from .base import GraphStore, Node, nodes_to_documents
from ragstack_knowledge_store import EmbeddingModel, graph_store


class _EmbeddingModelAdapter(EmbeddingModel):
    """Expose an ``Embeddings`` object through the ``EmbeddingModel`` interface.

    Text batches are forwarded to the wrapped object's ``embed_documents`` /
    ``aembed_documents``; queries are forwarded to its ``embed_query`` /
    ``aembed_query``.
    """

    def __init__(self, embeddings: Embeddings):
        self.embeddings = embeddings

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts with the wrapped ``embed_documents``."""
        vectors = self.embeddings.embed_documents(texts)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Embed one query string with the wrapped ``embed_query``."""
        vector = self.embeddings.embed_query(text)
        return vector

    async def aembed_texts(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed a batch of texts with ``aembed_documents``."""
        vectors = await self.embeddings.aembed_documents(texts)
        return vectors

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed one query string with ``aembed_query``."""
        vector = await self.embeddings.aembed_query(text)
        return vector

from .embedding_adapter import EmbeddingAdapter
from ragstack_knowledge_store import graph_store

class CassandraGraphStore(GraphStore):
def __init__(
Expand Down Expand Up @@ -60,7 +43,7 @@ def __init__(
_setup_mode = getattr(graph_store.SetupMode, setup_mode.name)

self.store = graph_store.GraphStore(
embedding=_EmbeddingModelAdapter(embedding),
embedding=EmbeddingAdapter(embedding),
node_table=node_table,
targets_table=targets_table,
session=session,
Expand All @@ -80,10 +63,8 @@ def add_nodes(
_nodes = []
for node in nodes:
_nodes.append(
graph_store.Node(
id=node.id, text=node.text, metadata=node.metadata, links=node.links
)
)
graph_store.Node(id=node.id, text=node.text, mime_type=node.mime_type, mime_encoding=node.mime_encoding, metadata=node.metadata)
)
return self.store.add_nodes(_nodes)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import List
from ragstack_knowledge_store import EmbeddingModel

class EmbeddingAdapter(EmbeddingModel):
    """``EmbeddingModel`` that aliases the text-batch methods to ``*_documents``."""

    def __init__(self, embeddings):
        # Logical name used by EmbeddingModel -> method name passed as its alias.
        document_aliases = {
            'embed_texts': 'embed_documents',
            'aembed_texts': 'aembed_documents',
        }
        super().__init__(embeddings, method_map=document_aliases)

Loading