From 3386976d0bb8595d654a24b66bbcb1dada6a8bb0 Mon Sep 17 00:00:00 2001 From: harini-venkataraman Date: Tue, 17 Dec 2024 15:50:03 +0530 Subject: [PATCH] UN-1920 Fix:Dynamic passing of File storage init --- .../prompt_studio_core_v2/constants.py | 13 ++++++ .../prompt_studio_helper.py | 7 ++- .../src/unstract/prompt_service/constants.py | 16 +++++++ .../unstract/prompt_service/env_manager.py | 18 +++++++- .../src/unstract/prompt_service/helper.py | 43 +++++++++++++++++-- .../src/unstract/prompt_service/main.py | 3 ++ .../prompt_service_file_helper.py | 31 +++++++++++++ tools/structure/src/constants.py | 2 + tools/structure/src/main.py | 1 + 9 files changed, 129 insertions(+), 5 deletions(-) create mode 100644 prompt-service/src/unstract/prompt_service/prompt_service_file_helper.py diff --git a/backend/prompt_studio/prompt_studio_core_v2/constants.py b/backend/prompt_studio/prompt_studio_core_v2/constants.py index cb335b90f..2c6a80ac6 100644 --- a/backend/prompt_studio/prompt_studio_core_v2/constants.py +++ b/backend/prompt_studio/prompt_studio_core_v2/constants.py @@ -96,6 +96,7 @@ class ToolStudioPromptKeys: RECORD = "record" FILE_PATH = "file_path" ENABLE_HIGHLIGHT = "enable_highlight" + EXECUTION_SOURCE = "execution_source" class FileViewTypes: @@ -132,3 +133,15 @@ class DefaultPrompts: "Do not include any explanation in the reply. " "Only include the extracted information in the reply." ) + + +class ExecutionSource(Enum): + """Enum to indicate the source of invocation. + Any new sources can be added to this enum. + This is to indicate the prompt service. + + Args: + Enum (_type_): ide/tool + """ + + IDE = "ide" diff --git a/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py b/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py index 6bcacf340..7985173a1 100644 --- a/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py +++ b/backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py @@ -19,7 +19,11 @@ from prompt_studio.prompt_profile_manager_v2.profile_manager_helper import ( ProfileManagerHelper, ) -from prompt_studio.prompt_studio_core_v2.constants import IndexingStatus, LogLevels +from prompt_studio.prompt_studio_core_v2.constants import ( + ExecutionSource, + IndexingStatus, + LogLevels, +) from prompt_studio.prompt_studio_core_v2.constants import ( ToolStudioPromptKeys as TSPKeys, ) @@ -1176,6 +1180,7 @@ def _fetch_single_pass_response( TSPKeys.FILE_HASH: file_hash, TSPKeys.FILE_NAME: doc_name, Common.LOG_EVENTS_ID: StateStore.get(Common.LOG_EVENTS_ID), + TSPKeys.EXECUTION_SOURCE: ExecutionSource.IDE.value, } util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) diff --git a/prompt-service/src/unstract/prompt_service/constants.py b/prompt-service/src/unstract/prompt_service/constants.py index 96807e4d8..0d4dc4521 100644 --- a/prompt-service/src/unstract/prompt_service/constants.py +++ b/prompt-service/src/unstract/prompt_service/constants.py @@ -72,6 +72,7 @@ class PromptServiceContants: FILE_PATH = "file_path" HIGHLIGHT_DATA = "highlight_data" CONFIDENCE_DATA = "confidence_data" + EXECUTION_SOURCE = "execution_source" class RunLevel(Enum): @@ -100,3 +101,18 @@ class DBTableV2: PROMPT_STUDIO_REGISTRY = "prompt_studio_registry" PLATFORM_KEY = "platform_key" TOKEN_USAGE = "usage" + + +class FileStorageKeys: + FILE_STORAGE_PROVIDER = "FILE_STORAGE_PROVIDER" + FILE_STORAGE_CREDENTIALS = "FILE_STORAGE_CREDENTIALS" + + +class FileStorageType(Enum): + PERMANENT = "permanent" + TEMPORARY = "temporary" + + +class ExecutionSource(Enum): + IDE = "ide" + TOOL = "tool" diff --git a/prompt-service/src/unstract/prompt_service/env_manager.py b/prompt-service/src/unstract/prompt_service/env_manager.py index c51f5cf80..745b9083d 100644 --- a/prompt-service/src/unstract/prompt_service/env_manager.py +++ b/prompt-service/src/unstract/prompt_service/env_manager.py @@ -1,5 +1,8 @@ +import json import os -from typing import Optional +from typing import Any, Optional + +from unstract.prompt_service.constants import FileStorageKeys class EnvLoader: @@ -9,3 +12,16 @@ def get_env_or_die(env_key: str, default: Optional[str] = None) -> str: if env_value is None or env_value == "": raise ValueError(f"Env variable {env_key} is required") return env_value + + @staticmethod + def load_provider_credentials() -> dict[str, Any]: + cred_env_data: str = EnvLoader.get_env_or_die( + env_key=FileStorageKeys.FILE_STORAGE_CREDENTIALS + ) + credentials = json.loads(cred_env_data) + provider_data: dict[str, Any] = {} + provider_data[FileStorageKeys.FILE_STORAGE_PROVIDER] = credentials["provider"] + provider_data[FileStorageKeys.FILE_STORAGE_CREDENTIALS] = credentials[ + "credentials" + ] + return provider_data diff --git a/prompt-service/src/unstract/prompt_service/helper.py b/prompt-service/src/unstract/prompt_service/helper.py index 62a24f2d1..b64522542 100644 --- a/prompt-service/src/unstract/prompt_service/helper.py +++ b/prompt-service/src/unstract/prompt_service/helper.py @@ -7,7 +7,12 @@ from dotenv import load_dotenv from flask import Flask, current_app from unstract.prompt_service.config import db -from unstract.prompt_service.constants import DBTableV2 +from unstract.prompt_service.constants import ( + DBTableV2, + ExecutionSource, + FeatureFlag, + FileStorageType, +) from unstract.prompt_service.constants import PromptServiceContants as PSKeys from unstract.prompt_service.db_utils import DBUtils from unstract.prompt_service.env_manager import EnvLoader @@ -16,6 +21,14 @@ from unstract.sdk.exceptions import SdkError from unstract.sdk.llm import LLM +from unstract.flags.src.unstract.flags.feature_flag import check_feature_flag_status + +if check_feature_flag_status(FeatureFlag.REMOTE_FILE_STORAGE): + from unstract.prompt_service.prompt_service_file_helper import ( + PromptServiceFileHelper, + ) + from unstract.sdk.file_storage import FileStorage + load_dotenv() # Global variable to store plugins @@ -278,6 +291,7 @@ def run_completion( prompt_type: Optional[str] = PSKeys.TEXT, enable_highlight: bool = False, file_path: str = "", + execution_source: Optional[str] = None, ) -> str: logger: Logger = current_app.logger try: @@ -286,8 +300,17 @@ def run_completion( ) highlight_data = None if highlight_data_plugin and enable_highlight: + fs_instance: FileStorage + if execution_source == ExecutionSource.IDE.value: + fs_instance = PromptServiceFileHelper.initialize_file_storage( + type=FileStorageType.PERMANENT + ) + if execution_source == ExecutionSource.TOOL.value: + fs_instance = PromptServiceFileHelper.initialize_file_storage( + type=FileStorageType.TEMPORARY + ) highlight_data = highlight_data_plugin["entrypoint_cls"]( - logger=current_app.logger, file_path=file_path + logger=current_app.logger, file_path=file_path, fs_instance=fs_instance ).run completion = llm.complete( prompt=prompt, @@ -325,6 +348,7 @@ def extract_table( structured_output: dict[str, Any], llm: LLM, enforce_type: str, + execution_source: str, ) -> dict[str, Any]: table_settings = output[PSKeys.TABLE_SETTINGS] table_extractor: dict[str, Any] = plugins.get("table-extractor", {}) @@ -333,9 +357,22 @@ def extract_table( "Unable to extract table details. " "Please contact admin to resolve this issue." ) + if check_feature_flag_status(FeatureFlag.REMOTE_FILE_STORAGE): + fs_instance: FileStorage + if execution_source == ExecutionSource.IDE.value: + fs_instance = PromptServiceFileHelper.initialize_file_storage( + type=FileStorageType.PERMANENT + ) + if execution_source == ExecutionSource.TOOL.value: + fs_instance = PromptServiceFileHelper.initialize_file_storage( + type=FileStorageType.TEMPORARY + ) try: answer = table_extractor["entrypoint_cls"].extract_large_table( - llm=llm, table_settings=table_settings, enforce_type=enforce_type + llm=llm, + table_settings=table_settings, + enforce_type=enforce_type, + fs_instance=fs_instance, ) structured_output[output[PSKeys.NAME]] = answer # We do not support summary and eval for table. diff --git a/prompt-service/src/unstract/prompt_service/main.py b/prompt-service/src/unstract/prompt_service/main.py index 531617cf8..5d9840641 100644 --- a/prompt-service/src/unstract/prompt_service/main.py +++ b/prompt-service/src/unstract/prompt_service/main.py @@ -110,6 +110,8 @@ def prompt_processor() -> Any: PSKeys.CONTEXT: {}, } variable_names: list[str] = [] + # Identifier for source of invocation + execution_source = payload.get(PSKeys.EXECUTION_SOURCE, "") publish_log( log_events_id, {"tool_id": tool_id, "run_id": run_id, "doc_name": doc_name}, @@ -225,6 +227,7 @@ def prompt_processor() -> Any: structured_output=structured_output, llm=llm, enforce_type=output[PSKeys.TYPE], + execution_source=execution_source, ) metadata = query_usage_metadata(token=platform_key, metadata=metadata) response = { diff --git a/prompt-service/src/unstract/prompt_service/prompt_service_file_helper.py b/prompt-service/src/unstract/prompt_service/prompt_service_file_helper.py new file mode 100644 index 000000000..b194c5170 --- /dev/null +++ b/prompt-service/src/unstract/prompt_service/prompt_service_file_helper.py @@ -0,0 +1,31 @@ +from typing import Union + +from unstract.prompt_service.constants import FileStorageKeys, FileStorageType +from unstract.prompt_service.env_manager import EnvLoader +from unstract.sdk.file_storage import ( + FileStorageProvider, + PermanentFileStorage, + SharedTemporaryFileStorage, +) + + +class PromptServiceFileHelper: + @staticmethod + def initialize_file_storage( + type: FileStorageType, + ) -> Union[PermanentFileStorage, SharedTemporaryFileStorage]: + provider_data = EnvLoader.load_provider_credentials() + provider = provider_data[FileStorageKeys.FILE_STORAGE_PROVIDER] + provider_value = FileStorageProvider(provider) + credentials = provider_data[FileStorageKeys.FILE_STORAGE_CREDENTIALS] + if type.value == FileStorageType.PERMANENT.value: + file_storage = PermanentFileStorage(provider=provider_value, **credentials) + elif type.value == FileStorageType.TEMPORARY.value: + file_storage = SharedTemporaryFileStorage( + provider=provider_value, **credentials + ) + else: + file_storage = PermanentFileStorage( + provider=FileStorageProvider.LOCAL, **credentials + ) + return file_storage diff --git a/tools/structure/src/constants.py b/tools/structure/src/constants.py index 8cf7c8653..cf8999905 100644 --- a/tools/structure/src/constants.py +++ b/tools/structure/src/constants.py @@ -75,3 +75,5 @@ class SettingsKeys: CONFIDENCE_DATA = "confidence_data" EXECUTION_RUN_DATA_FOLDER = "EXECUTION_RUN_DATA_FOLDER" FILE_PATH = "file_path" + EXECUTION_SOURCE = "execution_source" + TOOL = "tool" diff --git a/tools/structure/src/main.py b/tools/structure/src/main.py index a1736c554..776af010d 100644 --- a/tools/structure/src/main.py +++ b/tools/structure/src/main.py @@ -115,6 +115,7 @@ def run( SettingsKeys.FILE_HASH: file_hash, SettingsKeys.FILE_NAME: file_name, SettingsKeys.FILE_PATH: extracted_input_file, + SettingsKeys.EXECUTION_SOURCE: SettingsKeys.TOOL, } # TODO: Need to split extraction and indexing # to avoid unwanted indexing