From deb266d3e8d92f908998fb3f5a28094ba8b1bd1f Mon Sep 17 00:00:00 2001 From: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com> Date: Thu, 25 Jul 2024 17:44:05 +0530 Subject: [PATCH] fix: Perform checks in classifier after LLM response, bumped to 0.0.29 (#521) * Perform checks in classifier after LLM response, bumped to 0.0.29 * Bumped SDK version to 0.39.0 in structure and text extractor tool * Removed traceback printing for classifier, taken care in worker/SDK --------- Co-authored-by: Rahul Johny <116638720+johnyrahul@users.noreply.github.com> --- tools/classifier/requirements.txt | 2 +- tools/classifier/src/config/properties.json | 2 +- .../src/config/runtime_variables.json | 37 +---- tools/classifier/src/config/spec.json | 2 +- tools/classifier/src/helper.py | 136 +++++++++++++++--- tools/classifier/src/main.py | 60 +++----- tools/structure/requirements.txt | 2 +- tools/structure/src/config/properties.json | 2 +- tools/text_extractor/requirements.txt | 2 +- .../text_extractor/src/config/properties.json | 2 +- 10 files changed, 148 insertions(+), 99 deletions(-) diff --git a/tools/classifier/requirements.txt b/tools/classifier/requirements.txt index 225c3d2c0..4002066b4 100644 --- a/tools/classifier/requirements.txt +++ b/tools/classifier/requirements.txt @@ -1,4 +1,4 @@ # Add your dependencies here # Required for all unstract tools -unstract-sdk~=0.38.1 +unstract-sdk~=0.39.0 diff --git a/tools/classifier/src/config/properties.json b/tools/classifier/src/config/properties.json index a46b2c91c..e268639b0 100644 --- a/tools/classifier/src/config/properties.json +++ b/tools/classifier/src/config/properties.json @@ -2,7 +2,7 @@ "schemaVersion": "0.0.1", "displayName": "File Classifier", "functionName": "classify", - "toolVersion": "0.0.28", + "toolVersion": "0.0.29", "description": "Classifies a file into a bin based on its contents", "input": { "description": "File to be classified" diff --git a/tools/classifier/src/config/runtime_variables.json b/tools/classifier/src/config/runtime_variables.json index 8f4ef8f50..3f72970c3 100644 --- a/tools/classifier/src/config/runtime_variables.json +++ b/tools/classifier/src/config/runtime_variables.json @@ -2,39 +2,6 @@ "title": "Runtime Variables", "description": "Runtime Variables for classifier", "type": "object", - "required": [ - "OPENAI_API_KEY", - "OPENAI_API_BASE", - "OPENAI_API_MODEL", - "OPENAI_API_ENGINE", - "OPENAI_API_VERSION" - ], - "properties": { - "OPENAI_API_KEY": { - "type": "string", - "title": "OpenAI API Key", - "description": "Your OpenAI API key" - }, - "OPENAI_API_BASE": { - "type": "string", - "title": "OpenAI API Base URL", - "description": "The base URL for the OpenAI API" - }, - "OPENAI_API_MODEL": { - "type": "string", - "title": "OpenAI API Model", - "description": "The OpenAI model to use" - }, - "OPENAI_API_ENGINE": { - "type": "string", - "title": "OpenAI API Engine", - "description": "The OpenAI engine to use" - }, - "OPENAI_API_VERSION": { - "type": "string", - "title": "OpenAI API Version", - "description": "The OpenAI API version to use", - "default": "2023-05-15" - } - } + "required": [], + "properties": {} } diff --git a/tools/classifier/src/config/spec.json b/tools/classifier/src/config/spec.json index 6f8d1619d..47d837445 100644 --- a/tools/classifier/src/config/spec.json +++ b/tools/classifier/src/config/spec.json @@ -9,7 +9,7 @@ "classificationBins": { "type": "array", "title": "Classification bins", - "description": "Specify at least two unique classification bins.", + "description": "Specify at least two unique classification bins. 'unknown' and '__unstract_failed' are reserved bins. 'unknown' indicates the LLM can't determine the classification and `__unstract_failed` indicates a tool run failure for the given file.", "items": { "type": "string" }, diff --git a/tools/classifier/src/helper.py b/tools/classifier/src/helper.py index 488b67b8f..ed1722b46 100644 --- a/tools/classifier/src/helper.py +++ b/tools/classifier/src/helper.py @@ -1,18 +1,79 @@ +import re +import shutil +from pathlib import Path from typing import Any, Optional from unstract.sdk.cache import ToolCache -from unstract.sdk.constants import ToolEnv +from unstract.sdk.constants import LogLevel, MetadataKey, ToolEnv from unstract.sdk.llm import LLM from unstract.sdk.tool.base import BaseTool from unstract.sdk.utils import ToolUtils from unstract.sdk.x2txt import TextExtractionResult, X2Text +class ReservedBins: + UNKNOWN = "unknown" + FAILED = "__unstract_failed" + + class ClassifierHelper: """Helper functions for Classifier.""" - def __init__(self, tool: BaseTool) -> None: + def __init__(self, tool: BaseTool, output_dir: str) -> None: + """Creates a helper class for the Classifier tool. + + Args: + tool (BaseTool): Base tool instance + output_dir (str): Output directory in TOOL_DATA_DIR + """ self.tool = tool + self.output_dir = output_dir + + def stream_error_and_exit( + self, message: str, bin_to_copy_to: str = ReservedBins.FAILED + ) -> None: + """Streams error logs and performs required cleanup. + + Helper which copies files to a reserved bin in case of an error. + Args: + message (str): Error message to log + bin_to_copy_to (str): The folder to copy the failed source file to. + Defaults to `__unstract_failed`. + input_file (Optional[str], optional): Input file to copy. Defaults to None. + output_dir (Optional[str], optional): Output directory to copy to. + Defaults to None. + """ + source_name = self.tool.get_exec_metadata.get(MetadataKey.SOURCE_NAME) + self.copy_source_to_output_bin( + classification=bin_to_copy_to, + source_file=self.tool.get_source_file(), + source_name=source_name, + ) + + self.tool.stream_error_and_exit(message=message) + + def copy_source_to_output_bin( + self, + classification: str, + source_file: str, + source_name: str, + ) -> None: + """Method to save result in output folder and the data directory. + + Args: + classification (str): classification result + source_file (str): Path to source file used in the workflow + source_name (str): Name of the actual input from the source + """ + try: + output_folder_bin = Path(self.output_dir) / classification + if not output_folder_bin.is_dir(): + output_folder_bin.mkdir(parents=True, exist_ok=True) + + output_file = output_folder_bin / source_name + shutil.copyfile(source_file, output_file) + except Exception as e: + self.tool.stream_error_and_exit(f"Error creating output file: {e}") def extract_text( self, file: str, text_extraction_adapter_id: Optional[str] @@ -106,22 +167,20 @@ def find_classification( f"{ToolUtils.hash_str(settings_string)}:" f"{ToolUtils.hash_str(prompt)}" ) + self.tool.stream_log("Trying to fetch result from cache.") classification = self.get_result_from_cache(cache_key=cache_key) + if classification is not None: + return classification - if classification is None: - self.tool.stream_log("No classification found in cache, calling LLM.") - classification = self.call_llm(prompt=prompt, llm=llm) - if not classification: - classification = "unknown" - classification = classification.strip().lower() - bins = [bin.lower() for bin in bins] - if classification not in bins: - self.tool.stream_error_and_exit( - f"Invalid classification done: {classification}" - ) + self.tool.stream_log("No classification found in cache, calling LLM.") + llm_response = self.call_llm(prompt=prompt, llm=llm) + classification = self.clean_llm_response(llm_response=llm_response, bins=bins) + if use_cache and cache_key: + self.tool.stream_log("Saving result to cache.") + self.save_result_to_cache(cache_key=cache_key, result=classification) return classification - def call_llm(self, prompt: str, llm: LLM) -> Optional[str]: + def call_llm(self, prompt: str, llm: LLM) -> str: """Call LLM. Args: @@ -134,11 +193,54 @@ def call_llm(self, prompt: str, llm: LLM) -> Optional[str]: try: completion = llm.complete(prompt)[LLM.RESPONSE] classification: str = completion.text.strip() - self.tool.stream_log(f"LLM response: {completion}") + self.tool.stream_log(f"LLM response: {completion}", level=LogLevel.DEBUG) return classification except Exception as e: - self.tool.stream_error_and_exit(f"Error calling LLM {e}") - return None + self.stream_error_and_exit(f"Error calling LLM: {e}") + raise e + + def clean_llm_response(self, llm_response: str, bins: list[str]) -> str: + """Cleans the response from the LLM. + + Performs a substring search to find the returned classification. + Treats it as `unknown` if the classification is not clear + from the output. + + Args: + llm_response (str): Response from LLM to clean + bins (list(str)): List of bins to classify the file into. + + Returns: + str: Cleaned classification that matches one of the bins. + """ + classification = ReservedBins.UNKNOWN + cleaned_response = llm_response.strip().lower() + bins = [bin.lower() for bin in bins] + + # Truncate llm_response to the first 100 words + words = cleaned_response.split() + truncated_response = " ".join(words[:100]) + + # Count occurrences of each bin in the truncated text + bin_counts = { + bin: len(re.findall(r"\b" + re.escape(bin) + r"\b", truncated_response)) + for bin in bins + } + + # Filter bins that have a count greater than 0 + matching_bins = [bin for bin, count in bin_counts.items() if count > 0] + + # Determine classification based on the number of matching bins + if len(matching_bins) == 1: + classification = matching_bins[0] + else: + self.stream_error_and_exit( + f"Unable to deduce classified bin from possible values of " + f"'{matching_bins}', moving file to '{ReservedBins.UNKNOWN}' " + "folder instead.", + bin_to_copy_to=ReservedBins.UNKNOWN, + ) + return classification def get_result_from_cache(self, cache_key: str) -> Optional[str]: """Get result from cache. diff --git a/tools/classifier/src/main.py b/tools/classifier/src/main.py index e22f88ef1..5d12a379e 100644 --- a/tools/classifier/src/main.py +++ b/tools/classifier/src/main.py @@ -1,9 +1,8 @@ -import shutil import sys -from pathlib import Path from typing import Any, Optional from helper import ClassifierHelper # type: ignore +from helper import ReservedBins from unstract.sdk.constants import LogLevel, LogState, MetadataKey, ToolSettingsKey from unstract.sdk.exceptions import SdkError from unstract.sdk.llm import LLM @@ -14,21 +13,26 @@ class UnstractClassifier(BaseTool): def __init__(self, log_level: str = LogLevel.INFO) -> None: super().__init__(log_level) - self.helper = ClassifierHelper(tool=self) def validate(self, input_file: str, settings: dict[str, Any]) -> None: bins: Optional[list[str]] = settings.get("classificationBins") llm_adapter_instance_id = settings.get(ToolSettingsKey.LLM_ADAPTER_ID) text_extraction_adapter_id = settings.get("textExtractorId") if not bins: - self.stream_error_and_exit("Classification bins are required") + self.stream_error_and_exit("Classification bins are required.") elif len(bins) < 2: - self.stream_error_and_exit("At least two bins are required") + self.stream_error_and_exit("At least two classification bins are required.") + elif ReservedBins.UNKNOWN in bins: + self.stream_error_and_exit( + f"Classification bin '{ReservedBins.UNKNOWN}' is reserved to mark " + "files which cannot be classified." + ) + if not llm_adapter_instance_id: - self.stream_error_and_exit("Choose an LLM to process the classifier") + self.stream_error_and_exit("Choose an LLM to perform the classification.") if not text_extraction_adapter_id: self.stream_error_and_exit( - "Choose a text extractor to extract the documents" + "Choose a text extractor to extract the documents." ) def run( @@ -41,6 +45,7 @@ def run( use_cache = settings["useCache"] text_extraction_adapter_id = settings["textExtractorId"] llm_adapter_instance_id = settings[ToolSettingsKey.LLM_ADAPTER_ID] + self.helper = ClassifierHelper(tool=self, output_dir=output_dir) # Update GUI input_log = f"### Classification bins:\n```text\n{bins}\n```\n\n" @@ -54,7 +59,7 @@ def run( text_extraction_adapter_id=text_extraction_adapter_id, ) if not text: - self.stream_error_and_exit("Unable to extract text") + self.helper.stream_error_and_exit("Unable to extract text") return self.stream_log(f"Text length: {len(text)}") @@ -70,8 +75,8 @@ def run( self.stream_update(input_log, state=LogState.INPUT_UPDATE) self.stream_update(output_log, state=LogState.OUTPUT_UPDATE) - if "unknown" not in bins: - bins.append("unknown") + if ReservedBins.UNKNOWN not in bins: + bins.append(ReservedBins.UNKNOWN) bins_with_quotes = [f"'{b}'" for b in bins] usage_kwargs: dict[Any, Any] = dict() @@ -85,7 +90,8 @@ def run( usage_kwargs=usage_kwargs, ) except SdkError: - self.stream_error_and_exit("Unable to get llm instance") + self.helper.stream_error_and_exit("Unable to get llm instance") + return max_tokens = llm.get_max_tokens(reserved_for_output=50 + 1000) max_bytes = int(max_tokens * 1.3) @@ -103,8 +109,8 @@ def run( f"Classify the following text into one of the following categories: {' '.join(bins_with_quotes)}.\n\n" # noqa: E501 "Your categorization should be strictly exactly one of the items in the " # noqa: E501 "categories given, do not provide any explanation. Find a semantic match of category if possible. " # noqa: E501 - "If it does not categorize well into any of the listed categories, categorize it as 'unknown'.\n\n" # noqa: E501 - f"Text:\n\n{text}\n\n\nCategory:" + "If it does not categorize well into any of the listed categories, categorize it as 'unknown'." # noqa: E501 + f"Do not enclose the result within single quotes.\n\nText:\n\n{text}\n\n\nCategory:" # noqa: E501 ) settings_string = "".join(str(value) for value in settings.values()) @@ -117,8 +123,7 @@ def run( ) source_name = self.get_exec_metadata.get(MetadataKey.SOURCE_NAME) - self._copy_input_to_output_bin( - output_dir=output_dir, + self.helper.copy_source_to_output_bin( classification=classification, source_file=self.get_source_file(), source_name=source_name, @@ -134,31 +139,6 @@ def run( } self.write_tool_result(data=classification_dict) - def _copy_input_to_output_bin( - self, - output_dir: str, - classification: str, - source_file: str, - source_name: str, - ) -> None: - """Method to save result in output folder and the data directory. - - Args: - output_dir (str): Output directory in TOOL_DATA_DIR - classification (str): classification result - source_file (str): Path to source file used in the workflow - source_name (str): Name of the actual input from the source - """ - try: - output_folder_bin = Path(output_dir) / classification - if not output_folder_bin.is_dir(): - output_folder_bin.mkdir(parents=True, exist_ok=True) - - output_file = output_folder_bin / source_name - shutil.copyfile(source_file, output_file) - except Exception as e: - self.stream_error_and_exit(f"Error creating output file: {e}") - if __name__ == "__main__": args = sys.argv[1:] diff --git a/tools/structure/requirements.txt b/tools/structure/requirements.txt index 225c3d2c0..4002066b4 100644 --- a/tools/structure/requirements.txt +++ b/tools/structure/requirements.txt @@ -1,4 +1,4 @@ # Add your dependencies here # Required for all unstract tools -unstract-sdk~=0.38.1 +unstract-sdk~=0.39.0 diff --git a/tools/structure/src/config/properties.json b/tools/structure/src/config/properties.json index 0e7400f0a..c45f18c62 100644 --- a/tools/structure/src/config/properties.json +++ b/tools/structure/src/config/properties.json @@ -2,7 +2,7 @@ "schemaVersion": "0.0.1", "displayName": "Structure Tool", "functionName": "structure_tool", - "toolVersion": "0.0.32", + "toolVersion": "0.0.33", "description": "This is a template tool which can answer set of input prompts designed in the Prompt Studio", "input": { "description": "File that needs to be indexed and parsed for answers" diff --git a/tools/text_extractor/requirements.txt b/tools/text_extractor/requirements.txt index 225c3d2c0..4002066b4 100644 --- a/tools/text_extractor/requirements.txt +++ b/tools/text_extractor/requirements.txt @@ -1,4 +1,4 @@ # Add your dependencies here # Required for all unstract tools -unstract-sdk~=0.38.1 +unstract-sdk~=0.39.0 diff --git a/tools/text_extractor/src/config/properties.json b/tools/text_extractor/src/config/properties.json index dcca4ca14..edafbad92 100644 --- a/tools/text_extractor/src/config/properties.json +++ b/tools/text_extractor/src/config/properties.json @@ -2,7 +2,7 @@ "schemaVersion": "0.0.1", "displayName": "Text Extractor", "functionName": "text_extractor", - "toolVersion": "0.0.26", + "toolVersion": "0.0.27", "description": "The Text Extractor is a powerful tool designed to convert documents to its text form or Extract texts from documents", "input": { "description": "Document"