diff --git a/agents-api/agents_api/activities/embed_docs.py b/agents-api/agents_api/activities/embed_docs.py index c6c7663c3..a9a7cae44 100644 --- a/agents-api/agents_api/activities/embed_docs.py +++ b/agents-api/agents_api/activities/embed_docs.py @@ -7,13 +7,11 @@ from temporalio import activity from ..clients import cozo, litellm -from ..common.storage_handler import auto_blob_store from ..env import testing from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query from .types import EmbedDocsPayload -@auto_blob_store(deep=True) @beartype async def embed_docs( payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100 diff --git a/agents-api/agents_api/activities/excecute_api_call.py b/agents-api/agents_api/activities/excecute_api_call.py index 09a33aaa8..2167aaead 100644 --- a/agents-api/agents_api/activities/excecute_api_call.py +++ b/agents-api/agents_api/activities/excecute_api_call.py @@ -6,7 +6,6 @@ from temporalio import activity from ..autogen.openapi_model import ApiCallDef -from ..common.storage_handler import auto_blob_store from ..env import testing @@ -20,7 +19,6 @@ class RequestArgs(TypedDict): headers: Optional[dict[str, str]] -@auto_blob_store(deep=True) @beartype async def execute_api_call( api_call: ApiCallDef, diff --git a/agents-api/agents_api/activities/execute_integration.py b/agents-api/agents_api/activities/execute_integration.py index 3316ad6f5..d058553c4 100644 --- a/agents-api/agents_api/activities/execute_integration.py +++ b/agents-api/agents_api/activities/execute_integration.py @@ -7,12 +7,10 @@ from ..clients import integrations from ..common.exceptions.tools import IntegrationExecutionException from ..common.protocol.tasks import ExecutionInput, StepContext -from ..common.storage_handler import auto_blob_store from ..env import testing from ..models.tools import get_tool_args_from_metadata -@auto_blob_store(deep=True) @beartype async def execute_integration( context: StepContext, diff --git a/agents-api/agents_api/activities/execute_system.py b/agents-api/agents_api/activities/execute_system.py index 590849080..647327a8a 100644 --- a/agents-api/agents_api/activities/execute_system.py +++ b/agents-api/agents_api/activities/execute_system.py @@ -19,16 +19,14 @@ VectorDocSearchRequest, ) from ..common.protocol.tasks import ExecutionInput, StepContext -from ..common.storage_handler import auto_blob_store, load_from_blob_store_if_remote from ..env import testing -from ..queries.developer import get_developer +from ..queries.developers import get_developer from .utils import get_handler # For running synchronous code in the background process_pool_executor = ProcessPoolExecutor() -@auto_blob_store(deep=True) @beartype async def execute_system( context: StepContext, @@ -37,9 +35,6 @@ async def execute_system( """Execute a system call with the appropriate handler and transformed arguments.""" arguments: dict[str, Any] = system.arguments or {} - if set(arguments.keys()) == {"bucket", "key"}: - arguments = await load_from_blob_store_if_remote(arguments) - if not isinstance(context.execution_input, ExecutionInput): raise TypeError("Expected ExecutionInput type for context.execution_input") diff --git a/agents-api/agents_api/activities/sync_items_remote.py b/agents-api/agents_api/activities/sync_items_remote.py index d71a5c566..14751c2b6 100644 --- a/agents-api/agents_api/activities/sync_items_remote.py +++ b/agents-api/agents_api/activities/sync_items_remote.py @@ -9,20 +9,16 @@ @beartype async def save_inputs_remote_fn(inputs: list[Any]) -> list[Any | RemoteObject]: - from ..common.storage_handler import store_in_blob_store_if_large + from ..common.interceptors import offload_if_large - return await asyncio.gather( - *[store_in_blob_store_if_large(input) for input in inputs] - ) + return await asyncio.gather(*[offload_if_large(input) for input in inputs]) @beartype async def load_inputs_remote_fn(inputs: list[Any | RemoteObject]) -> list[Any]: - from ..common.storage_handler import load_from_blob_store_if_remote + from ..common.interceptors import load_if_remote - return await asyncio.gather( - *[load_from_blob_store_if_remote(input) for input in inputs] - ) + return await asyncio.gather(*[load_if_remote(input) for input in inputs]) save_inputs_remote = activity.defn(name="save_inputs_remote")(save_inputs_remote_fn) diff --git a/agents-api/agents_api/activities/task_steps/base_evaluate.py b/agents-api/agents_api/activities/task_steps/base_evaluate.py index d87b961d3..3bb04e390 100644 --- a/agents-api/agents_api/activities/task_steps/base_evaluate.py +++ b/agents-api/agents_api/activities/task_steps/base_evaluate.py @@ -13,7 +13,6 @@ from temporalio import activity # noqa: E402 from thefuzz import fuzz # noqa: E402 -from ...common.storage_handler import auto_blob_store # noqa: E402 from ...env import testing # noqa: E402 from ..utils import get_evaluator # noqa: E402 @@ -63,7 +62,6 @@ def _recursive_evaluate(expr, evaluator: SimpleEval): raise ValueError(f"Invalid expression: {expr}") -@auto_blob_store(deep=True) @beartype async def base_evaluate( exprs: Any, diff --git a/agents-api/agents_api/activities/task_steps/cozo_query_step.py b/agents-api/agents_api/activities/task_steps/cozo_query_step.py index 16e9a53d8..8d28d83c9 100644 --- a/agents-api/agents_api/activities/task_steps/cozo_query_step.py +++ b/agents-api/agents_api/activities/task_steps/cozo_query_step.py @@ -4,11 +4,9 @@ from temporalio import activity from ... import models -from ...common.storage_handler import auto_blob_store from ...env import testing -@auto_blob_store(deep=True) @beartype async def cozo_query_step( query_name: str, diff --git a/agents-api/agents_api/activities/task_steps/evaluate_step.py b/agents-api/agents_api/activities/task_steps/evaluate_step.py index 904ec3b9d..08fa6cd55 100644 --- a/agents-api/agents_api/activities/task_steps/evaluate_step.py +++ b/agents-api/agents_api/activities/task_steps/evaluate_step.py @@ -5,11 +5,9 @@ from ...activities.utils import simple_eval_dict from ...common.protocol.tasks import StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store from ...env import testing -@auto_blob_store(deep=True) @beartype async def evaluate_step( context: StepContext, diff --git a/agents-api/agents_api/activities/task_steps/for_each_step.py b/agents-api/agents_api/activities/task_steps/for_each_step.py index f51c1ef76..ca84eb75d 100644 --- a/agents-api/agents_api/activities/task_steps/for_each_step.py +++ b/agents-api/agents_api/activities/task_steps/for_each_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) @beartype async def for_each_step(context: StepContext) -> StepOutcome: try: diff --git a/agents-api/agents_api/activities/task_steps/get_value_step.py b/agents-api/agents_api/activities/task_steps/get_value_step.py index ca38bc4fe..feeb71bbf 100644 --- a/agents-api/agents_api/activities/task_steps/get_value_step.py +++ b/agents-api/agents_api/activities/task_steps/get_value_step.py @@ -2,13 +2,12 @@ from temporalio import activity from ...common.protocol.tasks import StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store from ...env import testing - # TODO: We should use this step to query the parent workflow and get the value from the workflow context # SCRUM-1 -@auto_blob_store(deep=True) + + @beartype async def get_value_step( context: StepContext, diff --git a/agents-api/agents_api/activities/task_steps/if_else_step.py b/agents-api/agents_api/activities/task_steps/if_else_step.py index cf3764199..ec4368640 100644 --- a/agents-api/agents_api/activities/task_steps/if_else_step.py +++ b/agents-api/agents_api/activities/task_steps/if_else_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) @beartype async def if_else_step(context: StepContext) -> StepOutcome: # NOTE: This activity is only for logging, so we just evaluate the expression diff --git a/agents-api/agents_api/activities/task_steps/log_step.py b/agents-api/agents_api/activities/task_steps/log_step.py index 28fea2dae..f54018683 100644 --- a/agents-api/agents_api/activities/task_steps/log_step.py +++ b/agents-api/agents_api/activities/task_steps/log_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store from ...common.utils.template import render_template from ...env import testing -@auto_blob_store(deep=True) @beartype async def log_step(context: StepContext) -> StepOutcome: # NOTE: This activity is only for logging, so we just evaluate the expression diff --git a/agents-api/agents_api/activities/task_steps/map_reduce_step.py b/agents-api/agents_api/activities/task_steps/map_reduce_step.py index 872988bb4..c39bace20 100644 --- a/agents-api/agents_api/activities/task_steps/map_reduce_step.py +++ b/agents-api/agents_api/activities/task_steps/map_reduce_step.py @@ -8,12 +8,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) @beartype async def map_reduce_step(context: StepContext) -> StepOutcome: try: diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index cf8b169d5..47560cadd 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -8,7 +8,6 @@ litellm, # We dont directly import `acompletion` so we can mock it ) from ...common.protocol.tasks import ExecutionInput, StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store from ...common.utils.template import render_template from ...env import debug from .base_evaluate import base_evaluate @@ -62,7 +61,6 @@ def format_tool(tool: Tool) -> dict: @activity.defn -@auto_blob_store(deep=True) @beartype async def prompt_step(context: StepContext) -> StepOutcome: # Get context data diff --git a/agents-api/agents_api/activities/task_steps/raise_complete_async.py b/agents-api/agents_api/activities/task_steps/raise_complete_async.py index 640d6ae4e..bbf27c500 100644 --- a/agents-api/agents_api/activities/task_steps/raise_complete_async.py +++ b/agents-api/agents_api/activities/task_steps/raise_complete_async.py @@ -6,12 +6,10 @@ from ...autogen.openapi_model import CreateTransitionRequest from ...common.protocol.tasks import StepContext -from ...common.storage_handler import auto_blob_store from .transition_step import original_transition_step @activity.defn -@auto_blob_store(deep=True) @beartype async def raise_complete_async(context: StepContext, output: Any) -> None: activity_info = activity.info() diff --git a/agents-api/agents_api/activities/task_steps/return_step.py b/agents-api/agents_api/activities/task_steps/return_step.py index 08ac20de4..f15354536 100644 --- a/agents-api/agents_api/activities/task_steps/return_step.py +++ b/agents-api/agents_api/activities/task_steps/return_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) @beartype async def return_step(context: StepContext) -> StepOutcome: try: diff --git a/agents-api/agents_api/activities/task_steps/set_value_step.py b/agents-api/agents_api/activities/task_steps/set_value_step.py index 1c97b6551..96db5d0d1 100644 --- a/agents-api/agents_api/activities/task_steps/set_value_step.py +++ b/agents-api/agents_api/activities/task_steps/set_value_step.py @@ -5,13 +5,12 @@ from ...activities.utils import simple_eval_dict from ...common.protocol.tasks import StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store from ...env import testing - # TODO: We should use this step to signal to the parent workflow and set the value on the workflow context # SCRUM-2 -@auto_blob_store(deep=True) + + @beartype async def set_value_step( context: StepContext, diff --git a/agents-api/agents_api/activities/task_steps/switch_step.py b/agents-api/agents_api/activities/task_steps/switch_step.py index 6a95e98d2..100d8020a 100644 --- a/agents-api/agents_api/activities/task_steps/switch_step.py +++ b/agents-api/agents_api/activities/task_steps/switch_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store from ...env import testing from ..utils import get_evaluator -@auto_blob_store(deep=True) @beartype async def switch_step(context: StepContext) -> StepOutcome: try: diff --git a/agents-api/agents_api/activities/task_steps/tool_call_step.py b/agents-api/agents_api/activities/task_steps/tool_call_step.py index 5725a75d1..a2d7fd7c2 100644 --- a/agents-api/agents_api/activities/task_steps/tool_call_step.py +++ b/agents-api/agents_api/activities/task_steps/tool_call_step.py @@ -11,7 +11,6 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store # FIXME: This shouldn't be here. @@ -47,7 +46,6 @@ def construct_tool_call( @activity.defn -@auto_blob_store(deep=True) @beartype async def tool_call_step(context: StepContext) -> StepOutcome: assert isinstance(context.current_step, ToolCallStep) diff --git a/agents-api/agents_api/activities/task_steps/transition_step.py b/agents-api/agents_api/activities/task_steps/transition_step.py index 44046a5e7..11c7befb5 100644 --- a/agents-api/agents_api/activities/task_steps/transition_step.py +++ b/agents-api/agents_api/activities/task_steps/transition_step.py @@ -8,7 +8,6 @@ from ...autogen.openapi_model import CreateTransitionRequest, Transition from ...clients.temporal import get_workflow_handle from ...common.protocol.tasks import ExecutionInput, StepContext -from ...common.storage_handler import load_from_blob_store_if_remote from ...env import ( temporal_activity_after_retry_timeout, testing, @@ -48,11 +47,6 @@ async def transition_step( TaskExecutionWorkflow.set_last_error, LastErrorInput(last_error=None) ) - # Load output from blob store if it is a remote object - transition_info.output = await load_from_blob_store_if_remote( - transition_info.output - ) - if not isinstance(context.execution_input, ExecutionInput): raise TypeError("Expected ExecutionInput type for context.execution_input") diff --git a/agents-api/agents_api/activities/task_steps/wait_for_input_step.py b/agents-api/agents_api/activities/task_steps/wait_for_input_step.py index ad6eeb63e..a3cb00f67 100644 --- a/agents-api/agents_api/activities/task_steps/wait_for_input_step.py +++ b/agents-api/agents_api/activities/task_steps/wait_for_input_step.py @@ -3,12 +3,10 @@ from ...autogen.openapi_model import WaitForInputStep from ...common.protocol.tasks import StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) @beartype async def wait_for_input_step(context: StepContext) -> StepOutcome: try: diff --git a/agents-api/agents_api/activities/task_steps/yield_step.py b/agents-api/agents_api/activities/task_steps/yield_step.py index 199008703..18e5383cc 100644 --- a/agents-api/agents_api/activities/task_steps/yield_step.py +++ b/agents-api/agents_api/activities/task_steps/yield_step.py @@ -5,12 +5,10 @@ from ...autogen.openapi_model import TransitionTarget, YieldStep from ...common.protocol.tasks import ExecutionInput, StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) @beartype async def yield_step(context: StepContext) -> StepOutcome: try: diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index af73e8015..d809e0a35 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -14,7 +14,6 @@ model_validator, ) -from ..common.storage_handler import RemoteObject from ..common.utils.datetime import utcnow from .Agents import * from .Chat import * @@ -358,7 +357,7 @@ def validate_subworkflows(self): class SystemDef(SystemDef): - arguments: dict[str, Any] | None | RemoteObject = None + arguments: dict[str, Any] | None = None class CreateTransitionRequest(Transition): diff --git a/agents-api/agents_api/clients/async_s3.py b/agents-api/agents_api/clients/async_s3.py index 0cd5235ee..b6ba76d8b 100644 --- a/agents-api/agents_api/clients/async_s3.py +++ b/agents-api/agents_api/clients/async_s3.py @@ -16,6 +16,7 @@ ) +@alru_cache(maxsize=1024) async def list_buckets() -> list[str]: session = get_session() diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py index da2d7f6fa..cd2178d95 100644 --- a/agents-api/agents_api/clients/temporal.py +++ b/agents-api/agents_api/clients/temporal.py @@ -1,3 +1,4 @@ +import asyncio from datetime import timedelta from uuid import UUID @@ -12,9 +13,9 @@ from temporalio.runtime import PrometheusConfig, Runtime, TelemetryConfig from ..autogen.openapi_model import TransitionTarget +from ..common.interceptors import offload_if_large from ..common.protocol.tasks import ExecutionInput from ..common.retry_policies import DEFAULT_RETRY_POLICY -from ..common.storage_handler import store_in_blob_store_if_large from ..env import ( temporal_client_cert, temporal_metrics_bind_host, @@ -96,8 +97,10 @@ async def run_task_execution_workflow( client = client or (await get_client()) execution_id = execution_input.execution.id execution_id_key = SearchAttributeKey.for_keyword("CustomStringField") - execution_input.arguments = await store_in_blob_store_if_large( - execution_input.arguments + + old_args = execution_input.arguments + execution_input.arguments = await asyncio.gather( + *[offload_if_large(arg) for arg in old_args] ) return await client.start_workflow( diff --git a/agents-api/agents_api/common/interceptors.py b/agents-api/agents_api/common/interceptors.py index 40600a818..bfd64c374 100644 --- a/agents-api/agents_api/common/interceptors.py +++ b/agents-api/agents_api/common/interceptors.py @@ -4,8 +4,12 @@ certain types of errors that are known to be non-retryable. """ -from typing import Optional, Type +import asyncio +import sys +from functools import wraps +from typing import Any, Awaitable, Callable, Optional, Sequence, Type +from temporalio import workflow from temporalio.activity import _CompleteAsyncError as CompleteAsyncError from temporalio.exceptions import ApplicationError, FailureError, TemporalError from temporalio.service import RPCError @@ -23,7 +27,97 @@ ReadOnlyContextError, ) -from .exceptions.tasks import is_retryable_error +with workflow.unsafe.imports_passed_through(): + from ..env import blob_store_cutoff_kb, use_blob_store_for_temporal + from .exceptions.tasks import is_retryable_error + from .protocol.remote import RemoteObject + +# Common exceptions that should be re-raised without modification +PASSTHROUGH_EXCEPTIONS = ( + ContinueAsNewError, + ReadOnlyContextError, + NondeterminismError, + RPCError, + CompleteAsyncError, + TemporalError, + FailureError, + ApplicationError, +) + + +def is_too_large(result: Any) -> bool: + return sys.getsizeof(result) > blob_store_cutoff_kb * 1024 + + +async def load_if_remote[T](arg: T | RemoteObject[T]) -> T: + if use_blob_store_for_temporal and isinstance(arg, RemoteObject): + return await arg.load() + + return arg + + +async def offload_if_large[T](result: T) -> T: + if use_blob_store_for_temporal and is_too_large(result): + return await RemoteObject.from_value(result) + + return result + + +def offload_to_blob_store[S, T]( + func: Callable[[S, ExecuteActivityInput | ExecuteWorkflowInput], Awaitable[T]], +) -> Callable[ + [S, ExecuteActivityInput | ExecuteWorkflowInput], Awaitable[T | RemoteObject[T]] +]: + @wraps(func) + async def wrapper( + self, + input: ExecuteActivityInput | ExecuteWorkflowInput, + ) -> T | RemoteObject[T]: + # Load all remote arguments from the blob store + args: Sequence[Any] = input.args + + if use_blob_store_for_temporal: + input.args = await asyncio.gather(*[load_if_remote(arg) for arg in args]) + + # Execute the function + result = await func(self, input) + + # Save the result to the blob store if necessary + return await offload_if_large(result) + + return wrapper + + +async def handle_execution_with_errors[I, T]( + execution_fn: Callable[[I], Awaitable[T]], + input: I, +) -> T: + """ + Common error handling logic for both activities and workflows. + + Args: + execution_fn: Async function to execute with error handling + input: Input to the execution function + + Returns: + The result of the execution function + + Raises: + ApplicationError: For non-retryable errors + Any other exception: For retryable errors + """ + try: + return await execution_fn(input) + except PASSTHROUGH_EXCEPTIONS: + raise + except BaseException as e: + if not is_retryable_error(e): + raise ApplicationError( + str(e), + type=type(e).__name__, + non_retryable=True, + ) + raise class CustomActivityInterceptor(ActivityInboundInterceptor): @@ -35,95 +129,45 @@ class CustomActivityInterceptor(ActivityInboundInterceptor): as non-retryable errors. """ - async def execute_activity(self, input: ExecuteActivityInput): + @offload_to_blob_store + async def execute_activity(self, input: ExecuteActivityInput) -> Any: """ - 🎭 The Activity Whisperer: Handles activity execution with style and grace - - This is like a safety net for your activities - catching errors and deciding - their fate with the wisdom of a fortune cookie. + Handles activity execution by intercepting errors and determining their retry behavior. """ - try: - return await super().execute_activity(input) - except ( - ContinueAsNewError, # When you need a fresh start - ReadOnlyContextError, # When someone tries to write in a museum - NondeterminismError, # When chaos theory kicks in - RPCError, # When computers can't talk to each other - CompleteAsyncError, # When async goes wrong - TemporalError, # When time itself rebels - FailureError, # When failure is not an option, but happens anyway - ApplicationError, # When the app says "nope" - ): - raise - except BaseException as e: - if not is_retryable_error(e): - # If it's not retryable, we wrap it in a nice bow (ApplicationError) - # and mark it as non-retryable to prevent further attempts - raise ApplicationError( - str(e), - type=type(e).__name__, - non_retryable=True, - ) - # For retryable errors, we'll let Temporal retry with backoff - # Default retry policy ensures at least 2 retries - raise + return await handle_execution_with_errors( + super().execute_activity, + input, + ) class CustomWorkflowInterceptor(WorkflowInboundInterceptor): """ - 🎪 The Workflow Circus Ringmaster + Custom interceptor for Temporal workflows. - This interceptor is like a circus ringmaster - keeping all the workflow acts - running smoothly and catching any lions (errors) that escape their cages. + Handles workflow execution errors and determines their retry behavior. """ - async def execute_workflow(self, input: ExecuteWorkflowInput): + @offload_to_blob_store + async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: """ - 🎪 The Main Event: Workflow Execution Extravaganza! - - Watch as we gracefully handle errors like a trapeze artist catching their partner! + Executes workflows and handles error cases appropriately. """ - try: - return await super().execute_workflow(input) - except ( - ContinueAsNewError, # The show must go on! - ReadOnlyContextError, # No touching, please! - NondeterminismError, # When butterflies cause hurricanes - RPCError, # Lost in translation - CompleteAsyncError, # Async said "bye" too soon - TemporalError, # Time is relative, errors are absolute - FailureError, # Task failed successfully - ApplicationError, # App.exe has stopped working - ): - raise - except BaseException as e: - if not is_retryable_error(e): - # Pack the error in a nice box with a "do not retry" sticker - raise ApplicationError( - str(e), - type=type(e).__name__, - non_retryable=True, - ) - # Let it retry - everyone deserves a second (or third) chance! - raise + return await handle_execution_with_errors( + super().execute_workflow, + input, + ) class CustomInterceptor(Interceptor): """ - 🎭 The Grand Interceptor: Master of Ceremonies - - This is like the backstage manager of a theater - making sure both the - activity actors and workflow directors have their interceptor costumes on. + Main interceptor class that provides both activity and workflow interceptors. """ def intercept_activity( self, next: ActivityInboundInterceptor ) -> ActivityInboundInterceptor: """ - 🎬 Activity Interceptor Factory: Where the magic begins! - - Creating custom activity interceptors faster than a caffeinated barista - makes espresso shots. + Creates and returns a custom activity interceptor. """ return CustomActivityInterceptor(super().intercept_activity(next)) @@ -131,9 +175,6 @@ def workflow_interceptor_class( self, input: WorkflowInterceptorClassInput ) -> Optional[Type[WorkflowInboundInterceptor]]: """ - 🎪 Workflow Interceptor Class Selector - - Like a matchmaker for workflows and their interceptors - a match made in - exception handling heaven! + Returns the custom workflow interceptor class. """ return CustomWorkflowInterceptor diff --git a/agents-api/agents_api/common/protocol/remote.py b/agents-api/agents_api/common/protocol/remote.py index ce2a2a63a..86add1949 100644 --- a/agents-api/agents_api/common/protocol/remote.py +++ b/agents-api/agents_api/common/protocol/remote.py @@ -1,91 +1,34 @@ from dataclasses import dataclass -from typing import Any +from typing import Generic, Self, Type, TypeVar, cast -from temporalio import activity, workflow +from temporalio import workflow with workflow.unsafe.imports_passed_through(): - from pydantic import BaseModel - + from ...clients import async_s3 from ...env import blob_store_bucket + from ...worker.codec import deserialize, serialize -@dataclass -class RemoteObject: - key: str - bucket: str = blob_store_bucket - - -class BaseRemoteModel(BaseModel): - _remote_cache: dict[str, Any] - - class Config: - arbitrary_types_allowed = True - - def __init__(self, **data: Any): - super().__init__(**data) - self._remote_cache = {} - - async def load_item(self, item: Any | RemoteObject) -> Any: - if not activity.in_activity(): - return item - - from ..storage_handler import load_from_blob_store_if_remote - - return await load_from_blob_store_if_remote(item) +T = TypeVar("T") - async def save_item(self, item: Any) -> Any: - if not activity.in_activity(): - return item - from ..storage_handler import store_in_blob_store_if_large - - return await store_in_blob_store_if_large(item) - - async def get_attribute(self, name: str) -> Any: - if name.startswith("_"): - return super().__getattribute__(name) - - try: - value = super().__getattribute__(name) - except AttributeError: - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}'" - ) - - if isinstance(value, RemoteObject): - cache = super().__getattribute__("_remote_cache") - if name in cache: - return cache[name] - - loaded_data = await self.load_item(value) - cache[name] = loaded_data - return loaded_data - - return value - - async def set_attribute(self, name: str, value: Any) -> None: - if name.startswith("_"): - super().__setattr__(name, value) - return +@dataclass +class RemoteObject(Generic[T]): + _type: Type[T] + key: str + bucket: str - stored_value = await self.save_item(value) - super().__setattr__(name, stored_value) + @classmethod + async def from_value(cls, x: T) -> Self: + await async_s3.setup() - if isinstance(stored_value, RemoteObject): - cache = self.__dict__.get("_remote_cache", {}) - cache.pop(name, None) + serialized = serialize(x) - async def load_all(self) -> None: - for name in self.model_fields_set: - await self.get_attribute(name) + key = await async_s3.add_object_with_hash(serialized) + return RemoteObject[T](key=key, bucket=blob_store_bucket, _type=type(x)) - async def unload_attribute(self, name: str) -> None: - if name in self._remote_cache: - data = self._remote_cache.pop(name) - remote_obj = await self.save_item(data) - super().__setattr__(name, remote_obj) + async def load(self) -> T: + await async_s3.setup() - async def unload_all(self) -> "BaseRemoteModel": - for name in list(self._remote_cache.keys()): - await self.unload_attribute(name) - return self + fetched = await async_s3.get_object(self.key) + return cast(self._type, deserialize(fetched)) diff --git a/agents-api/agents_api/common/protocol/sessions.py b/agents-api/agents_api/common/protocol/sessions.py index 121afe702..3b04178e1 100644 --- a/agents-api/agents_api/common/protocol/sessions.py +++ b/agents-api/agents_api/common/protocol/sessions.py @@ -103,7 +103,7 @@ def get_active_tools(self) -> list[Tool]: return active_toolset.tools - def get_chat_environment(self) -> dict[str, dict | list[dict]]: + def get_chat_environment(self) -> dict[str, dict | list[dict] | None]: """ Get the chat environment from the session data. """ diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py index 430a62f36..f3bb81d07 100644 --- a/agents-api/agents_api/common/protocol/tasks.py +++ b/agents-api/agents_api/common/protocol/tasks.py @@ -1,9 +1,8 @@ -import asyncio from typing import Annotated, Any, Literal from uuid import UUID from beartype import beartype -from temporalio import activity, workflow +from temporalio import workflow from temporalio.exceptions import ApplicationError with workflow.unsafe.imports_passed_through(): @@ -33,8 +32,6 @@ Workflow, WorkflowStep, ) - from ...common.storage_handler import load_from_blob_store_if_remote - from .remote import BaseRemoteModel, RemoteObject # TODO: Maybe we should use a library for this @@ -146,16 +143,16 @@ class ExecutionInput(BaseModel): task: TaskSpecDef agent: Agent agent_tools: list[Tool | CreateToolRequest] - arguments: dict[str, Any] | RemoteObject + arguments: dict[str, Any] # Not used at the moment user: User | None = None session: Session | None = None -class StepContext(BaseRemoteModel): - execution_input: ExecutionInput | RemoteObject - inputs: list[Any] | RemoteObject +class StepContext(BaseModel): + execution_input: ExecutionInput + inputs: list[Any] cursor: TransitionTarget @computed_field @@ -242,17 +239,9 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]: return dump | execution_input - async def prepare_for_step( - self, *args, include_remote: bool = True, **kwargs - ) -> dict[str, Any]: + async def prepare_for_step(self, *args, **kwargs) -> dict[str, Any]: current_input = self.current_input inputs = self.inputs - if activity.in_activity() and include_remote: - await self.load_all() - inputs = await asyncio.gather( - *[load_from_blob_store_if_remote(input) for input in inputs] - ) - current_input = await load_from_blob_store_if_remote(current_input) # Merge execution inputs into the dump dict dump = self.model_dump(*args, **kwargs) diff --git a/agents-api/agents_api/common/storage_handler.py b/agents-api/agents_api/common/storage_handler.py deleted file mode 100644 index 42beef270..000000000 --- a/agents-api/agents_api/common/storage_handler.py +++ /dev/null @@ -1,226 +0,0 @@ -import asyncio -import sys -from datetime import timedelta -from functools import wraps -from typing import Any, Callable - -from pydantic import BaseModel -from temporalio import workflow - -from ..activities.sync_items_remote import load_inputs_remote -from ..clients import async_s3 -from ..common.protocol.remote import BaseRemoteModel, RemoteObject -from ..common.retry_policies import DEFAULT_RETRY_POLICY -from ..env import ( - blob_store_cutoff_kb, - debug, - temporal_heartbeat_timeout, - temporal_schedule_to_close_timeout, - testing, - use_blob_store_for_temporal, -) -from ..worker.codec import deserialize, serialize - - -async def store_in_blob_store_if_large(x: Any) -> RemoteObject | Any: - if not use_blob_store_for_temporal: - return x - - await async_s3.setup() - - serialized = serialize(x) - data_size = sys.getsizeof(serialized) - - if data_size > blob_store_cutoff_kb * 1024: - key = await async_s3.add_object_with_hash(serialized) - return RemoteObject(key=key) - - return x - - -async def load_from_blob_store_if_remote(x: Any | RemoteObject) -> Any: - if not use_blob_store_for_temporal: - return x - - await async_s3.setup() - - if isinstance(x, RemoteObject): - fetched = await async_s3.get_object(x.key) - return deserialize(fetched) - - elif isinstance(x, dict) and set(x.keys()) == {"bucket", "key"}: - fetched = await async_s3.get_object(x["key"]) - return deserialize(fetched) - - return x - - -# Decorator that automatically does two things: -# 1. store in blob store if the output of a function is large -# 2. load from blob store if the input is a RemoteObject - - -def auto_blob_store(f: Callable | None = None, *, deep: bool = False) -> Callable: - def auto_blob_store_decorator(f: Callable) -> Callable: - async def load_args( - args: list | tuple, kwargs: dict[str, Any] - ) -> tuple[list | tuple, dict[str, Any]]: - new_args = await asyncio.gather( - *[load_from_blob_store_if_remote(arg) for arg in args] - ) - kwargs_keys, kwargs_values = list(zip(*kwargs.items())) or ([], []) - new_kwargs = await asyncio.gather( - *[load_from_blob_store_if_remote(v) for v in kwargs_values] - ) - new_kwargs = dict(zip(kwargs_keys, new_kwargs)) - - if deep: - args = new_args - kwargs = new_kwargs - - new_args = [] - - for arg in args: - if isinstance(arg, list): - new_args.append( - await asyncio.gather( - *[load_from_blob_store_if_remote(item) for item in arg] - ) - ) - elif isinstance(arg, dict): - keys, values = list(zip(*arg.items())) or ([], []) - values = await asyncio.gather( - *[load_from_blob_store_if_remote(value) for value in values] - ) - new_args.append(dict(zip(keys, values))) - - elif isinstance(arg, BaseRemoteModel): - new_args.append(await arg.unload_all()) - - elif isinstance(arg, BaseModel): - for field in arg.model_fields.keys(): - if isinstance(getattr(arg, field), RemoteObject): - setattr( - arg, - field, - await load_from_blob_store_if_remote( - getattr(arg, field) - ), - ) - elif isinstance(getattr(arg, field), list): - setattr( - arg, - field, - await asyncio.gather( - *[ - load_from_blob_store_if_remote(item) - for item in getattr(arg, field) - ] - ), - ) - elif isinstance(getattr(arg, field), BaseRemoteModel): - setattr( - arg, - field, - await getattr(arg, field).unload_all(), - ) - - new_args.append(arg) - - else: - new_args.append(arg) - - new_kwargs = {} - - for k, v in kwargs.items(): - if isinstance(v, list): - new_kwargs[k] = await asyncio.gather( - *[load_from_blob_store_if_remote(item) for item in v] - ) - - elif isinstance(v, dict): - keys, values = list(zip(*v.items())) or ([], []) - values = await asyncio.gather( - *[load_from_blob_store_if_remote(value) for value in values] - ) - new_kwargs[k] = dict(zip(keys, values)) - - elif isinstance(v, BaseRemoteModel): - new_kwargs[k] = await v.unload_all() - - elif isinstance(v, BaseModel): - for field in v.model_fields.keys(): - if isinstance(getattr(v, field), RemoteObject): - setattr( - v, - field, - await load_from_blob_store_if_remote( - getattr(v, field) - ), - ) - elif isinstance(getattr(v, field), list): - setattr( - v, - field, - await asyncio.gather( - *[ - load_from_blob_store_if_remote(item) - for item in getattr(v, field) - ] - ), - ) - elif isinstance(getattr(v, field), BaseRemoteModel): - setattr( - v, - field, - await getattr(v, field).unload_all(), - ) - new_kwargs[k] = v - - else: - new_kwargs[k] = v - - return new_args, new_kwargs - - async def unload_return_value(x: Any | BaseRemoteModel) -> Any: - if isinstance(x, BaseRemoteModel): - await x.unload_all() - - return await store_in_blob_store_if_large(x) - - @wraps(f) - async def async_wrapper(*args, **kwargs) -> Any: - new_args, new_kwargs = await load_args(args, kwargs) - output = await f(*new_args, **new_kwargs) - - return await unload_return_value(output) - - return async_wrapper if use_blob_store_for_temporal else f - - return auto_blob_store_decorator(f) if f else auto_blob_store_decorator - - -def auto_blob_store_workflow(f: Callable) -> Callable: - @wraps(f) - async def wrapper(*args, **kwargs) -> Any: - keys = kwargs.keys() - values = [kwargs[k] for k in keys] - - loaded = await workflow.execute_activity( - load_inputs_remote, - args=[[*args, *values]], - schedule_to_close_timeout=timedelta( - seconds=60 if debug or testing else temporal_schedule_to_close_timeout - ), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ) - - loaded_args = loaded[: len(args)] - loaded_kwargs = dict(zip(keys, loaded[len(args) :])) - - result = await f(*loaded_args, **loaded_kwargs) - - return result - - return wrapper if use_blob_store_for_temporal else f diff --git a/agents-api/agents_api/env.py b/agents-api/agents_api/env.py index 8b9fd4dae..7baa24653 100644 --- a/agents-api/agents_api/env.py +++ b/agents-api/agents_api/env.py @@ -36,8 +36,8 @@ # Blob Store # ---------- -use_blob_store_for_temporal: bool = ( - env.bool("USE_BLOB_STORE_FOR_TEMPORAL", default=False) if not testing else False +use_blob_store_for_temporal: bool = testing or env.bool( + "USE_BLOB_STORE_FOR_TEMPORAL", default=False ) blob_store_bucket: str = env.str("BLOB_STORE_BUCKET", default="agents-api") diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py index 63fbdc940..058462cf8 100644 --- a/agents-api/agents_api/queries/sessions/create_session.py +++ b/agents-api/agents_api/queries/sessions/create_session.py @@ -8,10 +8,8 @@ from ...autogen.openapi_model import ( CreateSessionRequest, - ResourceCreatedResponse, Session, ) -from ...common.utils.datetime import utcnow from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class diff --git a/agents-api/agents_api/routers/healthz/check_health.py b/agents-api/agents_api/routers/healthz/check_health.py new file mode 100644 index 000000000..5a466ba39 --- /dev/null +++ b/agents-api/agents_api/routers/healthz/check_health.py @@ -0,0 +1,19 @@ +import logging +from uuid import UUID + +from ...models.agent.list_agents import list_agents as list_agents_query +from .router import router + + +@router.get("/healthz", tags=["healthz"]) +async def check_health() -> dict: + try: + # Check if the database is reachable + list_agents_query( + developer_id=UUID("00000000-0000-0000-0000-000000000000"), + ) + except Exception as e: + logging.error("An error occurred while checking health: %s", str(e)) + return {"status": "error", "message": "An internal error has occurred."} + + return {"status": "ok"} diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index 6ea9239df..a76c13975 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -15,7 +15,7 @@ from ...activities.excecute_api_call import execute_api_call from ...activities.execute_integration import execute_integration from ...activities.execute_system import execute_system - from ...activities.sync_items_remote import load_inputs_remote, save_inputs_remote + from ...activities.sync_items_remote import save_inputs_remote from ...autogen.openapi_model import ( ApiCallDef, BaseIntegrationDef, @@ -214,16 +214,6 @@ async def run( # 3. Then, based on the outcome and step type, decide what to do next workflow.logger.info(f"Processing outcome for step {context.cursor.step}") - [outcome] = await workflow.execute_activity( - load_inputs_remote, - args=[[outcome]], - schedule_to_close_timeout=timedelta( - seconds=60 if debug or testing else temporal_schedule_to_close_timeout - ), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ) - # Init state state = None diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py index 1d68322f5..b2df640a7 100644 --- a/agents-api/agents_api/workflows/task_execution/helpers.py +++ b/agents-api/agents_api/workflows/task_execution/helpers.py @@ -19,11 +19,9 @@ ExecutionInput, StepContext, ) - from ...common.storage_handler import auto_blob_store_workflow from ...env import task_max_parallelism, temporal_heartbeat_timeout -@auto_blob_store_workflow async def continue_as_child( execution_input: ExecutionInput, start: TransitionTarget, @@ -50,7 +48,6 @@ async def continue_as_child( ) -@auto_blob_store_workflow async def execute_switch_branch( *, context: StepContext, @@ -84,7 +81,6 @@ async def execute_switch_branch( ) -@auto_blob_store_workflow async def execute_if_else_branch( *, context: StepContext, @@ -123,7 +119,6 @@ async def execute_if_else_branch( ) -@auto_blob_store_workflow async def execute_foreach_step( *, context: StepContext, @@ -161,7 +156,6 @@ async def execute_foreach_step( return results -@auto_blob_store_workflow async def execute_map_reduce_step( *, context: StepContext, @@ -209,7 +203,6 @@ async def execute_map_reduce_step( return result -@auto_blob_store_workflow async def execute_map_reduce_step_parallel( *, context: StepContext, diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 286fd10fb..430a2e3c5 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -1,6 +1,5 @@ import random import string -import time from uuid import UUID from fastapi.testclient import TestClient diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 171e56aa8..4673d6fc5 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -10,7 +10,6 @@ CreateOrUpdateSessionRequest, CreateSessionRequest, PatchSessionRequest, - ResourceCreatedResponse, ResourceDeletedResponse, ResourceUpdatedResponse, Session,