-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'make_wandbot_great_again' of https://github.com/wandb/wandbot into make_wandbot_great_again
- Loading branch information
Showing 5 changed files with 301 additions and 2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
"""Configuration for vector store.""" | ||
|
||
import os | ||
from pathlib import Path | ||
from typing import Optional | ||
|
||
from pydantic import Field | ||
from pydantic_settings import BaseSettings | ||
|
||
|
||
class VectorStoreConfig(BaseSettings):
    """Settings for the wandbot vector store.

    All fields can be overridden via environment variables, as usual for a
    pydantic ``BaseSettings`` subclass.

    Attributes:
        persist_dir: Directory to persist the database.
        collection_name: Name of the collection.
        embedding_model: OpenAI embedding model name.
        embedding_dimensions: Embedding dimensions.
        openai_api_key: OpenAI API key (falls back to the environment).
    """

    persist_dir: Path = Field(
        default=Path("artifacts/wandbot_chroma_index:v0"),
        description="Directory to persist the database",
    )
    collection_name: str = Field(
        default="vectorstore",
        description="Name of the collection",
    )
    embedding_model: str = Field(
        default="text-embedding-3-small",
        description="OpenAI embedding model name",
    )
    embedding_dimensions: int = Field(
        default=512,
        description="Embedding dimensions",
    )
    openai_api_key: Optional[str] = Field(
        default=None,
        description="OpenAI API key (defaults to env var)",
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""Retriever package for wandbot. | ||
This package provides the vector store and retrieval functionality for wandbot. | ||
It includes a native ChromaDB implementation with optimized MMR search. | ||
""" | ||
|
||
from wandbot.retriever.base import VectorStore | ||
from wandbot.retriever.native_chroma import NativeChromaWrapper, setup_native_chroma | ||
|
||
__all__ = ["VectorStore", "NativeChromaWrapper", "setup_native_chroma"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
"""Base VectorStore class for wandbot.""" | ||
|
||
from pathlib import Path | ||
from typing import Optional | ||
|
||
from wandbot.ingestion.config import VectorStoreConfig | ||
from wandbot.retriever.native_chroma import NativeChromaWrapper | ||
import chromadb | ||
from chromadb.utils import embedding_functions as chromadb_ef | ||
|
||
|
||
class VectorStore:
    """Base VectorStore class that handles initialization and configuration.

    Wraps a persistent chromadb collection behind a ``NativeChromaWrapper``
    and exposes it through a retriever interface.
    """

    def __init__(self, config: VectorStoreConfig):
        """Initialize VectorStore.

        Args:
            config: VectorStore configuration.
        """
        self.config = config
        self.vectorstore = None  # populated by _initialize()

    @classmethod
    def from_config(cls, config: VectorStoreConfig):
        """Create and initialize a VectorStore from config.

        Args:
            config: VectorStore configuration.

        Returns:
            Fully initialized VectorStore instance.
        """
        instance = cls(config)
        instance._initialize()
        return instance

    def _initialize(self):
        """Initialize the underlying chromadb client, collection and wrapper."""
        # Create persist directory if it doesn't exist
        persist_dir = Path(self.config.persist_dir)
        persist_dir.mkdir(parents=True, exist_ok=True)

        # Initialize chromadb client
        client = chromadb.PersistentClient(path=str(persist_dir))

        # Initialize OpenAI embeddings.
        # BUG FIX: the config field is `embedding_model` (see
        # VectorStoreConfig); the previous `embedding_model_name` attribute
        # access raised AttributeError on every initialization.
        # NOTE(review): `model_kwargs` may not be accepted by every chromadb
        # version's OpenAIEmbeddingFunction (newer releases take `dimensions`
        # directly) — confirm against the pinned chromadb version.
        embedding_fn = chromadb_ef.OpenAIEmbeddingFunction(
            api_key=self.config.openai_api_key,
            model_name=self.config.embedding_model,
            api_base="https://api.openai.com/v1",
            model_kwargs={"dimensions": self.config.embedding_dimensions},
        )

        # Get or create collection
        collection = client.get_or_create_collection(
            name=self.config.collection_name,
            embedding_function=embedding_fn,
            metadata={"hnsw:space": "cosine"},  # Use cosine similarity
        )

        # Create wrapper
        self.vectorstore = NativeChromaWrapper(collection, embedding_fn)

    def as_retriever(self, *args, **kwargs):
        """Return the vectorstore as a retriever.

        Args:
            *args: Positional arguments passed through to the wrapper.
            **kwargs: Keyword arguments passed through to the wrapper.

        Returns:
            Retriever interface (see NativeChromaWrapper.as_retriever).
        """
        return self.vectorstore.as_retriever(*args, **kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
"""Native ChromaDB implementation that uses chromadb's built-in distance metrics. | ||
This module provides a native ChromaDB implementation that replaces the langchain-chroma | ||
dependency. It uses chromadb's built-in distance metrics (cosine, l2, ip) for better | ||
performance and compatibility. | ||
""" | ||
|
||
import os | ||
from typing import List, Dict, Any, Optional | ||
from langchain_core.documents import Document | ||
from langchain_core.runnables import RunnableLambda | ||
import chromadb | ||
from chromadb.utils import embedding_functions as chromadb_ef | ||
|
||
|
||
class NativeChromaWrapper:
    """Native ChromaDB wrapper that matches langchain-chroma's interface.

    This class provides a drop-in replacement for langchain-chroma's Chroma
    class, implementing the same interface with native chromadb operations.
    MMR re-ranking is computed client-side, because chromadb's query API has
    no built-in MMR mode.
    """

    def __init__(self, collection, embedding_function):
        """Initialize the wrapper.

        Args:
            collection: ChromaDB collection.
            embedding_function: Callable mapping a list of texts to a list of
                embedding vectors (a chromadb embedding function).
        """
        self.collection = collection
        self.embedding_function = embedding_function

    def similarity_search(
        self,
        query: str,
        k: int = 2,
        filter: Optional[Dict[str, Any]] = None,
    ) -> "List[Document]":
        """Perform similarity search.

        Args:
            query: Query text.
            k: Number of results to return.
            filter: Optional metadata filter.

        Returns:
            List of Documents.
        """
        # 'distances' was previously requested but never used; fetch only
        # what we convert into Documents.
        results = self.collection.query(
            query_texts=[query],
            n_results=k,
            where=filter,
            include=["documents", "metadatas"],
        )
        return [
            Document(page_content=doc, metadata=meta)
            for doc, meta in zip(results["documents"][0], results["metadatas"][0])
        ]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 2,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
    ) -> "List[Document]":
        """Perform maximal-marginal-relevance (MMR) search.

        BUG FIX: the previous implementation passed `query_type="mmr"`,
        `mmr_lambda` and `mmr_k` to `collection.query`, but chromadb's query
        API accepts no such parameters, so every MMR search raised TypeError.
        MMR is now computed client-side over `fetch_k` candidates.

        Args:
            query: Query text.
            k: Number of results to return.
            fetch_k: Number of initial candidates to fetch.
            lambda_mult: Trade-off between relevance (1.0) and diversity (0.0).
            filter: Optional metadata filter.

        Returns:
            List of Documents in MMR selection order.
        """
        results = self.collection.query(
            query_texts=[query],
            n_results=max(fetch_k, k),
            where=filter,
            include=["documents", "metadatas", "embeddings"],
        )
        docs = results["documents"][0]
        metas = results["metadatas"][0]
        candidate_embeddings = results["embeddings"][0]
        if not len(docs):
            return []

        # Embed the query once so relevance and diversity use the same space.
        query_embedding = self.embedding_function([query])[0]
        order = self._mmr_select(
            query_embedding, candidate_embeddings, min(k, len(docs)), lambda_mult
        )
        return [Document(page_content=docs[i], metadata=metas[i]) for i in order]

    @staticmethod
    def _cosine_similarity(vec_a, vec_b) -> float:
        """Cosine similarity of two equal-length vectors; 0.0 on a zero norm."""
        dot = sum(a * b for a, b in zip(vec_a, vec_b))
        norm_a = sum(a * a for a in vec_a) ** 0.5
        norm_b = sum(b * b for b in vec_b) ** 0.5
        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0
        return dot / (norm_a * norm_b)

    @staticmethod
    def _mmr_select(query_embedding, candidate_embeddings, k: int, lambda_mult: float) -> List[int]:
        """Greedily pick up to k candidate indices by maximal marginal relevance.

        score(i) = lambda_mult * sim(query, i)
                   - (1 - lambda_mult) * max_j_in_selected sim(i, j)

        Returns:
            Indices into candidate_embeddings, in selection order.
        """
        sim = NativeChromaWrapper._cosine_similarity
        relevance = [sim(query_embedding, emb) for emb in candidate_embeddings]
        selected: List[int] = []
        remaining = list(range(len(candidate_embeddings)))
        while remaining and len(selected) < k:
            best_i = None
            best_score = None
            for i in remaining:
                diversity = max(
                    (sim(candidate_embeddings[i], candidate_embeddings[j]) for j in selected),
                    default=0.0,
                )
                score = lambda_mult * relevance[i] - (1.0 - lambda_mult) * diversity
                if best_score is None or score > best_score:
                    best_i, best_score = i, score
            selected.append(best_i)
            remaining.remove(best_i)
        return selected

    def as_retriever(
        self,
        search_type: str = "mmr",
        search_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Return a retriever interface matching langchain-chroma.

        Args:
            search_type: Type of search ("similarity" or "mmr").
            search_kwargs: Search parameters (k, fetch_k, lambda_mult, filter).

        Returns:
            RunnableLambda mapping a query string to a list of Documents.
        """
        if search_kwargs is None:
            search_kwargs = {"k": 5}

        def retrieve(query: str) -> "List[Document]":
            if search_type == "mmr":
                k = search_kwargs.get("k", 5)
                return self.max_marginal_relevance_search(
                    query=query,
                    k=k,
                    fetch_k=search_kwargs.get("fetch_k", min(k * 2, 20)),
                    lambda_mult=search_kwargs.get("lambda_mult", 0.5),
                    filter=search_kwargs.get("filter", None),
                )
            return self.similarity_search(query=query, **search_kwargs)

        return RunnableLambda(retrieve)
|
||
|
||
def setup_native_chroma(
    persist_dir: str,
    collection_name: str,
    embedding_model: str = "text-embedding-3-small",
    embedding_dimensions: int = 512,
    api_key: Optional[str] = None,
) -> "NativeChromaWrapper":
    """Setup a native chromadb vectorstore.

    Args:
        persist_dir: Directory to persist the database.
        collection_name: Name of the collection.
        embedding_model: OpenAI embedding model name.
        embedding_dimensions: Embedding dimensions.
        api_key: Optional OpenAI API key (defaults to the OPENAI_API_KEY
            environment variable).

    Returns:
        NativeChromaWrapper instance.
    """
    client = chromadb.PersistentClient(path=persist_dir)

    # Initialize OpenAI embeddings.
    # NOTE(review): `model_kwargs` may not be accepted by every chromadb
    # version's OpenAIEmbeddingFunction (newer releases take `dimensions`
    # directly) — confirm against the pinned chromadb version.
    embedding_fn = chromadb_ef.OpenAIEmbeddingFunction(
        api_key=api_key or os.getenv("OPENAI_API_KEY"),
        model_name=embedding_model,
        api_base="https://api.openai.com/v1",
        model_kwargs={"dimensions": embedding_dimensions},
    )

    # CONSISTENCY FIX: use cosine distance like VectorStore._initialize does.
    # Without this metadata chromadb defaults to l2, so collections created
    # through the two code paths would use different distance metrics.
    collection = client.get_or_create_collection(
        name=collection_name,
        embedding_function=embedding_fn,
        metadata={"hnsw:space": "cosine"},
    )

    return NativeChromaWrapper(collection, embedding_fn)