diff --git a/libs/knowledge-store/ragstack_knowledge_store/embedding_model.py b/libs/knowledge-store/ragstack_knowledge_store/embedding_model.py index 549814976..77cdc873f 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/embedding_model.py +++ b/libs/knowledge-store/ragstack_knowledge_store/embedding_model.py @@ -1,22 +1,108 @@ -from abc import ABC, abstractmethod -from typing import List - +from abc import ABC +from typing import List, Any, Optional +from collections import defaultdict class EmbeddingModel(ABC): """Embedding model.""" - @abstractmethod + def __init__(self, embeddings: Any, method_map: Optional[dict] = None, other_methods: Optional[List[str]] = None): + self.embeddings = embeddings + self.method_name = {} + method_map = method_map if method_map else {} + other_methods = other_methods if other_methods else [] + + base_methods = ['embed_texts', 'aembed_texts', 'embed_query', 'aembed_query'] + extended_methods = ['embed_images', 'aembed_images', 'embed_image', 'aembed_image'] + + # Combining all method names, including those mapped + all_methods = set(base_methods + extended_methods + other_methods + list(method_map.values())) + + for method in all_methods: + mapped_method = method_map.get(method) + if hasattr(embeddings, method): + self.method_name[method] = method + elif hasattr(embeddings, mapped_method) if mapped_method else False: + self.method_name[method] = mapped_method + else: + self.method_name[method] = None + + def does_implement(self, method_name: str) -> bool: + """Check if the method is implemented.""" + return self.method_name.get(method_name) is not None + + def implements(self) -> List[str]: + """List of methods that are implemented""" + return [method for method, impl in self.method_name.items() if impl is not None] + + def invoke(self, method_name: str, *args, **kwargs): + """Invoke a synchronous method if it's implemented.""" + target_method = self.method_name.get(method_name) + if target_method and hasattr(self.embeddings, target_method): + return getattr(self.embeddings, target_method)(*args, **kwargs) + else: + raise NotImplementedError(f"{self.embeddings.__class__.__name__} does not implement {target_method}") + + async def ainvoke(self, method_name: str, *args, **kwargs): + """Invoke an asynchronous method if it's implemented.""" + target_method = self.method_name.get(method_name) + if target_method and hasattr(self.embeddings, target_method): + return await getattr(self.embeddings, target_method)(*args, **kwargs) + else: + raise NotImplementedError(f"{self.embeddings.__class__.__name__} does not implement {target_method}") + + def embed_mimes(self, texts: List[str], mime_types: List[str]) -> List[List[float]]: + """Embed mime content.""" + + # Extract main MIME types + main_mime_types = [mime_type.split('/')[0] for mime_type in mime_types] + + # Group texts by main MIME types + grouped_texts = defaultdict(list) + index_mapping = defaultdict(list) + for index, (text, main_mime_type) in enumerate(zip(texts, main_mime_types)): + grouped_texts[main_mime_type].append(text) + index_mapping[main_mime_type].append(index) + + # Initialize result list with None to preserve order + embeddings = [None] * len(texts) + + # Process each MIME type group + for mime_type, group_texts in grouped_texts.items(): + method_name = f"embed_{mime_type}s" + if self.does_implement(method_name): + # Bulk embedding method exists + group_embeddings = self.invoke(method_name, group_texts) + for idx, emb in zip(index_mapping[mime_type], group_embeddings): + embeddings[idx] = emb + else: + # No bulk method, fall back to individual methods + singular_method_name = f"embed_{mime_type}" + for text, idx in zip(group_texts, index_mapping[mime_type]): + if self.does_implement(singular_method_name): + embedding = self.invoke(singular_method_name, text) + embeddings[idx] = embedding + else: + raise NotImplementedError(f"No embedding method available for MIME type: {mime_type}, implemented methods: {self.implements()}.") + + # Ensure all embeddings are computed + if None in embeddings: + raise ValueError("Some embeddings were not computed correctly.") + + return embeddings + def embed_texts(self, texts: List[str]) -> List[List[float]]: """Embed texts.""" + return self.invoke('embed_texts', texts) - @abstractmethod def embed_query(self, text: str) -> List[float]: """Embed query text.""" + return self.invoke('embed_query', text) - @abstractmethod async def aembed_texts(self, texts: List[str]) -> List[List[float]]: """Embed texts.""" + return await self.ainvoke('aembed_texts', texts) - @abstractmethod async def aembed_query(self, text: str) -> List[float]: """Embed query text.""" + return await self.ainvoke('aembed_query', text) + diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index 094d1f15f..ff9995c77 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -31,6 +31,7 @@ class Node: id: Optional[str] = None """Unique ID for the node. Will be generated by the GraphStore if not set.""" + text: str = None """Text contained by the node.""" metadata: dict = field(default_factory=dict) @@ -38,6 +39,11 @@ class Node: links: Set[Link] = field(default_factory=set) """Links for the node.""" + mime_type: str = "text/plain" + """Type of content, e.g. text/plain or image/png.""" + + mime_encoding: str = None + """Encoding format""" class SetupMode(Enum): SYNC = 1 @@ -338,6 +344,7 @@ def add_nodes( node_ids = [] texts = [] metadatas = [] + mime_types = [] nodes_links: List[Set[Link]] = [] for node in nodes: if not node.id: @@ -346,9 +353,10 @@ def add_nodes( node_ids.append(node.id) texts.append(node.text) metadatas.append(node.metadata) + mime_types.append(node.mime_type) nodes_links.append(node.links) - text_embeddings = self._embedding.embed_texts(texts) + text_embeddings = self._embedding.embed_mimes(texts,mime_types) with self._concurrent_queries() as cq: tuples = zip(node_ids, texts, text_embeddings, metadatas, nodes_links) diff --git a/libs/langchain/ragstack_langchain/graph_store/base.py b/libs/langchain/ragstack_langchain/graph_store/base.py index 7fed810be..aa027dbb6 100644 --- a/libs/langchain/ragstack_langchain/graph_store/base.py +++ b/libs/langchain/ragstack_langchain/graph_store/base.py @@ -22,10 +22,8 @@ from langchain_core.runnables import run_in_executor from langchain_core.vectorstores import VectorStore, VectorStoreRetriever from langchain_core.pydantic_v1 import Field - from ragstack_langchain.graph_store.links import METADATA_LINKS_KEY, Link - def _has_next(iterator: Iterator) -> bool: """Checks if the iterator has more elements. Warning: consumes an element from the iterator""" @@ -41,13 +39,21 @@ class Node(Serializable): id: Optional[str] = None """Unique ID for the node. Will be generated by the GraphStore if not set.""" + text: str """Text contained by the node.""" + metadata: dict = Field(default_factory=dict) """Metadata for the node.""" + links: Set[Link] = Field(default_factory=set) """Links associated with the node.""" + mime_type: str = "text/plain" + """Type of content, e.g. text/plain or image/png.""" + + mime_encoding: str = None + """Encoding format""" def _texts_to_nodes( texts: Iterable[str], @@ -67,6 +73,7 @@ def _texts_to_nodes( except StopIteration: raise ValueError("texts iterable longer than ids") + mime_type = _metadata.get("mime_type", "text/plain") links = _metadata.pop(METADATA_LINKS_KEY, set()) if not isinstance(links, Set): links = set(links) @@ -74,6 +81,7 @@ def _texts_to_nodes( id=_id, metadata=_metadata, text=text, + mime_type=mime_type, links=links, ) if ids_it and _has_next(ids_it): @@ -94,12 +102,16 @@ def _documents_to_nodes( raise ValueError("documents iterable longer than ids") metadata = doc.metadata.copy() links = metadata.pop(METADATA_LINKS_KEY, set()) + mime_type = metadata.get("mime_type","text/plain") + mime_encoding = metadata.get("mime_encoding") if not isinstance(links, Set): links = set(links) yield Node( id=_id, metadata=metadata, text=doc.page_content, + mime_type=mime_type, + mime_encoding=mime_encoding, links=links, ) if ids_it and _has_next(ids_it): diff --git a/libs/langchain/ragstack_langchain/graph_store/cassandra.py b/libs/langchain/ragstack_langchain/graph_store/cassandra.py index 06df93799..4b10f58f1 100644 --- a/libs/langchain/ragstack_langchain/graph_store/cassandra.py +++ b/libs/langchain/ragstack_langchain/graph_store/cassandra.py @@ -12,25 +12,8 @@ from langchain_core.embeddings import Embeddings from .base import GraphStore, Node, nodes_to_documents -from ragstack_knowledge_store import EmbeddingModel, graph_store - - -class _EmbeddingModelAdapter(EmbeddingModel): - def __init__(self, embeddings: Embeddings): - self.embeddings = embeddings - - def embed_texts(self, texts: List[str]) -> List[List[float]]: - return self.embeddings.embed_documents(texts) - - def embed_query(self, text: str) -> List[float]: - return self.embeddings.embed_query(text) - - async def aembed_texts(self, texts: List[str]) -> List[List[float]]: - return await self.embeddings.aembed_documents(texts) - - async def aembed_query(self, text: str) -> List[float]: - return await self.embeddings.aembed_query(text) - +from .embedding_adapter import EmbeddingAdapter +from ragstack_knowledge_store import graph_store class CassandraGraphStore(GraphStore): def __init__( @@ -60,7 +43,7 @@ def __init__( _setup_mode = getattr(graph_store.SetupMode, setup_mode.name) self.store = graph_store.GraphStore( - embedding=_EmbeddingModelAdapter(embedding), + embedding=EmbeddingAdapter(embedding), node_table=node_table, targets_table=targets_table, session=session, @@ -80,10 +63,8 @@ def add_nodes( _nodes = [] for node in nodes: _nodes.append( - graph_store.Node( - id=node.id, text=node.text, metadata=node.metadata, links=node.links - ) - ) + graph_store.Node(id=node.id, text=node.text, mime_type=node.mime_type, mime_encoding=node.mime_encoding, metadata=node.metadata) + ) return self.store.add_nodes(_nodes) @classmethod diff --git a/libs/langchain/ragstack_langchain/graph_store/embedding_adapter.py b/libs/langchain/ragstack_langchain/graph_store/embedding_adapter.py new file mode 100644 index 000000000..ed4dec78e --- /dev/null +++ b/libs/langchain/ragstack_langchain/graph_store/embedding_adapter.py @@ -0,0 +1,9 @@ +from typing import List +from ragstack_knowledge_store import EmbeddingModel + +class EmbeddingAdapter(EmbeddingModel): + def __init__(self, embeddings): + super().__init__(embeddings, + method_map={'embed_texts': 'embed_documents', + 'aembed_texts': 'aembed_documents'}) +