From 04bcda183e56bfdd00a2ce82ce1f1057994642f4 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Wed, 3 Jan 2024 16:55:21 -0800 Subject: [PATCH 01/16] First draft without langchain --- paperqa/chains.py | 169 +++++++++------- paperqa/docs.py | 460 +++++++++++++++++------------------------- paperqa/prompts.py | 25 +-- paperqa/readers.py | 12 +- paperqa/types.py | 291 +++++++++++++------------- paperqa/utils.py | 8 - paperqa/version.py | 2 +- setup.py | 7 +- tests/test_paperqa.py | 66 +----- 9 files changed, 460 insertions(+), 580 deletions(-) diff --git a/paperqa/chains.py b/paperqa/chains.py index 75894ce6c..d06c8128f 100644 --- a/paperqa/chains.py +++ b/paperqa/chains.py @@ -1,89 +1,116 @@ import re -from typing import Any, Dict, List, Optional, cast +from typing import Callable -from langchain.base_language import BaseLanguageModel -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, -) -from langchain.chains import LLMChain -from langchain.chat_models import ChatOpenAI -from langchain.memory.chat_memory import BaseChatMemory -from langchain.prompts import PromptTemplate, StringPromptTemplate -from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain.schema import LLMResult, SystemMessage +from openai import AsyncOpenAI from .prompts import default_system_prompt -from .types import CBManager -memory_prompt = PromptTemplate( - input_variables=["memory", "start"], - template="Here are previous questions and answers, which may be referenced in subsequent questions:\n\n{memory}\n\n" - "----------------------------------------\n\n" - "{start}", -) +default_system_prompt = "End your responses with [END]" -class FallbackLLMChain(LLMChain): - """Chain that falls back to synchronous generation if the async generation fails.""" - - async def agenerate( - self, - input_list: List[Dict[str, Any]], - run_manager: Optional[CBManager] = None, - ) -> LLMResult: - """Generate LLM result from inputs.""" - try: - run_manager = cast(AsyncCallbackManagerForChainRun, run_manager) - return await super().agenerate(input_list, run_manager=run_manager) - except NotImplementedError: - run_manager = cast(CallbackManagerForChainRun, run_manager) - return self.generate(input_list) +def process_llm_config(llm_config: dict) -> dict: + """Remove model_type and try to set max_tokens""" + result = {k: v for k, v in llm_config.items() if k != "model_type"} + if "max_tokens" not in result or result["max_tokens"] == -1: + model = llm_config["model"] + # now we guess! + if model.startswith("gpt-4") or ( + model.startswith("gpt-3.5") and "1106" in model + ): + result["max_tokens"] = 4096 + else: + result["max_tokens"] = 2048 # ? + return result -# TODO: If upstream is fixed remove this +def make_chain( + client: AsyncOpenAI, + prompt: str, + llm_config: dict, + skip_system: bool = False, + system_prompt: str = default_system_prompt, +) -> Callable[[list[dict], list[Callable[[str], None]] | None], list[str]]: + """Create a function to execute a batch of prompts + Args: + client: OpenAI client + prompt: The prompt to use + llm_config: The config to use + skip_system: Whether to skip the system prompt + system_prompt: The system prompt to use -class ExtendedHumanMessagePromptTemplate(HumanMessagePromptTemplate): - prompt: StringPromptTemplate + Returns: + A function to execute a prompt. Its signature is: + execute(data: dict, callbacks: list[Callable[[str], None]]] | None = None) -> str + where data is a dict with keys for the input variables that will be formatted into prompt + and callbacks is a list of functions to call with each chunk of the completion. + """ + if llm_config["model_type"] == "chat": + system_message_prompt = dict(role="system", content=system_prompt) + human_message_prompt = dict(role="user", content=prompt) + if skip_system: + chat_prompt = [human_message_prompt] + else: + chat_prompt = [system_message_prompt, human_message_prompt] + async def execute( + data: dict, callbacks: list[Callable[[str], None]] | None = None + ) -> str: + messages = chat_prompt[:-1] + [ + dict(role="user", content=chat_prompt[-1]["content"].format(**data)) + ] + if callbacks is None: + completion = await client.chat.completions.create( + messages=messages, **process_llm_config(llm_config) + ) + output = completion.choices[0].message.content + else: + completion = await client.chat.completions.create( + messages=messages, **process_llm_config(llm_config), stream=True + ) + result = [] + async for chunk in completion: + c = chunk.choices[0].delta.content + if c: + result.append(c) + [f(c) for f in callbacks] + output = "".join(result) + return output -def make_chain( - prompt: StringPromptTemplate, - llm: BaseLanguageModel, - skip_system: bool = False, - memory: Optional[BaseChatMemory] = None, - system_prompt: str = default_system_prompt, -) -> FallbackLLMChain: - if memory and len(memory.load_memory_variables({})["memory"]) > 0: - # we copy the prompt so we don't modify the original - # TODO: Figure out pipeline prompts to avoid this - # the problem with pipeline prompts is that - # the memory is a constant (or partial), not a prompt - # and I cannot seem to make an empty prompt (or str) - # work as an input to pipeline prompt - assert isinstance( - prompt, PromptTemplate - ), "Memory only works with prompt templates - see comment above" - assert "memory" in memory.load_memory_variables({}) - new_prompt = PromptTemplate( - input_variables=prompt.input_variables, - template=memory_prompt.format( - start=prompt.template, **memory.load_memory_variables({}) - ), - ) - prompt = new_prompt - if type(llm) == ChatOpenAI: - system_message_prompt = SystemMessage(content=system_prompt) - human_message_prompt = ExtendedHumanMessagePromptTemplate(prompt=prompt) + return execute + elif llm_config["model_type"] == "completion": if skip_system: - chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt]) + completion_prompt = prompt else: - chat_prompt = ChatPromptTemplate.from_messages( - [system_message_prompt, human_message_prompt] - ) - return FallbackLLMChain(prompt=chat_prompt, llm=llm) - return FallbackLLMChain(prompt=prompt, llm=llm) + completion_prompt = system_prompt + "\n\n" + prompt + + async def execute( + data: dict, callbacks: list[Callable[[str], None]] | None = None + ) -> str: + if callbacks is None: + completion = await client.completions.create( + prompt=completion_prompt.format(**data), + **process_llm_config(llm_config), + ) + output = completion.choices[0].text + else: + completion = await client.completions.create( + prompt=completion_prompt.format(**data), + **process_llm_config(llm_config), + stream=True, + ) + result = [] + async for chunk in completion: + c = chunk.choices[0].text + if c: + result.append(c) + [f(c) for f in callbacks] + output = "".join(result) + return output + + return execute + else: + raise NotImplementedError(f"Unknown model type {llm_config['model_type']}") def get_score(text: str) -> int: diff --git a/paperqa/docs.py b/paperqa/docs.py index cdc064d25..70b8e30f6 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -6,29 +6,27 @@ from datetime import datetime from io import BytesIO from pathlib import Path -from typing import BinaryIO, Dict, List, Optional, Set, Union, cast - -from langchain.chat_models import ChatOpenAI -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.memory import ConversationTokenBufferMemory -from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema.embeddings import Embeddings -from langchain.schema.language_model import BaseLanguageModel -from langchain.schema.vectorstore import VectorStore -from langchain.vectorstores import FAISS - -try: - from pydantic.v1 import BaseModel, validator -except ImportError: - from pydantic import BaseModel, validator +from typing import Any, BinaryIO, cast + +from openai import AsyncOpenAI +from pydantic import BaseModel, Field, field_validator, model_validator from .chains import get_score, make_chain from .paths import PAPERQA_DIR from .readers import read_doc -from .types import Answer, CallbackFactory, Context, Doc, DocKey, PromptCollection, Text +from .types import ( + Answer, + CallbackFactory, + Context, + Doc, + DocKey, + NumpyVectorStore, + PromptCollection, + Text, + VectorStore, +) from .utils import ( gather_with_concurrency, - get_llm_name, guess_is_4xx, maybe_is_html, maybe_is_pdf, @@ -39,80 +37,56 @@ ) -class Docs(BaseModel, arbitrary_types_allowed=True, smart_union=True): +class Docs(BaseModel): """A collection of documents to be used for answering questions.""" - docs: Dict[DocKey, Doc] = {} - texts: List[Text] = [] - docnames: Set[str] = set() - texts_index: Optional[VectorStore] = None - doc_index: Optional[VectorStore] = None - llm: Union[str, BaseLanguageModel] = ChatOpenAI( - temperature=0.1, model="gpt-3.5-turbo", client=None - ) - summary_llm: Optional[Union[str, BaseLanguageModel]] = None + _client: AsyncOpenAI + docs: dict[DocKey, Doc] = {} + texts: list[Text] = [] + docnames: set[str] = set() + texts_index: VectorStore = NumpyVectorStore() + doc_index: VectorStore = NumpyVectorStore() + llm_config: dict = dict(model="gpt-3.5-turbo") + summary_llm_config: dict | None = Field(default=None, validate_default=True) name: str = "default" - index_path: Optional[Path] = PAPERQA_DIR / name - embeddings: Embeddings = OpenAIEmbeddings(client=None) + index_path: Path | None = PAPERQA_DIR / name + embeddings_model: str = "text-embedding-ada-002" + batch_size: int = 1 max_concurrent: int = 5 - deleted_dockeys: Set[DocKey] = set() + deleted_dockeys: set[DocKey] = set() prompts: PromptCollection = PromptCollection() - memory: bool = False - memory_model: Optional[BaseChatMemory] = None jit_texts_index: bool = False # This is used to strip indirect citations that come up from the summary llm strip_citations: bool = True - # TODO: Not sure how to get this to work - # while also passing mypy checks - @validator("llm", "summary_llm") - def check_llm(cls, v: Union[BaseLanguageModel, str]) -> BaseLanguageModel: - if type(v) is str: - return ChatOpenAI(temperature=0.1, model=v, client=None) - return cast(BaseLanguageModel, v) - - @validator("summary_llm", always=True) - def copy_llm_if_not_set(cls, v, values): - return v or values["llm"] - - @validator("memory_model", always=True) - def check_memory_model(cls, v, values): - if values["memory"]: - if v is None: - return ConversationTokenBufferMemory( - llm=values["summary_llm"], - max_token_limit=512, - memory_key="memory", - human_prefix="Question", - ai_prefix="Answer", - input_key="Question", - output_key="Answer", - ) - if v.memory_variables()[0] != "memory": - raise ValueError("Memory model must have memory_variables=['memory']") - return values["memory_model"] - return None + def __init__(self, **data): + if "client" in data: + client = data.pop("client") + else: + client = AsyncOpenAI() + super().__init__(**data) + self._client = client + + @field_validator("llm_config", "summary_llm_config") + @classmethod + def llm_config_has_type(cls, v): + if v is not None and "model_type" not in v: + raise ValueError("Must specify if chat or completion model.") + return v + + @model_validator(mode="after") + @classmethod + def config_summary_llm_conig(cls, data: Any) -> Any: + if isinstance(data, dict): + if data["summary_llm_config"] is None: + data["summary_llm_config"] = data["llm_config"] + return data def clear_docs(self): self.texts = [] self.docs = {} self.docnames = set() - def update_llm( - self, - llm: Union[BaseLanguageModel, str], - summary_llm: Optional[Union[BaseLanguageModel, str]] = None, - ) -> None: - """Update the LLM for answering questions.""" - if type(llm) is str: - llm = ChatOpenAI(temperature=0.1, model=llm, client=None) - if type(summary_llm) is str: - summary_llm = ChatOpenAI(temperature=0.1, model=summary_llm, client=None) - self.llm = cast(BaseLanguageModel, llm) - if summary_llm is None: - summary_llm = llm - self.summary_llm = cast(BaseLanguageModel, summary_llm) - def _get_unique_name(self, docname: str) -> str: """Create a unique name given proposed name""" suffix = "" @@ -128,11 +102,11 @@ def _get_unique_name(self, docname: str) -> str: def add_file( self, file: BinaryIO, - citation: Optional[str] = None, - docname: Optional[str] = None, - dockey: Optional[DocKey] = None, + citation: str | None = None, + docname: str | None = None, + dockey: DocKey | None = None, chunk_chars: int = 3000, - ) -> Optional[str]: + ) -> str | None: """Add a document to the collection.""" # just put in temp file and use existing method suffix = ".txt" @@ -155,11 +129,11 @@ def add_file( def add_url( self, url: str, - citation: Optional[str] = None, - docname: Optional[str] = None, - dockey: Optional[DocKey] = None, + citation: str | None = None, + docname: str | None = None, + dockey: DocKey | None = None, chunk_chars: int = 3000, - ) -> Optional[str]: + ) -> str | None: """Add a document to the collection.""" import urllib.request @@ -177,20 +151,21 @@ def add_url( def add( self, path: Path, - citation: Optional[str] = None, - docname: Optional[str] = None, + citation: str | None = None, + docname: str | None = None, disable_check: bool = False, - dockey: Optional[DocKey] = None, + dockey: DocKey | None = None, chunk_chars: int = 3000, - ) -> Optional[str]: + ) -> str | None: """Add a document to the collection.""" if dockey is None: dockey = md5sum(path) if citation is None: # skip system because it's too hesitant to answer cite_chain = make_chain( + client=self._client, prompt=self.prompts.cite, - llm=cast(BaseLanguageModel, self.summary_llm), + llm_config=self.summary_llm_config, skip_system=True, ) # peak first chunk @@ -198,7 +173,12 @@ def add( texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=100) if len(texts) == 0: raise ValueError(f"Could not read document {path}. Is it empty?") - citation = cite_chain.run(texts[0].text) + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + citation = loop.run_until_complete(cite_chain(texts[0].text)) if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation: citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}" @@ -237,7 +217,7 @@ def add( def add_texts( self, - texts: List[Text], + texts: list[Text], doc: Doc, ) -> bool: """Add chunked texts to the collection. This is useful if you have already chunked the texts yourself. @@ -253,34 +233,21 @@ def add_texts( for t in texts: t.name = t.name.replace(doc.docname, new_docname) doc.docname = new_docname - if texts[0].embeddings is None: + if texts[0].embedding is None: text_embeddings = self.embeddings.embed_documents([t.text for t in texts]) for i, t in enumerate(texts): - t.embeddings = text_embeddings[i] - else: - text_embeddings = cast(List[List[float]], [t.embeddings for t in texts]) - if self.texts_index is not None: - try: - # TODO: Simplify - super weird - vec_store_text_and_embeddings = list( - map(lambda x: (x.text, x.embeddings), texts) - ) - self.texts_index.add_embeddings( # type: ignore - vec_store_text_and_embeddings, - metadatas=[t.dict(exclude={"embeddings", "text"}) for t in texts], - ) - except AttributeError: - raise ValueError("Need a vector store that supports adding embeddings.") - if self.doc_index is not None: - self.doc_index.add_texts([doc.citation], metadatas=[doc.dict()]) + t.embedding = text_embeddings[i] + if doc.embedding is None: + doc.embedding = self.embeddings.embed_documents([doc.citation])[0] + if not self.jit_texts_index: + self.texts_index.add_texts_and_embeddings(texts) + self.doc_index.add_texts_and_embeddings(doc) self.docs[doc.dockey] = doc self.texts += texts self.docnames.add(doc.docname) return True - def delete( - self, name: Optional[str] = None, dockey: Optional[DocKey] = None - ) -> None: + def delete(self, name: str | None = None, dockey: DocKey | None = None) -> None: """Delete a document from the collection.""" if name is not None: doc = next((doc for doc in self.docs.values() if doc.docname == name), None) @@ -295,48 +262,33 @@ async def adoc_match( self, query: str, k: int = 25, - rerank: Optional[bool] = None, + rerank: bool | None = None, get_callbacks: CallbackFactory = lambda x: None, - ) -> Set[DocKey]: + ) -> set[DocKey]: """Return a list of dockeys that match the query.""" - if self.doc_index is None: - if len(self.docs) == 0: - return set() - texts = [doc.citation for doc in self.docs.values()] - metadatas = [d.dict() for d in self.docs.values()] - self.doc_index = FAISS.from_texts( - texts, metadatas=metadatas, embedding=self.embeddings - ) - matches = self.doc_index.max_marginal_relevance_search( + matches, _ = self.doc_index.max_marginal_relevance_search( query, k=k + len(self.deleted_dockeys) ) # filter the matches - matches = [ - m for m in matches if m.metadata["dockey"] not in self.deleted_dockeys - ] - try: - # for backwards compatibility (old pickled objects) - matched_docs = [self.docs[m.metadata["dockey"]] for m in matches] - except KeyError: - matched_docs = [Doc(**m.metadata) for m in matches] + matched_docs = [m for m in matches if m.dockey not in self.deleted_dockeys] if len(matched_docs) == 0: return set() # this only works for gpt-4 (in my testing) try: if ( rerank is None - and get_llm_name(cast(BaseLanguageModel, self.llm)).startswith("gpt-4") + and self.llm_config.model.startswith("gpt-4") or rerank is True ): chain = make_chain( - self.prompts.select, - cast(BaseLanguageModel, self.llm), + client=self._client, + prompt=self.prompts.select, + llm_config=self.llm_config, skip_system=True, ) papers = [f"{d.docname}: {d.citation}" for d in matched_docs] - result = await chain.arun( # type: ignore - question=query, - papers="\n".join(papers), + result = await chain( + data=[dict(question=query, papers="\n".join(papers))], callbacks=get_callbacks("filter"), ) return set([d.dockey for d in matched_docs if d.docname in result]) @@ -344,48 +296,15 @@ async def adoc_match( pass return set([d.dockey for d in matched_docs]) - def __getstate__(self): - state = self.__dict__.copy() - if self.texts_index is not None and self.index_path is not None: - state["texts_index"].save_local(self.index_path) - del state["texts_index"] - del state["doc_index"] - return {"__dict__": state, "__fields_set__": self.__fields_set__} - - def __setstate__(self, state): - object.__setattr__(self, "__dict__", state["__dict__"]) - object.__setattr__(self, "__fields_set__", state["__fields_set__"]) - try: - self.texts_index = FAISS.load_local(self.index_path, self.embeddings) - except Exception: - # they use some special exception type, but I don't want to import it - self.texts_index = None - self.doc_index = None - - def _build_texts_index(self, keys: Optional[Set[DocKey]] = None): + def _build_texts_index(self, keys: set[DocKey] = None): if keys is not None and self.jit_texts_index: - del self.texts_index - self.texts_index = None - if self.texts_index is None: texts = self.texts if keys is not None: texts = [t for t in texts if t.doc.dockey in keys] if len(texts) == 0: return - raw_texts = [t.text for t in texts] - text_embeddings = [t.embeddings for t in texts] - metadatas = [t.dict(exclude={"embeddings", "text"}) for t in texts] - self.texts_index = FAISS.from_embeddings( - # wow adding list to the zip was tricky - text_embeddings=list(zip(raw_texts, text_embeddings)), - embedding=self.embeddings, - metadatas=metadatas, - ) - - def clear_memory(self): - """Clear the memory of the model.""" - if self.memory_model is not None: - self.memory_model.clear() + self.texts_index.clear() + self.texts_index.add_texts_and_embeddings(texts) def get_evidence( self, @@ -424,7 +343,7 @@ def get_evidence( async def aget_evidence( self, answer: Answer, - k: int = 10, # Number of vectors to retrieve + k: int = 10, # Number of evidence pieces to retrieve max_sources: int = 5, # Number of scored contexts to use marginal_relevance: bool = True, get_callbacks: CallbackFactory = lambda x: None, @@ -432,116 +351,100 @@ async def aget_evidence( disable_vector_search: bool = False, disable_summarization: bool = False, ) -> Answer: - if disable_vector_search: - k = k * 10000 if len(self.docs) == 0 and self.doc_index is None: + # do we have no docs? return answer self._build_texts_index(keys=answer.dockey_filter) - if self.texts_index is None: - return answer self.texts_index = cast(VectorStore, self.texts_index) _k = k if answer.dockey_filter is not None: - _k = k * 10 # heuristic - if marginal_relevance: - matches = self.texts_index.max_marginal_relevance_search( - answer.question, k=_k, fetch_k=5 * _k - ) + _k = k * 10 # heuristic - get enough so we can downselect + if disable_vector_search: + matches = self.texts else: - matches = self.texts_index.similarity_search( - answer.question, k=_k, fetch_k=5 * _k - ) - # ok now filter + if marginal_relevance: + matches = self.texts_index.max_marginal_relevance_search( + answer.question, k=_k, fetch_k=5 * _k + ) + else: + matches = self.texts_index.similarity_search(answer.question, k=_k) + # ok now filter (like ones from adoc_match) if answer.dockey_filter is not None: - matches = [ - m - for m in matches - if m.metadata["doc"]["dockey"] in answer.dockey_filter - ] + matches = [m for m in matches if m.doc.dockey in answer.dockey_filter] # check if it is deleted - matches = [ - m - for m in matches - if m.metadata["doc"]["dockey"] not in self.deleted_dockeys - ] + matches = [m for m in matches if m.doc.dockey not in self.deleted_dockeys] # check if it is already in answer cur_names = [c.text.name for c in answer.contexts] - matches = [m for m in matches if m.metadata["name"] not in cur_names] + matches = [m for m in matches if m.name not in cur_names] # now finally cut down matches = matches[:k] async def process(match): - callbacks = get_callbacks("evidence:" + match.metadata["name"]) - summary_chain = make_chain( - self.prompts.summary, - self.summary_llm, - memory=self.memory_model, - system_prompt=self.prompts.system, - ) - # This is dangerous because it - # could mask errors that are important- like auth errors - # I also cannot know what the exception - # type is because any model could be used - # my best idea is see if there is a 4XX - # http code in the exception - try: - citation = match.metadata["doc"]["citation"] - if detailed_citations: - citation = match.metadata["name"] + ": " + citation - if self.prompts.skip_summary: - context = match.page_content - else: - context = await summary_chain.arun( - question=answer.question, - # Add name so chunk is stated - citation=citation, - summary_length=answer.summary_length, - text=match.page_content, + callbacks = get_callbacks("evidence:" + match.name) + citation = match.doc.citation + if detailed_citations: + citation = match.name + ": " + citation + + if self.prompts.skip_summary or disable_summarization: + context = match.text + score = 5 + else: + summary_chain = make_chain( + client=self._client, + prompt=self.prompts.summary, + llm_config=self.summary_llm_config, + system_prompt=self.prompts.system, + ) + # This is dangerous because it + # could mask errors that are important- like auth errors + # I also cannot know what the exception + # type is because any model could be used + # my best idea is see if there is a 4XX + # http code in the exception + try: + context = await summary_chain( + dict( + question=answer.question, + # Add name so chunk is stated + citation=citation, + summary_length=answer.summary_length, + text=match.text, + ), callbacks=callbacks, ) - except Exception as e: - if guess_is_4xx(str(e)): + except Exception as e: + if guess_is_4xx(str(e)): + return None + raise e + if ( + "not applicable" in context.lower() + or "not relevant" in context.lower() + ): return None - raise e - if "not applicable" in context.lower() or "not relevant" in context.lower(): - return None - if self.strip_citations: - # remove citations that collide with our grounded citations (for the answer LLM) - context = strip_citations(context) + if self.strip_citations: + # remove citations that collide with our grounded citations (for the answer LLM) + context = strip_citations(context) + score = get_score(context) c = Context( context=context, + # below will remove embedding from Text/Doc text=Text( - text=match.page_content, - name=match.metadata["name"], - doc=Doc(**match.metadata["doc"]), + text=match.text, + name=match.name, + doc=Doc(**match.doc), ), - score=get_score(context), + score=score, ) return c - if disable_summarization: - contexts = [ - Context( - context=match.page_content, - score=10, - text=Text( - text=match.page_content, - name=match.metadata["name"], - doc=Doc(**match.metadata["doc"]), - ), - ) - for match in matches - ] - - else: - results = await gather_with_concurrency( - self.max_concurrent, *[process(m) for m in matches] - ) - # filter out failures - contexts = [c for c in results if c is not None] + results = await gather_with_concurrency( + self.max_concurrent, *[process(m) for m in matches] + ) + # filter out failures + contexts = [c for c in results if c is not None] answer.contexts = sorted( contexts + answer.contexts, key=lambda x: x.score, reverse=True @@ -567,8 +470,8 @@ def query( max_sources: int = 5, length_prompt="about 100 words", marginal_relevance: bool = True, - answer: Optional[Answer] = None, - key_filter: Optional[bool] = None, + answer: Answer | None = None, + key_filter: bool | None = None, get_callbacks: CallbackFactory = lambda x: None, ) -> Answer: # special case for jupyter notebooks @@ -601,8 +504,8 @@ async def aquery( max_sources: int = 5, length_prompt: str = "about 100 words", marginal_relevance: bool = True, - answer: Optional[Answer] = None, - key_filter: Optional[bool] = None, + answer: Answer | None = None, + key_filter: bool | None = None, get_callbacks: CallbackFactory = lambda x: None, ) -> Answer: if k < max_sources: @@ -627,13 +530,13 @@ async def aquery( ) if self.prompts.pre is not None: chain = make_chain( - self.prompts.pre, - cast(BaseLanguageModel, self.llm), - memory=self.memory_model, + client=self._client, + prompt=self.prompts.pre, + llm_config=self.llm_config, system_prompt=self.prompts.system, ) - pre = await chain.arun( - question=answer.question, callbacks=get_callbacks("pre") + pre = await chain( + data=dict(question=answer.question), callbacks=get_callbacks("pre") ) answer.context = answer.context + "\n\nExtra background information:" + pre bib = dict() @@ -644,17 +547,18 @@ async def aquery( else: callbacks = get_callbacks("answer") qa_chain = make_chain( - self.prompts.qa, - cast(BaseLanguageModel, self.llm), - memory=self.memory_model, + client=self._client, + prompt=self.prompts.qa, + llm_config=self.llm_config, system_prompt=self.prompts.system, ) - answer_text = await qa_chain.arun( - context=answer.context, - answer_length=answer.answer_length, - question=answer.question, + answer_text = await qa_chain( + data=dict( + context=answer.context, + answer_length=answer.answer_length, + question=answer.question, + ), callbacks=callbacks, - verbose=True, ) # it still happens if "(Example2012)" in answer_text: @@ -677,12 +581,14 @@ async def aquery( if self.prompts.post is not None: chain = make_chain( - self.prompts.post, - cast(BaseLanguageModel, self.llm), - memory=self.memory_model, + client=self._client, + prompt=self.prompts.post, + llm_config=self.llm_config, system_prompt=self.prompts.system, ) - post = await chain.arun(**answer.dict(), callbacks=get_callbacks("post")) + post = await chain.arun( + data=answer.model_dump(), callbacks=get_callbacks("post") + ) answer.answer = post answer.formatted_answer = f"Question: {answer.question}\n\n{post}\n" if len(bib) > 0: diff --git a/paperqa/prompts.py b/paperqa/prompts.py index a13ddb6d5..e77852138 100644 --- a/paperqa/prompts.py +++ b/paperqa/prompts.py @@ -1,8 +1,5 @@ -from langchain.prompts import PromptTemplate - -summary_prompt = PromptTemplate( - input_variables=["text", "citation", "question", "summary_length"], - template="Summarize the text below to help answer a question. " +summary_prompt = ( + "Summarize the text below to help answer a question. " "Do not directly answer the question, instead summarize " "to give evidence to help answer the question. " "Focus on specific details, including numbers, equations, or specific quotes. " @@ -16,9 +13,8 @@ "Relevant Information Summary:", ) -qa_prompt = PromptTemplate( - input_variables=["context", "answer_length", "question"], - template="Write an answer ({answer_length}) " +qa_prompt = ( + "Write an answer ({answer_length}) " "for the question below based on the provided context. " "If the context provides insufficient information and the question cannot be directly answered, " 'reply "I cannot answer". ' @@ -29,9 +25,8 @@ "Answer: ", ) -select_paper_prompt = PromptTemplate( - input_variables=["question", "papers"], - template="Select papers that may help answer the question below. " +select_paper_prompt = ( + "Select papers that may help answer the question below. " "Papers are listed as $KEY: $PAPER_INFO. " "Return a list of keys, separated by commas. " 'Return "None", if no papers are applicable. ' @@ -41,12 +36,8 @@ "Papers: {papers}\n\n" "Selected keys:", ) - -# We are unable to serialize with partial variables -# so TODO: update year next year -citation_prompt = PromptTemplate( - input_variables=["text"], - template="Provide the citation for the following text in MLA Format. The year is 2023\n" +citation_prompt = ( + "Provide the citation for the following text in MLA Format.\n\n" "{text}\n\n" "Citation:", ) diff --git a/paperqa/readers.py b/paperqa/readers.py index 1b9773067..de6b901f6 100644 --- a/paperqa/readers.py +++ b/paperqa/readers.py @@ -2,7 +2,6 @@ from typing import List from html2text import html2text -from langchain.text_splitter import TokenTextSplitter from .types import Doc, Text @@ -84,9 +83,14 @@ def parse_txt( text = f.read() if html: text = html2text(text) - # yo, no idea why but the texts are not split correctly - text_splitter = TokenTextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap) - raw_texts = text_splitter.split_text(text) + # chunk into size chunk_chars with overlap overlap + raw_texts = [] + start = 0 + while start < len(text): + end = min(start + chunk_chars, len(text)) + raw_texts.append(text[start:end]) + start = end - overlap + texts = [ Text(text=t, name=f"{doc.docname} chunk {i}", doc=doc) for i, t in enumerate(raw_texts) diff --git a/paperqa/types.py b/paperqa/types.py index f468d6575..a7603f894 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -1,18 +1,8 @@ -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from abc import ABC, abstractmethod +from typing import Any, Callable -from langchain.callbacks.base import BaseCallbackHandler -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, -) -from langchain.prompts import PromptTemplate - -try: - from pydantic.v1 import BaseModel, validator -except ImportError: - from pydantic import BaseModel, validator - -import re +import numpy as np +from pydantic import BaseModel, Field, field_validator from .prompts import ( citation_prompt, @@ -21,75 +11,191 @@ select_paper_prompt, summary_prompt, ) -from .utils import extract_doi, iter_citations +# Just for clarity DocKey = Any -CBManager = Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun] -CallbackFactory = Callable[[str], Union[None, List[BaseCallbackHandler]]] +CallbackFactory = Callable[[str], Callable[[str], None]] + + +class Embeddable(BaseModel): + embedding: list[float] | None = Field(default=None, repr=False) -class Doc(BaseModel): + +class Doc(Embeddable): docname: str citation: str dockey: DocKey -class Text(BaseModel): +class Text(Embeddable): text: str name: str doc: Doc - embeddings: Optional[List[float]] = None + + +def cosine_similarity(a, b): + dot_product = np.dot(a, b.T) + norm_product = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) + return dot_product / norm_product + + +class VectorStore(BaseModel, ABC): + """Interface for vector store - very similar to LangChain's VectorStore to be compatible""" + + @abstractmethod + def add_texts_and_embeddings(self, texts: list[Embeddable]) -> None: + pass + + @abstractmethod + def similarity_search( + self, query: list[float], k: int + ) -> list[tuple[Embeddable, float]]: + pass + + @abstractmethod + def clear(self) -> None: + pass + + def max_marginal_relevance_search( + self, query: list[float], k: int, fetch_k: int, lambda_: float = 0.5 + ) -> list[tuple[Embeddable, float]]: + """Vectorized implementation of Maximal Marginal Relevance (MMR) search. + + Args: + query: Query vector. + k: Number of results to return. + lambda_: Weighting of relevance and diversity. + + Returns: + List of tuples (doc, score) of length k. + """ + if fetch_k < k: + raise ValueError("fetch_k must be greater or equal to k") + + initial_results = self.similarity_search(query, fetch_k) + if len(initial_results) <= k: + return initial_results + + embeddings = np.array([t.embedding for t, _ in initial_results]) + scores = np.array([score for _, score in initial_results]) + similarity_matrix = cosine_similarity(embeddings, embeddings) + + selected_indices = [0] + remaining_indices = list(range(1, len(initial_results))) + + while len(selected_indices) < k: + selected_similarities = similarity_matrix[:, selected_indices] + max_sim_to_selected = selected_similarities.max(axis=1) + + mmr_scores = lambda_ * scores - (1 - lambda_) * max_sim_to_selected + mmr_scores[selected_indices] = -np.inf # Exclude already selected documents + + max_mmr_index = mmr_scores.argmax() + selected_indices.append(max_mmr_index) + remaining_indices.remove(max_mmr_index) + + return [(initial_results[i][0], scores[i]) for i in selected_indices] + + +class NumpyVectorStore(VectorStore): + texts: list[Embeddable] = [] + _embeddings_matrix: np.ndarray | None = None + + def clear(self) -> None: + self.texts = [] + self._embeddings_matrix = None + + def add_texts_and_embeddings( + self, + texts: list[Embeddable], + ) -> None: + self.texts.extend(texts) + self._embeddings_matrix = np.array([t.embedding for t in self.texts]) + + def similarity_search( + self, query: list[float], k: int + ) -> list[tuple[Embeddable, float]]: + if len(self.texts) == 0: + return [] + query = np.array(query) + similarity_scores = cosine_similarity( + query.reshape(1, -1), self._embeddings_matrix + )[0] + similarity_scores = np.nan_to_num(similarity_scores, nan=-np.inf) + sorted_indices = np.argsort(similarity_scores)[::-1] + return [(self.texts[i], similarity_scores[i]) for i in sorted_indices[:k]] + + +class _FormatDict(dict): + def __missing__(self, key: str) -> str: + return key + + +def get_formatted_variables(s: str) -> set[str]: + format_dict = _FormatDict() + s.format_map(format_dict) + return set(format_dict.keys()) class PromptCollection(BaseModel): - summary: PromptTemplate = summary_prompt - qa: PromptTemplate = qa_prompt - select: PromptTemplate = select_paper_prompt - cite: PromptTemplate = citation_prompt - pre: Optional[PromptTemplate] = None - post: Optional[PromptTemplate] = None + summary: str = summary_prompt + qa: str = qa_prompt + select: str = select_paper_prompt + cite: str = citation_prompt + pre: str | None = None + post: str | None = None system: str = default_system_prompt skip_summary: bool = False - @validator("summary") - def check_summary(cls, v: PromptTemplate) -> PromptTemplate: - if not set(v.input_variables).issubset(set(summary_prompt.input_variables)): + @field_validator("summary") + @classmethod + def check_summary(cls, v: str) -> str: + if not set(get_formatted_variables(v)).issubset( + set(get_formatted_variables(summary_prompt)) + ): raise ValueError( - f"Summary prompt can only have variables: {summary_prompt.input_variables}" + f"Summary prompt can only have variables: {get_formatted_variables(summary_prompt)}" ) return v - @validator("qa") - def check_qa(cls, v: PromptTemplate) -> PromptTemplate: - if not set(v.input_variables).issubset(set(qa_prompt.input_variables)): + @field_validator("qa") + @classmethod + def check_qa(cls, v: str) -> str: + if not set(get_formatted_variables(v)).issubset( + set(get_formatted_variables(qa_prompt)) + ): raise ValueError( - f"QA prompt can only have variables: {qa_prompt.input_variables}" + f"QA prompt can only have variables: {get_formatted_variables(qa_prompt)}" ) return v - @validator("select") - def check_select(cls, v: PromptTemplate) -> PromptTemplate: - if not set(v.input_variables).issubset( - set(select_paper_prompt.input_variables) + @field_validator("select") + @classmethod + def check_select(cls, v: str) -> str: + if not set(get_formatted_variables(v)).issubset( + set(get_formatted_variables(select_paper_prompt)) ): raise ValueError( - f"Select prompt can only have variables: {select_paper_prompt.input_variables}" + f"Select prompt can only have variables: {get_formatted_variables(select_paper_prompt)}" ) return v - @validator("pre") - def check_pre(cls, v: Optional[PromptTemplate]) -> Optional[PromptTemplate]: + @field_validator("pre") + @classmethod + def check_pre(cls, v: str | None) -> str | None: if v is not None: - if set(v.input_variables) != set(["question"]): + if set(get_formatted_variables(v)) != set(["question"]): raise ValueError("Pre prompt must have input variables: question") return v - @validator("post") - def check_post(cls, v: Optional[PromptTemplate]) -> Optional[PromptTemplate]: + @field_validator("post") + @classmethod + def check_post(cls, v: str | None) -> str | None: if v is not None: # kind of a hack to get list of attributes in answer attrs = [a.name for a in Answer.__fields__.values()] - if not set(v.input_variables).issubset(attrs): + if not set(get_formatted_variables(v)).issubset(attrs): raise ValueError(f"Post prompt must have input variables: {attrs}") return v @@ -113,18 +219,18 @@ class Answer(BaseModel): question: str answer: str = "" context: str = "" - contexts: List[Context] = [] + contexts: list[Context] = [] references: str = "" formatted_answer: str = "" - dockey_filter: Optional[Set[DocKey]] = None + dockey_filter: set[DocKey] | None = None summary_length: str = "about 100 words" answer_length: str = "about 100 words" - memory: Optional[str] = None + memory: str | None = None # these two below are for convenience # and are not set. But you can set them # if you want to use them. - cost: Optional[float] = None - token_counts: Optional[Dict[str, List[int]]] = None + cost: float | None = None + token_counts: dict[str, list[int]] | None = None def __str__(self) -> str: """Return the answer as a string.""" @@ -137,88 +243,3 @@ def get_citation(self, name: str) -> str: except StopIteration: raise ValueError(f"Could not find docname {name} in contexts") return doc.citation - - def markdown(self) -> Tuple[str, str]: - """Return the answer with footnote style citations.""" - # example: This is an answer.[^1] - # [^1]: This the citation. - output = self.answer - refs: Dict[str, int] = dict() - index = 1 - for citation in iter_citations(self.answer): - compound = "" - strip = True - for c in re.split(",|;", citation): - c = c.strip("() ") - if c == "Extra background information": - continue - if c in refs: - compound += f"[^{refs[c]}]" - continue - # check if it is a citation - try: - self.get_citation(c) - except ValueError: - # not a citation - strip = False - continue - refs[c] = index - compound += f"[^{index}]" - index += 1 - if strip: - output = output.replace(citation, compound) - formatted_refs = "\n".join( - [ - f"[^{i}]: [{self.get_citation(r)}]({extract_doi(self.get_citation(r))})" - for r, i in refs.items() - ] - ) - # quick fix of space before period - output = output.replace(" .", ".") - return output, formatted_refs - - def combine_with(self, other: "Answer") -> "Answer": - """ - Combine this answer object with another, merging their context/answer. - """ - combined = Answer( - question=self.question + " / " + other.question, - answer=self.answer + " " + other.answer, - context=self.context + " " + other.context, - contexts=self.contexts + other.contexts, - references=self.references + " " + other.references, - formatted_answer=self.formatted_answer + " " + other.formatted_answer, - summary_length=self.summary_length, # Assuming the same summary_length for both - answer_length=self.answer_length, # Assuming the same answer_length for both - memory=self.memory if self.memory else other.memory, - cost=self.cost if self.cost else other.cost, - token_counts=self.merge_token_counts(self.token_counts, other.token_counts), - ) - # Handling dockey_filter if present in either of the Answer objects - if self.dockey_filter or other.dockey_filter: - combined.dockey_filter = ( - self.dockey_filter if self.dockey_filter else set() - ) | (other.dockey_filter if other.dockey_filter else set()) - return combined - - @staticmethod - def merge_token_counts( - counts1: Optional[Dict[str, List[int]]], counts2: Optional[Dict[str, List[int]]] - ) -> Optional[Dict[str, List[int]]]: - """ - Merge two dictionaries of token counts. - """ - if counts1 is None and counts2 is None: - return None - if counts1 is None: - return counts2 - if counts2 is None: - return counts1 - merged_counts = counts1.copy() - for key, values in counts2.items(): - if key in merged_counts: - merged_counts[key][0] += values[0] - merged_counts[key][1] += values[1] - else: - merged_counts[key] = values - return merged_counts diff --git a/paperqa/utils.py b/paperqa/utils.py index 6cb0d1a0e..2b8930bbb 100644 --- a/paperqa/utils.py +++ b/paperqa/utils.py @@ -6,7 +6,6 @@ from typing import BinaryIO, List, Union import pypdf -from langchain.base_language import BaseLanguageModel StrPath = Union[str, Path] @@ -93,13 +92,6 @@ def guess_is_4xx(msg: str) -> bool: return False -def get_llm_name(llm: BaseLanguageModel) -> str: - try: - return llm.model_name # type: ignore - except AttributeError: - return llm.model # type: ignore - - def strip_citations(text: str) -> str: # Combined regex for identifying citations (see unit tests for examples) citation_regex = r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)" diff --git a/paperqa/version.py b/paperqa/version.py index 5de4250b4..16763f330 100644 --- a/paperqa/version.py +++ b/paperqa/version.py @@ -1 +1 @@ -__version__ = "3.13.5" +__version__ = "4.0.0-pre.1" diff --git a/setup.py b/setup.py index d2463365d..d712958e3 100644 --- a/setup.py +++ b/setup.py @@ -18,10 +18,9 @@ packages=["paperqa", "paperqa.contrib"], install_requires=[ "pypdf", - "pydantic<2", - "langchain>=0.0.303", - "openai <1", - "faiss-cpu", + "pydantic>=2", + "openai>=1", + "numpy", "PyCryptodome", "html2text", "tiktoken>=0.4.0", diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index 755268807..b467d398a 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -1,17 +1,12 @@ import os import pickle from io import BytesIO -from typing import Any from unittest import IsolatedAsyncioTestCase import numpy as np import requests -from langchain.callbacks.base import AsyncCallbackHandler -from langchain.llms import OpenAI -from langchain.llms.fake import FakeListLLM -from langchain.prompts import PromptTemplate -from paperqa import Answer, Context, Doc, Docs, PromptCollection, Text +from paperqa import Answer, Doc, Docs, PromptCollection, Text from paperqa.chains import get_score from paperqa.readers import read_doc from paperqa.utils import ( @@ -24,11 +19,6 @@ ) -class TestHandler(AsyncCallbackHandler): - async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - print(token) - - def test_iter_citations(): text = ( "Yes, COVID-19 vaccines are effective. Various studies have documented the " @@ -127,56 +117,6 @@ def test_citations_with_nonstandard_chars(): ) -def test_markdown(): - answer = Answer( - question="What was Fredic's greatest accomplishment?", - answer="Frederick Bates's greatest accomplishment was his role in resolving land disputes " - "and his service as governor of Missouri (Wiki2023 chunk 1, Wiki2023 chunk 2). It is said (in 2010) that foo." - "However many dispute this (Wiki2023 chunk 1).", - contexts=[ - Context( - context="", - text=Text( - text="Frederick Bates's greatest accomplishment was his role in resolving land disputes " - "and his service as governor of Missouri.", - name="Wiki2023 chunk 1", - doc=Doc( - name="Wiki2023", - docname="Wiki2023", - citation="WikiMedia Foundation, 2023, Accessed now", - texts=[], - ), - ), - score=5, - ), - Context( - context="", - text=Text( - text="It is said (in 2010) that foo.", - name="Wiki2023 chunk 2", - doc=Doc( - name="Wiki2023", - docname="Wiki2023", - citation="WikiMedia Foundation, 2023, Accessed now", - texts=[], - ), - ), - score=5, - ), - ], - ) - m, r = answer.markdown() - assert len(r.split("\n")) == 2 - assert "[^2]" in m - assert "[^3]" not in m - assert "[^1]" in m - print(m, r) - answer = answer.combine_with(answer) - m2, r2 = answer.markdown() - assert m2.startswith(m) - assert r2 == r - - def test_ablations(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") @@ -420,8 +360,8 @@ def test_extract_score(): def test_docs(): - llm = OpenAI(client=None, temperature=0.1, model="text-ada-001") - docs = Docs(llm=llm) + llm_config = dict(temperature=0.1, model="text-ada-001", model_type="completion") + docs = Docs(llm_config=llm_config) docs.add_url( "https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day", citation="WikiMedia Foundation, 2023, Accessed now", From 0009c0c159703c4e0921a1fe0c0f06c0e3f105ce Mon Sep 17 00:00:00 2001 From: Andrew White Date: Wed, 3 Jan 2024 22:47:26 -0800 Subject: [PATCH 02/16] Improved unit tests to be closer --- paperqa/docs.py | 122 ++++++++++++++++--------------- paperqa/{chains.py => llms.py} | 35 ++++++++- paperqa/prompts.py | 10 +-- paperqa/readers.py | 47 ++++++++---- paperqa/types.py | 33 +++++---- tests/test_paperqa.py | 127 +++++++++------------------------ 6 files changed, 188 insertions(+), 186 deletions(-) rename paperqa/{chains.py => llms.py} (82%) diff --git a/paperqa/docs.py b/paperqa/docs.py index 70b8e30f6..7b851ca18 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -1,7 +1,7 @@ +import nest_asyncio # isort:skip import asyncio import os import re -import sys import tempfile from datetime import datetime from io import BytesIO @@ -11,7 +11,7 @@ from openai import AsyncOpenAI from pydantic import BaseModel, Field, field_validator, model_validator -from .chains import get_score, make_chain +from .llms import embed_documents, get_score, guess_model_type, make_chain from .paths import PAPERQA_DIR from .readers import read_doc from .types import ( @@ -36,17 +36,20 @@ strip_citations, ) +# Apply the patch to allow nested loops +nest_asyncio.apply() + class Docs(BaseModel): """A collection of documents to be used for answering questions.""" - _client: AsyncOpenAI + _client: AsyncOpenAI | None docs: dict[DocKey, Doc] = {} texts: list[Text] = [] docnames: set[str] = set() texts_index: VectorStore = NumpyVectorStore() doc_index: VectorStore = NumpyVectorStore() - llm_config: dict = dict(model="gpt-3.5-turbo") + llm_config: dict = dict(model="gpt-3.5-turbo", model_type="chat") summary_llm_config: dict | None = Field(default=None, validate_default=True) name: str = "default" index_path: Path | None = PAPERQA_DIR / name @@ -69,17 +72,17 @@ def __init__(self, **data): @field_validator("llm_config", "summary_llm_config") @classmethod - def llm_config_has_type(cls, v): + def llm_guess_model_type(cls, v: dict) -> dict: if v is not None and "model_type" not in v: - raise ValueError("Must specify if chat or completion model.") + v["model_type"] = guess_model_type(v["model"]) return v @model_validator(mode="after") @classmethod def config_summary_llm_conig(cls, data: Any) -> Any: - if isinstance(data, dict): - if data["summary_llm_config"] is None: - data["summary_llm_config"] = data["llm_config"] + if isinstance(data, Docs): + if data.summary_llm_config is None: + data.summary_llm_config = data.llm_config return data def clear_docs(self): @@ -87,6 +90,21 @@ def clear_docs(self): self.docs = {} self.docnames = set() + def __getstate__(self): + state = super().__getstate__() + # remove client from private attributes + del state["__pydantic_private__"]["_client"] + return state + + def __setstate__(self, state): + super().__setstate__(state) + self._client = None + + def set_client(self, client: AsyncOpenAI | None = None): + if client is None: + client = AsyncOpenAI() + self._client = client + def _get_unique_name(self, docname: str) -> str: """Create a unique name given proposed name""" suffix = "" @@ -173,12 +191,7 @@ def add( texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=100) if len(texts) == 0: raise ValueError(f"Could not read document {path}. Is it empty?") - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - citation = loop.run_until_complete(cite_chain(texts[0].text)) + citation = asyncio.run(cite_chain(data=dict(text=texts[0].text))) if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation: citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}" @@ -234,14 +247,20 @@ def add_texts( t.name = t.name.replace(doc.docname, new_docname) doc.docname = new_docname if texts[0].embedding is None: - text_embeddings = self.embeddings.embed_documents([t.text for t in texts]) + text_embeddings = asyncio.run( + embed_documents( + self._client, [t.text for t in texts], self.embeddings_model + ) + ) for i, t in enumerate(texts): t.embedding = text_embeddings[i] if doc.embedding is None: - doc.embedding = self.embeddings.embed_documents([doc.citation])[0] + doc.embedding = asyncio.run( + embed_documents(self._client, [doc.citation], self.embeddings_model) + )[0] if not self.jit_texts_index: self.texts_index.add_texts_and_embeddings(texts) - self.doc_index.add_texts_and_embeddings(doc) + self.doc_index.add_texts_and_embeddings([doc]) self.docs[doc.dockey] = doc self.texts += texts self.docnames.add(doc.docname) @@ -266,8 +285,13 @@ async def adoc_match( get_callbacks: CallbackFactory = lambda x: None, ) -> set[DocKey]: """Return a list of dockeys that match the query.""" + query_vector = ( + await embed_documents(self._client, [query], self.embeddings_model) + )[0] matches, _ = self.doc_index.max_marginal_relevance_search( - query, k=k + len(self.deleted_dockeys) + query_vector, + k=k + len(self.deleted_dockeys), + fetch_k=5 * (k + len(self.deleted_dockeys)), ) # filter the matches matched_docs = [m for m in matches if m.dockey not in self.deleted_dockeys] @@ -277,7 +301,7 @@ async def adoc_match( try: if ( rerank is None - and self.llm_config.model.startswith("gpt-4") + and self.llm_config["model"].startswith("gpt-4") or rerank is True ): chain = make_chain( @@ -296,7 +320,7 @@ async def adoc_match( pass return set([d.dockey for d in matched_docs]) - def _build_texts_index(self, keys: set[DocKey] = None): + def _build_texts_index(self, keys: set[DocKey] | None = None): if keys is not None and self.jit_texts_index: texts = self.texts if keys is not None: @@ -317,17 +341,7 @@ def get_evidence( disable_vector_search: bool = False, disable_summarization: bool = False, ) -> Answer: - # special case for jupyter notebooks - if "get_ipython" in globals() or "google.colab" in sys.modules: - import nest_asyncio - - nest_asyncio.apply() - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop.run_until_complete( + return asyncio.run( self.aget_evidence( answer, k=k, @@ -362,12 +376,17 @@ async def aget_evidence( if disable_vector_search: matches = self.texts else: + query_vector = ( + await embed_documents( + self._client, [answer.question], self.embeddings_model + ) + )[0] if marginal_relevance: - matches = self.texts_index.max_marginal_relevance_search( - answer.question, k=_k, fetch_k=5 * _k + matches, _ = self.texts_index.max_marginal_relevance_search( + query_vector, k=_k, fetch_k=5 * _k ) else: - matches = self.texts_index.similarity_search(answer.question, k=_k) + matches, _ = self.texts_index.similarity_search(query_vector, k=_k) # ok now filter (like ones from adoc_match) if answer.dockey_filter is not None: matches = [m for m in matches if m.doc.dockey in answer.dockey_filter] @@ -406,7 +425,7 @@ async def process(match): # http code in the exception try: context = await summary_chain( - dict( + data=dict( question=answer.question, # Add name so chunk is stated citation=citation, @@ -434,7 +453,7 @@ async def process(match): text=Text( text=match.text, name=match.name, - doc=Doc(**match.doc), + doc=Doc(**match.doc.model_dump()), ), score=score, ) @@ -474,17 +493,7 @@ def query( key_filter: bool | None = None, get_callbacks: CallbackFactory = lambda x: None, ) -> Answer: - # special case for jupyter notebooks - if "get_ipython" in globals() or "google.colab" in sys.modules: - import nest_asyncio - - nest_asyncio.apply() - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop.run_until_complete( + return asyncio.run( self.aquery( query, k=k, @@ -540,12 +549,11 @@ async def aquery( ) answer.context = answer.context + "\n\nExtra background information:" + pre bib = dict() - if len(answer.context) < 10 and not self.memory: + if len(answer.context) < 10: # and not self.memory: answer_text = ( "I cannot answer this question due to insufficient information." ) else: - callbacks = get_callbacks("answer") qa_chain = make_chain( client=self._client, prompt=self.prompts.qa, @@ -558,7 +566,7 @@ async def aquery( answer_length=answer.answer_length, question=answer.question, ), - callbacks=callbacks, + callbacks=get_callbacks("answer"), ) # it still happens if "(Example2012)" in answer_text: @@ -586,17 +594,17 @@ async def aquery( llm_config=self.llm_config, system_prompt=self.prompts.system, ) - post = await chain.arun( + post = await chain( data=answer.model_dump(), callbacks=get_callbacks("post") ) answer.answer = post answer.formatted_answer = f"Question: {answer.question}\n\n{post}\n" if len(bib) > 0: answer.formatted_answer += f"\nReferences\n\n{bib_str}\n" - if self.memory_model is not None: - answer.memory = self.memory_model.load_memory_variables(inputs={})["memory"] - self.memory_model.save_context( - {"Question": answer.question}, {"Answer": answer.answer} - ) + # if self.memory_model is not None: + # answer.memory = self.memory_model.load_memory_variables(inputs={})["memory"] + # self.memory_model.save_context( + # {"Question": answer.question}, {"Answer": answer.answer} + # ) return answer diff --git a/paperqa/chains.py b/paperqa/llms.py similarity index 82% rename from paperqa/chains.py rename to paperqa/llms.py index d06c8128f..18fff67bf 100644 --- a/paperqa/chains.py +++ b/paperqa/llms.py @@ -1,11 +1,22 @@ import re -from typing import Callable +from typing import Any, Awaitable, Callable, get_args, get_type_hints from openai import AsyncOpenAI from .prompts import default_system_prompt -default_system_prompt = "End your responses with [END]" + +def guess_model_type(model_name: str) -> str: + import openai + + model_type = get_type_hints( + openai.types.chat.completion_create_params.CompletionCreateParamsBase + )["model"] + model_union = get_args(get_args(model_type)[1]) + model_arr = list(model_union) + if model_name in model_arr: + return "chat" + return "completion" def process_llm_config(llm_config: dict) -> dict: @@ -23,13 +34,27 @@ def process_llm_config(llm_config: dict) -> dict: return result +async def embed_documents( + client: AsyncOpenAI, texts: list[str], embedding_model: str +) -> list[list[float]]: + """Embed a list of documents with batching""" + if client is None: + raise ValueError( + "Your client is None - did you forget to set it after pickling?" + ) + response = await client.embeddings.create( + model=embedding_model, input=texts, encoding_format="float" + ) + return [e.embedding for e in response.data] + + def make_chain( client: AsyncOpenAI, prompt: str, llm_config: dict, skip_system: bool = False, system_prompt: str = default_system_prompt, -) -> Callable[[list[dict], list[Callable[[str], None]] | None], list[str]]: +) -> Awaitable[Any]: """Create a function to execute a batch of prompts Args: @@ -45,6 +70,10 @@ def make_chain( where data is a dict with keys for the input variables that will be formatted into prompt and callbacks is a list of functions to call with each chunk of the completion. """ + if client is None: + raise ValueError( + "Your client is None - did you forget to set it after pickling?" + ) if llm_config["model_type"] == "chat": system_message_prompt = dict(role="system", content=system_prompt) human_message_prompt = dict(role="user", content=prompt) diff --git a/paperqa/prompts.py b/paperqa/prompts.py index e77852138..1432244ce 100644 --- a/paperqa/prompts.py +++ b/paperqa/prompts.py @@ -10,7 +10,7 @@ "{text}\n\n" "Excerpt from {citation}\n" "Question: {question}\n" - "Relevant Information Summary:", + "Relevant Information Summary:" ) qa_prompt = ( @@ -22,7 +22,7 @@ "via valid citation markers at the end of sentences, like (Example2012). \n" "Context (with relevance scores):\n {context}\n" "Question: {question}\n" - "Answer: ", + "Answer: " ) select_paper_prompt = ( @@ -34,12 +34,12 @@ "(if the question requires timely information). \n\n" "Question: {question}\n\n" "Papers: {papers}\n\n" - "Selected keys:", + "Selected keys:" ) citation_prompt = ( - "Provide the citation for the following text in MLA Format.\n\n" + "Provide the citation for the following text in MLA Format. If reporting date accessed, the current year is 2024\n\n" "{text}\n\n" - "Citation:", + "Citation:" ) default_system_prompt = ( diff --git a/paperqa/readers.py b/paperqa/readers.py index de6b901f6..fd8be1b8a 100644 --- a/paperqa/readers.py +++ b/paperqa/readers.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import List +import tiktoken from html2text import html2text from .types import Doc, Text @@ -75,6 +76,10 @@ def parse_pdf(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List[Text def parse_txt( path: Path, doc: Doc, chunk_chars: int, overlap: int, html: bool = False ) -> List[Text]: + """Parse a document into chunks, based on tiktoken encoding. + + NOTE: We get some byte continuation errors. Currnetly ignored, but should explore more to make sure we don't miss anything. + """ try: with open(path) as f: text = f.read() @@ -83,18 +88,36 @@ def parse_txt( text = f.read() if html: text = html2text(text) - # chunk into size chunk_chars with overlap overlap - raw_texts = [] - start = 0 - while start < len(text): - end = min(start + chunk_chars, len(text)) - raw_texts.append(text[start:end]) - start = end - overlap - - texts = [ - Text(text=t, name=f"{doc.docname} chunk {i}", doc=doc) - for i, t in enumerate(raw_texts) - ] + texts = [] + # we tokenize using tiktoken so cuts are in reasonable places + enc = tiktoken.get_encoding("cl100k_base") + encoded = [enc.decode_single_token_bytes(token) for token in enc.encode(text)] + split_size = 0 + split_flat = "" + split = [] + for chunk in encoded: + split.append(chunk) + split_size += len(chunk) + if split_size > chunk_chars: + split_flat = b"".join(split).decode() + texts.append( + Text( + text=split_flat[:chunk_chars], + name=f"{doc.docname} chunk {len(texts) + 1}", + doc=doc, + ) + ) + split = [split_flat[chunk_chars - overlap :].encode("utf-8")] + split_size = len(split[0]) + if len(split) > overlap: + split_flat = b"".join(split).decode() + texts.append( + Text( + text=split_flat[:chunk_chars], + name=f"{doc.docname} lines {len(texts) + 1}", + doc=doc, + ) + ) return texts diff --git a/paperqa/types.py b/paperqa/types.py index a7603f894..1e80c8caa 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -50,7 +50,7 @@ def add_texts_and_embeddings(self, texts: list[Embeddable]) -> None: @abstractmethod def similarity_search( self, query: list[float], k: int - ) -> list[tuple[Embeddable, float]]: + ) -> tuple[list[Embeddable], list[float]]: pass @abstractmethod @@ -59,7 +59,7 @@ def clear(self) -> None: def max_marginal_relevance_search( self, query: list[float], k: int, fetch_k: int, lambda_: float = 0.5 - ) -> list[tuple[Embeddable, float]]: + ) -> tuple[list[Embeddable], list[float]]: """Vectorized implementation of Maximal Marginal Relevance (MMR) search. Args: @@ -73,16 +73,16 @@ def max_marginal_relevance_search( if fetch_k < k: raise ValueError("fetch_k must be greater or equal to k") - initial_results = self.similarity_search(query, fetch_k) - if len(initial_results) <= k: - return initial_results + texts, scores = self.similarity_search(query, fetch_k) + if len(texts) <= k: + return texts, scores - embeddings = np.array([t.embedding for t, _ in initial_results]) - scores = np.array([score for _, score in initial_results]) + embeddings = np.array([t.embedding for t in texts]) + scores = np.array(scores) similarity_matrix = cosine_similarity(embeddings, embeddings) selected_indices = [0] - remaining_indices = list(range(1, len(initial_results))) + remaining_indices = list(range(1, len(texts))) while len(selected_indices) < k: selected_similarities = similarity_matrix[:, selected_indices] @@ -95,7 +95,9 @@ def max_marginal_relevance_search( selected_indices.append(max_mmr_index) remaining_indices.remove(max_mmr_index) - return [(initial_results[i][0], scores[i]) for i in selected_indices] + return [texts[i] for i in selected_indices], [ + scores[i] for i in selected_indices + ] class NumpyVectorStore(VectorStore): @@ -115,16 +117,19 @@ def add_texts_and_embeddings( def similarity_search( self, query: list[float], k: int - ) -> list[tuple[Embeddable, float]]: + ) -> tuple[list[Embeddable], list[float]]: if len(self.texts) == 0: - return [] - query = np.array(query) + return [], [] + np_query = np.array(query) similarity_scores = cosine_similarity( - query.reshape(1, -1), self._embeddings_matrix + np_query.reshape(1, -1), self._embeddings_matrix )[0] similarity_scores = np.nan_to_num(similarity_scores, nan=-np.inf) sorted_indices = np.argsort(similarity_scores)[::-1] - return [(self.texts[i], similarity_scores[i]) for i in sorted_indices[:k]] + return ( + [self.texts[i] for i in sorted_indices[:k]], + [similarity_scores[i] for i in sorted_indices[:k]], + ) class _FormatDict(dict): diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index b467d398a..1ce15db55 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -7,7 +7,7 @@ import requests from paperqa import Answer, Doc, Docs, PromptCollection, Text -from paperqa.chains import get_score +from paperqa.llms import get_score from paperqa.readers import read_doc from paperqa.utils import ( iter_citations, @@ -370,15 +370,6 @@ def test_docs(): assert docs.docs["test"].docname == "Wiki2023" -def test_update_llm(): - doc = Docs() - doc.update_llm("gpt-3.5-turbo") - assert doc.llm == doc.summary_llm - - doc.update_llm(OpenAI(client=None, temperature=0.1, model="text-ada-001")) - assert doc.llm == doc.summary_llm - - def test_evidence(): doc_path = "example.html" with open(doc_path, "w", encoding="utf-8") as f: @@ -451,7 +442,6 @@ async def test_adoc_match(self): "What is Frederick Bates's greatest accomplishment?" ) assert len(sources) > 0 - docs.update_llm("gpt-3.5-turbo") sources = await docs.adoc_match( "What is Frederick Bates's greatest accomplishment?" ) @@ -464,15 +454,20 @@ def test_docs_pickle(): # get front page of wikipedia r = requests.get("https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day") f.write(r.text) - llm = OpenAI(client=None, temperature=0.0, model="text-curie-001") - docs = Docs(llm=llm) + docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002")) + old_config = docs.llm_config docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000) os.remove(doc_path) docs_pickle = pickle.dumps(docs) docs2 = pickle.loads(docs_pickle) - docs2.update_llm(llm) - assert llm.model_name == docs2.llm.model_name - assert docs2.summary_llm.model_name == docs2.llm.model_name + # make sure it fails if we haven't set client + try: + docs2.query("What date is bring your dog to work in the US?") + except ValueError: + pass + docs2.set_client() + assert docs2.llm_config == old_config + assert docs2.summary_llm_config == old_config assert len(docs.docs) == len(docs2.docs) context1, context2 = ( docs.get_evidence( @@ -497,45 +492,6 @@ def test_docs_pickle(): docs.query("What date is bring your dog to work in the US?") -def test_docs_pickle_no_faiss(): - doc_path = "example.html" - with open(doc_path, "w", encoding="utf-8") as f: - # get front page of wikipedia - r = requests.get("https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day") - f.write(r.text) - llm = OpenAI(client=None, temperature=0.0, model="text-curie-001") - docs = Docs(llm=llm) - docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000) - docs.doc_index = None - docs.texts_index = None - docs_pickle = pickle.dumps(docs) - docs2 = pickle.loads(docs_pickle) - docs2.update_llm(llm) - assert len(docs.docs) == len(docs2.docs) - assert ( - strings_similarity( - docs.get_evidence( - Answer( - question="What date is bring your dog to work in the US?", - summary_length="about 20 words", - ), - k=3, - max_sources=1, - ).context, - docs2.get_evidence( - Answer( - question="What date is bring your dog to work in the US?", - summary_length="about 20 words", - ), - k=3, - max_sources=1, - ).context, - ) - > 0.75 - ) - os.remove(doc_path) - - def test_bad_context(): doc_path = "example.html" with open(doc_path, "w", encoding="utf-8") as f: @@ -555,7 +511,7 @@ def test_repeat_keys(): # get wiki page about politician r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") f.write(r.text) - docs = Docs(llm=OpenAI(client=None, temperature=0.0, model="text-ada-001")) + docs = Docs(llm_config=dict(temperature=0.0, model="text-ada-001")) docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") try: docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") @@ -584,7 +540,7 @@ def test_repeat_keys(): def test_pdf_reader(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") - docs = Docs(llm=OpenAI(client=None, temperature=0.0, model="text-curie-001")) + docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002")) docs.add(doc_path, "Wellawatte et al, XAI Review, 2023") answer = docs.query("Are counterfactuals actionable?") assert "yes" in answer.answer or "Yes" in answer.answer @@ -594,7 +550,7 @@ def test_fileio_reader_pdf(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") with open(doc_path, "rb") as f: - docs = Docs(llm=OpenAI(client=None, temperature=0.0, model="text-curie-001")) + docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002")) docs.add_file(f, "Wellawatte et al, XAI Review, 2023") answer = docs.query("Are counterfactuals actionable?") assert "yes" in answer.answer or "Yes" in answer.answer @@ -602,7 +558,7 @@ def test_fileio_reader_pdf(): def test_fileio_reader_txt(): # can't use curie, because it has trouble with parsed HTML - docs = Docs(llm=OpenAI(client=None, temperature=0.0)) + docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002")) r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") if r.status_code != 200: raise ValueError("Could not download wikipedia page") @@ -652,7 +608,7 @@ def test_prompt_length(): def test_code(): # load this script doc_path = os.path.abspath(__file__) - docs = Docs(llm=OpenAI(client=None, temperature=0.0, model="text-ada-001")) + docs = Docs(llm_config=dict(temperature=0.0, model="babbage-002")) docs.add(doc_path, "test_paperqa.py", docname="test_paperqa.py", disable_check=True) assert len(docs.docs) == 1 docs.query("What function tests the preview?") @@ -667,8 +623,8 @@ def test_citation(): docs = Docs() docs.add(doc_path) assert ( - list(docs.docs.values())[0].docname == "Wikipedia2023" - or list(docs.docs.values())[0].docname == "Frederick2023" + list(docs.docs.values())[0].docname == "Wikipedia2024" + or list(docs.docs.values())[0].docname == "Frederick2024" ) @@ -740,19 +696,6 @@ def test_query_filter(): # the filter shouldn't trigger, so just checking that it doesn't crash -def test_nonopenai_client(): - responses = ["This is a test", "This is another test"] * 50 - model = FakeListLLM(responses=responses) - doc_path = "example.txt" - with open(doc_path, "w", encoding="utf-8") as f: - # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") - f.write(r.text) - docs = Docs(llm=model) - docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") - docs.query("What country is Bates from?") - - def test_zotera(): from paperqa.contrib import ZoteroDB @@ -786,11 +729,10 @@ def test_too_much_evidence(): def test_custom_prompts(): - my_qaprompt = PromptTemplate( - input_variables=["question", "context"], - template="Answer the question '{question}' " + my_qaprompt = ( + "Answer the question '{question}' " "using the country name alone. For example: " - "A: United States\nA: Canada\nA: Mexico\n\n Using the context:\n\n{context}\n\nA: ", + "A: United States\nA: Canada\nA: Mexico\n\n Using the context:\n\n{context}\n\nA: " ) docs = Docs(prompts=PromptCollection(qa=my_qaprompt)) @@ -807,11 +749,7 @@ def test_custom_prompts(): def test_pre_prompt(): - pre = PromptTemplate( - input_variables=["question"], - template="Provide context you have memorized " - "that could help answer '{question}'. ", - ) + pre = "Provide context you have memorized " "that could help answer '{question}'. " docs = Docs(prompts=PromptCollection(pre=pre)) @@ -825,13 +763,12 @@ def test_pre_prompt(): def test_post_prompt(): - post = PromptTemplate( - input_variables=["question", "answer"], - template="We are trying to answer the question below " + post = ( + "We are trying to answer the question below " "and have an answer provided. " "Please edit the answer be extremely terse, with no extra words or formatting" "with no extra information.\n\n" - "Q: {question}\nA: {answer}\n\n", + "Q: {question}\nA: {answer}\n\n" ) docs = Docs(prompts=PromptCollection(post=post)) @@ -883,8 +820,8 @@ def disabled_test_memory(): def test_add_texts(): - llm = OpenAI(client=None, temperature=0.1, model="text-ada-001") - docs = Docs(llm=llm) + llm_config = dict(temperature=0.1, model="text-ada-001") + docs = Docs(llm_config=llm_config) docs.add_url( "https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day", citation="WikiMedia Foundation, 2023, Accessed now", @@ -894,17 +831,17 @@ def test_add_texts(): docs2 = Docs() texts = [Text(**dict(t)) for t in docs.texts] for t in texts: - t.embeddings = None + t.embedding = None docs2.add_texts(texts, list(docs.docs.values())[0]) for t1, t2 in zip(docs2.texts, docs.texts): assert t1.text == t2.text - assert np.allclose(t1.embeddings, t2.embeddings, atol=1e-3) + assert np.allclose(t1.embedding, t2.embedding, atol=1e-3) docs2._build_texts_index() # now do it again to test after text index is already built - llm = OpenAI(client=None, temperature=0.1, model="text-ada-001") - docs = Docs(llm=llm) + llm_config = dict(temperature=0.1, model="text-ada-001") + docs = Docs(llm_config=llm_config) docs.add_url( "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", citation="WikiMedia Foundation, 2023, Accessed now", @@ -913,7 +850,7 @@ def test_add_texts(): texts = [Text(**dict(t)) for t in docs.texts] for t in texts: - t.embeddings = None + t.embedding = None docs2.add_texts(texts, list(docs.docs.values())[0]) assert len(docs2.docs) == 2 From 527472b1c8761cd70bc1f5269f0a97ccb76627c9 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Thu, 4 Jan 2024 11:01:52 -0800 Subject: [PATCH 03/16] Fixed remaining tests --- paperqa/docs.py | 44 ++++++++++++++++++-------------- paperqa/llms.py | 4 +-- paperqa/prompts.py | 5 ++-- paperqa/readers.py | 6 +++-- paperqa/types.py | 18 ++++++++++---- tests/test_paperqa.py | 58 +++++++++++++++++++++++++++++++++++++------ 6 files changed, 99 insertions(+), 36 deletions(-) diff --git a/paperqa/docs.py b/paperqa/docs.py index 7b851ca18..69fdf376d 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -49,7 +49,7 @@ class Docs(BaseModel): docnames: set[str] = set() texts_index: VectorStore = NumpyVectorStore() doc_index: VectorStore = NumpyVectorStore() - llm_config: dict = dict(model="gpt-3.5-turbo", model_type="chat") + llm_config: dict = dict(model="gpt-3.5-turbo", model_type="chat", temperature=0.1) summary_llm_config: dict | None = Field(default=None, validate_default=True) name: str = "default" index_path: Path | None = PAPERQA_DIR / name @@ -61,6 +61,7 @@ class Docs(BaseModel): jit_texts_index: bool = False # This is used to strip indirect citations that come up from the summary llm strip_citations: bool = True + verbose: bool = False def __init__(self, **data): if "client" in data: @@ -183,7 +184,7 @@ def add( cite_chain = make_chain( client=self._client, prompt=self.prompts.cite, - llm_config=self.summary_llm_config, + llm_config=cast(dict, self.summary_llm_config), skip_system=True, ) # peak first chunk @@ -191,7 +192,9 @@ def add( texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=100) if len(texts) == 0: raise ValueError(f"Could not read document {path}. Is it empty?") - citation = asyncio.run(cite_chain(data=dict(text=texts[0].text))) + citation = asyncio.run( + cite_chain(dict(text=texts[0].text), None), + ) if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation: citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}" @@ -312,8 +315,8 @@ async def adoc_match( ) papers = [f"{d.docname}: {d.citation}" for d in matched_docs] result = await chain( - data=[dict(question=query, papers="\n".join(papers))], - callbacks=get_callbacks("filter"), + dict(question=query, papers="\n".join(papers)), + get_callbacks("filter"), ) return set([d.dockey for d in matched_docs if d.docname in result]) except AttributeError: @@ -321,14 +324,22 @@ async def adoc_match( return set([d.dockey for d in matched_docs]) def _build_texts_index(self, keys: set[DocKey] | None = None): + texts = self.texts if keys is not None and self.jit_texts_index: - texts = self.texts if keys is not None: texts = [t for t in texts if t.doc.dockey in keys] if len(texts) == 0: return self.texts_index.clear() self.texts_index.add_texts_and_embeddings(texts) + if self.jit_texts_index and keys is None: + # Not sure what else to do here??????? + print( + "Warning: JIT text index without keys " + "requires rebuilding index each time!" + ) + self.texts_index.clear() + self.texts_index.add_texts_and_embeddings(texts) def get_evidence( self, @@ -369,7 +380,6 @@ async def aget_evidence( # do we have no docs? return answer self._build_texts_index(keys=answer.dockey_filter) - self.texts_index = cast(VectorStore, self.texts_index) _k = k if answer.dockey_filter is not None: _k = k * 10 # heuristic - get enough so we can downselect @@ -414,7 +424,7 @@ async def process(match): summary_chain = make_chain( client=self._client, prompt=self.prompts.summary, - llm_config=self.summary_llm_config, + llm_config=cast(dict, self.summary_llm_config), system_prompt=self.prompts.system, ) # This is dangerous because it @@ -425,14 +435,14 @@ async def process(match): # http code in the exception try: context = await summary_chain( - data=dict( + dict( question=answer.question, # Add name so chunk is stated citation=citation, summary_length=answer.summary_length, text=match.text, ), - callbacks=callbacks, + callbacks, ) except Exception as e: if guess_is_4xx(str(e)): @@ -544,9 +554,7 @@ async def aquery( llm_config=self.llm_config, system_prompt=self.prompts.system, ) - pre = await chain( - data=dict(question=answer.question), callbacks=get_callbacks("pre") - ) + pre = await chain(dict(question=answer.question), get_callbacks("pre")) answer.context = answer.context + "\n\nExtra background information:" + pre bib = dict() if len(answer.context) < 10: # and not self.memory: @@ -560,14 +568,16 @@ async def aquery( llm_config=self.llm_config, system_prompt=self.prompts.system, ) + print(answer.context) answer_text = await qa_chain( - data=dict( + dict( context=answer.context, answer_length=answer.answer_length, question=answer.question, ), - callbacks=get_callbacks("answer"), + get_callbacks("answer"), ) + print(answer_text) # it still happens if "(Example2012)" in answer_text: answer_text = answer_text.replace("(Example2012)", "") @@ -594,9 +604,7 @@ async def aquery( llm_config=self.llm_config, system_prompt=self.prompts.system, ) - post = await chain( - data=answer.model_dump(), callbacks=get_callbacks("post") - ) + post = await chain(answer.model_dump(), get_callbacks("post")) answer.answer = post answer.formatted_answer = f"Question: {answer.question}\n\n{post}\n" if len(bib) > 0: diff --git a/paperqa/llms.py b/paperqa/llms.py index 18fff67bf..934435c95 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -1,5 +1,5 @@ import re -from typing import Any, Awaitable, Callable, get_args, get_type_hints +from typing import Any, Callable, Coroutine, get_args, get_type_hints from openai import AsyncOpenAI @@ -54,7 +54,7 @@ def make_chain( llm_config: dict, skip_system: bool = False, system_prompt: str = default_system_prompt, -) -> Awaitable[Any]: +) -> Callable[[dict, list[Callable[[str], None]] | None], Coroutine[Any, Any, str]]: """Create a function to execute a batch of prompts Args: diff --git a/paperqa/prompts.py b/paperqa/prompts.py index 1432244ce..1d177cb86 100644 --- a/paperqa/prompts.py +++ b/paperqa/prompts.py @@ -15,7 +15,7 @@ qa_prompt = ( "Write an answer ({answer_length}) " - "for the question below based on the provided context. " + "for the question below based on the provided context. Ignore irrelevant context. " "If the context provides insufficient information and the question cannot be directly answered, " 'reply "I cannot answer". ' "For each part of your answer, indicate which sources most support it " @@ -37,7 +37,8 @@ "Selected keys:" ) citation_prompt = ( - "Provide the citation for the following text in MLA Format. If reporting date accessed, the current year is 2024\n\n" + "Provide the citation for the following text in MLA Format. " + "If reporting date accessed, the current year is 2024\n\n" "{text}\n\n" "Citation:" ) diff --git a/paperqa/readers.py b/paperqa/readers.py index fd8be1b8a..9c629573b 100644 --- a/paperqa/readers.py +++ b/paperqa/readers.py @@ -78,7 +78,9 @@ def parse_txt( ) -> List[Text]: """Parse a document into chunks, based on tiktoken encoding. - NOTE: We get some byte continuation errors. Currnetly ignored, but should explore more to make sure we don't miss anything. + NOTE: We get some byte continuation errors. + Currnetly ignored, but should explore more to make sure we + don't miss anything. """ try: with open(path) as f: @@ -88,7 +90,7 @@ def parse_txt( text = f.read() if html: text = html2text(text) - texts = [] + texts: list[Text] = [] # we tokenize using tiktoken so cuts are in reasonable places enc = tiktoken.get_encoding("cl100k_base") encoded = [enc.decode_single_token_bytes(token) for token in enc.encode(text)] diff --git a/paperqa/types.py b/paperqa/types.py index 1e80c8caa..da4705ab0 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -15,7 +15,7 @@ # Just for clarity DocKey = Any -CallbackFactory = Callable[[str], Callable[[str], None]] +CallbackFactory = Callable[[str], list[Callable[[str], None]] | None] class Embeddable(BaseModel): @@ -78,7 +78,7 @@ def max_marginal_relevance_search( return texts, scores embeddings = np.array([t.embedding for t in texts]) - scores = np.array(scores) + np_scores = np.array(scores) similarity_matrix = cosine_similarity(embeddings, embeddings) selected_indices = [0] @@ -88,7 +88,7 @@ def max_marginal_relevance_search( selected_similarities = similarity_matrix[:, selected_indices] max_sim_to_selected = selected_similarities.max(axis=1) - mmr_scores = lambda_ * scores - (1 - lambda_) * max_sim_to_selected + mmr_scores = lambda_ * np_scores - (1 - lambda_) * max_sim_to_selected mmr_scores[selected_indices] = -np.inf # Exclude already selected documents max_mmr_index = mmr_scores.argmax() @@ -132,15 +132,21 @@ def similarity_search( ) +# Mock a dictionary and store any missing items class _FormatDict(dict): + def __init__(self) -> None: + self.key_set: set[str] = set() + def __missing__(self, key: str) -> str: + self.key_set.add(key) return key def get_formatted_variables(s: str) -> set[str]: + """Returns the set of variables implied by the format string""" format_dict = _FormatDict() s.format_map(format_dict) - return set(format_dict.keys()) + return format_dict.key_set class PromptCollection(BaseModel): @@ -190,6 +196,8 @@ def check_select(cls, v: str) -> str: @classmethod def check_pre(cls, v: str | None) -> str | None: if v is not None: + print(v) + print(get_formatted_variables(v)) if set(get_formatted_variables(v)) != set(["question"]): raise ValueError("Pre prompt must have input variables: question") return v @@ -199,7 +207,7 @@ def check_pre(cls, v: str | None) -> str | None: def check_post(cls, v: str | None) -> str | None: if v is not None: # kind of a hack to get list of attributes in answer - attrs = [a.name for a in Answer.__fields__.values()] + attrs = set(Answer.model_fields.keys()) if not set(get_formatted_variables(v)).issubset(attrs): raise ValueError(f"Post prompt must have input variables: {attrs}") return v diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index 1ce15db55..ec17bfa89 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -5,9 +5,10 @@ import numpy as np import requests +from openai import AsyncOpenAI from paperqa import Answer, Doc, Docs, PromptCollection, Text -from paperqa.llms import get_score +from paperqa.llms import get_score, make_chain from paperqa.readers import read_doc from paperqa.utils import ( iter_citations, @@ -359,6 +360,49 @@ def test_extract_score(): assert get_score(sample) == 9 +class TestChains(IsolatedAsyncioTestCase): + async def test_chain_completion(self): + client = AsyncOpenAI() + call = make_chain( + client, + "The {animal} says", + llm_config=dict( + model_type="completion", + temperature=0, + model="babbage-002", + max_tokens=56, + ), + skip_system=True, + ) + outputs = [] + + def accum(x): + outputs.append(x) + + completion = await call(dict(animal="duck"), callbacks=[accum]) + assert completion == "".join(outputs) + assert type(completion) == str + + async def test_chain_chat(self): + client = AsyncOpenAI() + call = make_chain( + client, + "The {animal} says", + llm_config=dict( + model_type="chat", temperature=0, model="gpt-3.5-turbo", max_tokens=56 + ), + skip_system=True, + ) + outputs = [] + + def accum(x): + outputs.append(x) + + completion = await call(dict(animal="duck"), callbacks=[accum]) + assert completion == "".join(outputs) + assert type(completion) == str + + def test_docs(): llm_config = dict(temperature=0.1, model="text-ada-001", model_type="completion") docs = Docs(llm_config=llm_config) @@ -454,7 +498,7 @@ def test_docs_pickle(): # get front page of wikipedia r = requests.get("https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day") f.write(r.text) - docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002")) + docs = Docs(llm_config=dict(temperature=0.0, model="gpt-3.5-turbo")) old_config = docs.llm_config docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000) os.remove(doc_path) @@ -511,7 +555,7 @@ def test_repeat_keys(): # get wiki page about politician r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") f.write(r.text) - docs = Docs(llm_config=dict(temperature=0.0, model="text-ada-001")) + docs = Docs(llm_config=dict(temperature=0.0, model="babbage-002")) docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") try: docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") @@ -540,7 +584,7 @@ def test_repeat_keys(): def test_pdf_reader(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") - docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002")) + docs = Docs(llm_config=dict(temperature=0.0, model="gpt-3.5-turbo")) docs.add(doc_path, "Wellawatte et al, XAI Review, 2023") answer = docs.query("Are counterfactuals actionable?") assert "yes" in answer.answer or "Yes" in answer.answer @@ -550,7 +594,7 @@ def test_fileio_reader_pdf(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") with open(doc_path, "rb") as f: - docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002")) + docs = Docs(llm_config=dict(temperature=0.0, model="gpt-3.5-turbo")) docs.add_file(f, "Wellawatte et al, XAI Review, 2023") answer = docs.query("Are counterfactuals actionable?") assert "yes" in answer.answer or "Yes" in answer.answer @@ -558,7 +602,7 @@ def test_fileio_reader_pdf(): def test_fileio_reader_txt(): # can't use curie, because it has trouble with parsed HTML - docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002")) + docs = Docs(llm_config=dict(temperature=0.0, model="gpt-3.5-turbo")) r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") if r.status_code != 200: raise ValueError("Could not download wikipedia page") @@ -568,7 +612,7 @@ def test_fileio_reader_txt(): chunk_chars=1000, ) answer = docs.query("What country was Frederick Bates born in?") - assert "Virginia" in answer.answer + assert "United States" in answer.answer def test_pdf_pypdf_reader(): From 2a32876656259b383f45122426db6652de16d5c7 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Thu, 4 Jan 2024 11:12:06 -0800 Subject: [PATCH 04/16] Added new dependencies --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index d712958e3..a69571ff1 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ "pydantic>=2", "openai>=1", "numpy", + "nest-asyncio", "PyCryptodome", "html2text", "tiktoken>=0.4.0", From 37b82f99eaa80c60b0195a30df53c1e22f0fffb3 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Thu, 4 Jan 2024 16:54:18 -0800 Subject: [PATCH 05/16] Refactored LLMs to allow swapping --- paperqa/docs.py | 110 +++++++++++------- paperqa/llms.py | 265 ++++++++++++++++++++++++++++-------------- tests/test_paperqa.py | 60 ++++++---- 3 files changed, 283 insertions(+), 152 deletions(-) diff --git a/paperqa/docs.py b/paperqa/docs.py index 69fdf376d..4490b04f9 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -9,9 +9,16 @@ from typing import Any, BinaryIO, cast from openai import AsyncOpenAI -from pydantic import BaseModel, Field, field_validator, model_validator - -from .llms import embed_documents, get_score, guess_model_type, make_chain +from pydantic import BaseModel, Field, model_validator + +from .llms import ( + EmbeddingModel, + LLMModel, + OpenAIEmbeddingModel, + OpenAILLMModel, + get_score, + is_openai_model, +) from .paths import PAPERQA_DIR from .readers import read_doc from .types import ( @@ -43,17 +50,21 @@ class Docs(BaseModel): """A collection of documents to be used for answering questions.""" - _client: AsyncOpenAI | None + # ephemeral clients that should not be pickled + _client: Any | None + _embedding_client: Any | None + llm: str = "default" + summary_llm: str | None = None + llm_model: LLMModel = Field(default_factory=OpenAILLMModel) + summary_llm_model: LLMModel | None = Field(default=None, validate_default=True) + embedding: EmbeddingModel = OpenAIEmbeddingModel() docs: dict[DocKey, Doc] = {} texts: list[Text] = [] docnames: set[str] = set() texts_index: VectorStore = NumpyVectorStore() doc_index: VectorStore = NumpyVectorStore() - llm_config: dict = dict(model="gpt-3.5-turbo", model_type="chat", temperature=0.1) - summary_llm_config: dict | None = Field(default=None, validate_default=True) name: str = "default" index_path: Path | None = PAPERQA_DIR / name - embeddings_model: str = "text-embedding-ada-002" batch_size: int = 1 max_concurrent: int = 5 deleted_dockeys: set[DocKey] = set() @@ -61,29 +72,46 @@ class Docs(BaseModel): jit_texts_index: bool = False # This is used to strip indirect citations that come up from the summary llm strip_citations: bool = True - verbose: bool = False def __init__(self, **data): + if "embedding_client" in data: + embedding_client = data.pop("embedding_client") + elif "client" in data: + embedding_client = data["client"] + else: + embedding_client = AsyncOpenAI() if "client" in data: client = data.pop("client") else: client = AsyncOpenAI() super().__init__(**data) self._client = client + self._embedding_client = embedding_client - @field_validator("llm_config", "summary_llm_config") + @model_validator(mode="before") @classmethod - def llm_guess_model_type(cls, v: dict) -> dict: - if v is not None and "model_type" not in v: - v["model_type"] = guess_model_type(v["model"]) - return v + def setup_alias_models(cls, data: Any) -> Any: + if isinstance(data, dict): + if "llm" in data and data["llm"] != "default": + if is_openai_model(data["llm"]): + data["llm_model"] = OpenAILLMModel(config=dict(model=data["llm"])) + else: + raise ValueError(f"Could not guess model type for {data['llm']}. ") + if "summary_llm" in data and data["summary_llm"] is not None: + if is_openai_model(data["summary_llm"]): + data["summary_llm_model"] = OpenAILLMModel( + config=dict(model=data["summary_llm"]) + ) + else: + raise ValueError(f"Could not guess model type for {data['llm']}. ") + return data @model_validator(mode="after") @classmethod - def config_summary_llm_conig(cls, data: Any) -> Any: + def config_summary_llm_config(cls, data: Any) -> Any: if isinstance(data, Docs): - if data.summary_llm_config is None: - data.summary_llm_config = data.llm_config + if data.summary_llm_model is None: + data.summary_llm_model = data.llm_model return data def clear_docs(self): @@ -95,16 +123,25 @@ def __getstate__(self): state = super().__getstate__() # remove client from private attributes del state["__pydantic_private__"]["_client"] + del state["__pydantic_private__"]["_embedding_client"] return state def __setstate__(self, state): super().__setstate__(state) self._client = None + self._embedding_client = None - def set_client(self, client: AsyncOpenAI | None = None): + def set_client( + self, + client: AsyncOpenAI | None = None, + embedding_client: AsyncOpenAI | None = None, + ): if client is None: client = AsyncOpenAI() self._client = client + if embedding_client is None: + embedding_client = client + self._embedding_client = embedding_client def _get_unique_name(self, docname: str) -> str: """Create a unique name given proposed name""" @@ -181,10 +218,9 @@ def add( dockey = md5sum(path) if citation is None: # skip system because it's too hesitant to answer - cite_chain = make_chain( + cite_chain = self.llm_model.make_chain( client=self._client, prompt=self.prompts.cite, - llm_config=cast(dict, self.summary_llm_config), skip_system=True, ) # peak first chunk @@ -251,15 +287,15 @@ def add_texts( doc.docname = new_docname if texts[0].embedding is None: text_embeddings = asyncio.run( - embed_documents( - self._client, [t.text for t in texts], self.embeddings_model + self.embedding.embed_documents( + self._embedding_client, [t.text for t in texts] ) ) for i, t in enumerate(texts): t.embedding = text_embeddings[i] if doc.embedding is None: doc.embedding = asyncio.run( - embed_documents(self._client, [doc.citation], self.embeddings_model) + self.embedding.embed_documents(self._embedding_client, [doc.citation]) )[0] if not self.jit_texts_index: self.texts_index.add_texts_and_embeddings(texts) @@ -289,7 +325,7 @@ async def adoc_match( ) -> set[DocKey]: """Return a list of dockeys that match the query.""" query_vector = ( - await embed_documents(self._client, [query], self.embeddings_model) + await self.embedding.embed_documents(self._embedding_client, [query]) )[0] matches, _ = self.doc_index.max_marginal_relevance_search( query_vector, @@ -304,13 +340,17 @@ async def adoc_match( try: if ( rerank is None - and self.llm_config["model"].startswith("gpt-4") + and ( + type(self.llm) == OpenAILLMModel + and cast(OpenAILLMModel, self) + .llm.config["model"] + .startswith("gpt-4") + ) or rerank is True ): - chain = make_chain( + chain = self.llm_model.make_chain( client=self._client, prompt=self.prompts.select, - llm_config=self.llm_config, skip_system=True, ) papers = [f"{d.docname}: {d.citation}" for d in matched_docs] @@ -387,9 +427,7 @@ async def aget_evidence( matches = self.texts else: query_vector = ( - await embed_documents( - self._client, [answer.question], self.embeddings_model - ) + await self.embedding.embed_documents(self._client, [answer.question]) )[0] if marginal_relevance: matches, _ = self.texts_index.max_marginal_relevance_search( @@ -421,10 +459,9 @@ async def process(match): context = match.text score = 5 else: - summary_chain = make_chain( + summary_chain = self.summary_llm_model.make_chain( client=self._client, prompt=self.prompts.summary, - llm_config=cast(dict, self.summary_llm_config), system_prompt=self.prompts.system, ) # This is dangerous because it @@ -548,10 +585,9 @@ async def aquery( get_callbacks=get_callbacks, ) if self.prompts.pre is not None: - chain = make_chain( + chain = self.llm_model.make_chain( client=self._client, prompt=self.prompts.pre, - llm_config=self.llm_config, system_prompt=self.prompts.system, ) pre = await chain(dict(question=answer.question), get_callbacks("pre")) @@ -562,13 +598,11 @@ async def aquery( "I cannot answer this question due to insufficient information." ) else: - qa_chain = make_chain( + qa_chain = self.llm_model.make_chain( client=self._client, prompt=self.prompts.qa, - llm_config=self.llm_config, system_prompt=self.prompts.system, ) - print(answer.context) answer_text = await qa_chain( dict( context=answer.context, @@ -577,7 +611,6 @@ async def aquery( ), get_callbacks("answer"), ) - print(answer_text) # it still happens if "(Example2012)" in answer_text: answer_text = answer_text.replace("(Example2012)", "") @@ -598,10 +631,9 @@ async def aquery( answer.references = bib_str if self.prompts.post is not None: - chain = make_chain( + chain = self.llm_model.make_chain( client=self._client, prompt=self.prompts.post, - llm_config=self.llm_config, system_prompt=self.prompts.system, ) post = await chain(answer.model_dump(), get_callbacks("post")) diff --git a/paperqa/llms.py b/paperqa/llms.py index 934435c95..aec1210b3 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -1,7 +1,9 @@ import re -from typing import Any, Callable, Coroutine, get_args, get_type_hints +from abc import ABC, abstractmethod +from typing import Any, Callable, Coroutine, cast, get_args, get_type_hints from openai import AsyncOpenAI +from pydantic import BaseModel, Field, model_validator from .prompts import default_system_prompt @@ -19,6 +21,24 @@ def guess_model_type(model_name: str) -> str: return "completion" +def is_openai_model(model_name): + import openai + + model_type = get_type_hints( + openai.types.chat.completion_create_params.CompletionCreateParamsBase + )["model"] + model_union = get_args(get_args(model_type)[1]) + model_arr = list(model_union) + + complete_model_types = get_type_hints( + openai.types.completion_create_params.CompletionCreateParamsBase + )["model"] + complete_model_union = get_args(get_args(complete_model_types)[1]) + complete_model_arr = list(complete_model_union) + + return model_name in model_arr or model_name in complete_model_arr + + def process_llm_config(llm_config: dict) -> dict: """Remove model_type and try to set max_tokens""" result = {k: v for k, v in llm_config.items() if k != "model_type"} @@ -48,98 +68,163 @@ async def embed_documents( return [e.embedding for e in response.data] -def make_chain( - client: AsyncOpenAI, - prompt: str, - llm_config: dict, - skip_system: bool = False, - system_prompt: str = default_system_prompt, -) -> Callable[[dict, list[Callable[[str], None]] | None], Coroutine[Any, Any, str]]: - """Create a function to execute a batch of prompts - - Args: - client: OpenAI client - prompt: The prompt to use - llm_config: The config to use - skip_system: Whether to skip the system prompt - system_prompt: The system prompt to use - - Returns: - A function to execute a prompt. Its signature is: - execute(data: dict, callbacks: list[Callable[[str], None]]] | None = None) -> str - where data is a dict with keys for the input variables that will be formatted into prompt - and callbacks is a list of functions to call with each chunk of the completion. - """ - if client is None: - raise ValueError( - "Your client is None - did you forget to set it after pickling?" +class EmbeddingModel(ABC, BaseModel): + @abstractmethod + async def embed_documents(self, client: Any, texts: list[str]) -> list[list[float]]: + pass + + +class OpenAIEmbeddingModel(EmbeddingModel): + embedding_model: str = Field(default="text-embedding-ada-002") + + async def embed_documents(self, client: Any, texts: list[str]) -> list[list[float]]: + return await embed_documents( + cast(AsyncOpenAI, client), texts, self.embedding_model ) - if llm_config["model_type"] == "chat": - system_message_prompt = dict(role="system", content=system_prompt) - human_message_prompt = dict(role="user", content=prompt) - if skip_system: - chat_prompt = [human_message_prompt] - else: - chat_prompt = [system_message_prompt, human_message_prompt] - - async def execute( - data: dict, callbacks: list[Callable[[str], None]] | None = None - ) -> str: - messages = chat_prompt[:-1] + [ - dict(role="user", content=chat_prompt[-1]["content"].format(**data)) - ] - if callbacks is None: - completion = await client.chat.completions.create( - messages=messages, **process_llm_config(llm_config) - ) - output = completion.choices[0].message.content + + +class LLMModel(ABC, BaseModel): + llm_type: str = "completion" + + @abstractmethod + async def acomplete(self, client: Any, prompt: str) -> str: + pass + + @abstractmethod + async def acomplete_iter(self, client: Any, prompt: str) -> Any: + """Return an async generator that yields chunks of the completion. + + I cannot get mypy to understand the override, so marked as Any""" + pass + + @abstractmethod + async def achat(self, client: Any, messages: list[dict[str, str]]) -> str: + pass + + @abstractmethod + async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: + """Return an async generator that yields chunks of the completion. + + I cannot get mypy to understand the override, so marked as Any""" + pass + + def make_chain( + self, + client: Any, + prompt: str, + skip_system: bool = False, + system_prompt: str = default_system_prompt, + ) -> Callable[[dict, list[Callable[[str], None]] | None], Coroutine[Any, Any, str]]: + """Create a function to execute a batch of prompts + + Args: + client: a ephemeral client to use + prompt: The prompt to use + skip_system: Whether to skip the system prompt + system_prompt: The system prompt to use + + Returns: + A function to execute a prompt. Its signature is: + execute(data: dict, callbacks: list[Callable[[str], None]]] | None = None) -> str + where data is a dict with keys for the input variables that will be formatted into prompt + and callbacks is a list of functions to call with each chunk of the completion. + """ + if client is None: + raise ValueError( + "Your client is None - did you forget to set it after pickling?" + ) + if self.llm_type == "chat": + system_message_prompt = dict(role="system", content=system_prompt) + human_message_prompt = dict(role="user", content=prompt) + if skip_system: + chat_prompt = [human_message_prompt] else: - completion = await client.chat.completions.create( - messages=messages, **process_llm_config(llm_config), stream=True - ) - result = [] - async for chunk in completion: - c = chunk.choices[0].delta.content - if c: - result.append(c) - [f(c) for f in callbacks] - output = "".join(result) - return output - - return execute - elif llm_config["model_type"] == "completion": - if skip_system: - completion_prompt = prompt - else: - completion_prompt = system_prompt + "\n\n" + prompt - - async def execute( - data: dict, callbacks: list[Callable[[str], None]] | None = None - ) -> str: - if callbacks is None: - completion = await client.completions.create( - prompt=completion_prompt.format(**data), - **process_llm_config(llm_config), - ) - output = completion.choices[0].text + chat_prompt = [system_message_prompt, human_message_prompt] + + async def execute( + data: dict, callbacks: list[Callable[[str], None]] | None = None + ) -> str: + messages = chat_prompt[:-1] + [ + dict(role="user", content=chat_prompt[-1]["content"].format(**data)) + ] + if callbacks is None: + output = await self.achat(client, messages) + else: + completion = self.achat_iter(client, messages) # type: ignore + result = [] + async for chunk in completion: # type: ignore + if chunk: + result.append(chunk) + [f(chunk) for f in callbacks] + output = "".join(result) + return output + + return execute + elif self.llm_type == "completion": + if skip_system: + completion_prompt = prompt else: - completion = await client.completions.create( - prompt=completion_prompt.format(**data), - **process_llm_config(llm_config), - stream=True, - ) - result = [] - async for chunk in completion: - c = chunk.choices[0].text - if c: - result.append(c) - [f(c) for f in callbacks] - output = "".join(result) - return output - - return execute - else: - raise NotImplementedError(f"Unknown model type {llm_config['model_type']}") + completion_prompt = system_prompt + "\n\n" + prompt + + async def execute( + data: dict, callbacks: list[Callable[[str], None]] | None = None + ) -> str: + if callbacks is None: + output = await self.acomplete( + client, completion_prompt.format(**data) + ) + else: + completion = self.acomplete_iter( # type: ignore + client, + completion_prompt.format(**data), + ) + result = [] + async for chunk in completion: # type: ignore + if chunk: + result.append(chunk) + [f(chunk) for f in callbacks] + output = "".join(result) + return output + + return execute + raise ValueError(f"Unknown llm_type: {self.llm_type}") + + +class OpenAILLMModel(LLMModel): + config: dict = Field(default=dict(model="gpt-3.5-turbo", temperature=0.1)) + + @model_validator(mode="after") + @classmethod + def guess_llm_type(cls, data: Any) -> Any: + m = cast(OpenAILLMModel, data) + m.llm_type = guess_model_type(m.config["model"]) + return m + + async def acomplete(self, client: Any, prompt: str) -> str: + completion = await client.completions.create( + prompt=prompt, **process_llm_config(self.config) + ) + return completion.choices[0].text + + async def acomplete_iter(self, client: Any, prompt: str) -> Any: + completion = await client.completions.create( + prompt=prompt, **process_llm_config(self.config), stream=True + ) + async for chunk in completion: + yield chunk.choices[0].text + + async def achat(self, client: Any, messages: list[dict[str, str]]) -> str: + completion = await client.chat.completions.create( + messages=messages, **process_llm_config(self.config) + ) + return completion.choices[0].message.content + + async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: + completion = await client.chat.completions.create( + messages=messages, **process_llm_config(self.config), stream=True + ) + async for chunk in completion: + yield chunk.choices[0].delta.content def get_score(text: str) -> int: diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index ec17bfa89..df6180dc2 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -8,7 +8,7 @@ from openai import AsyncOpenAI from paperqa import Answer, Doc, Docs, PromptCollection, Text -from paperqa.llms import get_score, make_chain +from paperqa.llms import EmbeddingModel, OpenAILLMModel, get_score from paperqa.readers import read_doc from paperqa.utils import ( iter_citations, @@ -363,15 +363,10 @@ def test_extract_score(): class TestChains(IsolatedAsyncioTestCase): async def test_chain_completion(self): client = AsyncOpenAI() - call = make_chain( + llm = OpenAILLMModel(config=dict(model="babbage-002", temperature=0.2)) + call = llm.make_chain( client, "The {animal} says", - llm_config=dict( - model_type="completion", - temperature=0, - model="babbage-002", - max_tokens=56, - ), skip_system=True, ) outputs = [] @@ -385,12 +380,12 @@ def accum(x): async def test_chain_chat(self): client = AsyncOpenAI() - call = make_chain( + llm = OpenAILLMModel( + config=dict(temperature=0, model="gpt-3.5-turbo", max_tokens=56) + ) + call = llm.make_chain( client, "The {animal} says", - llm_config=dict( - model_type="chat", temperature=0, model="gpt-3.5-turbo", max_tokens=56 - ), skip_system=True, ) outputs = [] @@ -405,7 +400,7 @@ def accum(x): def test_docs(): llm_config = dict(temperature=0.1, model="text-ada-001", model_type="completion") - docs = Docs(llm_config=llm_config) + docs = Docs(llm_model=OpenAILLMModel(config=llm_config)) docs.add_url( "https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day", citation="WikiMedia Foundation, 2023, Accessed now", @@ -463,6 +458,11 @@ def test_duplicate(): ) +def test_custom_embedding(): + class MyEmbeds(EmbeddingModel): + pass + + class Test(IsolatedAsyncioTestCase): async def test_aquery(self): docs = Docs() @@ -498,8 +498,10 @@ def test_docs_pickle(): # get front page of wikipedia r = requests.get("https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day") f.write(r.text) - docs = Docs(llm_config=dict(temperature=0.0, model="gpt-3.5-turbo")) - old_config = docs.llm_config + docs = Docs( + llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="gpt-3.5-turbo")) + ) + old_config = docs.llm_model.config docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000) os.remove(doc_path) docs_pickle = pickle.dumps(docs) @@ -510,8 +512,8 @@ def test_docs_pickle(): except ValueError: pass docs2.set_client() - assert docs2.llm_config == old_config - assert docs2.summary_llm_config == old_config + assert docs2.llm_model.config == old_config + assert docs2.summary_llm_model.config == old_config assert len(docs.docs) == len(docs2.docs) context1, context2 = ( docs.get_evidence( @@ -555,7 +557,9 @@ def test_repeat_keys(): # get wiki page about politician r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") f.write(r.text) - docs = Docs(llm_config=dict(temperature=0.0, model="babbage-002")) + docs = Docs( + llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="babbage-002")) + ) docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") try: docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") @@ -584,7 +588,9 @@ def test_repeat_keys(): def test_pdf_reader(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") - docs = Docs(llm_config=dict(temperature=0.0, model="gpt-3.5-turbo")) + docs = Docs( + llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="gpt-3.5-turbo")) + ) docs.add(doc_path, "Wellawatte et al, XAI Review, 2023") answer = docs.query("Are counterfactuals actionable?") assert "yes" in answer.answer or "Yes" in answer.answer @@ -594,7 +600,11 @@ def test_fileio_reader_pdf(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") with open(doc_path, "rb") as f: - docs = Docs(llm_config=dict(temperature=0.0, model="gpt-3.5-turbo")) + docs = Docs( + llm_model=OpenAILLMModel( + config=dict(temperature=0.0, model="gpt-3.5-turbo") + ) + ) docs.add_file(f, "Wellawatte et al, XAI Review, 2023") answer = docs.query("Are counterfactuals actionable?") assert "yes" in answer.answer or "Yes" in answer.answer @@ -602,7 +612,9 @@ def test_fileio_reader_pdf(): def test_fileio_reader_txt(): # can't use curie, because it has trouble with parsed HTML - docs = Docs(llm_config=dict(temperature=0.0, model="gpt-3.5-turbo")) + docs = Docs( + llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="gpt-3.5-turbo")) + ) r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") if r.status_code != 200: raise ValueError("Could not download wikipedia page") @@ -652,7 +664,9 @@ def test_prompt_length(): def test_code(): # load this script doc_path = os.path.abspath(__file__) - docs = Docs(llm_config=dict(temperature=0.0, model="babbage-002")) + docs = Docs( + llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="babbage-002")) + ) docs.add(doc_path, "test_paperqa.py", docname="test_paperqa.py", disable_check=True) assert len(docs.docs) == 1 docs.query("What function tests the preview?") @@ -788,7 +802,6 @@ def test_custom_prompts(): f.write(r.text) docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") answer = docs.query("What country is Frederick Bates from?") - print(answer.answer) assert "United States" in answer.answer @@ -922,6 +935,7 @@ def test_external_texts_index(): citation="Flag Day of Canada, WikiMedia Foundation, 2023, Accessed now", ) answer = docs.query(query="On which date is flag day annually observed?") + print(answer.model_dump()) assert "February 15" in answer.answer docs.add_url( From 04d0254a05b6b34bcd685bfe8a07ccbce2aac311 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Thu, 4 Jan 2024 21:48:41 -0800 Subject: [PATCH 06/16] Added unit tests for custom embeds/llms --- paperqa/docs.py | 6 ++++-- paperqa/llms.py | 32 ++++++++++++++++++++------------ tests/test_paperqa.py | 29 +++++++++++++++++++++++++++-- 3 files changed, 51 insertions(+), 16 deletions(-) diff --git a/paperqa/docs.py b/paperqa/docs.py index 4490b04f9..8a291b654 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -76,7 +76,7 @@ class Docs(BaseModel): def __init__(self, **data): if "embedding_client" in data: embedding_client = data.pop("embedding_client") - elif "client" in data: + elif "client" in data and data["client"] is not None: embedding_client = data["client"] else: embedding_client = AsyncOpenAI() @@ -427,7 +427,9 @@ async def aget_evidence( matches = self.texts else: query_vector = ( - await self.embedding.embed_documents(self._client, [answer.question]) + await self.embedding.embed_documents( + self._embedding_client, [answer.question] + ) )[0] if marginal_relevance: matches, _ = self.texts_index.max_marginal_relevance_search( diff --git a/paperqa/llms.py b/paperqa/llms.py index aec1210b3..4fe3f2f2b 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -86,27 +86,23 @@ async def embed_documents(self, client: Any, texts: list[str]) -> list[list[floa class LLMModel(ABC, BaseModel): llm_type: str = "completion" - @abstractmethod async def acomplete(self, client: Any, prompt: str) -> str: - pass + raise NotImplementedError - @abstractmethod async def acomplete_iter(self, client: Any, prompt: str) -> Any: """Return an async generator that yields chunks of the completion. I cannot get mypy to understand the override, so marked as Any""" - pass + raise NotImplementedError - @abstractmethod async def achat(self, client: Any, messages: list[dict[str, str]]) -> str: - pass + raise NotImplementedError - @abstractmethod async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: """Return an async generator that yields chunks of the completion. I cannot get mypy to understand the override, so marked as Any""" - pass + raise NotImplementedError def make_chain( self, @@ -129,10 +125,6 @@ def make_chain( where data is a dict with keys for the input variables that will be formatted into prompt and callbacks is a list of functions to call with each chunk of the completion. """ - if client is None: - raise ValueError( - "Your client is None - did you forget to set it after pickling?" - ) if self.llm_type == "chat": system_message_prompt = dict(role="system", content=system_prompt) human_message_prompt = dict(role="user", content=prompt) @@ -201,12 +193,20 @@ def guess_llm_type(cls, data: Any) -> Any: return m async def acomplete(self, client: Any, prompt: str) -> str: + if client is None: + raise ValueError( + "Your client is None - did you forget to set it after pickling?" + ) completion = await client.completions.create( prompt=prompt, **process_llm_config(self.config) ) return completion.choices[0].text async def acomplete_iter(self, client: Any, prompt: str) -> Any: + if client is None: + raise ValueError( + "Your client is None - did you forget to set it after pickling?" + ) completion = await client.completions.create( prompt=prompt, **process_llm_config(self.config), stream=True ) @@ -214,12 +214,20 @@ async def acomplete_iter(self, client: Any, prompt: str) -> Any: yield chunk.choices[0].text async def achat(self, client: Any, messages: list[dict[str, str]]) -> str: + if client is None: + raise ValueError( + "Your client is None - did you forget to set it after pickling?" + ) completion = await client.chat.completions.create( messages=messages, **process_llm_config(self.config) ) return completion.choices[0].message.content async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: + if client is None: + raise ValueError( + "Your client is None - did you forget to set it after pickling?" + ) completion = await client.chat.completions.create( messages=messages, **process_llm_config(self.config), stream=True ) diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index df6180dc2..47a57bf94 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -8,7 +8,7 @@ from openai import AsyncOpenAI from paperqa import Answer, Doc, Docs, PromptCollection, Text -from paperqa.llms import EmbeddingModel, OpenAILLMModel, get_score +from paperqa.llms import EmbeddingModel, LLMModel, OpenAILLMModel, get_score from paperqa.readers import read_doc from paperqa.utils import ( iter_citations, @@ -460,7 +460,32 @@ def test_duplicate(): def test_custom_embedding(): class MyEmbeds(EmbeddingModel): - pass + async def embed_documents(self, client, texts): + return [[1, 2, 3] for _ in texts] + + docs = Docs(embedding=MyEmbeds()) + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + assert docs.docs["test"].embedding == [1, 2, 3] + + +def test_custom_llm(): + class MyLLM(LLMModel): + async def acomplete(self, client, prompt): + assert client is None + return "Echo" + + docs = Docs(llm_model=MyLLM(), client=None) + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + evidence = docs.get_evidence(Answer(question="Echo")) + assert "Echo" in evidence.context class Test(IsolatedAsyncioTestCase): From 5efe49a54cbbe040ac67ba82ce22299cfb54553b Mon Sep 17 00:00:00 2001 From: Andrew White Date: Wed, 10 Jan 2024 11:58:56 -0800 Subject: [PATCH 07/16] Fixed langchain compatibility and updated README --- README.md | 228 +++++++++++--------------------------- paperqa/__init__.py | 18 +++ paperqa/contrib/zotero.py | 5 +- paperqa/docs.py | 123 ++++++++++++++------ paperqa/llms.py | 172 +++++++++++++++++++++++----- paperqa/types.py | 6 +- paperqa/utils.py | 29 ++++- tests/test_paperqa.py | 129 +++++++++++++++++++-- 8 files changed, 460 insertions(+), 250 deletions(-) diff --git a/README.md b/README.md index a5fd3b2a3..cfc4b1c18 100644 --- a/README.md +++ b/README.md @@ -1,35 +1,4 @@ # Paper QA- [Paper QA](#paper-qa) -- [Paper QA- Paper QA](#paper-qa--paper-qa) - - [Output Example](#output-example) - - [References](#references) - - [Hugging Face Demo](#hugging-face-demo) - - [Install](#install) - - [Usage](#usage) - - [Adding Documents](#adding-documents) - - [Choosing Model](#choosing-model) - - [Adjusting number of sources](#adjusting-number-of-sources) - - [Using Code or HTML](#using-code-or-html) - - [Version 3 Changes](#version-3-changes) - - [New Features](#new-features) - - [Naming](#naming) - - [Breaking Changes](#breaking-changes) - - [Notebooks](#notebooks) - - [Where do I get papers?](#where-do-i-get-papers) - - [Zotero](#zotero) - - [Paper Scraper](#paper-scraper) - - [PDF Reading Options](#pdf-reading-options) - - [Typewriter View](#typewriter-view) - - [LLM/Embedding Caching](#llmembedding-caching) - - [Caching Embeddings](#caching-embeddings) - - [Customizing Prompts](#customizing-prompts) - - [Pre and Post Prompts](#pre-and-post-prompts) - - [FAQ](#faq) - - [How is this different from LlamaIndex?](#how-is-this-different-from-llamaindex) - - [How is this different from LangChain?](#how-is-this-different-from-langchain) - - [Can I use different LLMs?](#can-i-use-different-llms) - - [Where do the documents come from?](#where-do-the-documents-come-from) - - [Can I save or load?](#can-i-save-or-load) - [![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/whitead/paper-qa) [![tests](https://github.com/whitead/paper-qa/actions/workflows/tests.yml/badge.svg)](https://github.com/whitead/paper-qa) @@ -40,14 +9,26 @@ PDFs or text files (which can be raw HTML). It strives to give very good answers By default, it uses [OpenAI Embeddings](https://platform.openai.com/docs/guides/embeddings) with a vector DB called [FAISS](https://github.com/facebookresearch/faiss) to embed and search documents. However, via [langchain](https://github.com/hwchase17/langchain) you can use open-source models or embeddings (see details below). -PaperQA uses the process shown below: +paper-qa uses the process shown below: 1. embed docs into vectors 2. embed query into vector 3. search for top k passages in docs 4. create summary of each passage relevant to query -5. put summaries into prompt -6. generate answer with prompt +5. score and select only relevant summaries +6. put summaries into prompt +7. generate answer with prompt + +See our paper for more details: + +```bibtex +@article{lala2023paperqa, + title={PaperQA: Retrieval-Augmented Generative Agent for Scientific Research}, + author={L{\'a}la, Jakub and O'Donoghue, Odhran and Shtedritski, Aleksandar and Cox, Sam and Rodriques, Samuel G and White, Andrew D}, + journal={arXiv preprint arXiv:2312.07559}, + year={2023} +} +``` ## Output Example @@ -63,9 +44,10 @@ Tulevski2007: Tulevski, George S., et al. "Chemically assisted directed assembly Chen2014: Chen, Haitian, et al. "Large-scale complementary macroelectronics using hybrid integration of carbon nanotubes and IGZO thin-film transistors." Nature communications 5.1 (2014): 4097. -## Hugging Face Demo -[Hugging Face Demo](https://huggingface.co/spaces/whitead/paper-qa) +## Version 4 Changes + +Version 4 removed langchain from the package because it no longer supports pickling. This also simplifies the package a bit - especially prompts. Langchain can still be used, but it's not required. You can use any LLMs from langchain, but you will need to use the `LangchainLLMModel` class to wrap the model. ## Install @@ -75,17 +57,17 @@ Install with pip: pip install paper-qa ``` -## Usage +You need to have an LLM to use paper-qa. You can use OpenAI, llama.cpp (via Server), or any LLMs from langchain. OpenAI just works, as long as you have set your OpenAI API key (`export OPENAI_API_KEY=sk-...`). See instructions below for other LLMs. -Make sure you have set your OPENAI_API_KEY environment variable to your [openai api key](https://platform.openai.com/account/api-keys) +## Usage -To use paper-qa, you need to have a list of paths (valid extensions include: .pdf, .txt) and a list of citations (strings) that correspond to the paths. You can then use the `Docs` class to add the documents and then query them. If you don't have citations, `Docs` will try to guess them from the first page of your docs. +To use paper-qa, you need to have a list of paths/files/urls (valid extensions include: .pdf, .txt). You can then use the `Docs` class to add the documents and then query them. `Docs` will try to guess citation formats from the content of the files, but you can also provide them yourself. ```python from paperqa import Docs -# get a list of paths +my_docs = ...# get a list of paths docs = Docs() for d in my_docs: @@ -95,7 +77,7 @@ answer = docs.query("What manufacturing challenges are unique to bispecific anti print(answer.formatted_answer) ``` -The answer object has the following attributes: `formatted_answer`, `answer` (answer alone), `question`, `context` (the summaries of passages found for answer), `references` (the docs from which the passages came), and `passages` which contain the raw text of the passages as a dictionary. +The answer object has the following attributes: `formatted_answer`, `answer` (answer alone), `question` , and `context` (the summaries of passages found for answer). ### Adding Documents @@ -103,7 +85,7 @@ The answer object has the following attributes: `formatted_answer`, `answer` (an ### Choosing Model -By default, it uses a hybrid of `gpt-3.5-turbo` and `gpt-4`. If you don't have gpt-4 access or would like to save money, you can adjust: +By default, it uses a hybrid of `gpt-3.5-turbo` and `gpt-4-turbo`. If you don't have gpt-4 access or would like to save money, you can adjust: ```py docs = Docs(llm='gpt-3.5-turbo') @@ -112,51 +94,49 @@ docs = Docs(llm='gpt-3.5-turbo') or you can use any other model available in [langchain](https://github.com/hwchase17/langchain): ```py -from langchain.chat_models import ChatAnthropic, ChatOpenAI -model = ChatOpenAI(model='gpt-4') -summary_model = ChatAnthropic(model="claude-instant-v1-100k", anthropic_api_key="my-api-key") -docs = Docs(llm=model, summary_llm=summary_model) +from paperqa import Docs, LangchainLLMModel +from langchain_community.chat_models import ChatAnthropic +docs = Docs(llm_model=LangchainLLMModel(), + client=ChatAnthropic()) +``` + +Notice that we split the model into `LangchainLLMModel` and `client` which is `ChatAnthropic`. This is because paper-qa can be pickled, but typically Langchain models cannot be pickled. Thus, the client is the unpicklable part. Specifically, you can save your state in paper-qa: + +```py +import pickle +docs = Docs(llm_model=LangchainLLMModel(), + client=ChatAnthropic()) +model_str = pickle.dumps(docs) +docs = pickle.loads(model_str) +# but you have to set the client after loading +docs.set_client(ChatAnthropic()) ``` #### Locally Hosted -You can also use any other models (or embeddings) available in [langchain](https://github.com/hwchase17/langchain). Here's an example of using `llama.cpp` to have locally hosted paper-qa: +You can use llama.cpp to be the LLM. Note that you should be using relatively large models, because paper-qa requires following a lot of instructions. You won't get good performance with 7B models. I recommend using the SentenceTransformer models for embeddings, rather than llama.cpp embeddings. ```py -import paperscraper -from paperqa import Docs -from langchain.llms import LlamaCpp -from langchain import PromptTemplate, LLMChain -from langchain.callbacks.manager import CallbackManager -from langchain.embeddings import LlamaCppEmbeddings -from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler - -# Make sure the model path is correct for your system! -llm = LlamaCpp( - model_path="./ggml-model-q4_0.bin", callbacks=[StreamingStdOutCallbackHandler()] -) -embeddings = LlamaCppEmbeddings(model_path="./ggml-model-q4_0.bin") +from paperqa import Docs, SentenceTransformerEmbeddingModel +from openai import AsyncOpenAI -docs = Docs(llm=llm, embeddings=embeddings) +# start llamap.cpp client with: -cb -np 4 -a my-alais --embedding -keyword_search = 'bispecific antibody manufacture' -papers = paperscraper.search_papers(keyword_search, limit=2) -for path,data in papers.items(): - try: - docs.add(path,chunk_chars=500) - except ValueError as e: - print('Could not read', path, e) - -answer = docs.query("What manufacturing challenges are unique to bispecific antibodies?") -print(answer) +local_client = AsyncOpenAI( + base_url="http://localhost:8080/v1", + api_key = "sk-no-key-required" +) +docs = Docs(client=local_client, + embedding=SentenceTransformerEmbeddingModel(), + llm_model=OpenAILLMModel(config=dict(model="my-alias", temperature=0.1, frequency_penalty=1.5, max_tokens=512))) ``` ### Adjusting number of sources You can adjust the numbers of sources (passages of text) to reduce token usage or add more context. `k` refers to the top k most relevant and diverse (may from different sources) passages. Each passage is sent to the LLM to summarize, or determine if it is irrelevant. After this step, a limit of `max_sources` is applied so that the final answer can fit into the LLM context window. Thus, `k` > `max_sources` and `max_sources` is the number of sources used in the final answer. -```python +```py docs.query("What manufacturing challenges are unique to bispecific antibodies?", k = 5, max_sources = 2) ``` @@ -178,67 +158,6 @@ answer = docs.query("Where is the search bar in the header defined?") print(answer) ``` -## Version 3 Changes - -Version 3 includes many changes to type the code, make it more focused/modular, and enable performance to very large numbers of documents. The major breaking changes are documented below: - - -### New Features - -The following new features are in v3: - -1. Memory is now possible in `query` by setting `Docs(memory=True)` - this means follow-up questions will have a record of the previous question and answer. -2. `add_url` and `add_file` are now supported for adding from URLs and file objects -3. Prompts can be customized, and now can be executed pre and post query -4. Consistent use of `dockey` and `docname` for unique and natural language names enable better tracking with external databases -5. Texts and embeddings are no longer required to be part of `Docs` object, so you can use external databases or other strategies to manage them -6. Various simplifications, bug fixes, and performance improvements - -### Naming - -The following table shows the old names and the new names: - -| Old Name | New Name | Explanation | -| :--- | :---: | ---: | -| `key` | `name` | Name is a natural language name for text. | -| `dockey` | `docname` | Docname is a natural language name for a document. | -| `hash` | `dockey` | Dockey is a unique identifier for the document. | - - -### Breaking Changes - - -#### Pickled objects - -The pickled objects are not compatible with the new version. - -#### Agents - -The agent functionality has been removed, as it's not a core focus of the library - -#### Caching - -Caching has been removed because it's not a core focus of the library. See FAQ below for how to use caching. - -#### Answers - -Answers will not include passages, but instead return dockeys that can be used to retrieve the passages. Tokens/cost will also not be counted since that is built into langchain by default (see below for an example). - -#### Search Query - -The search query chain has been removed. You can use langchain directly to do this. - -## Notebooks - -If you want to use this in an jupyter notebook or colab, you need to run the following command: - -```python -import nest_asyncio -nest_asyncio.apply() -``` - -Also - if you know how to make this automated, please let me know! - ## Where do I get papers? Well that's a really good question! It's probably best to just download PDFs of papers you think will help answer your question and start from there. @@ -329,27 +248,17 @@ By default [PyPDF](https://pypi.org/project/pypdf/) is used since it's pure pyth pip install pymupdf ``` -## Typewriter View +## Callbacks Factory -To stream the completions as they occur (giving that ChatGPT typewriter look), you can simply instantiate models with those properties: +To execute a function on each chunk of LLM completions, you need to provide a function that when called with the name of the step produces a list of functions to execute on each chunk. For example, to get a typewriter view of the completions, you can do: ```python -from paperqa import Docs -from langchain.callbacks.manager import CallbackManager -from langchain.chat_models import ChatOpenAI -from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler - -my_llm = ChatOpenAI(callbacks=[StreamingStdOutCallbackHandler()], streaming=True) -docs = Docs(llm=my_llm) -``` - -## LLM/Embedding Caching - -You can using the builtin langchain caching capabilities. Just run this code at the top of yours: - -```py -from langchain.cache import InMemoryCache -langchain.llm_cache = InMemoryCache() +def make_typewriter(step_name): + def typewriter(chunk): + print(chunk, end="") + return [typewriter] # <- note that this is a list of functions +... +docs.query("What manufacturing challenges are unique to bispecific antibodies?", get_callbacks=make_typewriter) ``` ### Caching Embeddings @@ -366,17 +275,14 @@ You can customize any of the prompts, using the `PromptCollection` class. For ex ```python from paperqa import Docs, Answer, PromptCollection -from langchain.prompts import PromptTemplate -my_qaprompt = PromptTemplate( - input_variables=["context", "question"], - template="Answer the question '{question}' " +my_qaprompt = "Answer the question '{question}' " "Use the context below if helpful. " "You can cite the context using the key " "like (Example2012). " "If there is insufficient context, write a poem " "about how you cannot answer.\n\n" - "Context: {context}\n\n") + "Context: {context}\n\n" prompts=PromptCollection(qa=my_qaprompt) docs = Docs(prompts=prompts) ``` @@ -395,15 +301,7 @@ It's not that different! This is similar to the tree response method in LlamaInd ### How is this different from LangChain? -It's not! We use langchain to abstract the LLMS, and the process is very similar to the `map_reduce` chain in LangChain. - -### Can I use different LLMs? - -Yes, you can use any LLMs from [langchain](https://langchain.readthedocs.io/) by passing the `llm` argument to the `Docs` class. You can use different LLMs for summarization and for question answering too. - -### Where do the documents come from? - -You can provide your own. I use some of my own code to pull papers from Google Scholar. This code is not included because it may enable people to violate Google's terms of service and publisher's terms of service. +There has been some great work on retrievers in langchain and you could say this is an example of a retreiver. ### Can I save or load? diff --git a/paperqa/__init__.py b/paperqa/__init__.py index fa06c8925..1fc9643c6 100644 --- a/paperqa/__init__.py +++ b/paperqa/__init__.py @@ -1,5 +1,15 @@ from .docs import Answer, Docs, PromptCollection, Doc, Text, Context from .version import __version__ +from .llms import ( + LLMModel, + EmbeddingModel, + LangchainEmbeddingModel, + OpenAIEmbeddingModel, + LangchainLLMModel, + OpenAILLMModel, + LlamaEmbeddingModel, + SentenceTransformerEmbeddingModel, +) __all__ = [ "Docs", @@ -9,4 +19,12 @@ "Doc", "Text", "Context", + "LLMModel", + "EmbeddingModel", + "OpenAIEmbeddingModel", + "OpenAILLMModel", + "LangchainLLMModel", + "LlamaEmbeddingModel", + "SentenceTransformerEmbeddingModel", + "LangchainEmbeddingModel", ] diff --git a/paperqa/contrib/zotero.py b/paperqa/contrib/zotero.py index a390cd1c3..1d4330dc9 100644 --- a/paperqa/contrib/zotero.py +++ b/paperqa/contrib/zotero.py @@ -4,10 +4,7 @@ from pathlib import Path from typing import List, Optional, Union, cast -try: - from pydantic.v1 import BaseModel -except ImportError: - from pydantic import BaseModel +from pydantic import BaseModel try: from pyzotero import zotero diff --git a/paperqa/docs.py b/paperqa/docs.py index 8a291b654..010256089 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -13,6 +13,9 @@ from .llms import ( EmbeddingModel, + LangchainEmbeddingModel, + LangchainLLMModel, + LlamaEmbeddingModel, LLMModel, OpenAIEmbeddingModel, OpenAILLMModel, @@ -50,14 +53,17 @@ class Docs(BaseModel): """A collection of documents to be used for answering questions.""" - # ephemeral clients that should not be pickled + # ephemeral vars that should not be pickled (_things) _client: Any | None _embedding_client: Any | None llm: str = "default" summary_llm: str | None = None - llm_model: LLMModel = Field(default_factory=OpenAILLMModel) + llm_model: LLMModel = Field( + default=OpenAILLMModel(config=dict(model="gpt-4-1106-preview", temperature=0.1)) + ) summary_llm_model: LLMModel | None = Field(default=None, validate_default=True) - embedding: EmbeddingModel = OpenAIEmbeddingModel() + embedding: str | None = "default" + embedding_model: EmbeddingModel = OpenAIEmbeddingModel() docs: dict[DocKey, Doc] = {} texts: list[Text] = [] docnames: set[str] = set() @@ -66,7 +72,7 @@ class Docs(BaseModel): name: str = "default" index_path: Path | None = PAPERQA_DIR / name batch_size: int = 1 - max_concurrent: int = 5 + max_concurrent: int = 4 deleted_dockeys: set[DocKey] = set() prompts: PromptCollection = PromptCollection() jit_texts_index: bool = False @@ -74,16 +80,36 @@ class Docs(BaseModel): strip_citations: bool = True def __init__(self, **data): + # TODO: There may be a way to put this into pydantic model validator + # We do it here because we need to move things to private attributes if "embedding_client" in data: embedding_client = data.pop("embedding_client") - elif "client" in data and data["client"] is not None: + # convenience to pull embedding_client from client if reasonable + elif ( + "client" in data + and data["client"] is not None + and type(data["client"]) == AsyncOpenAI + ): + # convenience embedding_client = data["client"] else: - embedding_client = AsyncOpenAI() + # if embedding_model is explicitly set, but not client then make it None + if "embedding_model" in data and data["embedding_model"] is not None: + embedding_client = None + else: + embedding_client = AsyncOpenAI() if "client" in data: client = data.pop("client") else: - client = AsyncOpenAI() + # if llm_model is explicitly set, but not client then make it None + if "llm_model" in data and data["llm_model"] is not None: + # except if it is an OpenAILLMModel + if type(data["llm_model"]) == OpenAILLMModel: + client = AsyncOpenAI() + else: + client = None + else: + client = AsyncOpenAI() super().__init__(**data) self._client = client self._embedding_client = embedding_client @@ -95,6 +121,8 @@ def setup_alias_models(cls, data: Any) -> Any: if "llm" in data and data["llm"] != "default": if is_openai_model(data["llm"]): data["llm_model"] = OpenAILLMModel(config=dict(model=data["llm"])) + elif data["llm"] == "langchain": + data["llm_model"] = LangchainLLMModel() else: raise ValueError(f"Could not guess model type for {data['llm']}. ") if "summary_llm" in data and data["summary_llm"] is not None: @@ -104,13 +132,32 @@ def setup_alias_models(cls, data: Any) -> Any: ) else: raise ValueError(f"Could not guess model type for {data['llm']}. ") + if "embedding" in data and data["embedding"] != "default": + if data["embedding"] == "langchain": + data["embedding_model"] = LangchainEmbeddingModel() + elif data["embedding"] == "llama": + data["embedding_model"] = LlamaEmbeddingModel() + else: + raise ValueError( + f"Could not guess model type for {data['embedding']}. " + ) return data @model_validator(mode="after") @classmethod def config_summary_llm_config(cls, data: Any) -> Any: if isinstance(data, Docs): - if data.summary_llm_model is None: + # check our default gpt-4/3.5-turbo config + # default check is hard - becauise either llm is set or llm_model is set + if ( + data.summary_llm_model is None + and data.llm == "default" + and type(data.llm_model) == OpenAILLMModel + ): + data.summary_llm_model = OpenAILLMModel( + config=dict(model="gpt-3.5-turbo", temperature=0.1) + ) + elif data.summary_llm_model is None: data.summary_llm_model = data.llm_model return data @@ -140,7 +187,10 @@ def set_client( client = AsyncOpenAI() self._client = client if embedding_client is None: - embedding_client = client + if type(client) == AsyncOpenAI: + embedding_client = client + else: + embedding_client = AsyncOpenAI() self._embedding_client = embedding_client def _get_unique_name(self, docname: str) -> str: @@ -287,7 +337,7 @@ def add_texts( doc.docname = new_docname if texts[0].embedding is None: text_embeddings = asyncio.run( - self.embedding.embed_documents( + self.embedding_model.embed_documents( self._embedding_client, [t.text for t in texts] ) ) @@ -295,7 +345,9 @@ def add_texts( t.embedding = text_embeddings[i] if doc.embedding is None: doc.embedding = asyncio.run( - self.embedding.embed_documents(self._embedding_client, [doc.citation]) + self.embedding_model.embed_documents( + self._embedding_client, [doc.citation] + ) )[0] if not self.jit_texts_index: self.texts_index.add_texts_and_embeddings(texts) @@ -305,8 +357,16 @@ def add_texts( self.docnames.add(doc.docname) return True - def delete(self, name: str | None = None, dockey: DocKey | None = None) -> None: + def delete( + self, + name: str | None = None, + docname: str | None = None, + dockey: DocKey | None = None, + ) -> None: """Delete a document from the collection.""" + # name is an alias for docname + name = docname if name is None else name + if name is not None: doc = next((doc for doc in self.docs.values() if doc.docname == name), None) if doc is None: @@ -325,7 +385,7 @@ async def adoc_match( ) -> set[DocKey]: """Return a list of dockeys that match the query.""" query_vector = ( - await self.embedding.embed_documents(self._embedding_client, [query]) + await self.embedding_model.embed_documents(self._embedding_client, [query]) )[0] matches, _ = self.doc_index.max_marginal_relevance_search( query_vector, @@ -333,7 +393,10 @@ async def adoc_match( fetch_k=5 * (k + len(self.deleted_dockeys)), ) # filter the matches - matched_docs = [m for m in matches if m.dockey not in self.deleted_dockeys] + matched_docs = [ + m for m in cast(list[Doc], matches) if m.dockey not in self.deleted_dockeys + ] + if len(matched_docs) == 0: return set() # this only works for gpt-4 (in my testing) @@ -341,10 +404,8 @@ async def adoc_match( if ( rerank is None and ( - type(self.llm) == OpenAILLMModel - and cast(OpenAILLMModel, self) - .llm.config["model"] - .startswith("gpt-4") + type(self.llm_model) == OpenAILLMModel + and cast(OpenAILLMModel, self).config["model"].startswith("gpt-4") ) or rerank is True ): @@ -386,22 +447,18 @@ def get_evidence( answer: Answer, k: int = 10, max_sources: int = 5, - marginal_relevance: bool = True, get_callbacks: CallbackFactory = lambda x: None, detailed_citations: bool = False, disable_vector_search: bool = False, - disable_summarization: bool = False, ) -> Answer: return asyncio.run( self.aget_evidence( answer, k=k, max_sources=max_sources, - marginal_relevance=marginal_relevance, get_callbacks=get_callbacks, detailed_citations=detailed_citations, disable_vector_search=disable_vector_search, - disable_summarization=disable_summarization, ) ) @@ -410,11 +467,9 @@ async def aget_evidence( answer: Answer, k: int = 10, # Number of evidence pieces to retrieve max_sources: int = 5, # Number of scored contexts to use - marginal_relevance: bool = True, get_callbacks: CallbackFactory = lambda x: None, detailed_citations: bool = False, disable_vector_search: bool = False, - disable_summarization: bool = False, ) -> Answer: if len(self.docs) == 0 and self.doc_index is None: # do we have no docs? @@ -427,16 +482,16 @@ async def aget_evidence( matches = self.texts else: query_vector = ( - await self.embedding.embed_documents( + await self.embedding_model.embed_documents( self._embedding_client, [answer.question] ) )[0] - if marginal_relevance: - matches, _ = self.texts_index.max_marginal_relevance_search( + matches = cast( + list[Text], + self.texts_index.max_marginal_relevance_search( query_vector, k=_k, fetch_k=5 * _k - ) - else: - matches, _ = self.texts_index.similarity_search(query_vector, k=_k) + )[0], + ) # ok now filter (like ones from adoc_match) if answer.dockey_filter is not None: matches = [m for m in matches if m.doc.dockey in answer.dockey_filter] @@ -457,7 +512,7 @@ async def process(match): if detailed_citations: citation = match.name + ": " + citation - if self.prompts.skip_summary or disable_summarization: + if self.prompts.skip_summary: context = match.text score = 5 else: @@ -509,7 +564,7 @@ async def process(match): return c results = await gather_with_concurrency( - self.max_concurrent, *[process(m) for m in matches] + self.max_concurrent, [process(m) for m in matches] ) # filter out failures contexts = [c for c in results if c is not None] @@ -537,7 +592,6 @@ def query( k: int = 10, max_sources: int = 5, length_prompt="about 100 words", - marginal_relevance: bool = True, answer: Answer | None = None, key_filter: bool | None = None, get_callbacks: CallbackFactory = lambda x: None, @@ -548,7 +602,6 @@ def query( k=k, max_sources=max_sources, length_prompt=length_prompt, - marginal_relevance=marginal_relevance, answer=answer, key_filter=key_filter, get_callbacks=get_callbacks, @@ -561,7 +614,6 @@ async def aquery( k: int = 10, max_sources: int = 5, length_prompt: str = "about 100 words", - marginal_relevance: bool = True, answer: Answer | None = None, key_filter: bool | None = None, get_callbacks: CallbackFactory = lambda x: None, @@ -583,7 +635,6 @@ async def aquery( answer, k=k, max_sources=max_sources, - marginal_relevance=marginal_relevance, get_callbacks=get_callbacks, ) if self.prompts.pre is not None: diff --git a/paperqa/llms.py b/paperqa/llms.py index 4fe3f2f2b..9c3887213 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -1,11 +1,20 @@ import re from abc import ABC, abstractmethod -from typing import Any, Callable, Coroutine, cast, get_args, get_type_hints +from typing import ( + Any, + AsyncGenerator, + Callable, + Coroutine, + cast, + get_args, + get_type_hints, +) from openai import AsyncOpenAI from pydantic import BaseModel, Field, model_validator from .prompts import default_system_prompt +from .utils import batch_iter, flatten, gather_with_concurrency def guess_model_type(model_name: str) -> str: @@ -84,7 +93,7 @@ async def embed_documents(self, client: Any, texts: list[str]) -> list[list[floa class LLMModel(ABC, BaseModel): - llm_type: str = "completion" + llm_type: str | None = None async def acomplete(self, client: Any, prompt: str) -> str: raise NotImplementedError @@ -104,6 +113,9 @@ async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: I cannot get mypy to understand the override, so marked as Any""" raise NotImplementedError + def infer_llm_type(self, client: Any) -> str: + return "completion" + def make_chain( self, client: Any, @@ -125,6 +137,9 @@ def make_chain( where data is a dict with keys for the input variables that will be formatted into prompt and callbacks is a list of functions to call with each chunk of the completion. """ + # check if it needs to be set + if self.llm_type is None: + self.llm_type = self.infer_llm_type(client) if self.llm_type == "chat": system_message_prompt = dict(role="system", content=system_prompt) human_message_prompt = dict(role="user", content=prompt) @@ -185,6 +200,17 @@ async def execute( class OpenAILLMModel(LLMModel): config: dict = Field(default=dict(model="gpt-3.5-turbo", temperature=0.1)) + def _check_client(self, client: Any) -> AsyncOpenAI: + if client is None: + raise ValueError( + "Your client is None - did you forget to set it after pickling?" + ) + if not isinstance(client, AsyncOpenAI): + raise ValueError( + f"Your client is not a required AsyncOpenAI client. It is a {type(client)}" + ) + return cast(AsyncOpenAI, client) + @model_validator(mode="after") @classmethod def guess_llm_type(cls, data: Any) -> Any: @@ -193,48 +219,142 @@ def guess_llm_type(cls, data: Any) -> Any: return m async def acomplete(self, client: Any, prompt: str) -> str: - if client is None: - raise ValueError( - "Your client is None - did you forget to set it after pickling?" - ) - completion = await client.completions.create( + aclient = self._check_client(client) + completion = await aclient.completions.create( prompt=prompt, **process_llm_config(self.config) ) return completion.choices[0].text async def acomplete_iter(self, client: Any, prompt: str) -> Any: - if client is None: - raise ValueError( - "Your client is None - did you forget to set it after pickling?" - ) - completion = await client.completions.create( + aclient = self._check_client(client) + completion = await aclient.completions.create( prompt=prompt, **process_llm_config(self.config), stream=True ) async for chunk in completion: yield chunk.choices[0].text async def achat(self, client: Any, messages: list[dict[str, str]]) -> str: - if client is None: - raise ValueError( - "Your client is None - did you forget to set it after pickling?" - ) - completion = await client.chat.completions.create( - messages=messages, **process_llm_config(self.config) + aclient = self._check_client(client) + completion = await aclient.chat.completions.create( + messages=messages, **process_llm_config(self.config) # type: ignore ) - return completion.choices[0].message.content + return completion.choices[0].message.content or "" async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: - if client is None: - raise ValueError( - "Your client is None - did you forget to set it after pickling?" - ) - completion = await client.chat.completions.create( - messages=messages, **process_llm_config(self.config), stream=True + aclient = self._check_client(client) + completion = await aclient.chat.completions.create( + messages=messages, **process_llm_config(self.config), stream=True # type: ignore ) - async for chunk in completion: + async for chunk in cast(AsyncGenerator, completion): yield chunk.choices[0].delta.content +class LangchainLLMModel(LLMModel): + """A wrapper around the wrapper langchain""" + + def infer_llm_type(self, client: Any) -> str: + from langchain_core.language_models.chat_models import BaseChatModel + + if isinstance(client, BaseChatModel): + return "chat" + return "completion" + + async def acomplete(self, client: Any, prompt: str) -> str: + return await client.ainvoke(prompt) + + async def acomplete_iter(self, client: Any, prompt: str) -> Any: + async for chunk in cast(AsyncGenerator, client.astream(prompt)): + yield chunk + + async def achat(self, client: Any, messages: list[dict[str, str]]) -> str: + from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage + + lc_messages: list[BaseMessage] = [] + for m in messages: + if m["role"] == "user": + lc_messages.append(HumanMessage(content=m["content"])) + elif m["role"] == "system": + lc_messages.append(SystemMessage(content=m["content"])) + else: + raise ValueError(f"Unknown role: {m['role']}") + return (await client.ainvoke(lc_messages)).content + + async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: + from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage + + lc_messages: list[BaseMessage] = [] + for m in messages: + if m["role"] == "user": + lc_messages.append(HumanMessage(content=m["content"])) + elif m["role"] == "system": + lc_messages.append(SystemMessage(content=m["content"])) + else: + raise ValueError(f"Unknown role: {m['role']}") + async for chunk in client.astream(lc_messages): + yield chunk.content + + +class LangchainEmbeddingModel(EmbeddingModel): + """A wrapper around the wrapper langchain""" + + async def embed_documents(self, client: Any, texts: list[str]) -> list[list[float]]: + return await client.aembed_documents(texts) + + +class LlamaEmbeddingModel(EmbeddingModel): + embedding_model: str = Field(default="llama") + + batch_size: int = Field(default=4) + concurrency: int = Field(default=1) + + async def embed_documents(self, client: Any, texts: list[str]) -> list[list[float]]: + cast(AsyncOpenAI, client) + + async def process(texts: list[str]) -> list[float]: + for i in range(3): + # access httpx client directly to avoid type casting + response = await client._client.post( + client.base_url.join("../embedding"), json={"content": texts} + ) + body = response.json() + if len(texts) == 1: + if type(body) != dict or body.get("embedding") is None: + continue + return [body["embedding"]] + else: + if type(body) != list or body[0] != "results": + continue + return [e["embedding"] for e in body[1]] + raise ValueError("Failed to embed documents - response was ", body) + + return flatten( + await gather_with_concurrency( + self.concurrency, + [process(b) for b in batch_iter(texts, self.batch_size)], + ) + ) + + +class SentenceTransformerEmbeddingModel(EmbeddingModel): + embedding_model: str = Field(default="multi-qa-MiniLM-L6-cos-v1") + _model: Any = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + try: + from sentence_transformers import SentenceTransformer + except ImportError: + raise ImportError("Please install sentence-transformers to use this model") + + self._model = SentenceTransformer(self.embedding_model) + + async def embed_documents(self, client: Any, texts: list[str]) -> list[list[float]]: + from sentence_transformers import SentenceTransformer + + embeddings = cast(SentenceTransformer, self._model).encode(texts) + return embeddings + + def get_score(text: str) -> int: # check for N/A last_line = text.split("\n")[-1] diff --git a/paperqa/types.py b/paperqa/types.py index da4705ab0..8eb9278bb 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -2,7 +2,7 @@ from typing import Any, Callable import numpy as np -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, Sequence, field_validator from .prompts import ( citation_prompt, @@ -44,7 +44,7 @@ class VectorStore(BaseModel, ABC): """Interface for vector store - very similar to LangChain's VectorStore to be compatible""" @abstractmethod - def add_texts_and_embeddings(self, texts: list[Embeddable]) -> None: + def add_texts_and_embeddings(self, texts: Sequence[Embeddable]) -> None: pass @abstractmethod @@ -196,8 +196,6 @@ def check_select(cls, v: str) -> str: @classmethod def check_pre(cls, v: str | None) -> str | None: if v is not None: - print(v) - print(get_formatted_variables(v)) if set(get_formatted_variables(v)) != set(["question"]): raise ValueError("Pre prompt must have input variables: question") return v diff --git a/paperqa/utils.py b/paperqa/utils.py index 2b8930bbb..76105aa7d 100644 --- a/paperqa/utils.py +++ b/paperqa/utils.py @@ -3,7 +3,7 @@ import re import string from pathlib import Path -from typing import BinaryIO, List, Union +from typing import Any, BinaryIO, Coroutine, Iterator, Union import pypdf @@ -75,7 +75,7 @@ def md5sum(file_path: StrPath) -> str: return hashlib.md5(f.read()).hexdigest() -async def gather_with_concurrency(n: int, *coros: List) -> List: +async def gather_with_concurrency(n: int, coros: list[Coroutine]) -> list[Any]: # https://stackoverflow.com/a/61478547/2392535 semaphore = asyncio.Semaphore(n) @@ -100,7 +100,7 @@ def strip_citations(text: str) -> str: return text -def iter_citations(text: str) -> List[str]: +def iter_citations(text: str) -> list[str]: # Combined regex for identifying citations (see unit tests for examples) citation_regex = r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)" result = re.findall(citation_regex, text, flags=re.MULTILINE) @@ -123,3 +123,26 @@ def extract_doi(reference: str) -> str: return "https://doi.org/" + doi_match.group() else: return "" + + +def batch_iter(iterable: list, n: int = 1) -> Iterator[list]: + """ + Batch an iterable into chunks of size n + + :param iterable: The iterable to batch + :param n: The size of the batches + :return: A list of batches + """ + length = len(iterable) + for ndx in range(0, length, n): + yield iterable[ndx : min(ndx + n, length)] + + +def flatten(iteratble: list) -> list: + """ + Flatten a list of lists + + :param l: The list of lists to flatten + :return: A flattened list + """ + return [item for sublist in iteratble for item in sublist] diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index 47a57bf94..fec43a80d 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -8,7 +8,14 @@ from openai import AsyncOpenAI from paperqa import Answer, Doc, Docs, PromptCollection, Text -from paperqa.llms import EmbeddingModel, LLMModel, OpenAILLMModel, get_score +from paperqa.llms import ( + EmbeddingModel, + LangchainEmbeddingModel, + LangchainLLMModel, + LLMModel, + OpenAILLMModel, + get_score, +) from paperqa.readers import read_doc from paperqa.utils import ( iter_citations, @@ -122,14 +129,13 @@ def test_ablations(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") with open(doc_path, "rb") as f: - docs = Docs() + docs = Docs(prompts=PromptCollection(skip_summary=True)) docs.add_file(f, "Wellawatte et al, XAI Review, 2023") answer = docs.get_evidence( Answer( question="Which page is the statement 'Deep learning (DL) is advancing the boundaries of computational" + "chemistry because it can accurately model non-linear structure-function relationships.' on?" - ), - disable_summarization=True, + ) ) assert ( answer.contexts[0].text.text == answer.contexts[0].context @@ -463,13 +469,14 @@ class MyEmbeds(EmbeddingModel): async def embed_documents(self, client, texts): return [[1, 2, 3] for _ in texts] - docs = Docs(embedding=MyEmbeds()) + docs = Docs(embedding_model=MyEmbeds()) docs.add_url( "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", citation="WikiMedia Foundation, 2023, Accessed now", dockey="test", ) assert docs.docs["test"].embedding == [1, 2, 3] + assert docs._embedding_client is None def test_custom_llm(): @@ -488,6 +495,99 @@ async def acomplete(self, client, prompt): assert "Echo" in evidence.context +def test_custom_llm_stream(): + class MyLLM(LLMModel): + async def acomplete_iter(self, client, prompt): + assert client is None + yield "Echo" + + docs = Docs(llm_model=MyLLM(), client=None) + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + evidence = docs.get_evidence( + Answer(question="Echo"), get_callbacks=lambda x: [lambda y: print(y, end="")] + ) + assert "Echo" in evidence.context + + +def test_langchain_llm(): + from langchain_openai import ChatOpenAI, OpenAI + + docs = Docs(llm="langchain", client=ChatOpenAI(model="gpt-3.5-turbo")) + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + assert docs._client is not None + assert type(docs.llm_model) == LangchainLLMModel + assert docs.summary_llm_model == docs.llm_model + + docs.get_evidence( + Answer(question="What is Frederick Bates's greatest accomplishment?"), + get_callbacks=lambda x: [lambda y: print(y, end="")], + ) + + assert docs.llm_model.llm_type == "chat" + + # trying without callbacks (different codepath) + docs.get_evidence( + Answer(question="What is Frederick Bates's greatest accomplishment?") + ) + + # now completion + + docs = Docs(llm_model=LangchainLLMModel(), client=OpenAI(model="babbage-002")) + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + docs.get_evidence( + Answer(question="What is Frederick Bates's greatest accomplishment?"), + get_callbacks=lambda x: [lambda y: print(y, end="")], + ) + + assert docs.summary_llm_model.llm_type == "completion" + + # trying without callbacks (different codepath) + docs.get_evidence( + Answer(question="What is Frederick Bates's greatest accomplishment?") + ) + + # now make sure we can pickle it + docs_pickle = pickle.dumps(docs) + docs2 = pickle.loads(docs_pickle) + assert docs2._client is None + docs2.set_client(OpenAI(model="babbage-002")) + docs2.get_evidence( + Answer(question="What is Frederick Bates's greatest accomplishment?"), + get_callbacks=lambda x: [lambda y: print(y)], + ) + + +def test_langchain_embeddings(): + from langchain_openai import OpenAIEmbeddings + + docs = Docs( + embedding_model=LangchainEmbeddingModel(), embedding_client=OpenAIEmbeddings() + ) + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + docs = Docs(embedding="langchain", embedding_client=OpenAIEmbeddings()) + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + + class Test(IsolatedAsyncioTestCase): async def test_aquery(self): docs = Docs() @@ -526,7 +626,9 @@ def test_docs_pickle(): docs = Docs( llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="gpt-3.5-turbo")) ) + assert docs._client is not None old_config = docs.llm_model.config + old_sconfig = docs.summary_llm_model.config docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000) os.remove(doc_path) docs_pickle = pickle.dumps(docs) @@ -537,8 +639,9 @@ def test_docs_pickle(): except ValueError: pass docs2.set_client() + assert docs2._client is not None assert docs2.llm_model.config == old_config - assert docs2.summary_llm_model.config == old_config + assert docs2.summary_llm_model.config == old_sconfig assert len(docs.docs) == len(docs2.docs) context1, context2 = ( docs.get_evidence( @@ -742,18 +845,20 @@ def test_dockey_delete(): with open("example.txt", "w", encoding="utf-8") as f: f.write(r.text) f.write("\n\nBates could be from Angola") # so we don't have same hash - docs.add("example.txt", "WikiMedia Foundation, 2023, Accessed now", dockey="test") + docs.add("example.txt", "WikiMedia Foundation, 2023, Accessed now", docname="test") answer = Answer(question="What country was Bates born in?") - answer = docs.get_evidence(answer, marginal_relevance=False) - print(answer) + answer = docs.get_evidence( + answer, max_sources=25, k=30 + ) # we just have a lot so we get both docs keys = set([c.text.doc.dockey for c in answer.contexts]) assert len(keys) == 2 assert len(docs.docs) == 2 - docs.delete(dockey="test") - assert len(docs.docs) == 1 + docs.delete(docname="test") answer = Answer(question="What country was Bates born in?") - answer = docs.get_evidence(answer, marginal_relevance=False) + assert len(docs.docs) == 1 + assert len(docs.deleted_dockeys) == 1 + answer = docs.get_evidence(answer, max_sources=25, k=30) keys = set([c.text.doc.dockey for c in answer.contexts]) assert len(keys) == 1 From 71c70f2a5dac39e30de8c368a51ff570301951a8 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Wed, 10 Jan 2024 22:42:10 -0800 Subject: [PATCH 08/16] Refactored vector stores to maybe support langchain --- README.md | 47 +++++++++++++--- paperqa/__init__.py | 2 + paperqa/docs.py | 73 +++++++++++------------- paperqa/llms.py | 127 +++++++++++++++++++++++++++++++++++++++++- paperqa/types.py | 103 +--------------------------------- tests/test_paperqa.py | 28 ++++++---- 6 files changed, 220 insertions(+), 160 deletions(-) diff --git a/README.md b/README.md index cfc4b1c18..50b8b7f69 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ This is a minimal package for doing question and answering from PDFs or text files (which can be raw HTML). It strives to give very good answers, with no hallucinations, by grounding responses with in-text citations. -By default, it uses [OpenAI Embeddings](https://platform.openai.com/docs/guides/embeddings) with a vector DB called [FAISS](https://github.com/facebookresearch/faiss) to embed and search documents. However, via [langchain](https://github.com/hwchase17/langchain) you can use open-source models or embeddings (see details below). +By default, it uses [OpenAI Embeddings](https://platform.openai.com/docs/guides/embeddings) with a simple numpy vector DB to embed and search documents. However, via [langchain](https://github.com/hwchase17/langchain) you can use open-source models or embeddings (see details below). paper-qa uses the process shown below: @@ -45,7 +45,7 @@ Tulevski2007: Tulevski, George S., et al. "Chemically assisted directed assembly Chen2014: Chen, Haitian, et al. "Large-scale complementary macroelectronics using hybrid integration of carbon nanotubes and IGZO thin-film transistors." Nature communications 5.1 (2014): 4097. -## Version 4 Changes +## What's New? Version 4 removed langchain from the package because it no longer supports pickling. This also simplifies the package a bit - especially prompts. Langchain can still be used, but it's not required. You can use any LLMs from langchain, but you will need to use the `LangchainLLMModel` class to wrap the model. @@ -85,7 +85,7 @@ The answer object has the following attributes: `formatted_answer`, `answer` (an ### Choosing Model -By default, it uses a hybrid of `gpt-3.5-turbo` and `gpt-4-turbo`. If you don't have gpt-4 access or would like to save money, you can adjust: +By default, it uses a hybrid of `gpt-3.5-turbo` and `gpt-4-turbo`. You can adjust this: ```py docs = Docs(llm='gpt-3.5-turbo') @@ -100,7 +100,7 @@ docs = Docs(llm_model=LangchainLLMModel(), client=ChatAnthropic()) ``` -Notice that we split the model into `LangchainLLMModel` and `client` which is `ChatAnthropic`. This is because paper-qa can be pickled, but typically Langchain models cannot be pickled. Thus, the client is the unpicklable part. Specifically, you can save your state in paper-qa: +Note we split the model into `LangchainLLMModel` (always empty) and `client` which is `ChatAnthropic`. This is because `client` stores the non-pickleable part and langchain LLMs are only sometimes serializable/pickleable. The paper-qa `Docs` must always serializable. Thus, we split the model into two parts. ```py import pickle @@ -115,21 +115,52 @@ docs.set_client(ChatAnthropic()) #### Locally Hosted -You can use llama.cpp to be the LLM. Note that you should be using relatively large models, because paper-qa requires following a lot of instructions. You won't get good performance with 7B models. I recommend using the SentenceTransformer models for embeddings, rather than llama.cpp embeddings. +You can use llama.cpp to be the LLM. Note that you should be using relatively large models, because paper-qa requires following a lot of instructions. You won't get good performance with 7B models. + +The easiest way to get set-up is to download a [llama file](https://github.com/Mozilla-Ocho/llamafile) and execute it with `-cb -np 4 -a my-llm-model --embedding` which will enable continuous batching and embeddings. + +```py +from paperqa import Docs, LlamaEmbeddingModel +from openai import AsyncOpenAI + +# start llamap.cpp client with + +local_client = AsyncOpenAI( + base_url="http://localhost:8080/v1", + api_key = "sk-no-key-required" +) + +docs = Docs(client=local_client, + embedding=LlamaEmbeddingModel(), + llm_model=OpenAILLMModel(config=dict(model="my-llm-model", temperature=0.1, frequency_penalty=1.5, max_tokens=512))) +``` + +### Changing Embedding Model + +You can use langchain embedding models, or the [SentenceTransformer](https://www.sbert.net/) models. For example ```py from paperqa import Docs, SentenceTransformerEmbeddingModel from openai import AsyncOpenAI -# start llamap.cpp client with: -cb -np 4 -a my-alais --embedding +# start llamap.cpp client with local_client = AsyncOpenAI( base_url="http://localhost:8080/v1", api_key = "sk-no-key-required" ) + docs = Docs(client=local_client, embedding=SentenceTransformerEmbeddingModel(), - llm_model=OpenAILLMModel(config=dict(model="my-alias", temperature=0.1, frequency_penalty=1.5, max_tokens=512))) + llm_model=OpenAILLMModel(config=dict(model="my-llm-model", temperature=0.1, frequency_penalty=1.5, max_tokens=512))) +``` + +Just like in the above examples, we have to split the Langchain model into a client and model to keep `Docs` serializable. +```py + +from paperqa import Docs, LangchainEmbeddingModel + +docs = Docs(embedding_model=LangchainEmbeddingModel(), embedding_client=OpenAIEmbeddings()) ``` ### Adjusting number of sources @@ -317,4 +348,6 @@ with open("my_docs.pkl", "wb") as f: # load with open("my_docs.pkl", "rb") as f: docs = pickle.load(f) + +docs.set_client() #defaults to OpenAI ``` diff --git a/paperqa/__init__.py b/paperqa/__init__.py index 1fc9643c6..601f1bcac 100644 --- a/paperqa/__init__.py +++ b/paperqa/__init__.py @@ -8,6 +8,7 @@ LangchainLLMModel, OpenAILLMModel, LlamaEmbeddingModel, + NumpyVectorStore, SentenceTransformerEmbeddingModel, ) @@ -27,4 +28,5 @@ "LlamaEmbeddingModel", "SentenceTransformerEmbeddingModel", "LangchainEmbeddingModel", + "NumpyVectorStore", ] diff --git a/paperqa/docs.py b/paperqa/docs.py index 010256089..3c69f1b01 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -9,32 +9,21 @@ from typing import Any, BinaryIO, cast from openai import AsyncOpenAI -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from .llms import ( - EmbeddingModel, LangchainEmbeddingModel, LangchainLLMModel, - LlamaEmbeddingModel, LLMModel, - OpenAIEmbeddingModel, + NumpyVectorStore, OpenAILLMModel, + VectorStore, get_score, is_openai_model, ) from .paths import PAPERQA_DIR from .readers import read_doc -from .types import ( - Answer, - CallbackFactory, - Context, - Doc, - DocKey, - NumpyVectorStore, - PromptCollection, - Text, - VectorStore, -) +from .types import Answer, CallbackFactory, Context, Doc, DocKey, PromptCollection, Text from .utils import ( gather_with_concurrency, guess_is_4xx, @@ -63,12 +52,11 @@ class Docs(BaseModel): ) summary_llm_model: LLMModel | None = Field(default=None, validate_default=True) embedding: str | None = "default" - embedding_model: EmbeddingModel = OpenAIEmbeddingModel() docs: dict[DocKey, Doc] = {} texts: list[Text] = [] docnames: set[str] = set() - texts_index: VectorStore = NumpyVectorStore() - doc_index: VectorStore = NumpyVectorStore() + texts_index: VectorStore = Field(default_factory=NumpyVectorStore) + docs_index: VectorStore = Field(default_factory=NumpyVectorStore) name: str = "default" index_path: Path | None = PAPERQA_DIR / name batch_size: int = 1 @@ -78,9 +66,9 @@ class Docs(BaseModel): jit_texts_index: bool = False # This is used to strip indirect citations that come up from the summary llm strip_citations: bool = True + model_config = ConfigDict(extra="forbid") def __init__(self, **data): - # TODO: There may be a way to put this into pydantic model validator # We do it here because we need to move things to private attributes if "embedding_client" in data: embedding_client = data.pop("embedding_client") @@ -93,8 +81,7 @@ def __init__(self, **data): # convenience embedding_client = data["client"] else: - # if embedding_model is explicitly set, but not client then make it None - if "embedding_model" in data and data["embedding_model"] is not None: + if "embedding" in data and data["embedding"] != "default": embedding_client = None else: embedding_client = AsyncOpenAI() @@ -110,6 +97,9 @@ def __init__(self, **data): client = None else: client = AsyncOpenAI() + # backwards compatibility + if "doc_index" in data: + data["docs_index"] = data.pop("doc_index") super().__init__(**data) self._client = client self._embedding_client = embedding_client @@ -134,13 +124,19 @@ def setup_alias_models(cls, data: Any) -> Any: raise ValueError(f"Could not guess model type for {data['llm']}. ") if "embedding" in data and data["embedding"] != "default": if data["embedding"] == "langchain": - data["embedding_model"] = LangchainEmbeddingModel() - elif data["embedding"] == "llama": - data["embedding_model"] = LlamaEmbeddingModel() + if "texts_index" not in data: + data["texts_index"] = NumpyVectorStore( + embedding_model=LangchainEmbeddingModel() + ) + if "docs_index" not in data: + data["docs_index"] = NumpyVectorStore( + embedding_model=LangchainEmbeddingModel() + ) else: raise ValueError( - f"Could not guess model type for {data['embedding']}. " + f"Could not guess embedding model type for {data['embedding']}. " ) + return data @model_validator(mode="after") @@ -337,7 +333,7 @@ def add_texts( doc.docname = new_docname if texts[0].embedding is None: text_embeddings = asyncio.run( - self.embedding_model.embed_documents( + self.texts_index.embedding_model.embed_documents( self._embedding_client, [t.text for t in texts] ) ) @@ -345,13 +341,13 @@ def add_texts( t.embedding = text_embeddings[i] if doc.embedding is None: doc.embedding = asyncio.run( - self.embedding_model.embed_documents( + self.docs_index.embedding_model.embed_documents( self._embedding_client, [doc.citation] ) )[0] if not self.jit_texts_index: self.texts_index.add_texts_and_embeddings(texts) - self.doc_index.add_texts_and_embeddings([doc]) + self.docs_index.add_texts_and_embeddings([doc]) self.docs[doc.dockey] = doc self.texts += texts self.docnames.add(doc.docname) @@ -384,11 +380,9 @@ async def adoc_match( get_callbacks: CallbackFactory = lambda x: None, ) -> set[DocKey]: """Return a list of dockeys that match the query.""" - query_vector = ( - await self.embedding_model.embed_documents(self._embedding_client, [query]) - )[0] - matches, _ = self.doc_index.max_marginal_relevance_search( - query_vector, + matches, _ = await self.docs_index.max_marginal_relevance_search( + self._embedding_client, + query, k=k + len(self.deleted_dockeys), fetch_k=5 * (k + len(self.deleted_dockeys)), ) @@ -471,7 +465,7 @@ async def aget_evidence( detailed_citations: bool = False, disable_vector_search: bool = False, ) -> Answer: - if len(self.docs) == 0 and self.doc_index is None: + if len(self.docs) == 0 and self.docs_index is None: # do we have no docs? return answer self._build_texts_index(keys=answer.dockey_filter) @@ -481,15 +475,12 @@ async def aget_evidence( if disable_vector_search: matches = self.texts else: - query_vector = ( - await self.embedding_model.embed_documents( - self._embedding_client, [answer.question] - ) - )[0] matches = cast( list[Text], - self.texts_index.max_marginal_relevance_search( - query_vector, k=_k, fetch_k=5 * _k + ( + await self.texts_index.max_marginal_relevance_search( + self._embedding_client, answer.question, k=_k, fetch_k=5 * _k + ) )[0], ) # ok now filter (like ones from adoc_match) diff --git a/paperqa/llms.py b/paperqa/llms.py index 9c3887213..5d0391649 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -5,15 +5,18 @@ AsyncGenerator, Callable, Coroutine, + Sequence, cast, get_args, get_type_hints, ) +import numpy as np from openai import AsyncOpenAI -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from .prompts import default_system_prompt +from .types import Embeddable from .utils import batch_iter, flatten, gather_with_concurrency @@ -94,6 +97,7 @@ async def embed_documents(self, client: Any, texts: list[str]) -> list[list[floa class LLMModel(ABC, BaseModel): llm_type: str | None = None + model_config = ConfigDict(extra="forbid") async def acomplete(self, client: Any, prompt: str) -> str: raise NotImplementedError @@ -355,6 +359,127 @@ async def embed_documents(self, client: Any, texts: list[str]) -> list[list[floa return embeddings +def cosine_similarity(a, b): + dot_product = np.dot(a, b.T) + norm_product = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) + return dot_product / norm_product + + +class VectorStore(BaseModel, ABC): + """Interface for vector store - very similar to LangChain's VectorStore to be compatible""" + + embedding_model: EmbeddingModel = Field(default=OpenAIEmbeddingModel()) + model_config = ConfigDict(extra="forbid") + + @abstractmethod + def add_texts_and_embeddings(self, texts: Sequence[Embeddable]) -> None: + pass + + @abstractmethod + async def similarity_search( + self, client: Any, query: str, k: int + ) -> tuple[Sequence[Embeddable], list[float]]: + pass + + @abstractmethod + def clear(self) -> None: + pass + + async def max_marginal_relevance_search( + self, client: Any, query: str, k: int, fetch_k: int, lambda_: float = 0.5 + ) -> tuple[Sequence[Embeddable], list[float]]: + """Vectorized implementation of Maximal Marginal Relevance (MMR) search. + + Args: + query: Query vector. + k: Number of results to return. + lambda_: Weighting of relevance and diversity. + + Returns: + List of tuples (doc, score) of length k. + """ + if fetch_k < k: + raise ValueError("fetch_k must be greater or equal to k") + + texts, scores = await self.similarity_search(client, query, fetch_k) + if len(texts) <= k: + return texts, scores + + embeddings = np.array([t.embedding for t in texts]) + np_scores = np.array(scores) + similarity_matrix = cosine_similarity(embeddings, embeddings) + + selected_indices = [0] + remaining_indices = list(range(1, len(texts))) + + while len(selected_indices) < k: + selected_similarities = similarity_matrix[:, selected_indices] + max_sim_to_selected = selected_similarities.max(axis=1) + + mmr_scores = lambda_ * np_scores - (1 - lambda_) * max_sim_to_selected + mmr_scores[selected_indices] = -np.inf # Exclude already selected documents + + max_mmr_index = mmr_scores.argmax() + selected_indices.append(max_mmr_index) + remaining_indices.remove(max_mmr_index) + + return [texts[i] for i in selected_indices], [ + scores[i] for i in selected_indices + ] + + +class NumpyVectorStore(VectorStore): + texts: list[Embeddable] = [] + _embeddings_matrix: np.ndarray | None = None + + def clear(self) -> None: + self.texts = [] + self._embeddings_matrix = None + + def add_texts_and_embeddings( + self, + texts: Sequence[Embeddable], + ) -> None: + self.texts.extend(texts) + self._embeddings_matrix = np.array([t.embedding for t in self.texts]) + + async def similarity_search( + self, client: Any, query: str, k: int + ) -> tuple[Sequence[Embeddable], list[float]]: + if len(self.texts) == 0: + return [], [] + np_query = np.array( + (await self.embedding_model.embed_documents(client, [query]))[0] + ) + similarity_scores = cosine_similarity( + np_query.reshape(1, -1), self._embeddings_matrix + )[0] + similarity_scores = np.nan_to_num(similarity_scores, nan=-np.inf) + sorted_indices = np.argsort(similarity_scores)[::-1] + return ( + [self.texts[i] for i in sorted_indices[:k]], + [similarity_scores[i] for i in sorted_indices[:k]], + ) + + +class LangchainVectorStore(VectorStore): + """A wrapper around the wrapper langchain""" + + @abstractmethod + def add_texts_and_embeddings(self, texts: Sequence[Embeddable]) -> None: + pass + + @abstractmethod + async def similarity_search( + self, client: Any, query: str, k: int + ) -> tuple[Sequence[Embeddable], list[float]]: + pass + + @abstractmethod + def clear(self) -> None: + pass + + def get_score(text: str) -> int: # check for N/A last_line = text.split("\n")[-1] diff --git a/paperqa/types.py b/paperqa/types.py index 8eb9278bb..f5e6488eb 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -1,8 +1,6 @@ -from abc import ABC, abstractmethod from typing import Any, Callable -import numpy as np -from pydantic import BaseModel, Field, Sequence, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from .prompts import ( citation_prompt, @@ -34,104 +32,6 @@ class Text(Embeddable): doc: Doc -def cosine_similarity(a, b): - dot_product = np.dot(a, b.T) - norm_product = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) - return dot_product / norm_product - - -class VectorStore(BaseModel, ABC): - """Interface for vector store - very similar to LangChain's VectorStore to be compatible""" - - @abstractmethod - def add_texts_and_embeddings(self, texts: Sequence[Embeddable]) -> None: - pass - - @abstractmethod - def similarity_search( - self, query: list[float], k: int - ) -> tuple[list[Embeddable], list[float]]: - pass - - @abstractmethod - def clear(self) -> None: - pass - - def max_marginal_relevance_search( - self, query: list[float], k: int, fetch_k: int, lambda_: float = 0.5 - ) -> tuple[list[Embeddable], list[float]]: - """Vectorized implementation of Maximal Marginal Relevance (MMR) search. - - Args: - query: Query vector. - k: Number of results to return. - lambda_: Weighting of relevance and diversity. - - Returns: - List of tuples (doc, score) of length k. - """ - if fetch_k < k: - raise ValueError("fetch_k must be greater or equal to k") - - texts, scores = self.similarity_search(query, fetch_k) - if len(texts) <= k: - return texts, scores - - embeddings = np.array([t.embedding for t in texts]) - np_scores = np.array(scores) - similarity_matrix = cosine_similarity(embeddings, embeddings) - - selected_indices = [0] - remaining_indices = list(range(1, len(texts))) - - while len(selected_indices) < k: - selected_similarities = similarity_matrix[:, selected_indices] - max_sim_to_selected = selected_similarities.max(axis=1) - - mmr_scores = lambda_ * np_scores - (1 - lambda_) * max_sim_to_selected - mmr_scores[selected_indices] = -np.inf # Exclude already selected documents - - max_mmr_index = mmr_scores.argmax() - selected_indices.append(max_mmr_index) - remaining_indices.remove(max_mmr_index) - - return [texts[i] for i in selected_indices], [ - scores[i] for i in selected_indices - ] - - -class NumpyVectorStore(VectorStore): - texts: list[Embeddable] = [] - _embeddings_matrix: np.ndarray | None = None - - def clear(self) -> None: - self.texts = [] - self._embeddings_matrix = None - - def add_texts_and_embeddings( - self, - texts: list[Embeddable], - ) -> None: - self.texts.extend(texts) - self._embeddings_matrix = np.array([t.embedding for t in self.texts]) - - def similarity_search( - self, query: list[float], k: int - ) -> tuple[list[Embeddable], list[float]]: - if len(self.texts) == 0: - return [], [] - np_query = np.array(query) - similarity_scores = cosine_similarity( - np_query.reshape(1, -1), self._embeddings_matrix - )[0] - similarity_scores = np.nan_to_num(similarity_scores, nan=-np.inf) - sorted_indices = np.argsort(similarity_scores)[::-1] - return ( - [self.texts[i] for i in sorted_indices[:k]], - [similarity_scores[i] for i in sorted_indices[:k]], - ) - - # Mock a dictionary and store any missing items class _FormatDict(dict): def __init__(self) -> None: @@ -242,6 +142,7 @@ class Answer(BaseModel): # if you want to use them. cost: float | None = None token_counts: dict[str, list[int]] | None = None + model_config = ConfigDict(extra="forbid") def __str__(self) -> str: """Return the answer as a string.""" diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index fec43a80d..76cb58094 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -7,7 +7,7 @@ import requests from openai import AsyncOpenAI -from paperqa import Answer, Doc, Docs, PromptCollection, Text +from paperqa import Answer, Doc, Docs, NumpyVectorStore, PromptCollection, Text from paperqa.llms import ( EmbeddingModel, LangchainEmbeddingModel, @@ -469,14 +469,18 @@ class MyEmbeds(EmbeddingModel): async def embed_documents(self, client, texts): return [[1, 2, 3] for _ in texts] - docs = Docs(embedding_model=MyEmbeds()) + docs = Docs( + docs_index=NumpyVectorStore(embedding_model=MyEmbeds()), + texts_index=NumpyVectorStore(embedding_model=MyEmbeds()), + embedding_client=None, + ) + assert docs._embedding_client is None docs.add_url( "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", citation="WikiMedia Foundation, 2023, Accessed now", dockey="test", ) assert docs.docs["test"].embedding == [1, 2, 3] - assert docs._embedding_client is None def test_custom_llm(): @@ -573,8 +577,12 @@ def test_langchain_embeddings(): from langchain_openai import OpenAIEmbeddings docs = Docs( - embedding_model=LangchainEmbeddingModel(), embedding_client=OpenAIEmbeddings() + texts_index=NumpyVectorStore(embedding_model=LangchainEmbeddingModel()), + docs_index=NumpyVectorStore(embedding_model=LangchainEmbeddingModel()), + embedding_client=OpenAIEmbeddings(), ) + assert docs._embedding_client is not None + docs.add_url( "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", citation="WikiMedia Foundation, 2023, Accessed now", @@ -828,7 +836,7 @@ def test_dockey_filter(): f.write(r.text) f.write("\n") # so we don't have same hash docs.add("example.txt", "WikiMedia Foundation, 2023, Accessed now", dockey="test") - answer = Answer(question="What country is Bates from?", key_filter=["test"]) + answer = Answer(question="What country is Bates from?", dockey_filter=["test"]) docs.get_evidence(answer) @@ -1007,8 +1015,8 @@ def disabled_test_memory(): def test_add_texts(): - llm_config = dict(temperature=0.1, model="text-ada-001") - docs = Docs(llm_config=llm_config) + llm_config = dict(temperature=0.1, model="babbage-02") + docs = Docs(llm_model=OpenAILLMModel(config=llm_config)) docs.add_url( "https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day", citation="WikiMedia Foundation, 2023, Accessed now", @@ -1027,8 +1035,8 @@ def test_add_texts(): docs2._build_texts_index() # now do it again to test after text index is already built - llm_config = dict(temperature=0.1, model="text-ada-001") - docs = Docs(llm_config=llm_config) + llm_config = dict(temperature=0.1, model="babbage-02") + docs = Docs(llm_model=OpenAILLMModel(config=llm_config)) docs.add_url( "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", citation="WikiMedia Foundation, 2023, Accessed now", @@ -1052,7 +1060,7 @@ def test_external_doc_index(): dockey="test", ) evidence = docs.query(query="What is the date of flag day?", key_filter=True) - docs2 = Docs(doc_index=docs.doc_index, texts_index=docs.texts_index) + docs2 = Docs(docs_index=docs.docs_index, texts_index=docs.texts_index) assert len(docs2.docs) == 0 evidence = docs2.query("What is the date of flag day?", key_filter=True) assert "February 15" in evidence.context From cab5596e47ca3a768b6b16e586ce568b117453f5 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Wed, 10 Jan 2024 22:50:43 -0800 Subject: [PATCH 09/16] Addressed Matt's comments --- dev-requirements.txt | 2 +- paperqa/docs.py | 4 ++-- paperqa/llms.py | 13 +++++++++--- paperqa/prompts.py | 47 ++++++++++++++++++++++++-------------------- paperqa/readers.py | 1 + 5 files changed, 40 insertions(+), 27 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 6202c6408..da3104761 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -6,4 +6,4 @@ python-dotenv pymupdf build types-requests -numpy +langchain_openai diff --git a/paperqa/docs.py b/paperqa/docs.py index 3c69f1b01..a4091266b 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -656,8 +656,8 @@ async def aquery( get_callbacks("answer"), ) # it still happens - if "(Example2012)" in answer_text: - answer_text = answer_text.replace("(Example2012)", "") + if "(Example2012Example pages 3-4)" in answer_text: + answer_text = answer_text.replace("(Example2012Example pages 3-4)", "") for c in answer.contexts: name = c.text.name citation = c.text.doc.citation diff --git a/paperqa/llms.py b/paperqa/llms.py index 5d0391649..95572cee2 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -21,6 +21,7 @@ def guess_model_type(model_name: str) -> str: + """Guess the model type from the model name for OpenAI models""" import openai model_type = get_type_hints( @@ -129,6 +130,8 @@ def make_chain( ) -> Callable[[dict, list[Callable[[str], None]] | None], Coroutine[Any, Any, str]]: """Create a function to execute a batch of prompts + This replaces the previous use of langchain for combining prompts and LLMs. + Args: client: a ephemeral client to use prompt: The prompt to use @@ -369,6 +372,8 @@ class VectorStore(BaseModel, ABC): """Interface for vector store - very similar to LangChain's VectorStore to be compatible""" embedding_model: EmbeddingModel = Field(default=OpenAIEmbeddingModel()) + # can be tuned for different tasks + mmr_lambda: float = Field(default=0.5) model_config = ConfigDict(extra="forbid") @abstractmethod @@ -386,14 +391,13 @@ def clear(self) -> None: pass async def max_marginal_relevance_search( - self, client: Any, query: str, k: int, fetch_k: int, lambda_: float = 0.5 + self, client: Any, query: str, k: int, fetch_k: int ) -> tuple[Sequence[Embeddable], list[float]]: """Vectorized implementation of Maximal Marginal Relevance (MMR) search. Args: query: Query vector. k: Number of results to return. - lambda_: Weighting of relevance and diversity. Returns: List of tuples (doc, score) of length k. @@ -416,7 +420,10 @@ async def max_marginal_relevance_search( selected_similarities = similarity_matrix[:, selected_indices] max_sim_to_selected = selected_similarities.max(axis=1) - mmr_scores = lambda_ * np_scores - (1 - lambda_) * max_sim_to_selected + mmr_scores = ( + self.mmr_lambda * np_scores + - (1 - self.mmr_lambda) * max_sim_to_selected + ) mmr_scores[selected_indices] = -np.inf # Exclude already selected documents max_mmr_index = mmr_scores.argmax() diff --git a/paperqa/prompts.py b/paperqa/prompts.py index 1d177cb86..83c4d877d 100644 --- a/paperqa/prompts.py +++ b/paperqa/prompts.py @@ -1,28 +1,33 @@ summary_prompt = ( - "Summarize the text below to help answer a question. " - "Do not directly answer the question, instead summarize " - "to give evidence to help answer the question. " - "Focus on specific details, including numbers, equations, or specific quotes. " - 'Reply "Not applicable" if text is irrelevant. ' - "Use {summary_length}. At the end of your response, provide a score from 1-10 on a newline " - "indicating relevance to question. Do not explain your score. " - "\n\n" - "{text}\n\n" - "Excerpt from {citation}\n" - "Question: {question}\n" - "Relevant Information Summary:" + "Summarize the excerpt below to help answer a question.\n\n" + "Excerpt from {citation}\n\n----\n\n{text}\n\n----\n\n" + "Question: {question}\n\n" + "Do not directly answer the question, instead summarize to give evidence to help " + "answer the question. Stay detailed; report specific numbers, equations, or " + 'direct quotes (marked with quotation marks). Reply "Not applicable" if the ' + "excerpt is irrelevant. At the end of your response, provide an integer score " + "from 1-10 on a newline indicating relevance to question. Do not explain your score." + "\n\nRelevant Information Summary ({summary_length}):" ) qa_prompt = ( - "Write an answer ({answer_length}) " - "for the question below based on the provided context. Ignore irrelevant context. " - "If the context provides insufficient information and the question cannot be directly answered, " - 'reply "I cannot answer". ' - "For each part of your answer, indicate which sources most support it " - "via valid citation markers at the end of sentences, like (Example2012). \n" - "Context (with relevance scores):\n {context}\n" - "Question: {question}\n" - "Answer: " + "Answer the question below with the context.\n\n" + "Context (with relevance scores):\n\n{context}\n\n----\n\n" + "Question: {question}\n\n" + "Write an answer based on the context. " + "If the context provides insufficient information and " + "the question cannot be directly answered, reply " + '"I cannot answer."' + "For each part of your answer, indicate which sources most support " + "it via citation keys at the end of sentences, " + "like (Example2012Example pages 3-4). Only cite from the context " + "below and only use the valid keys. Write in the style of a " + "Wikipedia article, with concise sentences and coherent paragraphs. " + "The context comes from a variety of sources and is only a summary, " + "so there may inaccuracies or ambiguities. If quotes are present and " + "relevant, use them in the answer. This answer will go directly onto " + "Wikipedia, so do not add any extraneous information.\n\n" + "Answer ({answer_length}):" ) select_paper_prompt = ( diff --git a/paperqa/readers.py b/paperqa/readers.py index 9c629573b..8a74c3289 100644 --- a/paperqa/readers.py +++ b/paperqa/readers.py @@ -92,6 +92,7 @@ def parse_txt( text = html2text(text) texts: list[Text] = [] # we tokenize using tiktoken so cuts are in reasonable places + # See https://github.com/openai/tiktoken enc = tiktoken.get_encoding("cl100k_base") encoded = [enc.decode_single_token_bytes(token) for token in enc.encode(text)] split_size = 0 From a19875e2e6732f1244223bd49fc63a9111d41fdd Mon Sep 17 00:00:00 2001 From: Andrew White Date: Thu, 11 Jan 2024 11:16:58 -0800 Subject: [PATCH 10/16] Finished langchain vector store --- README.md | 27 ++++++ dev-requirements.txt | 2 + paperqa/__init__.py | 2 + paperqa/docs.py | 14 ++- paperqa/llms.py | 203 ++++++++++++++++++++++++++++++------------ tests/test_paperqa.py | 94 +++++++++++++++++++ 6 files changed, 279 insertions(+), 63 deletions(-) diff --git a/README.md b/README.md index 50b8b7f69..5740f11fe 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,33 @@ answer = docs.query("Where is the search bar in the header defined?") print(answer) ``` +### Using External DB/Vector DB and Caching + +You may want to cache parsed texts and embeddings in an external database or file. You can then build a Docs object from those directly: + +```py +#.... + +docs = Docs() + +for ... in my_docs: + doc = Doc(docname=..., citation=..., dockey=..., citation=...) + texts = [Text(text=..., name=..., doc=doc) for ... in my_texts] + docs.add_texts(texts, doc) +``` + +If you want to use an external vector store, you can also do that directly via langchain. For example, to use the [FAISS](https://ai.meta.com/tools/faiss/) from langchain: + +```py +from paperqa import LangchainVectorStore, Docs +from langchain_community.vector_store import FAISS +from langchain_openai import OpenAIEmbeddings + +my_index = LangchainVectorStore(cls=FAISS, embedding_model=OpenAIEmbeddings()) +docs = Docs(texts_index=my_index) + +``` + ## Where do I get papers? Well that's a really good question! It's probably best to just download PDFs of papers you think will help answer your question and start from there. diff --git a/dev-requirements.txt b/dev-requirements.txt index da3104761..92f2a9b2e 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -7,3 +7,5 @@ pymupdf build types-requests langchain_openai +langchain_community +faiss-cpu diff --git a/paperqa/__init__.py b/paperqa/__init__.py index 601f1bcac..af23e9938 100644 --- a/paperqa/__init__.py +++ b/paperqa/__init__.py @@ -9,6 +9,7 @@ OpenAILLMModel, LlamaEmbeddingModel, NumpyVectorStore, + LangchainVectorStore, SentenceTransformerEmbeddingModel, ) @@ -29,4 +30,5 @@ "SentenceTransformerEmbeddingModel", "LangchainEmbeddingModel", "NumpyVectorStore", + "LangchainVectorStore", ] diff --git a/paperqa/docs.py b/paperqa/docs.py index a4091266b..c9c45dc38 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -43,8 +43,8 @@ class Docs(BaseModel): """A collection of documents to be used for answering questions.""" # ephemeral vars that should not be pickled (_things) - _client: Any | None - _embedding_client: Any | None + _client: Any | None = None + _embedding_client: Any | None = None llm: str = "default" summary_llm: str | None = None llm_model: LLMModel = Field( @@ -163,6 +163,10 @@ def clear_docs(self): self.docnames = set() def __getstate__(self): + # You may wonder why make these private if we're just going + # to be overriding the behavior on setstaet/getstate anyway. + # The reason is that the other serialization methods from Pydantic - + # model_dump - will not drop private attributes. state = super().__getstate__() # remove client from private attributes del state["__pydantic_private__"]["_client"] @@ -170,9 +174,10 @@ def __getstate__(self): return state def __setstate__(self, state): + # add client back to private attributes + state["__pydantic_private__"]["_client"] = None + state["__pydantic_private__"]["_embedding_client"] = None super().__setstate__(state) - self._client = None - self._embedding_client = None def set_client( self, @@ -421,6 +426,7 @@ async def adoc_match( def _build_texts_index(self, keys: set[DocKey] | None = None): texts = self.texts if keys is not None and self.jit_texts_index: + # TODO: what is JIT even for?? if keys is not None: texts = [t for t in texts if t.doc.dockey in keys] if len(texts) == 0: diff --git a/paperqa/llms.py b/paperqa/llms.py index 95572cee2..a36ecba9d 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -1,11 +1,13 @@ import re from abc import ABC, abstractmethod +from inspect import signature from typing import ( Any, AsyncGenerator, Callable, Coroutine, Sequence, + Type, cast, get_args, get_type_hints, @@ -256,58 +258,6 @@ async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: yield chunk.choices[0].delta.content -class LangchainLLMModel(LLMModel): - """A wrapper around the wrapper langchain""" - - def infer_llm_type(self, client: Any) -> str: - from langchain_core.language_models.chat_models import BaseChatModel - - if isinstance(client, BaseChatModel): - return "chat" - return "completion" - - async def acomplete(self, client: Any, prompt: str) -> str: - return await client.ainvoke(prompt) - - async def acomplete_iter(self, client: Any, prompt: str) -> Any: - async for chunk in cast(AsyncGenerator, client.astream(prompt)): - yield chunk - - async def achat(self, client: Any, messages: list[dict[str, str]]) -> str: - from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage - - lc_messages: list[BaseMessage] = [] - for m in messages: - if m["role"] == "user": - lc_messages.append(HumanMessage(content=m["content"])) - elif m["role"] == "system": - lc_messages.append(SystemMessage(content=m["content"])) - else: - raise ValueError(f"Unknown role: {m['role']}") - return (await client.ainvoke(lc_messages)).content - - async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: - from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage - - lc_messages: list[BaseMessage] = [] - for m in messages: - if m["role"] == "user": - lc_messages.append(HumanMessage(content=m["content"])) - elif m["role"] == "system": - lc_messages.append(SystemMessage(content=m["content"])) - else: - raise ValueError(f"Unknown role: {m['role']}") - async for chunk in client.astream(lc_messages): - yield chunk.content - - -class LangchainEmbeddingModel(EmbeddingModel): - """A wrapper around the wrapper langchain""" - - async def embed_documents(self, client: Any, texts: list[str]) -> list[list[float]]: - return await client.aembed_documents(texts) - - class LlamaEmbeddingModel(EmbeddingModel): embedding_model: str = Field(default="llama") @@ -469,22 +419,157 @@ async def similarity_search( ) -class LangchainVectorStore(VectorStore): +# All the langchain stuff is below +# Many confusing woes here because langchain +# is not serializable and so we have to +# do some gymnastics to make it work + + +class LangchainLLMModel(LLMModel): """A wrapper around the wrapper langchain""" - @abstractmethod + def infer_llm_type(self, client: Any) -> str: + from langchain_core.language_models.chat_models import BaseChatModel + + if isinstance(client, BaseChatModel): + return "chat" + return "completion" + + async def acomplete(self, client: Any, prompt: str) -> str: + return await client.ainvoke(prompt) + + async def acomplete_iter(self, client: Any, prompt: str) -> Any: + async for chunk in cast(AsyncGenerator, client.astream(prompt)): + yield chunk + + async def achat(self, client: Any, messages: list[dict[str, str]]) -> str: + from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage + + lc_messages: list[BaseMessage] = [] + for m in messages: + if m["role"] == "user": + lc_messages.append(HumanMessage(content=m["content"])) + elif m["role"] == "system": + lc_messages.append(SystemMessage(content=m["content"])) + else: + raise ValueError(f"Unknown role: {m['role']}") + return (await client.ainvoke(lc_messages)).content + + async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: + from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage + + lc_messages: list[BaseMessage] = [] + for m in messages: + if m["role"] == "user": + lc_messages.append(HumanMessage(content=m["content"])) + elif m["role"] == "system": + lc_messages.append(SystemMessage(content=m["content"])) + else: + raise ValueError(f"Unknown role: {m['role']}") + async for chunk in client.astream(lc_messages): + yield chunk.content + + +class LangchainEmbeddingModel(EmbeddingModel): + """A wrapper around the wrapper langchain""" + + async def embed_documents(self, client: Any, texts: list[str]) -> list[list[float]]: + return await client.aembed_documents(texts) + + +class LangchainVectorStore(VectorStore): + """A wrapper around the wrapper langchain + + Note that if you this is cleared (e.g., by `Docs` having `jit_texts_index` set to True), + this will calls the `from_texts` class method on the `store`. This means that any non-default + constructor arguments will be lost. You can override the clear method on this class. + """ + + _store_builder: Any | None = None + _store: Any | None = None + # JIT Generics - store the class type (Doc or Text) + class_type: Type[Embeddable] = Field(default=Embeddable) + model_config = ConfigDict(extra="forbid") + + def __init__(self, **data): + # we have to separate out store from the rest of the data + # because langchain objects are not serializable + store_builder = None + if "store_builder" in data: + store_builder = LangchainVectorStore.check_store_builder( + data.pop("store_builder") + ) + if "cls" in data and "embedding_model" in data: + # make a little closure + cls = data.pop("cls") + embedding_model = data.pop("embedding_model") + + def candidate(x, y): + return cls.from_embeddings(x, embedding_model, y) + + store_builder = LangchainVectorStore.check_store_builder(candidate) + super().__init__(**data) + self._store_builder = store_builder + + @classmethod + def check_store_builder(cls, builder: Any) -> Any: + # check it is a callable + if not callable(builder): + raise ValueError("store_builder must be callable") + # check it takes two arguments + # we don't use type hints because it could be + # a partial + sig = signature(builder) + if len(sig.parameters) != 2: + raise ValueError("store_builder must take two arguments") + return builder + + def __getstate__(self): + state = super().__getstate__() + # remove non-serializable private attributes + del state["__pydantic_private__"]["_store"] + del state["__pydantic_private__"]["_store_builder"] + return state + + def __setstate__(self, state): + # restore non-serializable private attributes + state["__pydantic_private__"]["_store"] = None + state["__pydantic_private__"]["_store_builder"] = None + super().__setstate__(state) + def add_texts_and_embeddings(self, texts: Sequence[Embeddable]) -> None: - pass + if self._store_builder is None: + raise ValueError("You must set store_builder before adding texts") + self.class_type = type(texts[0]) + vec_store_text_and_embeddings = list( + map(lambda x: (x.text, x.embedding), texts) + ) + if self._store is None: + self._store = self._store_builder( # type: ignore + vec_store_text_and_embeddings, + texts, + ) + if self._store is None or not hasattr(self._store, "add_embeddings"): + raise ValueError("store_builder did not return a valid vectorstore") + self._store.add_embeddings( # type: ignore + vec_store_text_and_embeddings, + metadatas=texts, + ) - @abstractmethod async def similarity_search( self, client: Any, query: str, k: int ) -> tuple[Sequence[Embeddable], list[float]]: - pass + if self._store is None: + return [], [] + results = await self._store.asimilarity_search_with_relevance_scores(query, k=k) + texts, scores = [self.class_type(**r[0].metadata) for r in results], [ + r[1] for r in results + ] + return texts, scores - @abstractmethod def clear(self) -> None: - pass + del self._store # be explicit, because it could be large + self._store = None def get_score(text: str) -> int: diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index 76cb58094..4eaa83d2f 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -12,6 +12,7 @@ EmbeddingModel, LangchainEmbeddingModel, LangchainLLMModel, + LangchainVectorStore, LLMModel, OpenAILLMModel, get_score, @@ -596,6 +597,99 @@ def test_langchain_embeddings(): ) +class TestVectorStore(IsolatedAsyncioTestCase): + async def test_langchain_vector_store(self): + from langchain_community.vectorstores.faiss import FAISS + from langchain_openai import OpenAIEmbeddings + + some_texts = [ + Text( + embedding=OpenAIEmbeddings().embed_query("test"), + text="this is a test", + name="test", + doc=Doc(docname="test", citation="test", dockey="test"), + ) + ] + + # checks on builder + try: + index = LangchainVectorStore() + index.add_texts_and_embeddings(some_texts) + raise "Failed to check for builder" + except ValueError: + pass + + try: + index = LangchainVectorStore(store_builder=lambda x: None) + raise "Failed to count arguments" + except ValueError: + pass + + try: + index = LangchainVectorStore(store_builder="foo") + raise "Failed to check if builder is callable" + except ValueError: + pass + + # now with real builder + index = LangchainVectorStore( + store_builder=lambda x, y: FAISS.from_embeddings(x, OpenAIEmbeddings(), y) + ) + assert index._store is None + index.add_texts_and_embeddings(some_texts) + assert index._store is not None + # check search returns Text obj + data, score = await index.similarity_search(None, "test", k=1) + print(data) + assert type(data[0]) == Text + + # now try with convenience + index = LangchainVectorStore(cls=FAISS, embedding_model=OpenAIEmbeddings()) + assert index._store is None + index.add_texts_and_embeddings(some_texts) + assert index._store is not None + + docs = Docs( + texts_index=LangchainVectorStore( + cls=FAISS, embedding_model=OpenAIEmbeddings() + ) + ) + assert docs._embedding_client is not None # from docs_index default + + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + # should be embedded + + # now try with JIT + docs = Docs(texts_index=index, jit_texts_index=True) + docs.add_url( + "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", + citation="WikiMedia Foundation, 2023, Accessed now", + dockey="test", + ) + # should get cleared and rebuilt here + ev = docs.get_evidence( + answer=Answer(question="What is Frederick Bates's greatest accomplishment?") + ) + assert len(ev.context) > 0 + # now with dockkey filter + docs.get_evidence( + answer=Answer( + question="What is Frederick Bates's greatest accomplishment?", + dockey_filter=["test"], + ) + ) + + # make sure we can pickle it + docs_pickle = pickle.dumps(docs) + pickle.loads(docs_pickle) + + # will not work at this point - have to reset index + + class Test(IsolatedAsyncioTestCase): async def test_aquery(self): docs = Docs() From e601be1d911d7bd9cf6f6bf62b3f4fd62f654708 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Thu, 11 Jan 2024 11:25:44 -0800 Subject: [PATCH 11/16] Unit test prompt adjustments --- tests/test_paperqa.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index 4eaa83d2f..d9d05d417 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -818,11 +818,9 @@ def test_repeat_keys(): def test_pdf_reader(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") - docs = Docs( - llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="gpt-3.5-turbo")) - ) + docs = Docs(llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="gpt-4"))) docs.add(doc_path, "Wellawatte et al, XAI Review, 2023") - answer = docs.query("Are counterfactuals actionable?") + answer = docs.query("Are counterfactuals actionable? [yes/no]") assert "yes" in answer.answer or "Yes" in answer.answer @@ -830,21 +828,15 @@ def test_fileio_reader_pdf(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") with open(doc_path, "rb") as f: - docs = Docs( - llm_model=OpenAILLMModel( - config=dict(temperature=0.0, model="gpt-3.5-turbo") - ) - ) + docs = Docs() docs.add_file(f, "Wellawatte et al, XAI Review, 2023") - answer = docs.query("Are counterfactuals actionable?") + answer = docs.query("Are counterfactuals actionable?[yes/no]") assert "yes" in answer.answer or "Yes" in answer.answer def test_fileio_reader_txt(): # can't use curie, because it has trouble with parsed HTML - docs = Docs( - llm_model=OpenAILLMModel(config=dict(temperature=0.0, model="gpt-3.5-turbo")) - ) + docs = Docs() r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") if r.status_code != 200: raise ValueError("Could not download wikipedia page") From b8f1acfd3cbdb3ec56181d8216b38c5a455f59cf Mon Sep 17 00:00:00 2001 From: Andrew White Date: Thu, 11 Jan 2024 11:33:32 -0800 Subject: [PATCH 12/16] Added warning to README --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5740f11fe..c5ac1f5e3 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,13 @@ -# Paper QA- [Paper QA](#paper-qa) +# PaperQA [![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/whitead/paper-qa) [![tests](https://github.com/whitead/paper-qa/actions/workflows/tests.yml/badge.svg)](https://github.com/whitead/paper-qa) [![PyPI version](https://badge.fury.io/py/paper-qa.svg)](https://badge.fury.io/py/paper-qa) +## YOU ARE LOOKING AT PRE-RELEASE README + +**This is the README for an upcoming v4 release** You can see the current stable version [here](https://github.com/whitead/paper-qa/tree/84f13ea32c22b85924cd681a4b5f4fbd174afd71) + This is a minimal package for doing question and answering from PDFs or text files (which can be raw HTML). It strives to give very good answers, with no hallucinations, by grounding responses with in-text citations. From 90793ede73aefc79d16f551af1beafb3ae0b67b8 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Thu, 11 Jan 2024 11:41:15 -0800 Subject: [PATCH 13/16] Fixed some typos in README --- README.md | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index c5ac1f5e3..dd575ff63 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,9 @@ ## YOU ARE LOOKING AT PRE-RELEASE README -**This is the README for an upcoming v4 release** You can see the current stable version [here](https://github.com/whitead/paper-qa/tree/84f13ea32c22b85924cd681a4b5f4fbd174afd71) +**This is the README for an upcoming v4 release** + +You can see the current stable version [here](https://github.com/whitead/paper-qa/tree/84f13ea32c22b85924cd681a4b5f4fbd174afd71) This is a minimal package for doing question and answering from PDFs or text files (which can be raw HTML). It strives to give very good answers, with no hallucinations, by grounding responses with in-text citations. @@ -198,7 +200,6 @@ print(answer) You may want to cache parsed texts and embeddings in an external database or file. You can then build a Docs object from those directly: ```py -#.... docs = Docs() @@ -208,7 +209,7 @@ for ... in my_docs: docs.add_texts(texts, doc) ``` -If you want to use an external vector store, you can also do that directly via langchain. For example, to use the [FAISS](https://ai.meta.com/tools/faiss/) from langchain: +If you want to use an external vector store, you can also do that directly via langchain. For example, to use the [FAISS](https://ai.meta.com/tools/faiss/) vector store from langchain: ```py from paperqa import LangchainVectorStore, Docs @@ -325,11 +326,7 @@ docs.query("What manufacturing challenges are unique to bispecific antibodies?", ### Caching Embeddings -In general, embeddings are cached when you pickle a `Docs` regardless of what vector store you use. If you would like to manage caching embeddings via an external database or other strategy, -you can populate a `Docs` object directly via -the `add_texts` object. That can take chunked texts and documents, which are serializable objects, to populate `Docs`. - -You also can simply use a separate vector database by setting the `doc_index` and `texts_index` explicitly when building the `Docs` object. +In general, embeddings are cached when you pickle a `Docs` regardless of what vector store you use. See above for details on more explicit management of them. ## Customizing Prompts From 39ba98c84d6781cf98c2bf831afab2ea73d6acdb Mon Sep 17 00:00:00 2001 From: Andrew White Date: Mon, 15 Jan 2024 14:20:32 -0800 Subject: [PATCH 14/16] Fixed problem for very short texts --- paperqa/readers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/paperqa/readers.py b/paperqa/readers.py index 8a74c3289..816e75fc6 100644 --- a/paperqa/readers.py +++ b/paperqa/readers.py @@ -31,7 +31,7 @@ def parse_pdf_fitz(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List ) split = split[chunk_chars - overlap :] pages = [str(i + 1)] - if len(split) > overlap: + if len(split) > overlap or len(texts) == 0: pg = "-".join([pages[0], pages[-1]]) texts.append( Text(text=split[:chunk_chars], name=f"{doc.docname} pages {pg}", doc=doc) @@ -64,7 +64,7 @@ def parse_pdf(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List[Text ) split = split[chunk_chars - overlap :] pages = [str(i + 1)] - if len(split) > overlap: + if len(split) > overlap or len(texts) == 0: pg = "-".join([pages[0], pages[-1]]) texts.append( Text(text=split[:chunk_chars], name=f"{doc.docname} pages {pg}", doc=doc) @@ -112,7 +112,7 @@ def parse_txt( ) split = [split_flat[chunk_chars - overlap :].encode("utf-8")] split_size = len(split[0]) - if len(split) > overlap: + if split_size > overlap or len(texts) == 0: split_flat = b"".join(split).decode() texts.append( Text( @@ -134,7 +134,7 @@ def parse_code_txt(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List with open(path) as f: for i, line in enumerate(f): split += line - if len(split) > chunk_chars: + while len(split) > chunk_chars: texts.append( Text( text=split[:chunk_chars], @@ -144,7 +144,7 @@ def parse_code_txt(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List ) split = split[chunk_chars - overlap :] last_line = i - if len(split) > overlap: + if len(split) > overlap or len(texts) == 0: texts.append( Text( text=split[:chunk_chars], From 59ed8d317497c1b848f3029a8f1a34d44b014912 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Tue, 16 Jan 2024 14:50:20 -0800 Subject: [PATCH 15/16] Made it easier to access LLM names --- README.md | 9 ++++----- paperqa/docs.py | 24 +++++++++++++++++++++++- paperqa/llms.py | 12 ++++++++++++ tests/test_paperqa.py | 13 +++++++++++-- 4 files changed, 50 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index dd575ff63..b4511407e 100644 --- a/README.md +++ b/README.md @@ -100,17 +100,17 @@ docs = Docs(llm='gpt-3.5-turbo') or you can use any other model available in [langchain](https://github.com/hwchase17/langchain): ```py -from paperqa import Docs, LangchainLLMModel +from paperqa import Docs from langchain_community.chat_models import ChatAnthropic -docs = Docs(llm_model=LangchainLLMModel(), +docs = Docs(llm="langchain", client=ChatAnthropic()) ``` -Note we split the model into `LangchainLLMModel` (always empty) and `client` which is `ChatAnthropic`. This is because `client` stores the non-pickleable part and langchain LLMs are only sometimes serializable/pickleable. The paper-qa `Docs` must always serializable. Thus, we split the model into two parts. +Note we split the model into the wrapper and `client`, which is `ChatAnthropic` here. This is because `client` stores the non-pickleable part and langchain LLMs are only sometimes serializable/pickleable. The paper-qa `Docs` must always serializable. Thus, we split the model into two parts. ```py import pickle -docs = Docs(llm_model=LangchainLLMModel(), +docs = Docs(llm="langchain", client=ChatAnthropic()) model_str = pickle.dumps(docs) docs = pickle.loads(model_str) @@ -118,7 +118,6 @@ docs = pickle.loads(model_str) docs.set_client(ChatAnthropic()) ``` - #### Locally Hosted You can use llama.cpp to be the LLM. Note that you should be using relatively large models, because paper-qa requires following a lot of instructions. You won't get good performance with 7B models. diff --git a/paperqa/docs.py b/paperqa/docs.py index c9c45dc38..8d486a111 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -103,6 +103,10 @@ def __init__(self, **data): super().__init__(**data) self._client = client self._embedding_client = embedding_client + # run this here (instead of automateically) so it has access to privates + # If I ever figure out a better way of validating privates + # I can move this back to the decorator + Docs.make_llm_names_consistent(self) @model_validator(mode="before") @classmethod @@ -136,7 +140,6 @@ def setup_alias_models(cls, data: Any) -> Any: raise ValueError( f"Could not guess embedding model type for {data['embedding']}. " ) - return data @model_validator(mode="after") @@ -157,6 +160,24 @@ def config_summary_llm_config(cls, data: Any) -> Any: data.summary_llm_model = data.llm_model return data + @classmethod + def make_llm_names_consistent(cls, data: Any) -> Any: + if isinstance(data, Docs): + data.llm = data.llm_model.name + if data.llm == "langchain": + # from langchain models - kind of hacky + # langchain models cannot know type until + # it sees client + data.llm_model.infer_llm_type(data._client) + data.llm = data.llm_model.name + if data.summary_llm_model is not None: + if data.summary_llm == "langchain": + # from langchain models - kind of hacky + data.summary_llm_model.infer_llm_type(data._client) + data.summary_llm = data.summary_llm_model.name + + return data + def clear_docs(self): self.texts = [] self.docs = {} @@ -193,6 +214,7 @@ def set_client( else: embedding_client = AsyncOpenAI() self._embedding_client = embedding_client + Docs.make_llm_names_consistent(self) def _get_unique_name(self, docname: str) -> str: """Create a unique name given proposed name""" diff --git a/paperqa/llms.py b/paperqa/llms.py index a36ecba9d..7f00f67bb 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -100,6 +100,7 @@ async def embed_documents(self, client: Any, texts: list[str]) -> list[list[floa class LLMModel(ABC, BaseModel): llm_type: str | None = None + name: str model_config = ConfigDict(extra="forbid") async def acomplete(self, client: Any, prompt: str) -> str: @@ -208,6 +209,7 @@ async def execute( class OpenAILLMModel(LLMModel): config: dict = Field(default=dict(model="gpt-3.5-turbo", temperature=0.1)) + name: str = "gpt-3.5-turbo" def _check_client(self, client: Any) -> AsyncOpenAI: if client is None: @@ -227,6 +229,13 @@ def guess_llm_type(cls, data: Any) -> Any: m.llm_type = guess_model_type(m.config["model"]) return m + @model_validator(mode="after") + @classmethod + def set_model_name(cls, data: Any) -> Any: + m = cast(OpenAILLMModel, data) + m.name = m.config["model"] + return m + async def acomplete(self, client: Any, prompt: str) -> str: aclient = self._check_client(client) completion = await aclient.completions.create( @@ -428,9 +437,12 @@ async def similarity_search( class LangchainLLMModel(LLMModel): """A wrapper around the wrapper langchain""" + name: str = "langchain" + def infer_llm_type(self, client: Any) -> str: from langchain_core.language_models.chat_models import BaseChatModel + self.name = client.model_name if isinstance(client, BaseChatModel): return "chat" return "completion" diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index d9d05d417..fc483e3de 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -406,14 +406,15 @@ def accum(x): def test_docs(): - llm_config = dict(temperature=0.1, model="text-ada-001", model_type="completion") - docs = Docs(llm_model=OpenAILLMModel(config=llm_config)) + docs = Docs(llm="babbage-002") docs.add_url( "https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day", citation="WikiMedia Foundation, 2023, Accessed now", dockey="test", ) assert docs.docs["test"].docname == "Wiki2023" + assert docs.llm == "babbage-002" + assert docs.summary_llm == "babbage-002" def test_evidence(): @@ -486,6 +487,8 @@ async def embed_documents(self, client, texts): def test_custom_llm(): class MyLLM(LLMModel): + name: str = "myllm" + async def acomplete(self, client, prompt): assert client is None return "Echo" @@ -502,6 +505,8 @@ async def acomplete(self, client, prompt): def test_custom_llm_stream(): class MyLLM(LLMModel): + name: str = "myllm" + async def acomplete_iter(self, client, prompt): assert client is None yield "Echo" @@ -522,6 +527,8 @@ def test_langchain_llm(): from langchain_openai import ChatOpenAI, OpenAI docs = Docs(llm="langchain", client=ChatOpenAI(model="gpt-3.5-turbo")) + assert docs.llm == "gpt-3.5-turbo" + assert docs.summary_llm == "gpt-3.5-turbo" docs.add_url( "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", citation="WikiMedia Foundation, 2023, Accessed now", @@ -567,7 +574,9 @@ def test_langchain_llm(): docs_pickle = pickle.dumps(docs) docs2 = pickle.loads(docs_pickle) assert docs2._client is None + assert docs2.llm == "babbage-002" docs2.set_client(OpenAI(model="babbage-002")) + assert docs2.summary_llm == "babbage-002" docs2.get_evidence( Answer(question="What is Frederick Bates's greatest accomplishment?"), get_callbacks=lambda x: [lambda y: print(y)], From 8fb8bdd6983e0e43db5c31920690836a3e30d7c6 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Mon, 22 Jan 2024 13:12:42 -0800 Subject: [PATCH 16/16] Fixed text embedding errors --- paperqa/docs.py | 62 +++++++++++++++++++++------- paperqa/llms.py | 94 ++++++++++++++++++++++++++++++++----------- paperqa/readers.py | 36 ++++++++--------- paperqa/types.py | 31 ++++++++++++-- tests/test_paperqa.py | 24 +++++++++-- 5 files changed, 182 insertions(+), 65 deletions(-) diff --git a/paperqa/docs.py b/paperqa/docs.py index 8d486a111..bb0b026bf 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -23,7 +23,16 @@ ) from .paths import PAPERQA_DIR from .readers import read_doc -from .types import Answer, CallbackFactory, Context, Doc, DocKey, PromptCollection, Text +from .types import ( + Answer, + CallbackFactory, + Context, + Doc, + DocKey, + LLMResult, + PromptCollection, + Text, +) from .utils import ( gather_with_concurrency, guess_is_4xx, @@ -103,7 +112,7 @@ def __init__(self, **data): super().__init__(**data) self._client = client self._embedding_client = embedding_client - # run this here (instead of automateically) so it has access to privates + # run this here (instead of automatically) so it has access to privates # If I ever figure out a better way of validating privates # I can move this back to the decorator Docs.make_llm_names_consistent(self) @@ -171,11 +180,15 @@ def make_llm_names_consistent(cls, data: Any) -> Any: data.llm_model.infer_llm_type(data._client) data.llm = data.llm_model.name if data.summary_llm_model is not None: + if ( + data.summary_llm is None + and data.summary_llm_model is data.llm_model + ): + data.summary_llm = data.llm if data.summary_llm == "langchain": # from langchain models - kind of hacky data.summary_llm_model.infer_llm_type(data._client) data.summary_llm = data.summary_llm_model.name - return data def clear_docs(self): @@ -188,6 +201,9 @@ def __getstate__(self): # to be overriding the behavior on setstaet/getstate anyway. # The reason is that the other serialization methods from Pydantic - # model_dump - will not drop private attributes. + # So - this getstate/setstate removes private attributes for pickling + # and Pydantic will handle removing private attributes for other + # serialization methods (like model_dump) state = super().__getstate__() # remove client from private attributes del state["__pydantic_private__"]["_client"] @@ -301,9 +317,10 @@ def add( texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=100) if len(texts) == 0: raise ValueError(f"Could not read document {path}. Is it empty?") - citation = asyncio.run( + chain_result = asyncio.run( cite_chain(dict(text=texts[0].text), None), ) + citation = chain_result.text if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation: citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}" @@ -405,6 +422,7 @@ async def adoc_match( k: int = 25, rerank: bool | None = None, get_callbacks: CallbackFactory = lambda x: None, + answer: Answer | None = None, # used for tracking tokens ) -> set[DocKey]: """Return a list of dockeys that match the query.""" matches, _ = await self.docs_index.max_marginal_relevance_search( @@ -440,7 +458,9 @@ async def adoc_match( dict(question=query, papers="\n".join(papers)), get_callbacks("filter"), ) - return set([d.dockey for d in matched_docs if d.docname in result]) + if answer: + answer.add_tokens(result) + return set([d.dockey for d in matched_docs if d.docname in str(result)]) except AttributeError: pass return set([d.dockey for d in matched_docs]) @@ -528,6 +548,8 @@ async def aget_evidence( async def process(match): callbacks = get_callbacks("evidence:" + match.name) citation = match.doc.citation + # empty result + llm_result = LLMResult(model="", date="") if detailed_citations: citation = match.name + ": " + citation @@ -547,7 +569,7 @@ async def process(match): # my best idea is see if there is a 4XX # http code in the exception try: - context = await summary_chain( + llm_result = await summary_chain( dict( question=answer.question, # Add name so chunk is stated @@ -557,15 +579,16 @@ async def process(match): ), callbacks, ) + context = llm_result.text except Exception as e: if guess_is_4xx(str(e)): - return None + return None, llm_result raise e if ( "not applicable" in context.lower() or "not relevant" in context.lower() ): - return None + return None, llm_result if self.strip_citations: # remove citations that collide with our grounded citations (for the answer LLM) context = strip_citations(context) @@ -580,13 +603,16 @@ async def process(match): ), score=score, ) - return c + return c, llm_result results = await gather_with_concurrency( self.max_concurrent, [process(m) for m in matches] ) + # update token counts + [answer.add_tokens(r[1]) for r in results] + # filter out failures - contexts = [c for c in results if c is not None] + contexts = [c for c, r in results if c is not None] answer.contexts = sorted( contexts + answer.contexts, key=lambda x: x.score, reverse=True @@ -646,7 +672,9 @@ async def aquery( # comparable - one is chunks and one is docs if key_filter or (key_filter is None and len(self.docs) > k): keys = await self.adoc_match( - answer.question, get_callbacks=get_callbacks + answer.question, + get_callbacks=get_callbacks, + answer=answer, ) if len(keys) > 0: answer.dockey_filter = keys @@ -663,7 +691,10 @@ async def aquery( system_prompt=self.prompts.system, ) pre = await chain(dict(question=answer.question), get_callbacks("pre")) - answer.context = answer.context + "\n\nExtra background information:" + pre + answer.add_tokens(pre) + answer.context = ( + answer.context + "\n\nExtra background information:" + str(pre) + ) bib = dict() if len(answer.context) < 10: # and not self.memory: answer_text = ( @@ -675,7 +706,7 @@ async def aquery( prompt=self.prompts.qa, system_prompt=self.prompts.system, ) - answer_text = await qa_chain( + answer_result = await qa_chain( dict( context=answer.context, answer_length=answer.answer_length, @@ -683,6 +714,8 @@ async def aquery( ), get_callbacks("answer"), ) + answer_text = answer_result.text + answer.add_tokens(answer_result) # it still happens if "(Example2012Example pages 3-4)" in answer_text: answer_text = answer_text.replace("(Example2012Example pages 3-4)", "") @@ -709,7 +742,8 @@ async def aquery( system_prompt=self.prompts.system, ) post = await chain(answer.model_dump(), get_callbacks("post")) - answer.answer = post + answer.answer = post.text + answer.add_tokens(post) answer.formatted_answer = f"Question: {answer.question}\n\n{post}\n" if len(bib) > 0: answer.formatted_answer += f"\nReferences\n\n{bib_str}\n" diff --git a/paperqa/llms.py b/paperqa/llms.py index 7f00f67bb..dea6645c2 100644 --- a/paperqa/llms.py +++ b/paperqa/llms.py @@ -1,3 +1,5 @@ +import asyncio +import datetime import re from abc import ABC, abstractmethod from inspect import signature @@ -18,7 +20,7 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator from .prompts import default_system_prompt -from .types import Embeddable +from .types import Doc, Embeddable, LLMResult, Text from .utils import batch_iter, flatten, gather_with_concurrency @@ -59,13 +61,14 @@ def process_llm_config(llm_config: dict) -> dict: result = {k: v for k, v in llm_config.items() if k != "model_type"} if "max_tokens" not in result or result["max_tokens"] == -1: model = llm_config["model"] - # now we guess! + # now we guess - we could use tiktoken to count, + # but do have the initative right now if model.startswith("gpt-4") or ( model.startswith("gpt-3.5") and "1106" in model ): - result["max_tokens"] = 4096 + result["max_tokens"] = 3000 else: - result["max_tokens"] = 2048 # ? + result["max_tokens"] = 1500 return result @@ -124,13 +127,18 @@ async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any: def infer_llm_type(self, client: Any) -> str: return "completion" + def count_tokens(self, text: str) -> int: + return len(text) // 4 # gross approximation + def make_chain( self, client: Any, prompt: str, skip_system: bool = False, system_prompt: str = default_system_prompt, - ) -> Callable[[dict, list[Callable[[str], None]] | None], Coroutine[Any, Any, str]]: + ) -> Callable[ + [dict, list[Callable[[str], None]] | None], Coroutine[Any, Any, LLMResult] + ]: """Create a function to execute a batch of prompts This replaces the previous use of langchain for combining prompts and LLMs. @@ -143,7 +151,7 @@ def make_chain( Returns: A function to execute a prompt. Its signature is: - execute(data: dict, callbacks: list[Callable[[str], None]]] | None = None) -> str + execute(data: dict, callbacks: list[Callable[[str], None]]] | None = None) -> LLMResult where data is a dict with keys for the input variables that will be formatted into prompt and callbacks is a list of functions to call with each chunk of the completion. """ @@ -160,21 +168,39 @@ def make_chain( async def execute( data: dict, callbacks: list[Callable[[str], None]] | None = None - ) -> str: + ) -> LLMResult: + start_clock = asyncio.get_running_loop().time() + result = LLMResult( + model=self.name, + date=datetime.datetime.now().isoformat(), + ) messages = chat_prompt[:-1] + [ dict(role="user", content=chat_prompt[-1]["content"].format(**data)) ] + result.prompt_count = sum( + [self.count_tokens(m["content"]) for m in messages] + ) + sum([self.count_tokens(m["role"]) for m in messages]) + if callbacks is None: output = await self.achat(client, messages) else: completion = self.achat_iter(client, messages) # type: ignore - result = [] + text_result = [] async for chunk in completion: # type: ignore if chunk: - result.append(chunk) + if result.seconds_to_first_token == 0: + result.seconds_to_first_token = ( + asyncio.get_running_loop().time() - start_clock + ) + text_result.append(chunk) [f(chunk) for f in callbacks] - output = "".join(result) - return output + output = "".join(text_result) + result.completion_count = self.count_tokens(output) + result.text = output + result.seconds_to_last_token = ( + asyncio.get_running_loop().time() - start_clock + ) + return result return execute elif self.llm_type == "completion": @@ -185,23 +211,38 @@ async def execute( async def execute( data: dict, callbacks: list[Callable[[str], None]] | None = None - ) -> str: + ) -> LLMResult: + start_clock = asyncio.get_running_loop().time() + result = LLMResult( + model=self.name, + date=datetime.datetime.now().isoformat(), + ) + formatted_prompt = completion_prompt.format(**data) + result.prompt_count = self.count_tokens(formatted_prompt) + if callbacks is None: - output = await self.acomplete( - client, completion_prompt.format(**data) - ) + output = await self.acomplete(client, formatted_prompt) else: completion = self.acomplete_iter( # type: ignore client, - completion_prompt.format(**data), + formatted_prompt, ) - result = [] + text_result = [] async for chunk in completion: # type: ignore if chunk: - result.append(chunk) + if result.seconds_to_first_token == 0: + result.seconds_to_first_token = ( + asyncio.get_running_loop().time() - start_clock + ) + text_result.append(chunk) [f(chunk) for f in callbacks] - output = "".join(result) - return output + output = "".join(text_result) + result.completion_count = self.count_tokens(output) + result.text = output + result.seconds_to_last_token = ( + asyncio.get_running_loop().time() - start_clock + ) + return result return execute raise ValueError(f"Unknown llm_type: {self.llm_type}") @@ -553,9 +594,16 @@ def add_texts_and_embeddings(self, texts: Sequence[Embeddable]) -> None: if self._store_builder is None: raise ValueError("You must set store_builder before adding texts") self.class_type = type(texts[0]) - vec_store_text_and_embeddings = list( - map(lambda x: (x.text, x.embedding), texts) - ) + if self.class_type == Text: + vec_store_text_and_embeddings = list( + map(lambda x: (x.text, x.embedding), cast(list[Text], texts)) + ) + elif self.class_type == Doc: + vec_store_text_and_embeddings = list( + map(lambda x: (x.citation, x.embedding), cast(list[Doc], texts)) + ) + else: + raise ValueError("Only embeddings of type Text are supported") if self._store is None: self._store = self._store_builder( # type: ignore vec_store_text_and_embeddings, diff --git a/paperqa/readers.py b/paperqa/readers.py index 816e75fc6..f23630863 100644 --- a/paperqa/readers.py +++ b/paperqa/readers.py @@ -1,3 +1,4 @@ +from math import ceil from pathlib import Path from typing import List @@ -94,30 +95,25 @@ def parse_txt( # we tokenize using tiktoken so cuts are in reasonable places # See https://github.com/openai/tiktoken enc = tiktoken.get_encoding("cl100k_base") - encoded = [enc.decode_single_token_bytes(token) for token in enc.encode(text)] - split_size = 0 - split_flat = "" + encoded = enc.encode_ordinary(text) split = [] - for chunk in encoded: - split.append(chunk) - split_size += len(chunk) - if split_size > chunk_chars: - split_flat = b"".join(split).decode() - texts.append( - Text( - text=split_flat[:chunk_chars], - name=f"{doc.docname} chunk {len(texts) + 1}", - doc=doc, - ) + # convert from characters to chunks + char_count = len(text) # e.g., 25,000 + token_count = len(encoded) # e.g., 4,500 + chars_per_token = char_count / token_count # e.g., 5.5 + chunk_tokens = chunk_chars / chars_per_token # e.g., 3000 / 5.5 = 545 + overlap_tokens = overlap / chars_per_token # e.g., 100 / 5.5 = 18 + chunk_count = ceil(token_count / chunk_tokens) # e.g., 4500 / 545 = 9 + for i in range(chunk_count): + split = encoded[ + max(int(i * chunk_tokens - overlap_tokens), 0) : int( + (i + 1) * chunk_tokens + overlap_tokens ) - split = [split_flat[chunk_chars - overlap :].encode("utf-8")] - split_size = len(split[0]) - if split_size > overlap or len(texts) == 0: - split_flat = b"".join(split).decode() + ] texts.append( Text( - text=split_flat[:chunk_chars], - name=f"{doc.docname} lines {len(texts) + 1}", + text=enc.decode(split), + name=f"{doc.docname} chunk {i + 1}", doc=doc, ) ) diff --git a/paperqa/types.py b/paperqa/types.py index f5e6488eb..d259ac6e2 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -16,6 +16,19 @@ CallbackFactory = Callable[[str], list[Callable[[str], None]] | None] +class LLMResult(BaseModel): + text: str = "" + prompt_count: int = 0 + completion_count: int = 0 + model: str + date: str + seconds_to_first_token: float = 0 + seconds_to_last_token: float = 0 + + def __str__(self): + return self.text + + class Embeddable(BaseModel): embedding: list[float] | None = Field(default=None, repr=False) @@ -137,11 +150,10 @@ class Answer(BaseModel): summary_length: str = "about 100 words" answer_length: str = "about 100 words" memory: str | None = None - # these two below are for convenience - # and are not set. But you can set them - # if you want to use them. + # just for convenience you can override this cost: float | None = None - token_counts: dict[str, list[int]] | None = None + # key is model name, value is (prompt, completion) token counts + token_counts: dict[str, list[int]] = Field(default_factory=dict) model_config = ConfigDict(extra="forbid") def __str__(self) -> str: @@ -155,3 +167,14 @@ def get_citation(self, name: str) -> str: except StopIteration: raise ValueError(f"Could not find docname {name} in contexts") return doc.citation + + def add_tokens(self, result: LLMResult): + """Update the token counts for the given result.""" + if result.model not in self.token_counts: + self.token_counts[result.model] = [ + result.prompt_count, + result.completion_count, + ] + else: + self.token_counts[result.model][0] += result.prompt_count + self.token_counts[result.model][1] += result.completion_count diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index fc483e3de..57d9a4175 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -382,8 +382,14 @@ def accum(x): outputs.append(x) completion = await call(dict(animal="duck"), callbacks=[accum]) - assert completion == "".join(outputs) - assert type(completion) == str + assert completion.seconds_to_first_token > 0 + assert completion.prompt_count > 0 + assert completion.completion_count > 0 + assert str(completion) == "".join(outputs) + + completion = await call(dict(animal="duck")) + assert completion.seconds_to_first_token == 0 + assert completion.seconds_to_last_token > 0 async def test_chain_chat(self): client = AsyncOpenAI() @@ -401,8 +407,14 @@ def accum(x): outputs.append(x) completion = await call(dict(animal="duck"), callbacks=[accum]) - assert completion == "".join(outputs) - assert type(completion) == str + assert completion.seconds_to_first_token > 0 + assert completion.prompt_count > 0 + assert completion.completion_count > 0 + assert str(completion) == "".join(outputs) + + completion = await call(dict(animal="duck")) + assert completion.seconds_to_first_token == 0 + assert completion.seconds_to_last_token > 0 def test_docs(): @@ -527,6 +539,8 @@ def test_langchain_llm(): from langchain_openai import ChatOpenAI, OpenAI docs = Docs(llm="langchain", client=ChatOpenAI(model="gpt-3.5-turbo")) + assert type(docs.llm_model) == LangchainLLMModel + assert type(docs.summary_llm_model) == LangchainLLMModel assert docs.llm == "gpt-3.5-turbo" assert docs.summary_llm == "gpt-3.5-turbo" docs.add_url( @@ -914,6 +928,8 @@ def test_citation(): assert ( list(docs.docs.values())[0].docname == "Wikipedia2024" or list(docs.docs.values())[0].docname == "Frederick2024" + or list(docs.docs.values())[0].docname == "Wikipedia" + or list(docs.docs.values())[0].docname == "Frederick" )