[WIP] Version 4 with better open source model support and no langchain (#223)

* First draft without langchain

* Improved unit tests to be closer

* Fixed remaining tests

* Added new dependencies

* Refactored LLMs to allow swapping

* Added unit tests for custom embeds/llms

* Fixed langchain compatibility and updated README

* Refactored vector stores to maybe support langchain

* Addressed Matt's comments

* Finished langchain vector store

* Unit test prompt adjustments

* Added warning to README

* Fixed some typos in README

* Fixed problem for very short texts

* Made it easier to access LLM names

* Fixed text embedding errors
whitead authored Jan 22, 2024
1 parent 84f13ea commit 3cb16f2
Showing 14 changed files with 1,750 additions and 987 deletions.
278 changes: 118 additions & 160 deletions README.md

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion dev-requirements.txt
@@ -6,4 +6,6 @@ python-dotenv
pymupdf
build
types-requests
numpy
langchain_openai
langchain_community
faiss-cpu
22 changes: 22 additions & 0 deletions paperqa/__init__.py
@@ -1,5 +1,17 @@
from .docs import Answer, Docs, PromptCollection, Doc, Text, Context
from .version import __version__
from .llms import (
LLMModel,
EmbeddingModel,
LangchainEmbeddingModel,
OpenAIEmbeddingModel,
LangchainLLMModel,
OpenAILLMModel,
LlamaEmbeddingModel,
NumpyVectorStore,
LangchainVectorStore,
SentenceTransformerEmbeddingModel,
)

__all__ = [
"Docs",
@@ -9,4 +21,14 @@
"Doc",
"Text",
"Context",
"LLMModel",
"EmbeddingModel",
"OpenAIEmbeddingModel",
"OpenAILLMModel",
"LangchainLLMModel",
"LlamaEmbeddingModel",
"SentenceTransformerEmbeddingModel",
"LangchainEmbeddingModel",
"NumpyVectorStore",
"LangchainVectorStore",
]
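
For reference, the new exports can be imported directly; a minimal sketch (imports only — how these plug into `Docs` is defined in `docs.py`, whose diff is not rendered below, so no constructor arguments are assumed here):

```python
# A minimal sketch of pulling in the new swappable components from the public API.
from paperqa import (
    Docs,
    OpenAILLMModel,
    OpenAIEmbeddingModel,
    NumpyVectorStore,
    PromptCollection,
)
```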
113 changes: 0 additions & 113 deletions paperqa/chains.py

This file was deleted.

5 changes: 1 addition & 4 deletions paperqa/contrib/zotero.py
@@ -4,10 +4,7 @@
from pathlib import Path
from typing import List, Optional, Union, cast

try:
from pydantic.v1 import BaseModel
except ImportError:
from pydantic import BaseModel
from pydantic import BaseModel

try:
from pyzotero import zotero
724 changes: 392 additions & 332 deletions paperqa/docs.py

Large diffs are not rendered by default.

659 changes: 659 additions & 0 deletions paperqa/llms.py

Large diffs are not rendered by default.

73 changes: 35 additions & 38 deletions paperqa/prompts.py
@@ -1,54 +1,51 @@
from langchain.prompts import PromptTemplate

summary_prompt = PromptTemplate(
input_variables=["text", "citation", "question", "summary_length"],
template="Summarize the text below to help answer a question. "
"Do not directly answer the question, instead summarize "
"to give evidence to help answer the question. "
"Focus on specific details, including numbers, equations, or specific quotes. "
'Reply "Not applicable" if text is irrelevant. '
"Use {summary_length}. At the end of your response, provide a score from 1-10 on a newline "
"indicating relevance to question. Do not explain your score. "
"\n\n"
"{text}\n\n"
"Excerpt from {citation}\n"
"Question: {question}\n"
"Relevant Information Summary:",
summary_prompt = (
"Summarize the excerpt below to help answer a question.\n\n"
"Excerpt from {citation}\n\n----\n\n{text}\n\n----\n\n"
"Question: {question}\n\n"
"Do not directly answer the question, instead summarize to give evidence to help "
"answer the question. Stay detailed; report specific numbers, equations, or "
'direct quotes (marked with quotation marks). Reply "Not applicable" if the '
"excerpt is irrelevant. At the end of your response, provide an integer score "
"from 1-10 on a newline indicating relevance to question. Do not explain your score."
"\n\nRelevant Information Summary ({summary_length}):"
)
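
Since the prompts are now plain format strings rather than `PromptTemplate` objects, they can be filled with `str.format`; a small sketch with placeholder values:

```python
# A sketch of filling the new string prompt directly; all values are placeholders.
from paperqa.prompts import summary_prompt

filled = summary_prompt.format(
    citation="Example2024 pages 1-2",
    text="An example excerpt about the topic...",
    question="What is the main finding?",
    summary_length="about 100 words",
)
print(filled)
```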

qa_prompt = PromptTemplate(
input_variables=["context", "answer_length", "question"],
template="Write an answer ({answer_length}) "
"for the question below based on the provided context. "
"If the context provides insufficient information and the question cannot be directly answered, "
'reply "I cannot answer". '
"For each part of your answer, indicate which sources most support it "
"via valid citation markers at the end of sentences, like (Example2012). \n"
"Context (with relevance scores):\n {context}\n"
"Question: {question}\n"
"Answer: ",
qa_prompt = (
"Answer the question below with the context.\n\n"
"Context (with relevance scores):\n\n{context}\n\n----\n\n"
"Question: {question}\n\n"
"Write an answer based on the context. "
"If the context provides insufficient information and "
"the question cannot be directly answered, reply "
'"I cannot answer."'
"For each part of your answer, indicate which sources most support "
"it via citation keys at the end of sentences, "
"like (Example2012Example pages 3-4). Only cite from the context "
"below and only use the valid keys. Write in the style of a "
"Wikipedia article, with concise sentences and coherent paragraphs. "
"The context comes from a variety of sources and is only a summary, "
"so there may be inaccuracies or ambiguities. If quotes are present and "
"relevant, use them in the answer. This answer will go directly onto "
"Wikipedia, so do not add any extraneous information.\n\n"
"Answer ({answer_length}):"
)

select_paper_prompt = PromptTemplate(
input_variables=["question", "papers"],
template="Select papers that may help answer the question below. "
select_paper_prompt = (
"Select papers that may help answer the question below. "
"Papers are listed as $KEY: $PAPER_INFO. "
"Return a list of keys, separated by commas. "
'Return "None", if no papers are applicable. '
"Choose papers that are relevant, from reputable sources, and timely "
"(if the question requires timely information). \n\n"
"Question: {question}\n\n"
"Papers: {papers}\n\n"
"Selected keys:",
"Selected keys:"
)

# We are unable to serialize with partial variables
# so TODO: update year next year
citation_prompt = PromptTemplate(
input_variables=["text"],
template="Provide the citation for the following text in MLA Format. The year is 2023\n"
citation_prompt = (
"Provide the citation for the following text in MLA Format. "
"If reporting date accessed, the current year is 2024\n\n"
"{text}\n\n"
"Citation:",
"Citation:"
)

default_system_prompt = (
50 changes: 38 additions & 12 deletions paperqa/readers.py
@@ -1,8 +1,9 @@
from math import ceil
from pathlib import Path
from typing import List

import tiktoken
from html2text import html2text
from langchain.text_splitter import TokenTextSplitter

from .types import Doc, Text

@@ -31,7 +32,7 @@ def parse_pdf_fitz(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List
)
split = split[chunk_chars - overlap :]
pages = [str(i + 1)]
if len(split) > overlap:
if len(split) > overlap or len(texts) == 0:
pg = "-".join([pages[0], pages[-1]])
texts.append(
Text(text=split[:chunk_chars], name=f"{doc.docname} pages {pg}", doc=doc)
@@ -64,7 +65,7 @@ def parse_pdf(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List[Text
)
split = split[chunk_chars - overlap :]
pages = [str(i + 1)]
if len(split) > overlap:
if len(split) > overlap or len(texts) == 0:
pg = "-".join([pages[0], pages[-1]])
texts.append(
Text(text=split[:chunk_chars], name=f"{doc.docname} pages {pg}", doc=doc)
@@ -76,6 +77,12 @@ def parse_pdf(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List[Text
def parse_txt(
path: Path, doc: Doc, chunk_chars: int, overlap: int, html: bool = False
) -> List[Text]:
"""Parse a document into chunks, based on tiktoken encoding.
NOTE: We get some byte continuation errors.
Currently ignored, but we should explore more to make sure we
don't miss anything.
"""
try:
with open(path) as f:
text = f.read()
@@ -84,13 +91,32 @@ def parse_txt(
text = f.read()
if html:
text = html2text(text)
# yo, no idea why but the texts are not split correctly
text_splitter = TokenTextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
raw_texts = text_splitter.split_text(text)
texts = [
Text(text=t, name=f"{doc.docname} chunk {i}", doc=doc)
for i, t in enumerate(raw_texts)
]
texts: list[Text] = []
# we tokenize using tiktoken so cuts are in reasonable places
# See https://github.com/openai/tiktoken
enc = tiktoken.get_encoding("cl100k_base")
encoded = enc.encode_ordinary(text)
split = []
# convert from characters to chunks
char_count = len(text) # e.g., 25,000
token_count = len(encoded) # e.g., 4,500
chars_per_token = char_count / token_count # e.g., 5.5
chunk_tokens = chunk_chars / chars_per_token # e.g., 3000 / 5.5 = 545
overlap_tokens = overlap / chars_per_token # e.g., 100 / 5.5 = 18
chunk_count = ceil(token_count / chunk_tokens) # e.g., 4500 / 545 = 9
for i in range(chunk_count):
split = encoded[
max(int(i * chunk_tokens - overlap_tokens), 0) : int(
(i + 1) * chunk_tokens + overlap_tokens
)
]
texts.append(
Text(
text=enc.decode(split),
name=f"{doc.docname} chunk {i + 1}",
doc=doc,
)
)
return texts
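
A standalone sketch of the character-to-token conversion performed above, assuming `tiktoken` is installed (the sample text and chunk sizes are illustrative):

```python
# Estimate a characters-per-token ratio, convert the character-based chunk
# settings into token counts, then cut the encoded text into chunks.
from math import ceil
import tiktoken

text = "An example document about a topic. " * 1000
chunk_chars, overlap = 3000, 100

enc = tiktoken.get_encoding("cl100k_base")
encoded = enc.encode_ordinary(text)

chars_per_token = len(text) / len(encoded)    # e.g. ~4-6 characters per token
chunk_tokens = chunk_chars / chars_per_token  # target chunk size in tokens
overlap_tokens = overlap / chars_per_token    # overlap in tokens
chunk_count = ceil(len(encoded) / chunk_tokens)

chunks = [
    enc.decode(
        encoded[
            max(int(i * chunk_tokens - overlap_tokens), 0) : int(
                (i + 1) * chunk_tokens + overlap_tokens
            )
        ]
    )
    for i in range(chunk_count)
]
```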


@@ -104,7 +130,7 @@ def parse_code_txt(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List
with open(path) as f:
for i, line in enumerate(f):
split += line
if len(split) > chunk_chars:
while len(split) > chunk_chars:
texts.append(
Text(
text=split[:chunk_chars],
@@ -114,7 +140,7 @@ def parse_code_txt(path: Path, doc: Doc, chunk_chars: int, overlap: int) -> List
)
split = split[chunk_chars - overlap :]
last_line = i
if len(split) > overlap:
if len(split) > overlap or len(texts) == 0:
texts.append(
Text(
text=split[:chunk_chars],
230 changes: 93 additions & 137 deletions paperqa/types.py
@@ -1,18 +1,6 @@
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.prompts import PromptTemplate

try:
from pydantic.v1 import BaseModel, validator
except ImportError:
from pydantic import BaseModel, validator

import re
from pydantic import BaseModel, ConfigDict, Field, field_validator

from .prompts import (
citation_prompt,
@@ -21,75 +9,117 @@
select_paper_prompt,
summary_prompt,
)
from .utils import extract_doi, iter_citations

# Just for clarity
DocKey = Any
CBManager = Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun]
CallbackFactory = Callable[[str], Union[None, List[BaseCallbackHandler]]]

CallbackFactory = Callable[[str], list[Callable[[str], None]] | None]


class LLMResult(BaseModel):
text: str = ""
prompt_count: int = 0
completion_count: int = 0
model: str
date: str
seconds_to_first_token: float = 0
seconds_to_last_token: float = 0

def __str__(self):
return self.text


class Embeddable(BaseModel):
embedding: list[float] | None = Field(default=None, repr=False)

class Doc(BaseModel):

class Doc(Embeddable):
docname: str
citation: str
dockey: DocKey


class Text(BaseModel):
class Text(Embeddable):
text: str
name: str
doc: Doc
embeddings: Optional[List[float]] = None


# Mock a dictionary and store any missing items
class _FormatDict(dict):
def __init__(self) -> None:
self.key_set: set[str] = set()

def __missing__(self, key: str) -> str:
self.key_set.add(key)
return key


def get_formatted_variables(s: str) -> set[str]:
"""Returns the set of variables implied by the format string"""
format_dict = _FormatDict()
s.format_map(format_dict)
return format_dict.key_set
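
A quick usage sketch of this helper:

```python
# The helper collects every placeholder a format string references.
from paperqa.types import get_formatted_variables

assert get_formatted_variables("Excerpt from {citation}: {text}") == {"citation", "text"}
assert get_formatted_variables("no placeholders here") == set()
```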


class PromptCollection(BaseModel):
summary: PromptTemplate = summary_prompt
qa: PromptTemplate = qa_prompt
select: PromptTemplate = select_paper_prompt
cite: PromptTemplate = citation_prompt
pre: Optional[PromptTemplate] = None
post: Optional[PromptTemplate] = None
summary: str = summary_prompt
qa: str = qa_prompt
select: str = select_paper_prompt
cite: str = citation_prompt
pre: str | None = None
post: str | None = None
system: str = default_system_prompt
skip_summary: bool = False

@validator("summary")
def check_summary(cls, v: PromptTemplate) -> PromptTemplate:
if not set(v.input_variables).issubset(set(summary_prompt.input_variables)):
@field_validator("summary")
@classmethod
def check_summary(cls, v: str) -> str:
if not set(get_formatted_variables(v)).issubset(
set(get_formatted_variables(summary_prompt))
):
raise ValueError(
f"Summary prompt can only have variables: {summary_prompt.input_variables}"
f"Summary prompt can only have variables: {get_formatted_variables(summary_prompt)}"
)
return v

@validator("qa")
def check_qa(cls, v: PromptTemplate) -> PromptTemplate:
if not set(v.input_variables).issubset(set(qa_prompt.input_variables)):
@field_validator("qa")
@classmethod
def check_qa(cls, v: str) -> str:
if not set(get_formatted_variables(v)).issubset(
set(get_formatted_variables(qa_prompt))
):
raise ValueError(
f"QA prompt can only have variables: {qa_prompt.input_variables}"
f"QA prompt can only have variables: {get_formatted_variables(qa_prompt)}"
)
return v

@validator("select")
def check_select(cls, v: PromptTemplate) -> PromptTemplate:
if not set(v.input_variables).issubset(
set(select_paper_prompt.input_variables)
@field_validator("select")
@classmethod
def check_select(cls, v: str) -> str:
if not set(get_formatted_variables(v)).issubset(
set(get_formatted_variables(select_paper_prompt))
):
raise ValueError(
f"Select prompt can only have variables: {select_paper_prompt.input_variables}"
f"Select prompt can only have variables: {get_formatted_variables(select_paper_prompt)}"
)
return v

@validator("pre")
def check_pre(cls, v: Optional[PromptTemplate]) -> Optional[PromptTemplate]:
@field_validator("pre")
@classmethod
def check_pre(cls, v: str | None) -> str | None:
if v is not None:
if set(v.input_variables) != set(["question"]):
if set(get_formatted_variables(v)) != set(["question"]):
raise ValueError("Pre prompt must have input variables: question")
return v

@validator("post")
def check_post(cls, v: Optional[PromptTemplate]) -> Optional[PromptTemplate]:
@field_validator("post")
@classmethod
def check_post(cls, v: str | None) -> str | None:
if v is not None:
# kind of a hack to get list of attributes in answer
attrs = [a.name for a in Answer.__fields__.values()]
if not set(v.input_variables).issubset(attrs):
attrs = set(Answer.model_fields.keys())
if not set(get_formatted_variables(v)).issubset(attrs):
raise ValueError(f"Post prompt must have input variables: {attrs}")
return v
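
With string prompts, overriding a prompt only requires staying within the default prompt's variables; a hedged sketch using the QA prompt:

```python
# A sketch of overriding one prompt: the validators only check that its
# variables are a subset of the default prompt's variables.
from paperqa import PromptCollection

my_qa = (
    "Answer the question using only the context below.\n\n"
    "{context}\n\nQuestion: {question}\n\nAnswer ({answer_length}):"
)
prompts = PromptCollection(qa=my_qa)  # a stray variable like {foo} would raise ValueError
```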

@@ -113,18 +143,18 @@ class Answer(BaseModel):
question: str
answer: str = ""
context: str = ""
contexts: List[Context] = []
contexts: list[Context] = []
references: str = ""
formatted_answer: str = ""
dockey_filter: Optional[Set[DocKey]] = None
dockey_filter: set[DocKey] | None = None
summary_length: str = "about 100 words"
answer_length: str = "about 100 words"
memory: Optional[str] = None
# these two below are for convenience
# and are not set. But you can set them
# if you want to use them.
cost: Optional[float] = None
token_counts: Optional[Dict[str, List[int]]] = None
memory: str | None = None
# just for convenience you can override this
cost: float | None = None
# key is model name, value is (prompt, completion) token counts
token_counts: dict[str, list[int]] = Field(default_factory=dict)
model_config = ConfigDict(extra="forbid")

def __str__(self) -> str:
"""Return the answer as a string."""
@@ -138,87 +168,13 @@ def get_citation(self, name: str) -> str:
raise ValueError(f"Could not find docname {name} in contexts")
return doc.citation

def markdown(self) -> Tuple[str, str]:
"""Return the answer with footnote style citations."""
# example: This is an answer.[^1]
# [^1]: This the citation.
output = self.answer
refs: Dict[str, int] = dict()
index = 1
for citation in iter_citations(self.answer):
compound = ""
strip = True
for c in re.split(",|;", citation):
c = c.strip("() ")
if c == "Extra background information":
continue
if c in refs:
compound += f"[^{refs[c]}]"
continue
# check if it is a citation
try:
self.get_citation(c)
except ValueError:
# not a citation
strip = False
continue
refs[c] = index
compound += f"[^{index}]"
index += 1
if strip:
output = output.replace(citation, compound)
formatted_refs = "\n".join(
[
f"[^{i}]: [{self.get_citation(r)}]({extract_doi(self.get_citation(r))})"
for r, i in refs.items()
def add_tokens(self, result: LLMResult):
"""Update the token counts for the given result."""
if result.model not in self.token_counts:
self.token_counts[result.model] = [
result.prompt_count,
result.completion_count,
]
)
# quick fix of space before period
output = output.replace(" .", ".")
return output, formatted_refs

def combine_with(self, other: "Answer") -> "Answer":
"""
Combine this answer object with another, merging their context/answer.
"""
combined = Answer(
question=self.question + " / " + other.question,
answer=self.answer + " " + other.answer,
context=self.context + " " + other.context,
contexts=self.contexts + other.contexts,
references=self.references + " " + other.references,
formatted_answer=self.formatted_answer + " " + other.formatted_answer,
summary_length=self.summary_length, # Assuming the same summary_length for both
answer_length=self.answer_length, # Assuming the same answer_length for both
memory=self.memory if self.memory else other.memory,
cost=self.cost if self.cost else other.cost,
token_counts=self.merge_token_counts(self.token_counts, other.token_counts),
)
# Handling dockey_filter if present in either of the Answer objects
if self.dockey_filter or other.dockey_filter:
combined.dockey_filter = (
self.dockey_filter if self.dockey_filter else set()
) | (other.dockey_filter if other.dockey_filter else set())
return combined

@staticmethod
def merge_token_counts(
counts1: Optional[Dict[str, List[int]]], counts2: Optional[Dict[str, List[int]]]
) -> Optional[Dict[str, List[int]]]:
"""
Merge two dictionaries of token counts.
"""
if counts1 is None and counts2 is None:
return None
if counts1 is None:
return counts2
if counts2 is None:
return counts1
merged_counts = counts1.copy()
for key, values in counts2.items():
if key in merged_counts:
merged_counts[key][0] += values[0]
merged_counts[key][1] += values[1]
else:
merged_counts[key] = values
return merged_counts
else:
self.token_counts[result.model][0] += result.prompt_count
self.token_counts[result.model][1] += result.completion_count
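
Token accounting now lives on `Answer` itself; a small sketch of `add_tokens` with placeholder values:

```python
# Accumulate per-model (prompt, completion) token counts on an Answer.
from paperqa import Answer
from paperqa.types import LLMResult

answer = Answer(question="What is the boiling point of water?")
result = LLMResult(
    text="100 degrees Celsius at standard pressure.",
    prompt_count=120,
    completion_count=18,
    model="gpt-4",        # placeholder model name
    date="2024-01-22",    # placeholder timestamp
)
answer.add_tokens(result)
assert answer.token_counts == {"gpt-4": [120, 18]}
```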
37 changes: 26 additions & 11 deletions paperqa/utils.py
@@ -3,10 +3,9 @@
import re
import string
from pathlib import Path
from typing import BinaryIO, List, Union
from typing import Any, BinaryIO, Coroutine, Iterator, Union

import pypdf
from langchain.base_language import BaseLanguageModel

StrPath = Union[str, Path]

@@ -76,7 +75,7 @@ def md5sum(file_path: StrPath) -> str:
return hashlib.md5(f.read()).hexdigest()


async def gather_with_concurrency(n: int, *coros: List) -> List:
async def gather_with_concurrency(n: int, coros: list[Coroutine]) -> list[Any]:
# https://stackoverflow.com/a/61478547/2392535
semaphore = asyncio.Semaphore(n)
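
The signature now takes an explicit list of coroutines instead of `*coros`; a usage sketch with a hypothetical `work` coroutine:

```python
# At most n coroutines from the list run concurrently.
import asyncio

from paperqa.utils import gather_with_concurrency

async def work(i: int) -> int:
    await asyncio.sleep(0.01)
    return i * i

async def main() -> list[int]:
    # at most 5 of the 20 coroutines run at once
    return await gather_with_concurrency(5, [work(i) for i in range(20)])

results = asyncio.run(main())
```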

@@ -93,13 +92,6 @@ def guess_is_4xx(msg: str) -> bool:
return False


def get_llm_name(llm: BaseLanguageModel) -> str:
try:
return llm.model_name # type: ignore
except AttributeError:
return llm.model # type: ignore


def strip_citations(text: str) -> str:
# Combined regex for identifying citations (see unit tests for examples)
citation_regex = r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)"
@@ -108,7 +100,7 @@ def strip_citations(text: str) -> str:
return text


def iter_citations(text: str) -> List[str]:
def iter_citations(text: str) -> list[str]:
# Combined regex for identifying citations (see unit tests for examples)
citation_regex = r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)"
result = re.findall(citation_regex, text, flags=re.MULTILINE)
@@ -131,3 +123,26 @@ def extract_doi(reference: str) -> str:
return "https://doi.org/" + doi_match.group()
else:
return ""


def batch_iter(iterable: list, n: int = 1) -> Iterator[list]:
"""
Batch an iterable into chunks of size n
:param iterable: The iterable to batch
:param n: The size of the batches
:return: An iterator over the batches
"""
length = len(iterable)
for ndx in range(0, length, n):
yield iterable[ndx : min(ndx + n, length)]


def flatten(iterable: list) -> list:
"""
Flatten a list of lists
:param iterable: The list of lists to flatten
:return: A flattened list
"""
return [item for sublist in iterable for item in sublist]
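
A quick sketch of the two new helpers:

```python
# batch_iter yields successive slices of at most n items; flatten undoes the batching.
from paperqa.utils import batch_iter, flatten

batches = list(batch_iter(list(range(7)), n=3))
assert batches == [[0, 1, 2], [3, 4, 5], [6]]
assert flatten(batches) == [0, 1, 2, 3, 4, 5, 6]
```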
2 changes: 1 addition & 1 deletion paperqa/version.py
@@ -1 +1 @@
__version__ = "3.13.5"
__version__ = "4.0.0-pre.1"
8 changes: 4 additions & 4 deletions setup.py
@@ -18,10 +18,10 @@
packages=["paperqa", "paperqa.contrib"],
install_requires=[
"pypdf",
"pydantic<2",
"langchain>=0.0.303",
"openai <1",
"faiss-cpu",
"pydantic>=2",
"openai>=1",
"numpy",
"nest-asyncio",
"PyCryptodome",
"html2text",
"tiktoken>=0.4.0",
532 changes: 358 additions & 174 deletions tests/test_paperqa.py

Large diffs are not rendered by default.
