Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Text Generation Inference with JSON output #235

Merged
merged 8 commits into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions keybert/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from keybert._utils import NotInstalled
from keybert.llm._base import BaseLLM

# TextGenerationInference
# Optional backend: the module imports `huggingface_hub` and `pydantic`, which
# may not be installed. If the import fails with ModuleNotFoundError, bind a
# `NotInstalled` object to the same name so that importing `keybert.llm` itself
# still succeeds; `msg` carries the install hint passed to that placeholder.
try:
    from keybert.llm._textgenerationinference import TextGenerationInference
except ModuleNotFoundError:
    msg = "`pip install huggingface-hub pydantic ` \n\n"
    TextGenerationInference = NotInstalled("TextGenerationInference", "huggingface-hub", custom_msg=msg)

# TextGeneration
try:
Expand Down Expand Up @@ -43,6 +49,7 @@
"Cohere",
"OpenAI",
"TextGeneration",
"TextGenerationInference",
"LangChain",
"LiteLLM"
]
125 changes: 125 additions & 0 deletions keybert/llm/_textgenerationinference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from tqdm import tqdm
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from typing import Mapping, List, Any
from keybert.llm._base import BaseLLM
from keybert.llm._utils import process_candidate_keywords
import json

# Prompt template used when the caller does not supply a prompt of their own.
# `[DOCUMENT]` and `[CANDIDATES]` are placeholders that
# `TextGenerationInference.extract_keywords` substitutes per document with
# `str.replace`; `[CANDIDATES]` is only substituted when candidates exist.
DEFAULT_PROMPT = """
I have the following document:
[DOCUMENT]

With the following candidate keywords:
[CANDIDATES]

Based on the information above, improve the candidate keywords to best describe the topic of the document.

Output in JSON format:
"""


# Default response schema: a JSON object with a single `keywords` field that
# holds a list of strings. `extract_keywords` reads the `keywords` key from
# the decoded response.
# NOTE: documented with `#` comments on purpose -- pydantic copies a model's
# docstring into the generated JSON schema as `description`, which would
# change the grammar sent to the server.
class Keywords(BaseModel):
    keywords: List[str]


class TextGenerationInference(BaseLLM):
    """Extract keywords through a Text Generation Inference (TGI) endpoint.

    The model's output is constrained to JSON by passing the JSON schema of a
    Pydantic model as `grammar` to `client.text_generation`. The decoded
    response must contain a `keywords` field.

    Arguments:
        client: InferenceClient from huggingface_hub.
        prompt: The prompt to be used in the model. If no prompt is given,
                `self.default_prompt_` is used instead.
                NOTE: Use `"[DOCUMENT]"` and `"[CANDIDATES]"` in the prompt
                to decide where the document and the candidate keywords
                need to be inserted.
        json_schema: Pydantic BaseModel class to be used as guidance for keywords.
                     It must expose a `keywords` field, since that key is read
                     from the response. By default uses:
                        class Keywords(BaseModel):
                            keywords: List[str]
        verbose: Set this to True if you want to see a progress bar for the
                 keyword extraction.

    Usage:

    ```python
    from typing import List
    from pydantic import BaseModel
    from huggingface_hub import InferenceClient
    from keybert.llm import TextGenerationInference
    from keybert import KeyLLM

    # Json Schema
    class Keywords(BaseModel):
        keywords: List[str]

    # Create your LLM
    generator = InferenceClient('url')
    llm = TextGenerationInference(generator, json_schema=Keywords)

    # Load it in KeyLLM
    kw_model = KeyLLM(llm)

    # Extract keywords
    document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
    keywords = kw_model.extract_keywords(document)
    ```

    You can use a custom prompt and decide where the document should
    be inserted with the `[DOCUMENT]` tag:

    ```python
    from keybert.llm import TextGenerationInference

    prompt = "I have the following documents '[DOCUMENT]'. Please give me the keywords that are present in this document and separate them with commas:"

    # Create your representation model
    from huggingface_hub import InferenceClient
    generator = InferenceClient('url')
    llm = TextGenerationInference(generator, prompt=prompt)
    ```
    """

    def __init__(self,
                 client: InferenceClient,
                 prompt: str = None,
                 json_schema: BaseModel = Keywords,
                 verbose: bool = False
                 ):
        self.client = client
        self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
        self.default_prompt_ = DEFAULT_PROMPT
        self.json_schema = json_schema
        # `extract_keywords` reads `self.verbose` to toggle the progress bar,
        # so it has to be set here.
        self.verbose = verbose

    def extract_keywords(
        self,
        documents: List[str],
        candidate_keywords: List[List[str]] = None,
        inference_kwargs: Mapping[str, Any] = None
    ):
        """ Extract keywords for each document

        Arguments:
            documents: The documents to extract keywords from
            candidate_keywords: A list of candidate keywords that the LLM will fine-tune
                        For example, it will create a nicer representation of
                        the candidate keywords, remove redundant keywords, or
                        shorten them depending on the input prompt.
            inference_kwargs: Extra keyword arguments forwarded to
                        `client.text_generation` on every call.

        Returns:
            all_keywords: All keywords for each document, i.e. one list of
                          keywords per input document, in input order.
        """
        # Avoid a shared mutable default; `None` means "no extra arguments".
        inference_kwargs = {} if inference_kwargs is None else inference_kwargs
        candidate_keywords = process_candidate_keywords(documents, candidate_keywords)

        # The grammar does not depend on the document, so build it once.
        # Prefer pydantic v2's `model_json_schema`; fall back to the older
        # `schema` classmethod when it is not available.
        schema_fn = getattr(self.json_schema, "model_json_schema", None) or self.json_schema.schema
        grammar = {"type": "json", "value": schema_fn()}

        all_keywords = []
        for document, candidates in tqdm(zip(documents, candidate_keywords), disable=not self.verbose):
            prompt = self.prompt.replace("[DOCUMENT]", document)
            # Without candidates the `[CANDIDATES]` placeholder is left untouched.
            if candidates is not None:
                prompt = prompt.replace("[CANDIDATES]", ", ".join(candidates))

            # Generate a JSON-constrained response and collect its keywords
            response = self.client.text_generation(
                prompt=prompt,
                grammar=grammar,
                **inference_kwargs
            )
            # Append (rather than rebind) so every document keeps its result.
            all_keywords.append(json.loads(response)["keywords"])

        return all_keywords
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies = [
"numpy>=1.18.5",
"rich>=10.4.0",
"scikit-learn>=0.22.2",
"sentence-transformers>=0.3.8",
"sentence-transformers>=0.3.8"
]

[project.optional-dependencies]
Expand Down Expand Up @@ -70,6 +70,10 @@ test = [
"pytest-cov>=2.6.1",
"pytest>=5.4.3",
]
tgi = [
"huggingface-hub>=0.23.3",
"pydantic>=2.7.4"
]
use = [
"tensorflow",
"tensorflow_hub",
Expand Down