Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated OpenAI compatibility to v1 #207

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions paperqa/chains.py
Original file line number Diff line number Diff line change
@@ -83,6 +83,14 @@ def make_chain(
[system_message_prompt, human_message_prompt]
)
return FallbackLLMChain(prompt=chat_prompt, llm=llm)
if not skip_system:
return FallbackLLMChain(
prompt=PromptTemplate(
input_variables=prompt.input_variables,
template=system_prompt + "\n\n" + prompt.template,
),
llm=llm,
)
return FallbackLLMChain(prompt=prompt, llm=llm)


71 changes: 57 additions & 14 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
@@ -18,9 +18,9 @@
from langchain.vectorstores import FAISS

try:
from pydantic.v1 import BaseModel, validator
from pydantic.v1 import BaseModel, Field, validator
except ImportError:
from pydantic import BaseModel, validator
from pydantic import BaseModel, validator, Field

from .chains import get_score, make_chain
from .paths import PAPERQA_DIR
@@ -47,13 +47,18 @@ class Docs(BaseModel, arbitrary_types_allowed=True, smart_union=True):
docnames: Set[str] = set()
texts_index: Optional[VectorStore] = None
doc_index: Optional[VectorStore] = None
llm: Union[str, BaseLanguageModel] = ChatOpenAI(
temperature=0.1, model="gpt-3.5-turbo", client=None
llm: Optional[Union[str, BaseLanguageModel]] = Field(
default_factory=lambda: ChatOpenAI(
model="gpt-3.5-turbo", temperature=0.1, request_timeout=30, client=None
)
)

summary_llm: Optional[Union[str, BaseLanguageModel]] = None
name: str = "default"
index_path: Optional[Path] = PAPERQA_DIR / name
embeddings: Embeddings = OpenAIEmbeddings(client=None)
embeddings: Optional[Embeddings] = Field(
default_factory=lambda: OpenAIEmbeddings(client=None)
)
max_concurrent: int = 5
deleted_dockeys: Set[DocKey] = set()
prompts: PromptCollection = PromptCollection()
@@ -93,6 +98,16 @@ def check_memory_model(cls, v, values):
return values["memory_model"]
return None

def _check_is_set(self) -> None:
"""Check if the non-serialized fields are set."""
status = self.embeddings is not None and self.llm is not None
if not status:
raise ValueError(
"You need to set embeddings and llm before using the Docs object. "
"This is a new requirement in 4.0.0 due to inability to "
"pickle OpenAI (version > 1) objects. "
)

def clear_docs(self):
self.texts = []
self.docs = {}
@@ -105,13 +120,21 @@ def update_llm(
) -> None:
"""Update the LLM for answering questions."""
if type(llm) is str:
llm = ChatOpenAI(temperature=0.1, model=llm, client=None)
llm = ChatOpenAI(
temperature=0.1, model=llm, client=None, request_timeout=30
)
if type(summary_llm) is str:
summary_llm = ChatOpenAI(temperature=0.1, model=summary_llm, client=None)
summary_llm = ChatOpenAI(
temperature=0.1, model=summary_llm, client=None, request_timeout=30
)
self.llm = cast(BaseLanguageModel, llm)
if summary_llm is None:
summary_llm = llm
self.summary_llm = cast(BaseLanguageModel, summary_llm)
# set the embeddings to OpenAI if not set already
# and we predict it is an OpenAI model
if self.embeddings is None and "OpenAI" in self.llm.__class__.__name__:
self.embeddings = OpenAIEmbeddings(client=None)

def _get_unique_name(self, docname: str) -> str:
"""Create a unique name given proposed name"""
@@ -244,6 +267,7 @@ def add_texts(

Returns True if the document was added, False if it was already in the collection.
"""
self._check_is_set()
if doc.dockey in self.docs:
return False
if len(texts) == 0:
@@ -254,7 +278,9 @@ def add_texts(
t.name = t.name.replace(doc.docname, new_docname)
doc.docname = new_docname
if texts[0].embeddings is None:
text_embeddings = self.embeddings.embed_documents([t.text for t in texts])
text_embeddings = cast(Embeddings, self.embeddings).embed_documents(
[t.text for t in texts]
)
for i, t in enumerate(texts):
t.embeddings = text_embeddings[i]
else:
@@ -299,6 +325,7 @@ async def adoc_match(
get_callbacks: CallbackFactory = lambda x: None,
) -> Set[DocKey]:
"""Return a list of dockeys that match the query."""
self._check_is_set()
if self.doc_index is None:
if len(self.docs) == 0:
return set()
@@ -350,6 +377,13 @@ def __getstate__(self):
state["texts_index"].save_local(self.index_path)
del state["texts_index"]
del state["doc_index"]
# check for name this way to catch ChatOpenAI and OpenAI
if "OpenAI" in state["llm"].__class__.__name__:
del state["llm"]
if "OpenAI" in state["summary_llm"].__class__.__name__:
del state["summary_llm"]
if state["embeddings"].__class__.__name__ == "OpenAIEmbeddings":
del state["embeddings"]
return {"__dict__": state, "__fields_set__": self.__fields_set__}

def __setstate__(self, state):
@@ -361,6 +395,12 @@ def __setstate__(self, state):
# they use some special exception type, but I don't want to import it
self.texts_index = None
self.doc_index = None
if not hasattr(self, "llm"):
self.llm = None
if not hasattr(self, "summary_llm"):
self.summary_llm = None
if not hasattr(self, "embeddings"):
self.embeddings = None

def _build_texts_index(self, keys: Optional[Set[DocKey]] = None):
if keys is not None and self.jit_texts_index:
@@ -432,6 +472,7 @@ async def aget_evidence(
disable_vector_search: bool = False,
disable_summarization: bool = False,
) -> Answer:
self._check_is_set()
if disable_vector_search:
k = k * 10000
if len(self.docs) == 0 and self.doc_index is None:
@@ -481,12 +522,6 @@ async def process(match):
memory=self.memory_model,
system_prompt=self.prompts.system,
)
# This is dangerous because it
# could mask errors that are important- like auth errors
# I also cannot know what the exception
# type is because any model could be used
# my best idea is see if there is a 4XX
# http code in the exception
try:
citation = match.metadata["doc"]["citation"]
if detailed_citations:
@@ -503,6 +538,13 @@ async def process(match):
callbacks=callbacks,
)
except Exception as e:
# This is to catch timeouts
# This is dangerous because it
# could mask errors that are important- like auth errors
# I also cannot know what the exception
# type is because any model could be used
# my best idea is see if there is a 4XX
# http code in the exception
if guess_is_4xx(str(e)):
return None
raise e
@@ -605,6 +647,7 @@ async def aquery(
key_filter: Optional[bool] = None,
get_callbacks: CallbackFactory = lambda x: None,
) -> Answer:
self._check_is_set()
if k < max_sources:
raise ValueError("k should be greater than max_sources")
if answer is None:
2 changes: 1 addition & 1 deletion paperqa/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.13.2"
__version__ = "4.0.0rc1"
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@
"pypdf",
"pydantic<2",
"langchain>=0.0.303",
"openai <1",
"openai>1",
"faiss-cpu",
"PyCryptodome",
"html2text",
24 changes: 24 additions & 0 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
import numpy as np
import requests
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate
@@ -486,6 +487,29 @@ def test_docs_pickle():
docs.query("What date is bring your dog to work in the US?")


def test_docs_pickle_noopenai():
    """Pickle round-trip of a Docs built with a non-OpenAI (fake) LLM.

    After unpickling, the embeddings field is gone, so querying must raise
    a ValueError until the caller restores it explicitly.
    """
    doc_path = "example.html"
    # Fetch a throwaway document: the Wikipedia article used by the other tests.
    response = requests.get("https://en.wikipedia.org/wiki/Take_Your_Dog_to_Work_Day")
    with open(doc_path, "w", encoding="utf-8") as f:
        f.write(response.text)
    fake_llm = FakeListLLM(responses=["This is a test", "This is another test"] * 50)
    docs = Docs(llm=fake_llm)
    docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now", chunk_chars=1000)
    os.remove(doc_path)
    restored = pickle.loads(pickle.dumps(docs))
    # Without embeddings set, querying must fail loudly with a ValueError.
    try:
        restored.query("What date is bring your dog to work in the US?")
    except ValueError as e:
        assert "embeddings" in str(e)
    else:
        raise RuntimeError("Should have raised ValueError")
    # Restoring the embeddings makes the object usable again.
    restored.embeddings = OpenAIEmbeddings(client=None)
    restored.query("What date is bring your dog to work in the US?")


def test_docs_pickle_no_faiss():
doc_path = "example.html"
with open(doc_path, "w", encoding="utf-8") as f: