From 0009c0c159703c4e0921a1fe0c0f06c0e3f105ce Mon Sep 17 00:00:00 2001 From: Andrew White Date: Wed, 3 Jan 2024 22:47:26 -0800 Subject: [PATCH] Improved unit tests to be closer --- paperqa/docs.py | 122 ++++++++++++++++--------------- paperqa/{chains.py => llms.py} | 35 ++++++++- paperqa/prompts.py | 10 +-- paperqa/readers.py | 47 ++++++++---- paperqa/types.py | 33 +++++---- tests/test_paperqa.py | 127 +++++++++------------------------ 6 files changed, 188 insertions(+), 186 deletions(-) rename paperqa/{chains.py => llms.py} (82%) diff --git a/paperqa/docs.py b/paperqa/docs.py index 70b8e30f6..7b851ca18 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -1,7 +1,7 @@ +import nest_asyncio # isort:skip import asyncio import os import re -import sys import tempfile from datetime import datetime from io import BytesIO @@ -11,7 +11,7 @@ from openai import AsyncOpenAI from pydantic import BaseModel, Field, field_validator, model_validator -from .chains import get_score, make_chain +from .llms import embed_documents, get_score, guess_model_type, make_chain from .paths import PAPERQA_DIR from .readers import read_doc from .types import ( @@ -36,17 +36,20 @@ strip_citations, ) +# Apply the patch to allow nested loops +nest_asyncio.apply() + class Docs(BaseModel): """A collection of documents to be used for answering questions.""" - _client: AsyncOpenAI + _client: AsyncOpenAI | None docs: dict[DocKey, Doc] = {} texts: list[Text] = [] docnames: set[str] = set() texts_index: VectorStore = NumpyVectorStore() doc_index: VectorStore = NumpyVectorStore() - llm_config: dict = dict(model="gpt-3.5-turbo") + llm_config: dict = dict(model="gpt-3.5-turbo", model_type="chat") summary_llm_config: dict | None = Field(default=None, validate_default=True) name: str = "default" index_path: Path | None = PAPERQA_DIR / name @@ -69,17 +72,17 @@ def __init__(self, **data): @field_validator("llm_config", "summary_llm_config") @classmethod - def llm_config_has_type(cls, v): + def 
llm_guess_model_type(cls, v: dict) -> dict: if v is not None and "model_type" not in v: - raise ValueError("Must specify if chat or completion model.") + v["model_type"] = guess_model_type(v["model"]) return v @model_validator(mode="after") @classmethod def config_summary_llm_conig(cls, data: Any) -> Any: - if isinstance(data, dict): - if data["summary_llm_config"] is None: - data["summary_llm_config"] = data["llm_config"] + if isinstance(data, Docs): + if data.summary_llm_config is None: + data.summary_llm_config = data.llm_config return data def clear_docs(self): @@ -87,6 +90,21 @@ def clear_docs(self): self.docs = {} self.docnames = set() + def __getstate__(self): + state = super().__getstate__() + # remove client from private attributes + del state["__pydantic_private__"]["_client"] + return state + + def __setstate__(self, state): + super().__setstate__(state) + self._client = None + + def set_client(self, client: AsyncOpenAI | None = None): + if client is None: + client = AsyncOpenAI() + self._client = client + def _get_unique_name(self, docname: str) -> str: """Create a unique name given proposed name""" suffix = "" @@ -173,12 +191,7 @@ def add( texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=100) if len(texts) == 0: raise ValueError(f"Could not read document {path}. 
Is it empty?") - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - citation = loop.run_until_complete(cite_chain(texts[0].text)) + citation = asyncio.run(cite_chain(data=dict(text=texts[0].text))) if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation: citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}" @@ -234,14 +247,20 @@ def add_texts( t.name = t.name.replace(doc.docname, new_docname) doc.docname = new_docname if texts[0].embedding is None: - text_embeddings = self.embeddings.embed_documents([t.text for t in texts]) + text_embeddings = asyncio.run( + embed_documents( + self._client, [t.text for t in texts], self.embeddings_model + ) + ) for i, t in enumerate(texts): t.embedding = text_embeddings[i] if doc.embedding is None: - doc.embedding = self.embeddings.embed_documents([doc.citation])[0] + doc.embedding = asyncio.run( + embed_documents(self._client, [doc.citation], self.embeddings_model) + )[0] if not self.jit_texts_index: self.texts_index.add_texts_and_embeddings(texts) - self.doc_index.add_texts_and_embeddings(doc) + self.doc_index.add_texts_and_embeddings([doc]) self.docs[doc.dockey] = doc self.texts += texts self.docnames.add(doc.docname) @@ -266,8 +285,13 @@ async def adoc_match( get_callbacks: CallbackFactory = lambda x: None, ) -> set[DocKey]: """Return a list of dockeys that match the query.""" + query_vector = ( + await embed_documents(self._client, [query], self.embeddings_model) + )[0] matches, _ = self.doc_index.max_marginal_relevance_search( - query, k=k + len(self.deleted_dockeys) + query_vector, + k=k + len(self.deleted_dockeys), + fetch_k=5 * (k + len(self.deleted_dockeys)), ) # filter the matches matched_docs = [m for m in matches if m.dockey not in self.deleted_dockeys] @@ -277,7 +301,7 @@ async def adoc_match( try: if ( rerank is None - and self.llm_config.model.startswith("gpt-4") + and 
self.llm_config["model"].startswith("gpt-4") or rerank is True ): chain = make_chain( @@ -296,7 +320,7 @@ async def adoc_match( pass return set([d.dockey for d in matched_docs]) - def _build_texts_index(self, keys: set[DocKey] = None): + def _build_texts_index(self, keys: set[DocKey] | None = None): if keys is not None and self.jit_texts_index: texts = self.texts if keys is not None: @@ -317,17 +341,7 @@ def get_evidence( disable_vector_search: bool = False, disable_summarization: bool = False, ) -> Answer: - # special case for jupyter notebooks - if "get_ipython" in globals() or "google.colab" in sys.modules: - import nest_asyncio - - nest_asyncio.apply() - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop.run_until_complete( + return asyncio.run( self.aget_evidence( answer, k=k, @@ -362,12 +376,17 @@ async def aget_evidence( if disable_vector_search: matches = self.texts else: + query_vector = ( + await embed_documents( + self._client, [answer.question], self.embeddings_model + ) + )[0] if marginal_relevance: - matches = self.texts_index.max_marginal_relevance_search( - answer.question, k=_k, fetch_k=5 * _k + matches, _ = self.texts_index.max_marginal_relevance_search( + query_vector, k=_k, fetch_k=5 * _k ) else: - matches = self.texts_index.similarity_search(answer.question, k=_k) + matches, _ = self.texts_index.similarity_search(query_vector, k=_k) # ok now filter (like ones from adoc_match) if answer.dockey_filter is not None: matches = [m for m in matches if m.doc.dockey in answer.dockey_filter] @@ -406,7 +425,7 @@ async def process(match): # http code in the exception try: context = await summary_chain( - dict( + data=dict( question=answer.question, # Add name so chunk is stated citation=citation, @@ -434,7 +453,7 @@ async def process(match): text=Text( text=match.text, name=match.name, - doc=Doc(**match.doc), + doc=Doc(**match.doc.model_dump()), ), score=score, ) @@ 
-474,17 +493,7 @@ def query( key_filter: bool | None = None, get_callbacks: CallbackFactory = lambda x: None, ) -> Answer: - # special case for jupyter notebooks - if "get_ipython" in globals() or "google.colab" in sys.modules: - import nest_asyncio - - nest_asyncio.apply() - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop.run_until_complete( + return asyncio.run( self.aquery( query, k=k, @@ -540,12 +549,11 @@ async def aquery( ) answer.context = answer.context + "\n\nExtra background information:" + pre bib = dict() - if len(answer.context) < 10 and not self.memory: + if len(answer.context) < 10: # and not self.memory: answer_text = ( "I cannot answer this question due to insufficient information." ) else: - callbacks = get_callbacks("answer") qa_chain = make_chain( client=self._client, prompt=self.prompts.qa, @@ -558,7 +566,7 @@ async def aquery( answer_length=answer.answer_length, question=answer.question, ), - callbacks=callbacks, + callbacks=get_callbacks("answer"), ) # it still happens if "(Example2012)" in answer_text: @@ -586,17 +594,17 @@ async def aquery( llm_config=self.llm_config, system_prompt=self.prompts.system, ) - post = await chain.arun( + post = await chain( data=answer.model_dump(), callbacks=get_callbacks("post") ) answer.answer = post answer.formatted_answer = f"Question: {answer.question}\n\n{post}\n" if len(bib) > 0: answer.formatted_answer += f"\nReferences\n\n{bib_str}\n" - if self.memory_model is not None: - answer.memory = self.memory_model.load_memory_variables(inputs={})["memory"] - self.memory_model.save_context( - {"Question": answer.question}, {"Answer": answer.answer} - ) + # if self.memory_model is not None: + # answer.memory = self.memory_model.load_memory_variables(inputs={})["memory"] + # self.memory_model.save_context( + # {"Question": answer.question}, {"Answer": answer.answer} + # ) return answer diff --git a/paperqa/chains.py 
b/paperqa/llms.py similarity index 82% rename from paperqa/chains.py rename to paperqa/llms.py index d06c8128f..18fff67bf 100644 --- a/paperqa/chains.py +++ b/paperqa/llms.py @@ -1,11 +1,22 @@ import re -from typing import Callable +from typing import Any, Awaitable, Callable, get_args, get_type_hints from openai import AsyncOpenAI from .prompts import default_system_prompt -default_system_prompt = "End your responses with [END]" + +def guess_model_type(model_name: str) -> str: + import openai + + model_type = get_type_hints( + openai.types.chat.completion_create_params.CompletionCreateParamsBase + )["model"] + model_union = get_args(get_args(model_type)[1]) + model_arr = list(model_union) + if model_name in model_arr: + return "chat" + return "completion" def process_llm_config(llm_config: dict) -> dict: @@ -23,13 +34,27 @@ def process_llm_config(llm_config: dict) -> dict: return result +async def embed_documents( + client: AsyncOpenAI, texts: list[str], embedding_model: str +) -> list[list[float]]: + """Embed a list of documents with batching""" + if client is None: + raise ValueError( + "Your client is None - did you forget to set it after pickling?" + ) + response = await client.embeddings.create( + model=embedding_model, input=texts, encoding_format="float" + ) + return [e.embedding for e in response.data] + + def make_chain( client: AsyncOpenAI, prompt: str, llm_config: dict, skip_system: bool = False, system_prompt: str = default_system_prompt, -) -> Callable[[list[dict], list[Callable[[str], None]] | None], list[str]]: +) -> Awaitable[Any]: """Create a function to execute a batch of prompts Args: @@ -45,6 +70,10 @@ def make_chain( where data is a dict with keys for the input variables that will be formatted into prompt and callbacks is a list of functions to call with each chunk of the completion. """ + if client is None: + raise ValueError( + "Your client is None - did you forget to set it after pickling?" 
+ ) if llm_config["model_type"] == "chat": system_message_prompt = dict(role="system", content=system_prompt) human_message_prompt = dict(role="user", content=prompt) diff --git a/paperqa/prompts.py b/paperqa/prompts.py index e77852138..1432244ce 100644 --- a/paperqa/prompts.py +++ b/paperqa/prompts.py @@ -10,7 +10,7 @@ "{text}\n\n" "Excerpt from {citation}\n" "Question: {question}\n" - "Relevant Information Summary:", + "Relevant Information Summary:" ) qa_prompt = ( @@ -22,7 +22,7 @@ "via valid citation markers at the end of sentences, like (Example2012). \n" "Context (with relevance scores):\n {context}\n" "Question: {question}\n" - "Answer: ", + "Answer: " ) select_paper_prompt = ( @@ -34,12 +34,12 @@ "(if the question requires timely information). \n\n" "Question: {question}\n\n" "Papers: {papers}\n\n" - "Selected keys:", + "Selected keys:" ) citation_prompt = ( - "Provide the citation for the following text in MLA Format.\n\n" + "Provide the citation for the following text in MLA Format. If reporting date accessed, the current year is 2024\n\n" "{text}\n\n" - "Citation:", + "Citation:" ) default_system_prompt = ( diff --git a/paperqa/readers.py b/paperqa/readers.py index de6b901f6..fd8be1b8a 100644 --- a/paperqa/readers.py +++ b/paperqa/readers.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import List +import tiktoken from html2text import html2text from .types import Doc, Text @@ -75,6 +76,10 @@ def parse_pdf(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List[Text def parse_txt( path: Path, doc: Doc, chunk_chars: int, overlap: int, html: bool = False ) -> List[Text]: + """Parse a document into chunks, based on tiktoken encoding. + + NOTE: We get some byte continuation errors. Currently ignored, but should explore more to make sure we don't miss anything. 
+ """ try: with open(path) as f: text = f.read() @@ -83,18 +88,36 @@ def parse_txt( text = f.read() if html: text = html2text(text) - # chunk into size chunk_chars with overlap overlap - raw_texts = [] - start = 0 - while start < len(text): - end = min(start + chunk_chars, len(text)) - raw_texts.append(text[start:end]) - start = end - overlap - - texts = [ - Text(text=t, name=f"{doc.docname} chunk {i}", doc=doc) - for i, t in enumerate(raw_texts) - ] + texts = [] + # we tokenize using tiktoken so cuts are in reasonable places + enc = tiktoken.get_encoding("cl100k_base") + encoded = [enc.decode_single_token_bytes(token) for token in enc.encode(text)] + split_size = 0 + split_flat = "" + split = [] + for chunk in encoded: + split.append(chunk) + split_size += len(chunk) + if split_size > chunk_chars: + split_flat = b"".join(split).decode() + texts.append( + Text( + text=split_flat[:chunk_chars], + name=f"{doc.docname} chunk {len(texts) + 1}", + doc=doc, + ) + ) + split = [split_flat[chunk_chars - overlap :].encode("utf-8")] + split_size = len(split[0]) + if len(split) > overlap: + split_flat = b"".join(split).decode() + texts.append( + Text( + text=split_flat[:chunk_chars], + name=f"{doc.docname} lines {len(texts) + 1}", + doc=doc, + ) + ) return texts diff --git a/paperqa/types.py b/paperqa/types.py index a7603f894..1e80c8caa 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -50,7 +50,7 @@ def add_texts_and_embeddings(self, texts: list[Embeddable]) -> None: @abstractmethod def similarity_search( self, query: list[float], k: int - ) -> list[tuple[Embeddable, float]]: + ) -> tuple[list[Embeddable], list[float]]: pass @abstractmethod @@ -59,7 +59,7 @@ def clear(self) -> None: def max_marginal_relevance_search( self, query: list[float], k: int, fetch_k: int, lambda_: float = 0.5 - ) -> list[tuple[Embeddable, float]]: + ) -> tuple[list[Embeddable], list[float]]: """Vectorized implementation of Maximal Marginal Relevance (MMR) search. 
Args: @@ -73,16 +73,16 @@ def max_marginal_relevance_search( if fetch_k < k: raise ValueError("fetch_k must be greater or equal to k") - initial_results = self.similarity_search(query, fetch_k) - if len(initial_results) <= k: - return initial_results + texts, scores = self.similarity_search(query, fetch_k) + if len(texts) <= k: + return texts, scores - embeddings = np.array([t.embedding for t, _ in initial_results]) - scores = np.array([score for _, score in initial_results]) + embeddings = np.array([t.embedding for t in texts]) + scores = np.array(scores) similarity_matrix = cosine_similarity(embeddings, embeddings) selected_indices = [0] - remaining_indices = list(range(1, len(initial_results))) + remaining_indices = list(range(1, len(texts))) while len(selected_indices) < k: selected_similarities = similarity_matrix[:, selected_indices] @@ -95,7 +95,9 @@ def max_marginal_relevance_search( selected_indices.append(max_mmr_index) remaining_indices.remove(max_mmr_index) - return [(initial_results[i][0], scores[i]) for i in selected_indices] + return [texts[i] for i in selected_indices], [ + scores[i] for i in selected_indices + ] class NumpyVectorStore(VectorStore): @@ -115,16 +117,19 @@ def add_texts_and_embeddings( def similarity_search( self, query: list[float], k: int - ) -> list[tuple[Embeddable, float]]: + ) -> tuple[list[Embeddable], list[float]]: if len(self.texts) == 0: - return [] - query = np.array(query) + return [], [] + np_query = np.array(query) similarity_scores = cosine_similarity( - query.reshape(1, -1), self._embeddings_matrix + np_query.reshape(1, -1), self._embeddings_matrix )[0] similarity_scores = np.nan_to_num(similarity_scores, nan=-np.inf) sorted_indices = np.argsort(similarity_scores)[::-1] - return [(self.texts[i], similarity_scores[i]) for i in sorted_indices[:k]] + return ( + [self.texts[i] for i in sorted_indices[:k]], + [similarity_scores[i] for i in sorted_indices[:k]], + ) class _FormatDict(dict): diff --git a/tests/test_paperqa.py 
b/tests/test_paperqa.py index b467d398a..1ce15db55 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -7,7 +7,7 @@ import requests from paperqa import Answer, Doc, Docs, PromptCollection, Text -from paperqa.chains import get_score +from paperqa.llms import get_score from paperqa.readers import read_doc from paperqa.utils import ( iter_citations, @@ -370,15 +370,6 @@ def test_docs(): assert docs.docs["test"].docname == "Wiki2023" -def test_update_llm(): - doc = Docs() - doc.update_llm("gpt-3.5-turbo") - assert doc.llm == doc.summary_llm - - doc.update_llm(OpenAI(client=None, temperature=0.1, model="text-ada-001")) - assert doc.llm == doc.summary_llm - - def test_evidence(): doc_path = "example.html" with open(doc_path, "w", encoding="utf-8") as f: @@ -451,7 +442,6 @@ async def test_adoc_match(self): "What is Frederick Bates's greatest accomplishment?" ) assert len(sources) > 0 - docs.update_llm("gpt-3.5-turbo") sources = await docs.adoc_match( "What is Frederick Bates's greatest accomplishment?" 
) @@ -464,15 +454,20 @@ def test_docs_pickle(): # get front page of wikipedia r = requests.get("https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day") f.write(r.text) - llm = OpenAI(client=None, temperature=0.0, model="text-curie-001") - docs = Docs(llm=llm) + docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002")) + old_config = docs.llm_config docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000) os.remove(doc_path) docs_pickle = pickle.dumps(docs) docs2 = pickle.loads(docs_pickle) - docs2.update_llm(llm) - assert llm.model_name == docs2.llm.model_name - assert docs2.summary_llm.model_name == docs2.llm.model_name + # make sure it fails if we haven't set client + try: + docs2.query("What date is bring your dog to work in the US?") + except ValueError: + pass + docs2.set_client() + assert docs2.llm_config == old_config + assert docs2.summary_llm_config == old_config assert len(docs.docs) == len(docs2.docs) context1, context2 = ( docs.get_evidence( @@ -497,45 +492,6 @@ def test_docs_pickle(): docs.query("What date is bring your dog to work in the US?") -def test_docs_pickle_no_faiss(): - doc_path = "example.html" - with open(doc_path, "w", encoding="utf-8") as f: - # get front page of wikipedia - r = requests.get("https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day") - f.write(r.text) - llm = OpenAI(client=None, temperature=0.0, model="text-curie-001") - docs = Docs(llm=llm) - docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000) - docs.doc_index = None - docs.texts_index = None - docs_pickle = pickle.dumps(docs) - docs2 = pickle.loads(docs_pickle) - docs2.update_llm(llm) - assert len(docs.docs) == len(docs2.docs) - assert ( - strings_similarity( - docs.get_evidence( - Answer( - question="What date is bring your dog to work in the US?", - summary_length="about 20 words", - ), - k=3, - max_sources=1, - ).context, - docs2.get_evidence( - Answer( - question="What date is bring your dog to work in 
the US?", - summary_length="about 20 words", - ), - k=3, - max_sources=1, - ).context, - ) - > 0.75 - ) - os.remove(doc_path) - - def test_bad_context(): doc_path = "example.html" with open(doc_path, "w", encoding="utf-8") as f: @@ -555,7 +511,7 @@ def test_repeat_keys(): # get wiki page about politician r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") f.write(r.text) - docs = Docs(llm=OpenAI(client=None, temperature=0.0, model="text-ada-001")) + docs = Docs(llm_config=dict(temperature=0.0, model="text-ada-001")) docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") try: docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") @@ -584,7 +540,7 @@ def test_repeat_keys(): def test_pdf_reader(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") - docs = Docs(llm=OpenAI(client=None, temperature=0.0, model="text-curie-001")) + docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002")) docs.add(doc_path, "Wellawatte et al, XAI Review, 2023") answer = docs.query("Are counterfactuals actionable?") assert "yes" in answer.answer or "Yes" in answer.answer @@ -594,7 +550,7 @@ def test_fileio_reader_pdf(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf") with open(doc_path, "rb") as f: - docs = Docs(llm=OpenAI(client=None, temperature=0.0, model="text-curie-001")) + docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002")) docs.add_file(f, "Wellawatte et al, XAI Review, 2023") answer = docs.query("Are counterfactuals actionable?") assert "yes" in answer.answer or "Yes" in answer.answer @@ -602,7 +558,7 @@ def test_fileio_reader_pdf(): def test_fileio_reader_txt(): # can't use curie, because it has trouble with parsed HTML - docs = Docs(llm=OpenAI(client=None, temperature=0.0)) + docs = Docs(llm_config=dict(temperature=0.0, model="davinci-002")) r = 
requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") if r.status_code != 200: raise ValueError("Could not download wikipedia page") @@ -652,7 +608,7 @@ def test_prompt_length(): def test_code(): # load this script doc_path = os.path.abspath(__file__) - docs = Docs(llm=OpenAI(client=None, temperature=0.0, model="text-ada-001")) + docs = Docs(llm_config=dict(temperature=0.0, model="babbage-002")) docs.add(doc_path, "test_paperqa.py", docname="test_paperqa.py", disable_check=True) assert len(docs.docs) == 1 docs.query("What function tests the preview?") @@ -667,8 +623,8 @@ def test_citation(): docs = Docs() docs.add(doc_path) assert ( - list(docs.docs.values())[0].docname == "Wikipedia2023" - or list(docs.docs.values())[0].docname == "Frederick2023" + list(docs.docs.values())[0].docname == "Wikipedia2024" + or list(docs.docs.values())[0].docname == "Frederick2024" ) @@ -740,19 +696,6 @@ def test_query_filter(): # the filter shouldn't trigger, so just checking that it doesn't crash -def test_nonopenai_client(): - responses = ["This is a test", "This is another test"] * 50 - model = FakeListLLM(responses=responses) - doc_path = "example.txt" - with open(doc_path, "w", encoding="utf-8") as f: - # get wiki page about politician - r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)") - f.write(r.text) - docs = Docs(llm=model) - docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now") - docs.query("What country is Bates from?") - - def test_zotera(): from paperqa.contrib import ZoteroDB @@ -786,11 +729,10 @@ def test_too_much_evidence(): def test_custom_prompts(): - my_qaprompt = PromptTemplate( - input_variables=["question", "context"], - template="Answer the question '{question}' " + my_qaprompt = ( + "Answer the question '{question}' " "using the country name alone. 
For example: " - "A: United States\nA: Canada\nA: Mexico\n\n Using the context:\n\n{context}\n\nA: ", + "A: United States\nA: Canada\nA: Mexico\n\n Using the context:\n\n{context}\n\nA: " ) docs = Docs(prompts=PromptCollection(qa=my_qaprompt)) @@ -807,11 +749,7 @@ def test_custom_prompts(): def test_pre_prompt(): - pre = PromptTemplate( - input_variables=["question"], - template="Provide context you have memorized " - "that could help answer '{question}'. ", - ) + pre = "Provide context you have memorized " "that could help answer '{question}'. " docs = Docs(prompts=PromptCollection(pre=pre)) @@ -825,13 +763,12 @@ def test_pre_prompt(): def test_post_prompt(): - post = PromptTemplate( - input_variables=["question", "answer"], - template="We are trying to answer the question below " + post = ( + "We are trying to answer the question below " "and have an answer provided. " "Please edit the answer be extremely terse, with no extra words or formatting" "with no extra information.\n\n" - "Q: {question}\nA: {answer}\n\n", + "Q: {question}\nA: {answer}\n\n" ) docs = Docs(prompts=PromptCollection(post=post)) @@ -883,8 +820,8 @@ def disabled_test_memory(): def test_add_texts(): - llm = OpenAI(client=None, temperature=0.1, model="text-ada-001") - docs = Docs(llm=llm) + llm_config = dict(temperature=0.1, model="text-ada-001") + docs = Docs(llm_config=llm_config) docs.add_url( "https://en.wikipedia.org/wiki/National_Flag_of_Canada_Day", citation="WikiMedia Foundation, 2023, Accessed now", @@ -894,17 +831,17 @@ def test_add_texts(): docs2 = Docs() texts = [Text(**dict(t)) for t in docs.texts] for t in texts: - t.embeddings = None + t.embedding = None docs2.add_texts(texts, list(docs.docs.values())[0]) for t1, t2 in zip(docs2.texts, docs.texts): assert t1.text == t2.text - assert np.allclose(t1.embeddings, t2.embeddings, atol=1e-3) + assert np.allclose(t1.embedding, t2.embedding, atol=1e-3) docs2._build_texts_index() # now do it again to test after text index is already built - 
llm = OpenAI(client=None, temperature=0.1, model="text-ada-001") - docs = Docs(llm=llm) + llm_config = dict(temperature=0.1, model="text-ada-001") + docs = Docs(llm_config=llm_config) docs.add_url( "https://en.wikipedia.org/wiki/Frederick_Bates_(politician)", citation="WikiMedia Foundation, 2023, Accessed now", @@ -913,7 +850,7 @@ def test_add_texts(): texts = [Text(**dict(t)) for t in docs.texts] for t in texts: - t.embeddings = None + t.embedding = None docs2.add_texts(texts, list(docs.docs.values())[0]) assert len(docs2.docs) == 2