Add partitioning func capabilities to allow doc-types-based embedding ranking (#752)

Co-authored-by: James Braza <[email protected]>
mskarlin and jamesbraza authored Dec 10, 2024
1 parent 08026d3 commit e3623ed
Showing 6 changed files with 2,929 additions and 8 deletions.
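
The thread running through the diffs below: an optional partitioning_fn (a Callable[[Embeddable], int] assigning each chunk an integer group id) is plumbed from the agent tools through Docs down to the vector store, which runs retrieval per group and interleaves the results so each document type is represented in the top-k. A minimal usage sketch; the partition_by_doc_type helper, the file paths, and the reliance on a populated DocDetails.journal field are illustrative assumptions, not part of this commit:

    from paperqa import Docs
    from paperqa.types import DocDetails, Embeddable


    def partition_by_doc_type(chunk: Embeddable) -> int:
        # Illustrative grouping: preprint-like docs vs everything else. Any
        # function mapping a chunk to an int group id works; chunks produced
        # by Docs are Text objects carrying a .doc that may be a DocDetails.
        doc = getattr(chunk, "doc", None)
        if isinstance(doc, DocDetails) and "rxiv" in (doc.journal or "").lower():
            return 0
        return 1


    docs = Docs()
    docs.add("preprint.pdf")  # hypothetical paths
    docs.add("journal_article.pdf")
    session = docs.query(
        "What do these papers conclude?",
        partitioning_fn=partition_by_doc_type,
    )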
6 changes: 5 additions & 1 deletion paperqa/agents/tools.py
@@ -15,7 +15,7 @@
 from paperqa.docs import Docs
 from paperqa.llms import EmbeddingModel, LiteLLMModel
 from paperqa.settings import Settings
-from paperqa.types import DocDetails, PQASession
+from paperqa.types import DocDetails, Embeddable, PQASession
 
 from .search import get_directory_index
 
@@ -193,6 +193,7 @@ class GatherEvidence(NamedTool):
     settings: Settings
     summary_llm_model: LiteLLMModel
     embedding_model: EmbeddingModel
+    partitioning_fn: Callable[[Embeddable], int] | None = None
 
     async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
         """
@@ -236,6 +237,7 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
                 settings=self.settings,
                 embedding_model=self.embedding_model,
                 summary_llm_model=self.summary_llm_model,
+                partitioning_fn=self.partitioning_fn,
                 callbacks=self.settings.agent.callbacks.get(
                     f"{self.TOOL_FN_NAME}_aget_evidence"
                 ),
@@ -275,6 +277,7 @@ class GenerateAnswer(NamedTool):
     llm_model: LiteLLMModel
     summary_llm_model: LiteLLMModel
     embedding_model: EmbeddingModel
+    partitioning_fn: Callable[[Embeddable], int] | None = None
 
     async def gen_answer(self, state: EnvironmentState) -> str:
         """
@@ -305,6 +308,7 @@ async def gen_answer(self, state: EnvironmentState) -> str:
             llm_model=self.llm_model,
             summary_llm_model=self.summary_llm_model,
             embedding_model=self.embedding_model,
+            partitioning_fn=self.partitioning_fn,
             callbacks=self.settings.agent.callbacks.get(
                 f"{self.TOOL_FN_NAME}_aget_query"
             ),
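Both tools gain the same optional field, so an agent loop can retrieve evidence and generate the final answer with one consistent grouping. A sketch of the wiring, assuming settings, the models, and the partition_by_doc_type helper from the sketch above are already in scope:

    from paperqa.agents.tools import GatherEvidence

    gather = GatherEvidence(
        settings=settings,
        summary_llm_model=summary_llm_model,
        embedding_model=embedding_model,
        partitioning_fn=partition_by_doc_type,
    )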
21 changes: 19 additions & 2 deletions paperqa/docs.py
@@ -38,6 +38,7 @@
     Doc,
     DocDetails,
     DocKey,
+    Embeddable,
     LLMResult,
     PQASession,
     Text,
@@ -518,6 +519,7 @@ async def retrieve_texts(
         k: int,
         settings: MaybeSettings = None,
         embedding_model: EmbeddingModel | None = None,
+        partitioning_fn: Callable[[Embeddable], int] | None = None,
     ) -> list[Text]:
 
         settings = get_settings(settings)
@@ -533,7 +535,11 @@
             list[Text],
             (
                 await self.texts_index.max_marginal_relevance_search(
-                    query, k=_k, fetch_k=2 * _k, embedding_model=embedding_model
+                    query,
+                    k=_k,
+                    fetch_k=2 * _k,
+                    embedding_model=embedding_model,
+                    partitioning_fn=partitioning_fn,
                 )
             )[0],
         )
@@ -548,6 +554,7 @@ def get_evidence(
         callbacks: list[Callable] | None = None,
         embedding_model: EmbeddingModel | None = None,
         summary_llm_model: LLMModel | None = None,
+        partitioning_fn: Callable[[Embeddable], int] | None = None,
     ) -> PQASession:
         return get_loop().run_until_complete(
             self.aget_evidence(
@@ -557,6 +564,7 @@
                 callbacks=callbacks,
                 embedding_model=embedding_model,
                 summary_llm_model=summary_llm_model,
+                partitioning_fn=partitioning_fn,
             )
         )
 
@@ -568,6 +576,7 @@ async def aget_evidence(
         callbacks: list[Callable] | None = None,
         embedding_model: EmbeddingModel | None = None,
         summary_llm_model: LLMModel | None = None,
+        partitioning_fn: Callable[[Embeddable], int] | None = None,
     ) -> PQASession:
 
         evidence_settings = get_settings(settings)
@@ -600,7 +609,11 @@ async def aget_evidence(
 
         if answer_config.evidence_retrieval:
             matches = await self.retrieve_texts(
-                session.question, _k, evidence_settings, embedding_model
+                session.question,
+                _k,
+                evidence_settings,
+                embedding_model,
+                partitioning_fn=partitioning_fn,
             )
         else:
             matches = self.texts
@@ -662,6 +675,7 @@ def query(
         llm_model: LLMModel | None = None,
         summary_llm_model: LLMModel | None = None,
         embedding_model: EmbeddingModel | None = None,
+        partitioning_fn: Callable[[Embeddable], int] | None = None,
     ) -> PQASession:
         return get_loop().run_until_complete(
             self.aquery(
@@ -671,6 +685,7 @@
                 llm_model=llm_model,
                 summary_llm_model=summary_llm_model,
                 embedding_model=embedding_model,
+                partitioning_fn=partitioning_fn,
             )
         )
 
@@ -682,6 +697,7 @@ async def aquery(  # noqa: PLR0912
         llm_model: LLMModel | None = None,
         summary_llm_model: LLMModel | None = None,
         embedding_model: EmbeddingModel | None = None,
+        partitioning_fn: Callable[[Embeddable], int] | None = None,
     ) -> PQASession:
 
         query_settings = get_settings(settings)
@@ -709,6 +725,7 @@ async def aquery(  # noqa: PLR0912
                 settings=settings,
                 embedding_model=embedding_model,
                 summary_llm_model=summary_llm_model,
+                partitioning_fn=partitioning_fn,
             )
         contexts = session.contexts
         pre_str = None
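On the Docs side the parameter is pure plumbing: query/aquery and get_evidence/aget_evidence forward it into retrieve_texts, which hands it to the vector store's max_marginal_relevance_search. Calling the evidence step directly looks like this, a sketch reusing the assumed helper and objects from above:

    session = await docs.aget_evidence(
        "What do these papers conclude?",
        settings=settings,
        embedding_model=embedding_model,
        partitioning_fn=partition_by_doc_type,
    )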
93 changes: 89 additions & 4 deletions paperqa/llms.py
@@ -3,6 +3,8 @@
 import asyncio
 import contextlib
 import functools
+import itertools
+import logging
 from abc import ABC, abstractmethod
 from collections.abc import (
     AsyncGenerator,
@@ -43,6 +45,8 @@
 
 MODEL_COST_MAP = litellm.get_model_cost_map("")
 
+logger = logging.getLogger(__name__)
+
 
 def prepare_args(func: Callable, chunk: str, name: str | None) -> tuple[tuple, dict]:
     with contextlib.suppress(TypeError):
@@ -802,8 +806,35 @@ async def similarity_search(
     def clear(self) -> None:
         self.texts_hashes = set()
 
+    async def partitioned_similarity_search(
+        self,
+        query: str,
+        k: int,
+        embedding_model: EmbeddingModel,
+        partitioning_fn: Callable[[Embeddable], int],
+    ) -> tuple[Sequence[Embeddable], list[float]]:
+        """Partition the documents into different groups and perform similarity search.
+
+        Args:
+            query: query string
+            k: Number of results to return
+            embedding_model: model used to embed the query
+            partitioning_fn: function to partition the documents into different groups.
+
+        Returns:
+            Tuple of lists of Embeddables and scores of length k.
+        """
+        raise NotImplementedError(
+            "partitioned_similarity_search is not implemented for this VectorStore."
+        )
+
     async def max_marginal_relevance_search(
-        self, query: str, k: int, fetch_k: int, embedding_model: EmbeddingModel
+        self,
+        query: str,
+        k: int,
+        fetch_k: int,
+        embedding_model: EmbeddingModel,
+        partitioning_fn: Callable[[Embeddable], int] | None = None,
     ) -> tuple[Sequence[Embeddable], list[float]]:
         """Vectorized implementation of Maximal Marginal Relevance (MMR) search.
@@ -812,14 +843,24 @@ async def max_marginal_relevance_search(
             k: Number of results to return.
             fetch_k: Number of results to fetch from the vector store.
             embedding_model: model used to embed the query
+            partitioning_fn: optional function to partition the documents into
+                different groups, performing MMR within each group.
 
         Returns:
             List of tuples (doc, score) of length k.
         """
         if fetch_k < k:
             raise ValueError("fetch_k must be greater or equal to k")
 
-        texts, scores = await self.similarity_search(query, fetch_k, embedding_model)
+        if partitioning_fn is None:
+            texts, scores = await self.similarity_search(
+                query, fetch_k, embedding_model
+            )
+        else:
+            texts, scores = await self.partitioned_similarity_search(
+                query, fetch_k, embedding_model, partitioning_fn
+            )
 
         if len(texts) <= k or self.mmr_lambda >= 1.0:
             return texts, scores
@@ -852,6 +893,7 @@ async def max_marginal_relevance_search(
 class NumpyVectorStore(VectorStore):
     texts: list[Embeddable] = Field(default_factory=list)
     _embeddings_matrix: np.ndarray | None = None
+    _texts_filter: np.ndarray | None = None
 
     def __eq__(self, other) -> bool:
         if not isinstance(other, type(self)):
@@ -875,12 +917,47 @@ def clear(self) -> None:
         super().clear()
         self.texts = []
         self._embeddings_matrix = None
+        self._texts_filter = None
 
     def add_texts_and_embeddings(self, texts: Iterable[Embeddable]) -> None:
         super().add_texts_and_embeddings(texts)
         self.texts.extend(texts)
         self._embeddings_matrix = np.array([t.embedding for t in self.texts])
 
+    async def partitioned_similarity_search(
+        self,
+        query: str,
+        k: int,
+        embedding_model: EmbeddingModel,
+        partitioning_fn: Callable[[Embeddable], int],
+    ) -> tuple[Sequence[Embeddable], list[float]]:
+        scores: list[list[float]] = []
+        texts: list[Sequence[Embeddable]] = []
+
+        text_partitions = np.array([partitioning_fn(t) for t in self.texts])
+        # CPU bound so replacing w a gather wouldn't get us anything
+        # plus we need to reset self._texts_filter each iteration
+        for partition in np.unique(text_partitions):
+            self._texts_filter = text_partitions == partition
+            _texts, _scores = await self.similarity_search(query, k, embedding_model)
+            texts.append(_texts)
+            scores.append(_scores)
+        # reset the filter after running
+        self._texts_filter = None
+
+        return (
+            [
+                t
+                for t in itertools.chain.from_iterable(itertools.zip_longest(*texts))
+                if t is not None
+            ][:k],
+            [
+                s
+                for s in itertools.chain.from_iterable(itertools.zip_longest(*scores))
+                if s is not None
+            ][:k],
+        )
+
     async def similarity_search(
         self, query: str, k: int, embedding_model: EmbeddingModel
     ) -> tuple[Sequence[Embeddable], list[float]]:
@@ -895,16 +972,24 @@ async def similarity_search(
 
         embedding_model.set_mode(EmbeddingModes.DOCUMENT)
 
+        embedding_matrix = self._embeddings_matrix
+
+        if self._texts_filter is not None:
+            original_indices = np.where(self._texts_filter)[0]
+            embedding_matrix = embedding_matrix[self._texts_filter]  # type: ignore[index]
+        else:
+            original_indices = np.arange(len(self.texts))
+
         similarity_scores = cosine_similarity(
-            np_query.reshape(1, -1), self._embeddings_matrix
+            np_query.reshape(1, -1), embedding_matrix
         )[0]
         similarity_scores = np.nan_to_num(similarity_scores, nan=-np.inf)
         # minus so descending
         # we could use arg-partition here
         # but a lot of algorithms expect a sorted list
         sorted_indices = np.argsort(-similarity_scores)
         return (
-            [self.texts[i] for i in sorted_indices[:k]],
+            [self.texts[i] for i in original_indices[sorted_indices][:k]],
             [similarity_scores[i] for i in sorted_indices[:k]],
         )

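Two details in the llms.py changes are worth unpacking. First, the round-robin merge in NumpyVectorStore.partitioned_similarity_search: zip_longest pads shorter partitions with None, chain.from_iterable flattens column by column, and the None filter drops the padding, so the final top-k alternates across groups rather than letting one group dominate. A standalone demonstration of the idiom, with strings standing in for Embeddables:

    import itertools

    # Per-partition results, best-first within each partition.
    texts = [["a1", "a2", "a3"], ["b1"], ["c1", "c2"]]
    merged = [
        t
        for t in itertools.chain.from_iterable(itertools.zip_longest(*texts))
        if t is not None
    ]
    print(merged)  # ['a1', 'b1', 'c1', 'a2', 'c2', 'a3']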
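Second, the _texts_filter bookkeeping in similarity_search: when the boolean mask is set, argsort runs over scores computed from the filtered matrix, so its indices point into the filtered rows; np.where(mask)[0] records the original row numbers, and original_indices[sorted_indices] translates back into positions in self.texts. A toy example of that index translation:

    import numpy as np

    mask = np.array([True, False, True, True, False])
    original_indices = np.where(mask)[0]  # -> array([0, 2, 3])
    # Suppose argsort over the 3 filtered scores returns best-first order:
    sorted_indices = np.array([2, 0, 1])
    # Map filtered-matrix positions back to positions in self.texts:
    print(original_indices[sorted_indices])  # -> [3 0 2]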
[Diffs for the remaining 3 changed files are not shown.]