Skip to content

Commit

Permalink
Merge pull request #977 from julep-ai/f/simplify-blob-store
Browse files Browse the repository at this point in the history
feat(agents-api): Remove auto_blob_store in favor of interceptor based system
  • Loading branch information
creatorrr authored Dec 21, 2024
2 parents 0e32cbe + 76819e1 commit 1982583
Show file tree
Hide file tree
Showing 37 changed files with 181 additions and 484 deletions.
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@
from temporalio import activity

from ..clients import cozo, litellm
from ..common.storage_handler import auto_blob_store
from ..env import testing
from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query
from .types import EmbedDocsPayload


@auto_blob_store(deep=True)
@beartype
async def embed_docs(
payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/excecute_api_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from temporalio import activity

from ..autogen.openapi_model import ApiCallDef
from ..common.storage_handler import auto_blob_store
from ..env import testing


Expand All @@ -20,7 +19,6 @@ class RequestArgs(TypedDict):
headers: Optional[dict[str, str]]


@auto_blob_store(deep=True)
@beartype
async def execute_api_call(
api_call: ApiCallDef,
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/execute_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
from ..clients import integrations
from ..common.exceptions.tools import IntegrationExecutionException
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..common.storage_handler import auto_blob_store
from ..env import testing
from ..models.tools import get_tool_args_from_metadata


@auto_blob_store(deep=True)
@beartype
async def execute_integration(
context: StepContext,
Expand Down
7 changes: 1 addition & 6 deletions agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@
VectorDocSearchRequest,
)
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..common.storage_handler import auto_blob_store, load_from_blob_store_if_remote
from ..env import testing
from ..queries.developer import get_developer
from ..queries.developers import get_developer
from .utils import get_handler

# For running synchronous code in the background
process_pool_executor = ProcessPoolExecutor()


@auto_blob_store(deep=True)
@beartype
async def execute_system(
context: StepContext,
Expand All @@ -37,9 +35,6 @@ async def execute_system(
"""Execute a system call with the appropriate handler and transformed arguments."""
arguments: dict[str, Any] = system.arguments or {}

if set(arguments.keys()) == {"bucket", "key"}:
arguments = await load_from_blob_store_if_remote(arguments)

if not isinstance(context.execution_input, ExecutionInput):
raise TypeError("Expected ExecutionInput type for context.execution_input")

Expand Down
12 changes: 4 additions & 8 deletions agents-api/agents_api/activities/sync_items_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,16 @@

@beartype
async def save_inputs_remote_fn(inputs: list[Any]) -> list[Any | RemoteObject]:
from ..common.storage_handler import store_in_blob_store_if_large
from ..common.interceptors import offload_if_large

return await asyncio.gather(
*[store_in_blob_store_if_large(input) for input in inputs]
)
return await asyncio.gather(*[offload_if_large(input) for input in inputs])


@beartype
async def load_inputs_remote_fn(inputs: list[Any | RemoteObject]) -> list[Any]:
from ..common.storage_handler import load_from_blob_store_if_remote
from ..common.interceptors import load_if_remote

return await asyncio.gather(
*[load_from_blob_store_if_remote(input) for input in inputs]
)
return await asyncio.gather(*[load_if_remote(input) for input in inputs])


save_inputs_remote = activity.defn(name="save_inputs_remote")(save_inputs_remote_fn)
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from temporalio import activity # noqa: E402
from thefuzz import fuzz # noqa: E402

from ...common.storage_handler import auto_blob_store # noqa: E402
from ...env import testing # noqa: E402
from ..utils import get_evaluator # noqa: E402

Expand Down Expand Up @@ -63,7 +62,6 @@ def _recursive_evaluate(expr, evaluator: SimpleEval):
raise ValueError(f"Invalid expression: {expr}")


@auto_blob_store(deep=True)
@beartype
async def base_evaluate(
exprs: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
from temporalio import activity

from ... import models
from ...common.storage_handler import auto_blob_store
from ...env import testing


@auto_blob_store(deep=True)
@beartype
async def cozo_query_step(
query_name: str,
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/evaluate_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@

from ...activities.utils import simple_eval_dict
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...env import testing


@auto_blob_store(deep=True)
@beartype
async def evaluate_step(
context: StepContext,
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/for_each_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store
from ...env import testing
from .base_evaluate import base_evaluate


@auto_blob_store(deep=True)
@beartype
async def for_each_step(context: StepContext) -> StepOutcome:
try:
Expand Down
5 changes: 2 additions & 3 deletions agents-api/agents_api/activities/task_steps/get_value_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
from temporalio import activity

from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...env import testing


# TODO: We should use this step to query the parent workflow and get the value from the workflow context
# SCRUM-1
@auto_blob_store(deep=True)


@beartype
async def get_value_step(
context: StepContext,
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/if_else_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store
from ...env import testing
from .base_evaluate import base_evaluate


@auto_blob_store(deep=True)
@beartype
async def if_else_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for logging, so we just evaluate the expression
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/log_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store
from ...common.utils.template import render_template
from ...env import testing


@auto_blob_store(deep=True)
@beartype
async def log_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for logging, so we just evaluate the expression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store
from ...env import testing
from .base_evaluate import base_evaluate


@auto_blob_store(deep=True)
@beartype
async def map_reduce_step(context: StepContext) -> StepOutcome:
try:
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
litellm, # We dont directly import `acompletion` so we can mock it
)
from ...common.protocol.tasks import ExecutionInput, StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...common.utils.template import render_template
from ...env import debug
from .base_evaluate import base_evaluate
Expand Down Expand Up @@ -62,7 +61,6 @@ def format_tool(tool: Tool) -> dict:


@activity.defn
@auto_blob_store(deep=True)
@beartype
async def prompt_step(context: StepContext) -> StepOutcome:
# Get context data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@

from ...autogen.openapi_model import CreateTransitionRequest
from ...common.protocol.tasks import StepContext
from ...common.storage_handler import auto_blob_store
from .transition_step import original_transition_step


@activity.defn
@auto_blob_store(deep=True)
@beartype
async def raise_complete_async(context: StepContext, output: Any) -> None:
activity_info = activity.info()
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/return_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store
from ...env import testing
from .base_evaluate import base_evaluate


@auto_blob_store(deep=True)
@beartype
async def return_step(context: StepContext) -> StepOutcome:
try:
Expand Down
5 changes: 2 additions & 3 deletions agents-api/agents_api/activities/task_steps/set_value_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

from ...activities.utils import simple_eval_dict
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...env import testing


# TODO: We should use this step to signal to the parent workflow and set the value on the workflow context
# SCRUM-2
@auto_blob_store(deep=True)


@beartype
async def set_value_step(
context: StepContext,
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/switch_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store
from ...env import testing
from ..utils import get_evaluator


@auto_blob_store(deep=True)
@beartype
async def switch_step(context: StepContext) -> StepOutcome:
try:
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/tool_call_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
StepContext,
StepOutcome,
)
from ...common.storage_handler import auto_blob_store


# FIXME: This shouldn't be here.
Expand Down Expand Up @@ -47,7 +46,6 @@ def construct_tool_call(


@activity.defn
@auto_blob_store(deep=True)
@beartype
async def tool_call_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, ToolCallStep)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from ...autogen.openapi_model import CreateTransitionRequest, Transition
from ...clients.temporal import get_workflow_handle
from ...common.protocol.tasks import ExecutionInput, StepContext
from ...common.storage_handler import load_from_blob_store_if_remote
from ...env import (
temporal_activity_after_retry_timeout,
testing,
Expand Down Expand Up @@ -48,11 +47,6 @@ async def transition_step(
TaskExecutionWorkflow.set_last_error, LastErrorInput(last_error=None)
)

# Load output from blob store if it is a remote object
transition_info.output = await load_from_blob_store_if_remote(
transition_info.output
)

if not isinstance(context.execution_input, ExecutionInput):
raise TypeError("Expected ExecutionInput type for context.execution_input")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

from ...autogen.openapi_model import WaitForInputStep
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...env import testing
from .base_evaluate import base_evaluate


@auto_blob_store(deep=True)
@beartype
async def wait_for_input_step(context: StepContext) -> StepOutcome:
try:
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/yield_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@

from ...autogen.openapi_model import TransitionTarget, YieldStep
from ...common.protocol.tasks import ExecutionInput, StepContext, StepOutcome
from ...common.storage_handler import auto_blob_store
from ...env import testing
from .base_evaluate import base_evaluate


@auto_blob_store(deep=True)
@beartype
async def yield_step(context: StepContext) -> StepOutcome:
try:
Expand Down
3 changes: 1 addition & 2 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
model_validator,
)

from ..common.storage_handler import RemoteObject
from ..common.utils.datetime import utcnow
from .Agents import *
from .Chat import *
Expand Down Expand Up @@ -358,7 +357,7 @@ def validate_subworkflows(self):


class SystemDef(SystemDef):
arguments: dict[str, Any] | None | RemoteObject = None
arguments: dict[str, Any] | None = None


class CreateTransitionRequest(Transition):
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/clients/async_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)


@alru_cache(maxsize=1024)
async def list_buckets() -> list[str]:
session = get_session()

Expand Down
9 changes: 6 additions & 3 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from datetime import timedelta
from uuid import UUID

Expand All @@ -12,9 +13,9 @@
from temporalio.runtime import PrometheusConfig, Runtime, TelemetryConfig

from ..autogen.openapi_model import TransitionTarget
from ..common.interceptors import offload_if_large
from ..common.protocol.tasks import ExecutionInput
from ..common.retry_policies import DEFAULT_RETRY_POLICY
from ..common.storage_handler import store_in_blob_store_if_large
from ..env import (
temporal_client_cert,
temporal_metrics_bind_host,
Expand Down Expand Up @@ -96,8 +97,10 @@ async def run_task_execution_workflow(
client = client or (await get_client())
execution_id = execution_input.execution.id
execution_id_key = SearchAttributeKey.for_keyword("CustomStringField")
execution_input.arguments = await store_in_blob_store_if_large(
execution_input.arguments

old_args = execution_input.arguments
execution_input.arguments = await asyncio.gather(
*[offload_if_large(arg) for arg in old_args]
)

return await client.start_workflow(
Expand Down
Loading

0 comments on commit 1982583

Please sign in to comment.