Skip to content

Commit

Permalink
Merge branch 'make_wandbot_great_again' of https://github.com/wandb/wandbot into make_wandbot_great_again
Browse files Browse the repository at this point in the history
  • Loading branch information
morganmcg1 committed Jan 1, 2025
2 parents a3a1fdd + 03fb31b commit 8567ba1
Show file tree
Hide file tree
Showing 5 changed files with 301 additions and 2 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
numpy>=1.26.1
numpy>=2.2.0
pandas>=2.1.2
pydantic>=2.10.0
pydantic-settings>=2.5.1
Expand All @@ -19,7 +19,7 @@ langchain>=0.3.10
langchain-openai>=0.2.14
langchain-experimental>=0.3.3
langchain-core>=0.3.27
langchain-chroma>=0.1.2
chromadb>=0.6.0
weave>=0.51.25
wandb[workspaces]>=0.19.0
fasttext-wheel
Expand Down
41 changes: 41 additions & 0 deletions wandbot/ingestion/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Configuration for vector store."""

import os
from pathlib import Path
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings


class VectorStoreConfig(BaseSettings):
    """Settings for the Chroma vector store.

    As a pydantic-settings model, each field may also be supplied via an
    environment variable of the same name.

    Attributes:
        persist_dir: Directory where the Chroma database is persisted.
        collection_name: Name of the Chroma collection.
        embedding_model: OpenAI embedding model name.
        embedding_dimensions: Dimensionality of the embedding vectors.
        openai_api_key: OpenAI API key; when unset, falls back to the env var.
    """

    persist_dir: Path = Field(
        default=Path("artifacts/wandbot_chroma_index:v0"),
        description="Directory to persist the database",
    )
    collection_name: str = Field(
        default="vectorstore",
        description="Name of the collection",
    )
    embedding_model: str = Field(
        default="text-embedding-3-small",
        description="OpenAI embedding model name",
    )
    embedding_dimensions: int = Field(
        default=512,
        description="Embedding dimensions",
    )
    openai_api_key: Optional[str] = Field(
        default=None,
        description="OpenAI API key (defaults to env var)",
    )
10 changes: 10 additions & 0 deletions wandbot/retriever/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Retriever package for wandbot.
This package provides the vector store and retrieval functionality for wandbot.
It includes a native ChromaDB implementation with optimized MMR search.
"""

# Re-export the package's public retriever API at package level.
from wandbot.retriever.base import VectorStore
from wandbot.retriever.native_chroma import NativeChromaWrapper, setup_native_chroma

# Names exported via `from wandbot.retriever import *`.
__all__ = ["VectorStore", "NativeChromaWrapper", "setup_native_chroma"]
75 changes: 75 additions & 0 deletions wandbot/retriever/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Base VectorStore class for wandbot."""

from pathlib import Path
from typing import Optional

from wandbot.ingestion.config import VectorStoreConfig
from wandbot.retriever.native_chroma import NativeChromaWrapper
import chromadb
from chromadb.utils import embedding_functions as chromadb_ef


class VectorStore:
"""Base VectorStore class that handles initialization and configuration."""

def __init__(self, config: VectorStoreConfig):
"""Initialize VectorStore.
Args:
config: VectorStore configuration
"""
self.config = config
self.vectorstore = None

@classmethod
def from_config(cls, config: VectorStoreConfig):
"""Create VectorStore from config.
Args:
config: VectorStore configuration
Returns:
VectorStore instance
"""
instance = cls(config)
instance._initialize()
return instance

def _initialize(self):
"""Initialize the vectorstore."""
# Create persist directory if it doesn't exist
persist_dir = Path(self.config.persist_dir)
persist_dir.mkdir(parents=True, exist_ok=True)

# Initialize chromadb client
client = chromadb.PersistentClient(path=str(persist_dir))

# Initialize OpenAI embeddings
embedding_fn = chromadb_ef.OpenAIEmbeddingFunction(
api_key=self.config.openai_api_key,
model_name=self.config.embedding_model_name,
api_base="https://api.openai.com/v1",
model_kwargs={"dimensions": self.config.embedding_dimensions}
)

# Get or create collection
collection = client.get_or_create_collection(
name=self.config.collection_name,
embedding_function=embedding_fn,
metadata={"hnsw:space": "cosine"} # Use cosine similarity
)

# Create wrapper
self.vectorstore = NativeChromaWrapper(collection, embedding_fn)

def as_retriever(self, *args, **kwargs):
"""Return vectorstore as retriever.
Args:
*args: Positional arguments to pass to vectorstore
**kwargs: Keyword arguments to pass to vectorstore
Returns:
Retriever interface
"""
return self.vectorstore.as_retriever(*args, **kwargs)
173 changes: 173 additions & 0 deletions wandbot/retriever/native_chroma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""Native ChromaDB implementation that uses chromadb's built-in distance metrics.
This module provides a native ChromaDB implementation that replaces the langchain-chroma
dependency. It uses chromadb's built-in distance metrics (cosine, l2, ip) for better
performance and compatibility.
"""

