diff --git a/keybert/llm/__init__.py b/keybert/llm/__init__.py
index e7ec4ee4..e24fe48a 100644
--- a/keybert/llm/__init__.py
+++ b/keybert/llm/__init__.py
@@ -1,6 +1,12 @@
 from keybert._utils import NotInstalled
 from keybert.llm._base import BaseLLM
 
+# TextGenerationInference
+try:
+    from keybert.llm._textgenerationinference import TextGenerationInference
+except ModuleNotFoundError:
+    msg = "`pip install huggingface-hub pydantic`\n\n"
+    TextGenerationInference = NotInstalled("TextGenerationInference", "huggingface-hub", custom_msg=msg)
 
 # TextGeneration
 try:
@@ -43,6 +49,7 @@
     "Cohere",
     "OpenAI",
     "TextGeneration",
+    "TextGenerationInference",
     "LangChain",
     "LiteLLM"
 ]
diff --git a/keybert/llm/_textgenerationinference.py b/keybert/llm/_textgenerationinference.py
new file mode 100644
index 00000000..f2b99a66
--- /dev/null
+++ b/keybert/llm/_textgenerationinference.py
@@ -0,0 +1,137 @@
+import json
+from typing import Mapping, List, Any
+
+from tqdm import tqdm
+from pydantic import BaseModel
+from huggingface_hub import InferenceClient
+
+from keybert.llm._base import BaseLLM
+from keybert.llm._utils import process_candidate_keywords
+
+DEFAULT_PROMPT = """
+    I have the following document:
+    [DOCUMENT]
+
+    With the following candidate keywords:
+    [CANDIDATES]
+
+    Based on the information above, improve the candidate keywords to best describe the topic of the document.
+
+    Output in JSON format:
+"""
+
+
+class Keywords(BaseModel):
+    # Default grammar: the model must return `{"keywords": ["...", ...]}`
+    keywords: List[str]
+
+
+class TextGenerationInference(BaseLLM):
+    """ Extract keywords using a Text Generation Inference server.
+
+    Arguments:
+        client: InferenceClient from huggingface_hub.
+        prompt: The prompt to be used in the model. If no prompt is given,
+                `self.default_prompt_` is used instead.
+                NOTE: Use `"[DOCUMENT]"` and `"[CANDIDATES]"` in the prompt
+                to decide where the document and the candidate keywords
+                need to be inserted.
+        json_schema: Pydantic BaseModel used as a JSON grammar to constrain
+                     the generated text. By default uses:
+
+                     class Keywords(BaseModel):
+                         keywords: List[str]
+        verbose: Set this to True if you want to see a progress bar during
+                 keyword extraction.
+
+    Usage:
+
+    ```python
+    from pydantic import BaseModel
+    from huggingface_hub import InferenceClient
+    from keybert.llm import TextGenerationInference
+    from keybert import KeyLLM
+
+    # Json Schema
+    class Keywords(BaseModel):
+        keywords: List[str]
+
+    # Create your LLM
+    generator = InferenceClient('url')
+    llm = TextGenerationInference(generator, json_schema=Keywords)
+
+    # Load it in KeyLLM
+    kw_model = KeyLLM(llm)
+
+    # Extract keywords
+    document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
+    keywords = kw_model.extract_keywords(document)
+    ```
+
+    You can use a custom prompt and decide where the document should
+    be inserted with the `[DOCUMENT]` tag:
+
+    ```python
+    from huggingface_hub import InferenceClient
+    from keybert.llm import TextGenerationInference
+
+    prompt = "I have the following documents '[DOCUMENT]'. Please give me the keywords that are present in this document and separate them with commas:"
+
+    # Create your representation model
+    generator = InferenceClient('url')
+    llm = TextGenerationInference(generator)
+    ```
+    """
+
+    def __init__(self,
+                 client: InferenceClient,
+                 prompt: str = None,
+                 json_schema: BaseModel = Keywords,
+                 verbose: bool = False
+                 ):
+        self.client = client
+        self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
+        self.default_prompt_ = DEFAULT_PROMPT
+        self.json_schema = json_schema
+        # `extract_keywords` reads `self.verbose`; previously it was never set
+        self.verbose = verbose
+
+    def extract_keywords(
+            self,
+            documents: List[str], candidate_keywords: List[List[str]] = None,
+            inference_kwargs: Mapping[str, Any] = None
+    ):
+        """ Extract keywords from the documents
+
+        Arguments:
+            documents: The documents to extract keywords from
+            candidate_keywords: A list of candidate keywords that the LLM will fine-tune
+                        For example, it will create a nicer representation of
+                        the candidate keywords, remove redundant keywords, or
+                        shorten them depending on the input prompt.
+            inference_kwargs: Kwargs passed to `client.text_generation`
+                        when it is called.
+
+        Returns:
+            all_keywords: All keywords for each document
+        """
+        all_keywords = []
+        candidate_keywords = process_candidate_keywords(documents, candidate_keywords)
+        # Avoid a mutable default argument
+        inference_kwargs = {} if inference_kwargs is None else inference_kwargs
+
+        for document, candidates in tqdm(zip(documents, candidate_keywords), disable=not self.verbose):
+            prompt = self.prompt.replace("[DOCUMENT]", document)
+            if candidates is not None:
+                prompt = prompt.replace("[CANDIDATES]", ", ".join(candidates))
+
+            # Constrain generation with a JSON grammar so the response is parseable
+            response = self.client.text_generation(
+                prompt=prompt,
+                grammar={"type": "json", "value": self.json_schema.model_json_schema()},
+                **inference_kwargs
+            )
+            # Collect keywords per document instead of overwriting the result
+            all_keywords.append(json.loads(response)["keywords"])
+
+        return all_keywords
diff --git a/pyproject.toml b/pyproject.toml
index 4eba6572..795b64f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,6 +70,10 @@ test = [
     "pytest-cov>=2.6.1",
     "pytest>=5.4.3",
 ]
+tgi = [
+    "huggingface-hub>=0.23.3",
+    "pydantic>=2.7.4"
+]
 use = [
     "tensorflow",
     "tensorflow_hub",