Skip to content

Commit

Permalink
UN-1920 Fix:Dynamic passing of File storage init
Browse files Browse the repository at this point in the history
  • Loading branch information
harini-venkataraman committed Dec 17, 2024
1 parent 3c491e5 commit 3386976
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 5 deletions.
13 changes: 13 additions & 0 deletions backend/prompt_studio/prompt_studio_core_v2/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class ToolStudioPromptKeys:
RECORD = "record"
FILE_PATH = "file_path"
ENABLE_HIGHLIGHT = "enable_highlight"
EXECUTION_SOURCE = "execution_source"


class FileViewTypes:
Expand Down Expand Up @@ -132,3 +133,15 @@ class DefaultPrompts:
"Do not include any explanation in the reply. "
"Only include the extracted information in the reply."
)


class ExecutionSource(Enum):
"""Enum to indicate the source of invocation.
Any new sources can be added to this enum.
This is to indicate the prompt service.
Args:
Enum (_type_): ide/tool
"""

IDE = "ide"
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
from prompt_studio.prompt_profile_manager_v2.profile_manager_helper import (
ProfileManagerHelper,
)
from prompt_studio.prompt_studio_core_v2.constants import IndexingStatus, LogLevels
from prompt_studio.prompt_studio_core_v2.constants import (
ExecutionSource,
IndexingStatus,
LogLevels,
)
from prompt_studio.prompt_studio_core_v2.constants import (
ToolStudioPromptKeys as TSPKeys,
)
Expand Down Expand Up @@ -1176,6 +1180,7 @@ def _fetch_single_pass_response(
TSPKeys.FILE_HASH: file_hash,
TSPKeys.FILE_NAME: doc_name,
Common.LOG_EVENTS_ID: StateStore.get(Common.LOG_EVENTS_ID),
TSPKeys.EXECUTION_SOURCE: ExecutionSource.IDE.value,
}

util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id)
Expand Down
16 changes: 16 additions & 0 deletions prompt-service/src/unstract/prompt_service/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class PromptServiceContants:
FILE_PATH = "file_path"
HIGHLIGHT_DATA = "highlight_data"
CONFIDENCE_DATA = "confidence_data"
EXECUTION_SOURCE = "execution_source"


class RunLevel(Enum):
Expand Down Expand Up @@ -100,3 +101,18 @@ class DBTableV2:
PROMPT_STUDIO_REGISTRY = "prompt_studio_registry"
PLATFORM_KEY = "platform_key"
TOKEN_USAGE = "usage"


class FileStorageKeys:
FILE_STORAGE_PROVIDER = "FILE_STORAGE_PROVIDER"
FILE_STORAGE_CREDENTIALS = "FILE_STORAGE_CREDENTIALS"


class FileStorageType(Enum):
PERMANENT = "permanent"
TEMPORARY = "temporary"


class ExecutionSource(Enum):
IDE = "ide"
TOOL = "tool"
18 changes: 17 additions & 1 deletion prompt-service/src/unstract/prompt_service/env_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import json
import os
from typing import Optional
from typing import Any, Optional

from unstract.prompt_service.constants import FileStorageKeys


class EnvLoader:
Expand All @@ -9,3 +12,16 @@ def get_env_or_die(env_key: str, default: Optional[str] = None) -> str:
if env_value is None or env_value == "":
raise ValueError(f"Env variable {env_key} is required")
return env_value

@staticmethod
def load_provider_credentials() -> dict[str, Any]:
cred_env_data: str = EnvLoader.get_env_or_die(
env_key=FileStorageKeys.FILE_STORAGE_CREDENTIALS
)
credentials = json.loads(cred_env_data)
provider_data: dict[str, Any] = {}
provider_data[FileStorageKeys.FILE_STORAGE_PROVIDER] = credentials["provider"]
provider_data[FileStorageKeys.FILE_STORAGE_CREDENTIALS] = credentials[
"credentials"
]
return provider_data
43 changes: 40 additions & 3 deletions prompt-service/src/unstract/prompt_service/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from dotenv import load_dotenv
from flask import Flask, current_app
from unstract.prompt_service.config import db
from unstract.prompt_service.constants import DBTableV2
from unstract.prompt_service.constants import (
DBTableV2,
ExecutionSource,
FeatureFlag,
FileStorageType,
)
from unstract.prompt_service.constants import PromptServiceContants as PSKeys
from unstract.prompt_service.db_utils import DBUtils
from unstract.prompt_service.env_manager import EnvLoader
Expand All @@ -16,6 +21,14 @@
from unstract.sdk.exceptions import SdkError
from unstract.sdk.llm import LLM

from unstract.flags.src.unstract.flags.feature_flag import check_feature_flag_status

if check_feature_flag_status(FeatureFlag.REMOTE_FILE_STORAGE):
from unstract.prompt_service.prompt_service_file_helper import (
PromptServiceFileHelper,
)
from unstract.sdk.file_storage import FileStorage

load_dotenv()