import os
from typing import List, Dict, Any, Optional
from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda
import chromadb
from chromadb.utils import embedding_functions as chromadb_ef


class NativeChromaWrapper:
    """Native ChromaDB wrapper that matches langchain-chroma's interface.

    This class provides a drop-in replacement for langchain-chroma's Chroma class,
    implementing the same interface but using native chromadb operations for better
    performance.
    """

    def __init__(self, collection, embedding_function):
        """Initialize the wrapper.

        Args:
            collection: ChromaDB collection
            embedding_function: Function to generate embeddings
        """
        self.collection = collection
        self.embedding_function = embedding_function

    def similarity_search(
        self,
        query: str,
        k: int = 2,
        filter: Optional[Dict[str, Any]] = None
    ) -> List[Document]:
        """Perform similarity search.

        Args:
            query: Query text
            k: Number of results to return
            filter: Optional metadata filter

        Returns:
            List of Documents
        """
        results = self.collection.query(
            query_texts=[query],
            n_results=k,
            where=filter,
            include=['documents', 'metadatas', 'distances']
        )

        return [
            Document(page_content=doc, metadata=meta)
            for doc, meta in zip(results['documents'][0], results['metadatas'][0])
        ]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 2,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None
    ) -> List[Document]:
        """Perform MMR search.

        Fix: chromadb's ``Collection.query`` has no ``query_type``/``mmr_*``
        parameters (there is no built-in MMR), so the previous call raised a
        TypeError. We instead fetch ``fetch_k`` candidates along with their
        embeddings and run the standard greedy MMR selection locally.

        Args:
            query: Query text
            k: Number of results to return
            fetch_k: Number of initial candidates to fetch
            lambda_mult: MMR diversity weight (1.0 = pure relevance,
                0.0 = pure diversity)
            filter: Optional metadata filter

        Returns:
            List of Documents
        """
        results = self.collection.query(
            query_texts=[query],
            n_results=max(fetch_k, k),
            where=filter,
            include=['documents', 'metadatas', 'embeddings']
        )

        docs = results['documents'][0]
        metas = results['metadatas'][0]
        embeddings = results['embeddings'][0]
        if not docs:
            return []

        # Embed the query once; chromadb embedding functions take a list of
        # texts and return a list of vectors.
        query_embedding = self.embedding_function([query])[0]

        chosen = self._mmr_indices(
            query_embedding, embeddings, min(k, len(docs)), lambda_mult
        )
        return [
            Document(page_content=docs[i], metadata=metas[i]) for i in chosen
        ]

    @staticmethod
    def _cosine_similarity(a, b) -> float:
        """Cosine similarity between two equal-length numeric vectors."""
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5
        if norm_a == 0.0 or norm_b == 0.0:
            # Degenerate zero vector: treat as orthogonal.
            return 0.0
        return dot / (norm_a * norm_b)

    @staticmethod
    def _mmr_indices(query_embedding, doc_embeddings, k, lambda_mult) -> List[int]:
        """Greedy maximal-marginal-relevance selection.

        Args:
            query_embedding: Embedding of the query.
            doc_embeddings: Candidate document embeddings.
            k: Number of indices to select.
            lambda_mult: Trade-off between relevance and diversity.

        Returns:
            Indices into ``doc_embeddings`` in selection order.
        """
        sims = [
            NativeChromaWrapper._cosine_similarity(query_embedding, emb)
            for emb in doc_embeddings
        ]
        selected: List[int] = []
        candidates = list(range(len(doc_embeddings)))
        while candidates and len(selected) < k:
            best_idx = None
            best_score = float("-inf")
            for idx in candidates:
                if selected:
                    # Penalize similarity to anything already picked.
                    redundancy = max(
                        NativeChromaWrapper._cosine_similarity(
                            doc_embeddings[idx], doc_embeddings[j]
                        )
                        for j in selected
                    )
                else:
                    redundancy = 0.0
                score = lambda_mult * sims[idx] - (1.0 - lambda_mult) * redundancy
                if score > best_score:
                    best_score = score
                    best_idx = idx
            selected.append(best_idx)
            candidates.remove(best_idx)
        return selected

    def as_retriever(
        self,
        search_type: str = "mmr",
        search_kwargs: Optional[Dict[str, Any]] = None
    ):
        """Return a retriever interface matching langchain-chroma.

        Args:
            search_type: Type of search ("similarity" or "mmr")
            search_kwargs: Search parameters

        Returns:
            Retriever callable
        """
        if search_kwargs is None:
            search_kwargs = {"k": 5}

        def retrieve(query: str) -> List[Document]:
            if search_type == "mmr":
                k = search_kwargs.get("k", 5)
                fetch_k = search_kwargs.get("fetch_k", min(k * 2, 20))
                lambda_mult = search_kwargs.get("lambda_mult", 0.5)
                filter_dict = search_kwargs.get("filter", None)

                return self.max_marginal_relevance_search(
                    query=query,
                    k=k,
                    fetch_k=fetch_k,
                    lambda_mult=lambda_mult,
                    filter=filter_dict
                )
            else:
                return self.similarity_search(
                    query=query,
                    **search_kwargs
                )

        return RunnableLambda(retrieve)


