Improved unit tests to be closer
whitead committed Jan 4, 2024
1 parent 04bcda1 commit 0009c0c
Showing 6 changed files with 188 additions and 186 deletions.
122 changes: 65 additions & 57 deletions paperqa/docs.py
@@ -1,7 +1,7 @@
+import nest_asyncio # isort:skip
import asyncio
import os
import re
-import sys
import tempfile
from datetime import datetime
from io import BytesIO
@@ -11,7 +11,7 @@
from openai import AsyncOpenAI
from pydantic import BaseModel, Field, field_validator, model_validator

-from .chains import get_score, make_chain
+from .llms import embed_documents, get_score, guess_model_type, make_chain
from .paths import PAPERQA_DIR
from .readers import read_doc
from .types import (
@@ -36,17 +36,20 @@
strip_citations,
)

# Apply the patch to allow nested loops
nest_asyncio.apply()


class Docs(BaseModel):
"""A collection of documents to be used for answering questions."""

-    _client: AsyncOpenAI
+    _client: AsyncOpenAI | None
docs: dict[DocKey, Doc] = {}
texts: list[Text] = []
docnames: set[str] = set()
texts_index: VectorStore = NumpyVectorStore()
doc_index: VectorStore = NumpyVectorStore()
-    llm_config: dict = dict(model="gpt-3.5-turbo")
+    llm_config: dict = dict(model="gpt-3.5-turbo", model_type="chat")
summary_llm_config: dict | None = Field(default=None, validate_default=True)
name: str = "default"
index_path: Path | None = PAPERQA_DIR / name
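
The llm_config fields are now plain dicts, and model_type is inferred instead of being required up front. A minimal sketch of the resulting behavior (model name illustrative; assumes OPENAI_API_KEY is set), relying on the two validators shown in the next hunk:

from paperqa import Docs

docs = Docs(llm_config=dict(model="gpt-4", temperature=0.1))
# llm_guess_model_type fills in the missing model_type key:
assert docs.llm_config["model_type"] == "chat"
# config_summary_llm_conig falls back to llm_config when summary_llm_config is unset:
assert docs.summary_llm_config["model"] == "gpt-4"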
@@ -69,24 +72,39 @@ def __init__(self, **data):

@field_validator("llm_config", "summary_llm_config")
@classmethod
-    def llm_config_has_type(cls, v):
+    def llm_guess_model_type(cls, v: dict) -> dict:
if v is not None and "model_type" not in v:
-            raise ValueError("Must specify if chat or completion model.")
+            v["model_type"] = guess_model_type(v["model"])
return v

@model_validator(mode="after")
@classmethod
def config_summary_llm_conig(cls, data: Any) -> Any:
if isinstance(data, dict):
if data["summary_llm_config"] is None:
data["summary_llm_config"] = data["llm_config"]
if isinstance(data, Docs):
if data.summary_llm_config is None:
data.summary_llm_config = data.llm_config
return data

def clear_docs(self):
self.texts = []
self.docs = {}
self.docnames = set()

def __getstate__(self):
state = super().__getstate__()
# remove client from private attributes
del state["__pydantic_private__"]["_client"]
return state

def __setstate__(self, state):
super().__setstate__(state)
self._client = None

def set_client(self, client: AsyncOpenAI | None = None):
if client is None:
client = AsyncOpenAI()
self._client = client

def _get_unique_name(self, docname: str) -> str:
"""Create a unique name given proposed name"""
suffix = ""
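
The new __getstate__/__setstate__/set_client hooks make Docs picklable by dropping the AsyncOpenAI client and re-attaching one later. A sketch of the intended round-trip (file path hypothetical; assumes OPENAI_API_KEY is set when the client is re-created):

import pickle

from paperqa import Docs

docs = Docs()
with open("docs.pkl", "wb") as f:
    pickle.dump(docs, f)  # __getstate__ drops the unpicklable client

with open("docs.pkl", "rb") as f:
    restored = pickle.load(f)  # __setstate__ leaves _client as None

restored.set_client()  # attach a fresh AsyncOpenAI() before any LLM call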
@@ -173,12 +191,7 @@ def add(
texts = read_doc(path, fake_doc, chunk_chars=chunk_chars, overlap=100)
if len(texts) == 0:
raise ValueError(f"Could not read document {path}. Is it empty?")
-        try:
-            loop = asyncio.get_event_loop()
-        except RuntimeError:
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-        citation = loop.run_until_complete(cite_chain(texts[0].text))
+        citation = asyncio.run(cite_chain(data=dict(text=texts[0].text)))
if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}"
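
The manual event-loop juggling collapses to asyncio.run() because nest_asyncio.apply() at import time patches the loop to be re-entrant, so the same call works inside Jupyter. A toy sketch of the pattern (the coroutine is a stand-in, not the real citation chain):

import asyncio

import nest_asyncio

nest_asyncio.apply()

async def cite_chain_stub() -> str:
    return "Smith, J. Example Paper. 2024."

citation = asyncio.run(cite_chain_stub())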

@@ -234,14 +247,20 @@ def add_texts(
t.name = t.name.replace(doc.docname, new_docname)
doc.docname = new_docname
if texts[0].embedding is None:
-            text_embeddings = self.embeddings.embed_documents([t.text for t in texts])
+            text_embeddings = asyncio.run(
+                embed_documents(
+                    self._client, [t.text for t in texts], self.embeddings_model
+                )
+            )
for i, t in enumerate(texts):
t.embedding = text_embeddings[i]
if doc.embedding is None:
-            doc.embedding = self.embeddings.embed_documents([doc.citation])[0]
+            doc.embedding = asyncio.run(
+                embed_documents(self._client, [doc.citation], self.embeddings_model)
+            )[0]
if not self.jit_texts_index:
self.texts_index.add_texts_and_embeddings(texts)
-        self.doc_index.add_texts_and_embeddings(doc)
+        self.doc_index.add_texts_and_embeddings([doc])
self.docs[doc.dockey] = doc
self.texts += texts
self.docnames.add(doc.docname)
Expand All @@ -266,8 +285,13 @@ async def adoc_match(
get_callbacks: CallbackFactory = lambda x: None,
) -> set[DocKey]:
"""Return a list of dockeys that match the query."""
+        query_vector = (
+            await embed_documents(self._client, [query], self.embeddings_model)
+        )[0]
        matches, _ = self.doc_index.max_marginal_relevance_search(
-            query, k=k + len(self.deleted_dockeys)
+            query_vector,
+            k=k + len(self.deleted_dockeys),
+            fetch_k=5 * (k + len(self.deleted_dockeys)),
)
# filter the matches
matched_docs = [m for m in matches if m.dockey not in self.deleted_dockeys]
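
The document index is now queried with a precomputed embedding plus an explicit fetch_k candidate pool. For orientation, a generic maximal-marginal-relevance sketch (illustrative only, not the NumpyVectorStore implementation; lam is a hypothetical relevance/diversity trade-off weight):

import numpy as np

def mmr(query: np.ndarray, vectors: np.ndarray, k: int, fetch_k: int, lam: float = 0.5) -> list[int]:
    sims = vectors @ query  # cosine similarity when rows are unit-normalized
    candidates = list(np.argsort(sims)[::-1][:fetch_k])  # top-fetch_k relevance pool
    selected: list[int] = []
    while candidates and len(selected) < k:
        def score(i: int) -> float:
            # penalize similarity to anything already selected
            redundancy = max((float(vectors[i] @ vectors[j]) for j in selected), default=0.0)
            return lam * float(sims[i]) - (1 - lam) * redundancy
        best = max(candidates, key=score)
        selected.append(best)
        candidates.remove(best)
    return selected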
@@ -277,7 +301,7 @@
try:
if (
rerank is None
-                and self.llm_config.model.startswith("gpt-4")
+                and self.llm_config["model"].startswith("gpt-4")
or rerank is True
):
chain = make_chain(
@@ -296,7 +320,7 @@
pass
return set([d.dockey for d in matched_docs])
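
Note the operator precedence in the rerank condition above: `and` binds tighter than `or`, so reranking defaults on for GPT-4 models and can also be forced with rerank=True. A parenthesized equivalent:

def should_rerank(rerank: bool | None, model: str) -> bool:
    return (rerank is None and model.startswith("gpt-4")) or rerank is True

assert should_rerank(None, "gpt-4") is True
assert should_rerank(None, "gpt-3.5-turbo") is False
assert should_rerank(True, "gpt-3.5-turbo") is True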

-    def _build_texts_index(self, keys: set[DocKey] = None):
+    def _build_texts_index(self, keys: set[DocKey] | None = None):
if keys is not None and self.jit_texts_index:
texts = self.texts
if keys is not None:
@@ -317,17 +341,7 @@ def get_evidence(
disable_vector_search: bool = False,
disable_summarization: bool = False,
) -> Answer:
-        # special case for jupyter notebooks
-        if "get_ipython" in globals() or "google.colab" in sys.modules:
-            import nest_asyncio
-
-            nest_asyncio.apply()
-        try:
-            loop = asyncio.get_event_loop()
-        except RuntimeError:
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-        return loop.run_until_complete(
+        return asyncio.run(
self.aget_evidence(
answer,
k=k,
@@ -362,12 +376,17 @@ async def aget_evidence(
if disable_vector_search:
matches = self.texts
else:
+            query_vector = (
+                await embed_documents(
+                    self._client, [answer.question], self.embeddings_model
+                )
+            )[0]
            if marginal_relevance:
-                matches = self.texts_index.max_marginal_relevance_search(
-                    answer.question, k=_k, fetch_k=5 * _k
+                matches, _ = self.texts_index.max_marginal_relevance_search(
+                    query_vector, k=_k, fetch_k=5 * _k
                )
            else:
-                matches = self.texts_index.similarity_search(answer.question, k=_k)
+                matches, _ = self.texts_index.similarity_search(query_vector, k=_k)
# ok now filter (like ones from adoc_match)
if answer.dockey_filter is not None:
matches = [m for m in matches if m.doc.dockey in answer.dockey_filter]
@@ -406,7 +425,7 @@ async def process(match):
# http code in the exception
try:
context = await summary_chain(
-                    dict(
+                    data=dict(
question=answer.question,
# Add name so chunk is stated
citation=citation,
@@ -434,7 +453,7 @@ async def process(match):
text=Text(
text=match.text,
name=match.name,
-                        doc=Doc(**match.doc),
+                        doc=Doc(**match.doc.model_dump()),
),
score=score,
)
@@ -474,17 +493,7 @@ def query(
key_filter: bool | None = None,
get_callbacks: CallbackFactory = lambda x: None,
) -> Answer:
-        # special case for jupyter notebooks
-        if "get_ipython" in globals() or "google.colab" in sys.modules:
-            import nest_asyncio
-
-            nest_asyncio.apply()
-        try:
-            loop = asyncio.get_event_loop()
-        except RuntimeError:
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-        return loop.run_until_complete(
+        return asyncio.run(
self.aquery(
query,
k=k,
@@ -540,12 +549,11 @@ async def aquery(
)
answer.context = answer.context + "\n\nExtra background information:" + pre
bib = dict()
-        if len(answer.context) < 10 and not self.memory:
+        if len(answer.context) < 10:  # and not self.memory:
answer_text = (
"I cannot answer this question due to insufficient information."
)
else:
-            callbacks = get_callbacks("answer")
qa_chain = make_chain(
client=self._client,
prompt=self.prompts.qa,
@@ -558,7 +566,7 @@
answer_length=answer.answer_length,
question=answer.question,
),
-                callbacks=callbacks,
+                callbacks=get_callbacks("answer"),
)
# it still happens
if "(Example2012)" in answer_text:
@@ -586,17 +594,17 @@
llm_config=self.llm_config,
system_prompt=self.prompts.system,
)
-            post = await chain.arun(
+            post = await chain(
data=answer.model_dump(), callbacks=get_callbacks("post")
)
answer.answer = post
answer.formatted_answer = f"Question: {answer.question}\n\n{post}\n"
if len(bib) > 0:
answer.formatted_answer += f"\nReferences\n\n{bib_str}\n"
-        if self.memory_model is not None:
-            answer.memory = self.memory_model.load_memory_variables(inputs={})["memory"]
-            self.memory_model.save_context(
-                {"Question": answer.question}, {"Answer": answer.answer}
-            )
+        # if self.memory_model is not None:
+        #     answer.memory = self.memory_model.load_memory_variables(inputs={})["memory"]
+        #     self.memory_model.save_context(
+        #         {"Question": answer.question}, {"Answer": answer.answer}
+        #     )

return answer
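
With the notebook special-casing gone, query() is a thin asyncio.run() wrapper over aquery(). A usage sketch (document path hypothetical; assumes OPENAI_API_KEY is set):

from paperqa import Docs

docs = Docs()
docs.add("example_paper.pdf")

# synchronous callers:
answer = docs.query("What is the main finding?")
print(answer.formatted_answer)

# async callers await the coroutine directly:
#     answer = await docs.aquery("What is the main finding?")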
35 changes: 32 additions & 3 deletions paperqa/chains.py → paperqa/llms.py
@@ -1,11 +1,22 @@
import re
-from typing import Callable
+from typing import Any, Awaitable, Callable, get_args, get_type_hints

from openai import AsyncOpenAI

from .prompts import default_system_prompt

default_system_prompt = "End your responses with [END]"

def guess_model_type(model_name: str) -> str:
import openai

model_type = get_type_hints(
openai.types.chat.completion_create_params.CompletionCreateParamsBase
)["model"]
model_union = get_args(get_args(model_type)[1])
model_arr = list(model_union)
if model_name in model_arr:
return "chat"
return "completion"
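
The heuristic treats any model name listed in the openai package's chat-completions type hints as a chat model and everything else as a completion model. A quick check (assumes the installed openai version lists these names):

from paperqa.llms import guess_model_type

assert guess_model_type("gpt-3.5-turbo") == "chat"
assert guess_model_type("babbage-002") == "completion"  # not in the chat Literal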


def process_llm_config(llm_config: dict) -> dict:
@@ -23,13 +34,27 @@ def process_llm_config(llm_config: dict) -> dict:
return result


async def embed_documents(
client: AsyncOpenAI, texts: list[str], embedding_model: str
) -> list[list[float]]:
"""Embed a list of documents with batching"""
if client is None:
raise ValueError(
"Your client is None - did you forget to set it after pickling?"
)
response = await client.embeddings.create(
model=embedding_model, input=texts, encoding_format="float"
)
return [e.embedding for e in response.data]
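
A minimal usage sketch of the new embedding helper (embedding model name illustrative; assumes OPENAI_API_KEY is set):

import asyncio

from openai import AsyncOpenAI

from paperqa.llms import embed_documents

async def main() -> None:
    client = AsyncOpenAI()
    vectors = await embed_documents(
        client, ["first chunk", "second chunk"], "text-embedding-ada-002"
    )
    print(len(vectors), len(vectors[0]))  # one embedding vector per input text

asyncio.run(main())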


def make_chain(
client: AsyncOpenAI,
prompt: str,
llm_config: dict,
skip_system: bool = False,
system_prompt: str = default_system_prompt,
-) -> Callable[[list[dict], list[Callable[[str], None]] | None], list[str]]:
+) -> Awaitable[Any]:
"""Create a function to execute a batch of prompts
Args:
@@ -45,6 +70,10 @@ def make_chain(
where data is a dict with keys for the input variables that will be formatted into prompt
and callbacks is a list of functions to call with each chunk of the completion.
"""
if client is None:
raise ValueError(
"Your client is None - did you forget to set it after pickling?"
)
if llm_config["model_type"] == "chat":
system_message_prompt = dict(role="system", content=system_prompt)
human_message_prompt = dict(role="user", content=prompt)
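
Chains built by make_chain are now awaited directly (the chain.arun call in docs.py became a plain call). A sketch under illustrative prompt and config values:

import asyncio

from openai import AsyncOpenAI

from paperqa.llms import make_chain

async def main() -> None:
    client = AsyncOpenAI()
    chain = make_chain(
        client=client,
        prompt="Summarize: {text}",
        llm_config=dict(model="gpt-3.5-turbo", model_type="chat"),
    )
    # chains take data= for the prompt variables, plus optional callbacks:
    summary = await chain(data=dict(text="A long passage ..."), callbacks=None)
    print(summary)

asyncio.run(main())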
10 changes: 5 additions & 5 deletions paperqa/prompts.py
@@ -10,7 +10,7 @@
"{text}\n\n"
"Excerpt from {citation}\n"
"Question: {question}\n"
"Relevant Information Summary:",
"Relevant Information Summary:"
)

qa_prompt = (
@@ -22,7 +22,7 @@
"via valid citation markers at the end of sentences, like (Example2012). \n"
"Context (with relevance scores):\n {context}\n"
"Question: {question}\n"
"Answer: ",
"Answer: "
)

select_paper_prompt = (
@@ -34,12 +34,12 @@
"(if the question requires timely information). \n\n"
"Question: {question}\n\n"
"Papers: {papers}\n\n"
"Selected keys:",
"Selected keys:"
)
citation_prompt = (
"Provide the citation for the following text in MLA Format.\n\n"
"Provide the citation for the following text in MLA Format. If reporting date accessed, the current year is 2024\n\n"
"{text}\n\n"
"Citation:",
"Citation:"
)
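
The dropped trailing commas are load-bearing: inside parentheses, adjacent string literals concatenate into a single str, while a trailing comma turns the whole expression into a 1-tuple, on which the later .format() calls would fail. For example:

prompt_str = (
    "line one\n"
    "line two"
)
prompt_tuple = (
    "line one\n"
    "line two",
)
assert isinstance(prompt_str, str)
assert isinstance(prompt_tuple, tuple)  # prompt_tuple.format(...) would raise AttributeError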

default_system_prompt = (