# Global variable to store plugins
Expand Down Expand Up @@ -278,6 +291,7 @@ def run_completion(
prompt_type: Optional[str] = PSKeys.TEXT,
enable_highlight: bool = False,
file_path: str = "",
execution_source: Optional[str] = None,
) -> str:
logger: Logger = current_app.logger
try:
Expand All @@ -286,8 +300,17 @@ def run_completion(
)
highlight_data = None
if highlight_data_plugin and enable_highlight:
fs_instance: FileStorage
if execution_source == ExecutionSource.IDE.value:
fs_instance = PromptServiceFileHelper.initialize_file_storage(
type=FileStorageType.PERMANENT
)
if execution_source == ExecutionSource.TOOL.value:
fs_instance = PromptServiceFileHelper.initialize_file_storage(
type=FileStorageType.TEMPORARY
)
highlight_data = highlight_data_plugin["entrypoint_cls"](
logger=current_app.logger, file_path=file_path
logger=current_app.logger, file_path=file_path, fs_instance=fs_instance
).run
completion = llm.complete(
prompt=prompt,
Expand Down Expand Up @@ -325,6 +348,7 @@ def extract_table(
structured_output: dict[str, Any],
llm: LLM,
enforce_type: str,
execution_source: str,
) -> dict[str, Any]:
table_settings = output[PSKeys.TABLE_SETTINGS]
table_extractor: dict[str, Any] = plugins.get("table-extractor", {})
Expand All @@ -333,9 +357,22 @@ def extract_table(
"Unable to extract table details. "
"Please contact admin to resolve this issue."
)
if check_feature_flag_status(FeatureFlag.REMOTE_FILE_STORAGE):
fs_instance: FileStorage
if execution_source == ExecutionSource.IDE.value:
fs_instance = PromptServiceFileHelper.initialize_file_storage(
type=FileStorageType.PERMANENT
)
if execution_source == ExecutionSource.TOOL.value:
fs_instance = PromptServiceFileHelper.initialize_file_storage(
type=FileStorageType.TEMPORARY
)
try:
answer = table_extractor["entrypoint_cls"].extract_large_table(
llm=llm, table_settings=table_settings, enforce_type=enforce_type
llm=llm,
table_settings=table_settings,
enforce_type=enforce_type,
fs_instance=fs_instance,
)
structured_output[output[PSKeys.NAME]] = answer
# We do not support summary and eval for table.
Expand Down
3 changes: 3 additions & 0 deletions prompt-service/src/unstract/prompt_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def prompt_processor() -> Any:
PSKeys.CONTEXT: {},
}
variable_names: list[str] = []
# Identifier for source of invocation
execution_source = payload.get(PSKeys.EXECUTION_SOURCE, "")
publish_log(
log_events_id,
{"tool_id": tool_id, "run_id": run_id, "doc_name": doc_name},
Expand Down Expand Up @@ -225,6 +227,7 @@ def prompt_processor() -> Any:
structured_output=structured_output,
llm=llm,
enforce_type=output[PSKeys.TYPE],
execution_source=execution_source,
)
metadata = query_usage_metadata(token=platform_key, metadata=metadata)
response = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Union

from unstract.prompt_service.constants import FileStorageKeys, FileStorageType
from unstract.prompt_service.env_manager import EnvLoader
from unstract.sdk.file_storage import (
FileStorageProvider,
PermanentFileStorage,
SharedTemporaryFileStorage,
)


class PromptServiceFileHelper:
@staticmethod
def initialize_file_storage(
type: FileStorageType,
) -> Union[PermanentFileStorage, SharedTemporaryFileStorage]:
provider_data = EnvLoader.load_provider_credentials()
provider = provider_data[FileStorageKeys.FILE_STORAGE_PROVIDER]
provider_value = FileStorageProvider(provider)
credentials = provider_data[FileStorageKeys.FILE_STORAGE_CREDENTIALS]
if type.value == FileStorageType.PERMANENT.value:
file_storage = PermanentFileStorage(provider=provider_value, **credentials)
elif type.value == FileStorageType.TEMPORARY.value:
file_storage = SharedTemporaryFileStorage(
provider=provider_value, **credentials
)
else:
file_storage = PermanentFileStorage(
provider=FileStorageProvider.LOCAL, **credentials
)
return file_storage
2 changes: 2 additions & 0 deletions tools/structure/src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,5 @@ class SettingsKeys:
CONFIDENCE_DATA = "confidence_data"
EXECUTION_RUN_DATA_FOLDER = "EXECUTION_RUN_DATA_FOLDER"
FILE_PATH = "file_path"
EXECUTION_SOURCE = "execution_source"
TOOL = "tool"
1 change: 1 addition & 0 deletions tools/structure/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def run(
SettingsKeys.FILE_HASH: file_hash,
SettingsKeys.FILE_NAME: file_name,
SettingsKeys.FILE_PATH: extracted_input_file,
SettingsKeys.EXECUTION_SOURCE: SettingsKeys.TOOL,
}
# TODO: Need to split extraction and indexing
# to avoid unwanted indexing
Expand Down

0 comments on commit 3386976

Please sign in to comment.