Skip to content

Commit

Permalink
fix: Perform checks in classifier after LLM response, bumped to 0.0.29 (#521)
Browse files Browse the repository at this point in the history

* Perform checks in classifier after LLM response, bumped to 0.0.29

* Bumped SDK version to 0.39.0 in structure and text extractor tool

* Removed traceback printing for classifier, taken care in worker/SDK

---------

Co-authored-by: Rahul Johny <[email protected]>
  • Loading branch information
chandrasekharan-zipstack and johnyrahul authored Jul 25, 2024
1 parent f6fce93 commit deb266d
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 99 deletions.
2 changes: 1 addition & 1 deletion tools/classifier/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Add your dependencies here

# Required for all unstract tools
unstract-sdk~=0.38.1
unstract-sdk~=0.39.0
2 changes: 1 addition & 1 deletion tools/classifier/src/config/properties.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"schemaVersion": "0.0.1",
"displayName": "File Classifier",
"functionName": "classify",
"toolVersion": "0.0.28",
"toolVersion": "0.0.29",
"description": "Classifies a file into a bin based on its contents",
"input": {
"description": "File to be classified"
Expand Down
37 changes: 2 additions & 35 deletions tools/classifier/src/config/runtime_variables.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,6 @@
"title": "Runtime Variables",
"description": "Runtime Variables for classifier",
"type": "object",
"required": [
"OPENAI_API_KEY",
"OPENAI_API_BASE",
"OPENAI_API_MODEL",
"OPENAI_API_ENGINE",
"OPENAI_API_VERSION"
],
"properties": {
"OPENAI_API_KEY": {
"type": "string",
"title": "OpenAI API Key",
"description": "Your OpenAI API key"
},
"OPENAI_API_BASE": {
"type": "string",
"title": "OpenAI API Base URL",
"description": "The base URL for the OpenAI API"
},
"OPENAI_API_MODEL": {
"type": "string",
"title": "OpenAI API Model",
"description": "The OpenAI model to use"
},
"OPENAI_API_ENGINE": {
"type": "string",
"title": "OpenAI API Engine",
"description": "The OpenAI engine to use"
},
"OPENAI_API_VERSION": {
"type": "string",
"title": "OpenAI API Version",
"description": "The OpenAI API version to use",
"default": "2023-05-15"
}
}
"required": [],
"properties": {}
}
2 changes: 1 addition & 1 deletion tools/classifier/src/config/spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"classificationBins": {
"type": "array",
"title": "Classification bins",
"description": "Specify at least two unique classification bins.",
"description": "Specify at least two unique classification bins. 'unknown' and '__unstract_failed' are reserved bins. 'unknown' indicates the LLM can't determine the classification and `__unstract_failed` indicates a tool run failure for the given file.",
"items": {
"type": "string"
},
Expand Down
136 changes: 119 additions & 17 deletions tools/classifier/src/helper.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,79 @@
import re
import shutil
from pathlib import Path
from typing import Any, Optional

from unstract.sdk.cache import ToolCache
from unstract.sdk.constants import ToolEnv
from unstract.sdk.constants import LogLevel, MetadataKey, ToolEnv
from unstract.sdk.llm import LLM
from unstract.sdk.tool.base import BaseTool
from unstract.sdk.utils import ToolUtils
from unstract.sdk.x2txt import TextExtractionResult, X2Text


class ReservedBins:
    """Bin names the classifier reserves for itself."""

    # Bin for files whose classification could not be determined.
    UNKNOWN = "unknown"
    # Bin for files whose tool run failed.
    FAILED = "__unstract_failed"


class ClassifierHelper:
"""Helper functions for Classifier."""

def __init__(self, tool: BaseTool) -> None:
def __init__(self, tool: BaseTool, output_dir: str) -> None:
"""Creates a helper class for the Classifier tool.
Args:
tool (BaseTool): Base tool instance
output_dir (str): Output directory in TOOL_DATA_DIR
"""
self.tool = tool
self.output_dir = output_dir

def stream_error_and_exit(
    self, message: str, bin_to_copy_to: str = ReservedBins.FAILED
) -> None:
    """Copy the source file into a reserved bin, then report the error.

    The source file is copied into ``bin_to_copy_to`` first; afterwards the
    message is handed to the tool's own ``stream_error_and_exit``.

    Args:
        message (str): Error message to log
        bin_to_copy_to (str): The folder to copy the failed source file to.
            Defaults to `__unstract_failed`.
    """
    tool = self.tool
    # Look up the name and path of the file being processed from the tool.
    failed_source_name = tool.get_exec_metadata.get(MetadataKey.SOURCE_NAME)
    failed_source_path = tool.get_source_file()

    self.copy_source_to_output_bin(
        classification=bin_to_copy_to,
        source_file=failed_source_path,
        source_name=failed_source_name,
    )
    tool.stream_error_and_exit(message=message)

def copy_source_to_output_bin(
    self,
    classification: str,
    source_file: str,
    source_name: str,
) -> None:
    """Copy the source file into the output folder of its classification.

    The file is written to ``<output_dir>/<classification>/<source_name>``.
    The bin folder is created when it is missing. Any exception raised while
    doing so is reported through the tool's ``stream_error_and_exit``.

    Args:
        classification (str): classification result
        source_file (str): Path to source file used in the workflow
        source_name (str): Name of the actual input from the source
    """
    try:
        bin_dir = Path(self.output_dir) / classification
        # exist_ok=True makes this a no-op when the bin folder already exists
        bin_dir.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(source_file, bin_dir / source_name)
    except Exception as e:
        self.tool.stream_error_and_exit(f"Error creating output file: {e}")

def extract_text(
self, file: str, text_extraction_adapter_id: Optional[str]
Expand Down Expand Up @@ -106,22 +167,20 @@ def find_classification(
f"{ToolUtils.hash_str(settings_string)}:"
f"{ToolUtils.hash_str(prompt)}"
)
self.tool.stream_log("Trying to fetch result from cache.")
classification = self.get_result_from_cache(cache_key=cache_key)
if classification is not None:
return classification

if classification is None:
self.tool.stream_log("No classification found in cache, calling LLM.")
classification = self.call_llm(prompt=prompt, llm=llm)
if not classification:
classification = "unknown"
classification = classification.strip().lower()
bins = [bin.lower() for bin in bins]
if classification not in bins:
self.tool.stream_error_and_exit(
f"Invalid classification done: {classification}"
)
self.tool.stream_log("No classification found in cache, calling LLM.")
llm_response = self.call_llm(prompt=prompt, llm=llm)
classification = self.clean_llm_response(llm_response=llm_response, bins=bins)
if use_cache and cache_key:
self.tool.stream_log("Saving result to cache.")
self.save_result_to_cache(cache_key=cache_key, result=classification)
return classification

def call_llm(self, prompt: str, llm: LLM) -> Optional[str]:
def call_llm(self, prompt: str, llm: LLM) -> str:
"""Call LLM.
Args:
Expand All @@ -134,11 +193,54 @@ def call_llm(self, prompt: str, llm: LLM) -> Optional[str]:
try:
completion = llm.complete(prompt)[LLM.RESPONSE]
classification: str = completion.text.strip()
self.tool.stream_log(f"LLM response: {completion}")
self.tool.stream_log(f"LLM response: {completion}", level=LogLevel.DEBUG)
return classification
except Exception as e:
self.tool.stream_error_and_exit(f"Error calling LLM {e}")
return None
self.stream_error_and_exit(f"Error calling LLM: {e}")
raise e

def clean_llm_response(self, llm_response: str, bins: list[str]) -> str:
    """Map the raw LLM response onto exactly one of the allowed bins.

    Only the first 100 whitespace-separated words of the response are
    searched. A bin matches when it occurs, case-insensitively, as a
    standalone term, i.e. with no word character directly before or after
    it. If exactly one bin matches, it is returned. Otherwise the source
    file is copied to the `unknown` bin and the error is streamed through
    ``self.stream_error_and_exit``.

    Args:
        llm_response (str): Response from LLM to clean
        bins (list(str)): List of bins to classify the file into.

    Returns:
        str: The lower-cased bin that matched, or `unknown` when no single
            bin could be deduced and the error helper returned.
    """
    max_words = 100
    searched_text = " ".join(llm_response.strip().lower().split()[:max_words])

    # dict.fromkeys drops bins that differ only by case and keeps the order
    # in which they were given.
    candidate_bins = list(dict.fromkeys(name.lower() for name in bins))

    # Lookarounds are used instead of `\b`: `\b` needs a word character on
    # one side, so a bin that starts or ends with punctuation (e.g. "c++")
    # could never match. For bins made of word characters the two forms are
    # equivalent. Empty names are skipped because they would match anything.
    matching_bins = [
        name
        for name in candidate_bins
        if name
        and re.search(r"(?<!\w)" + re.escape(name) + r"(?!\w)", searched_text)
    ]

    if len(matching_bins) == 1:
        return matching_bins[0]

    # Zero or several candidates: the answer is ambiguous.
    self.stream_error_and_exit(
        f"Unable to deduce classified bin from possible values of "
        f"'{matching_bins}', moving file to '{ReservedBins.UNKNOWN}' "
        "folder instead.",
        bin_to_copy_to=ReservedBins.UNKNOWN,
    )
    return ReservedBins.UNKNOWN

def get_result_from_cache(self, cache_key: str) -> Optional[str]:
"""Get result from cache.
Expand Down
60 changes: 20 additions & 40 deletions tools/classifier/src/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import shutil
import sys
from pathlib import Path
from typing import Any, Optional

from helper import ClassifierHelper # type: ignore
from helper import ReservedBins
from unstract.sdk.constants import LogLevel, LogState, MetadataKey, ToolSettingsKey
from unstract.sdk.exceptions import SdkError
from unstract.sdk.llm import LLM
Expand All @@ -14,21 +13,26 @@
class UnstractClassifier(BaseTool):
def __init__(self, log_level: str = LogLevel.INFO) -> None:
super().__init__(log_level)
self.helper = ClassifierHelper(tool=self)

def validate(self, input_file: str, settings: dict[str, Any]) -> None:
bins: Optional[list[str]] = settings.get("classificationBins")
llm_adapter_instance_id = settings.get(ToolSettingsKey.LLM_ADAPTER_ID)
text_extraction_adapter_id = settings.get("textExtractorId")
if not bins:
self.stream_error_and_exit("Classification bins are required")
self.stream_error_and_exit("Classification bins are required.")
elif len(bins) < 2:
self.stream_error_and_exit("At least two bins are required")
self.stream_error_and_exit("At least two classification bins are required.")
elif ReservedBins.UNKNOWN in bins:
self.stream_error_and_exit(
f"Classification bin '{ReservedBins.UNKNOWN}' is reserved to mark "
"files which cannot be classified."
)

if not llm_adapter_instance_id:
self.stream_error_and_exit("Choose an LLM to process the classifier")
self.stream_error_and_exit("Choose an LLM to perform the classification.")
if not text_extraction_adapter_id:
self.stream_error_and_exit(
"Choose a text extractor to extract the documents"
"Choose a text extractor to extract the documents."
)

def run(
Expand All @@ -41,6 +45,7 @@ def run(
use_cache = settings["useCache"]
text_extraction_adapter_id = settings["textExtractorId"]
llm_adapter_instance_id = settings[ToolSettingsKey.LLM_ADAPTER_ID]
self.helper = ClassifierHelper(tool=self, output_dir=output_dir)

# Update GUI
input_log = f"### Classification bins:\n```text\n{bins}\n```\n\n"
Expand All @@ -54,7 +59,7 @@ def run(
text_extraction_adapter_id=text_extraction_adapter_id,
)
if not text:
self.stream_error_and_exit("Unable to extract text")
self.helper.stream_error_and_exit("Unable to extract text")
return
self.stream_log(f"Text length: {len(text)}")

Expand All @@ -70,8 +75,8 @@ def run(
self.stream_update(input_log, state=LogState.INPUT_UPDATE)
self.stream_update(output_log, state=LogState.OUTPUT_UPDATE)

if "unknown" not in bins:
bins.append("unknown")
if ReservedBins.UNKNOWN not in bins:
bins.append(ReservedBins.UNKNOWN)
bins_with_quotes = [f"'{b}'" for b in bins]

usage_kwargs: dict[Any, Any] = dict()
Expand All @@ -85,7 +90,8 @@ def run(
usage_kwargs=usage_kwargs,
)
except SdkError:
self.stream_error_and_exit("Unable to get llm instance")
self.helper.stream_error_and_exit("Unable to get llm instance")
return

max_tokens = llm.get_max_tokens(reserved_for_output=50 + 1000)
max_bytes = int(max_tokens * 1.3)
Expand All @@ -103,8 +109,8 @@ def run(
f"Classify the following text into one of the following categories: {' '.join(bins_with_quotes)}.\n\n" # noqa: E501
"Your categorization should be strictly exactly one of the items in the " # noqa: E501
"categories given, do not provide any explanation. Find a semantic match of category if possible. " # noqa: E501
"If it does not categorize well into any of the listed categories, categorize it as 'unknown'.\n\n" # noqa: E501
f"Text:\n\n{text}\n\n\nCategory:"
"If it does not categorize well into any of the listed categories, categorize it as 'unknown'." # noqa: E501
f"Do not enclose the result within single quotes.\n\nText:\n\n{text}\n\n\nCategory:" # noqa: E501
)

settings_string = "".join(str(value) for value in settings.values())
Expand All @@ -117,8 +123,7 @@ def run(
)

source_name = self.get_exec_metadata.get(MetadataKey.SOURCE_NAME)
self._copy_input_to_output_bin(
output_dir=output_dir,
self.helper.copy_source_to_output_bin(
classification=classification,
source_file=self.get_source_file(),
source_name=source_name,
Expand All @@ -134,31 +139,6 @@ def run(
}
self.write_tool_result(data=classification_dict)

def _copy_input_to_output_bin(
    self,
    output_dir: str,
    classification: str,
    source_file: str,
    source_name: str,
) -> None:
    """Copy the source file into the output folder of its classification.

    The file is written to ``<output_dir>/<classification>/<source_name>``.
    The bin folder is created when it is missing. Any exception raised while
    doing so is reported through ``stream_error_and_exit``.

    Args:
        output_dir (str): Output directory in TOOL_DATA_DIR
        classification (str): classification result
        source_file (str): Path to source file used in the workflow
        source_name (str): Name of the actual input from the source
    """
    try:
        bin_dir = Path(output_dir) / classification
        # exist_ok=True makes this a no-op when the bin folder already exists
        bin_dir.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(source_file, bin_dir / source_name)
    except Exception as e:
        self.stream_error_and_exit(f"Error creating output file: {e}")


if __name__ == "__main__":
args = sys.argv[1:]
Expand Down
2 changes: 1 addition & 1 deletion tools/structure/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Add your dependencies here

# Required for all unstract tools
unstract-sdk~=0.38.1
unstract-sdk~=0.39.0
2 changes: 1 addition & 1 deletion tools/structure/src/config/properties.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"schemaVersion": "0.0.1",
"displayName": "Structure Tool",
"functionName": "structure_tool",
"toolVersion": "0.0.32",
"toolVersion": "0.0.33",
"description": "This is a template tool which can answer set of input prompts designed in the Prompt Studio",
"input": {
"description": "File that needs to be indexed and parsed for answers"
Expand Down
2 changes: 1 addition & 1 deletion tools/text_extractor/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Add your dependencies here

# Required for all unstract tools
unstract-sdk~=0.38.1
unstract-sdk~=0.39.0
Loading

0 comments on commit deb266d

Please sign in to comment.