Skip to content

Commit

Permalink
Merge branch 'dev' into d/cookbookfix
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 committed Oct 16, 2024
2 parents 00cd4f4 + 5262a5b commit 002996c
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import base64
from typing import Any

from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import CreateTransitionRequest
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...common.protocol.tasks import StepContext
from ...common.storage_handler import auto_blob_store
from .transition_step import original_transition_step


@activity.defn
@auto_blob_store
@beartype
async def raise_complete_async(context: StepContext, output: StepOutcome) -> None:
async def raise_complete_async(context: StepContext, output: Any) -> None:
activity_info = activity.info()

captured_token = base64.b64encode(activity_info.task_token).decode("ascii")
Expand Down
14 changes: 13 additions & 1 deletion agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from uuid import UUID

from temporalio.client import Client, TLSConfig
from temporalio.common import (
SearchAttributeKey,
SearchAttributePair,
TypedSearchAttributes,
)

from ..autogen.openapi_model import TransitionTarget
from ..common.protocol.tasks import ExecutionInput
Expand Down Expand Up @@ -48,6 +53,7 @@ async def run_task_execution_workflow(
from ..workflows.task_execution import TaskExecutionWorkflow

client = client or (await get_client())
execution_id_key = SearchAttributeKey.for_keyword("CustomStringField")

return await client.start_workflow(
TaskExecutionWorkflow.run,
Expand All @@ -56,7 +62,13 @@ async def run_task_execution_workflow(
id=str(job_id),
run_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
# TODO: Should add search_attributes for queryability
search_attributes=TypedSearchAttributes(
[
SearchAttributePair(
execution_id_key, str(execution_input.execution.id)
),
]
),
)


Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/common/utils/yaml.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from io import StringIO
from typing import Any

import yaml
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# -----
task_max_parallelism: int = env.int("AGENTS_API_TASK_MAX_PARALLELISM", default=100)


# Blob Store
# ----------
use_blob_store_for_temporal: bool = env.bool(
Expand All @@ -37,6 +38,7 @@
s3_access_key: str | None = env.str("S3_ACCESS_KEY", default=None)
s3_secret_key: str | None = env.str("S3_SECRET_KEY", default=None)


# Debug
# -----
debug: bool = env.bool("AGENTS_API_DEBUG", default=False)
Expand Down
5 changes: 5 additions & 0 deletions agents-api/agents_api/models/execution/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
##########
# Consts #
##########

OUTPUT_UNNEST_KEY = "$$e7w_unnest$$"
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/execution/create_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
verify_developer_owns_resource_query,
wrap_in_class,
)
from .constants import OUTPUT_UNNEST_KEY

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")
Expand Down Expand Up @@ -59,6 +60,9 @@ def create_execution(
data["metadata"] = data.get("metadata", {})
execution_data = data

if execution_data["output"] is not None and not isinstance(execution_data["output"], dict):
execution_data["output"] = {OUTPUT_UNNEST_KEY: execution_data["output"]}

columns, values = cozo_process_mutate_data(
{
**execution_data,
Expand Down
10 changes: 9 additions & 1 deletion agents-api/agents_api/models/execution/get_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
rewrap_exceptions,
wrap_in_class,
)
from .constants import OUTPUT_UNNEST_KEY

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")
Expand All @@ -26,7 +27,14 @@
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(Execution, one=True)
@wrap_in_class(
Execution,
one=True,
transform=lambda d: {
**d,
"output": d["output"][OUTPUT_UNNEST_KEY] if OUTPUT_UNNEST_KEY in d["output"] else d["output"],
},
)
@cozo_query
@beartype
def get_execution(
Expand Down
11 changes: 10 additions & 1 deletion agents-api/agents_api/models/execution/list_executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
verify_developer_owns_resource_query,
wrap_in_class,
)
from .constants import OUTPUT_UNNEST_KEY

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")
Expand All @@ -27,7 +28,15 @@
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(Execution)
@wrap_in_class(
Execution,
transform=lambda d: {
**d,
"output": d["output"][OUTPUT_UNNEST_KEY]
if OUTPUT_UNNEST_KEY in d["output"]
else d["output"],
},
)
@cozo_query
@beartype
def list_executions(
Expand Down
6 changes: 5 additions & 1 deletion agents-api/agents_api/models/execution/update_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
verify_developer_owns_resource_query,
wrap_in_class,
)
from .constants import OUTPUT_UNNEST_KEY

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")
Expand Down Expand Up @@ -50,7 +51,7 @@ def update_execution(
task_id: UUID,
execution_id: UUID,
data: UpdateExecutionRequest,
output: dict | None = None,
output: dict | Any | None = None,
error: str | None = None,
) -> tuple[list[str], dict]:
developer_id = str(developer_id)
Expand All @@ -63,6 +64,9 @@ def update_execution(

execution_data: dict = data.model_dump(exclude_none=True)

if output is not None and not isinstance(output, dict):
output: dict = {OUTPUT_UNNEST_KEY: output}

columns, values = cozo_process_mutate_data(
{
**execution_data,
Expand Down

0 comments on commit 002996c

Please sign in to comment.