def setup_native_chroma(
    persist_dir: str,
    collection_name: str,
    embedding_model: str = "text-embedding-3-small",
    embedding_dimensions: int = 512,
    api_key: Optional[str] = None
) -> NativeChromaWrapper:
    """Setup a native chromadb vectorstore.

    Args:
        persist_dir: Directory to persist the database
        collection_name: Name of the collection
        embedding_model: OpenAI embedding model name
        embedding_dimensions: Embedding dimensions
        api_key: Optional OpenAI API key (defaults to env var)

    Returns:
        NativeChromaWrapper instance
    """
    client = chromadb.PersistentClient(path=persist_dir)

    # Initialize OpenAI embeddings.
    # Fix: chromadb's OpenAIEmbeddingFunction takes `dimensions` directly;
    # `model_kwargs` is a langchain concept it does not accept.
    embedding_fn = chromadb_ef.OpenAIEmbeddingFunction(
        api_key=api_key or os.getenv("OPENAI_API_KEY"),
        model_name=embedding_model,
        api_base="https://api.openai.com/v1",
        dimensions=embedding_dimensions
    )

    # Get or create collection. Use cosine distance to stay consistent with
    # VectorStore._initialize, which configures the same metric.
    collection = client.get_or_create_collection(
        name=collection_name,
        embedding_function=embedding_fn,
        metadata={"hnsw:space": "cosine"}
    )

    return NativeChromaWrapper(collection, embedding_fn)

0 comments on commit 8567ba1

Please sign in to comment.