Skip to content

Commit

Permalink
fix(agents-api): Fix tests for workflows
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Dec 27, 2024
1 parent 4a3e14d commit 505a25d
Show file tree
Hide file tree
Showing 18 changed files with 47 additions and 182 deletions.
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi.background import BackgroundTasks
from temporalio import activity

from ..app import lifespan
from ..app import app, lifespan
from ..autogen.openapi_model import (
ChatInput,
CreateDocRequest,
Expand All @@ -29,7 +29,7 @@
process_pool_executor = ProcessPoolExecutor()


@lifespan(container)
@lifespan(app, container) # Both are needed because we are using the routes
@beartype
async def execute_system(
context: StepContext,
Expand Down
11 changes: 1 addition & 10 deletions agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from temporalio import activity # noqa: E402
from thefuzz import fuzz # noqa: E402

from ...env import testing # noqa: E402
from ..utils import get_evaluator # noqa: E402


Expand Down Expand Up @@ -62,6 +61,7 @@ def _recursive_evaluate(expr, evaluator: SimpleEval):
raise ValueError(f"Invalid expression: {expr}")


@activity.defn
@beartype
async def base_evaluate(
exprs: Any,
Expand Down Expand Up @@ -100,12 +100,3 @@ async def base_evaluate(
# Recursively evaluate the expression
result = _recursive_evaluate(exprs, evaluator)
return result


# Note: This is here just for clarity. We could have just imported base_evaluate directly
# They do the same thing, so we dont need to mock the base_evaluate function
mock_base_evaluate = base_evaluate

base_evaluate = activity.defn(name="base_evaluate")(
base_evaluate if not testing else mock_base_evaluate
)
11 changes: 1 addition & 10 deletions agents-api/agents_api/activities/task_steps/evaluate_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from ...activities.utils import simple_eval_dict
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing


@activity.defn
@beartype
async def evaluate_step(
context: StepContext,
Expand All @@ -31,12 +31,3 @@ async def evaluate_step(
except BaseException as e:
activity.logger.error(f"Error in evaluate_step: {e}")
return StepOutcome(error=str(e) or repr(e))


# Note: This is here just for clarity. We could have just imported evaluate_step directly
# They do the same thing, so we dont need to mock the evaluate_step function
mock_evaluate_step = evaluate_step

evaluate_step = activity.defn(name="evaluate_step")(
evaluate_step if not testing else mock_evaluate_step
)
11 changes: 1 addition & 10 deletions agents-api/agents_api/activities/task_steps/for_each_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
StepContext,
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


@activity.defn
@beartype
async def for_each_step(context: StepContext) -> StepOutcome:
try:
Expand All @@ -23,12 +23,3 @@ async def for_each_step(context: StepContext) -> StepOutcome:
except BaseException as e:
activity.logger.error(f"Error in for_each_step: {e}")
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported if_else_step directly
# They do the same thing, so we dont need to mock the if_else_step function
mock_if_else_step = for_each_step

for_each_step = activity.defn(name="for_each_step")(
for_each_step if not testing else mock_if_else_step
)
11 changes: 1 addition & 10 deletions agents-api/agents_api/activities/task_steps/get_value_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,15 @@
from temporalio import activity

from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing

# TODO: We should use this step to query the parent workflow and get the value from the workflow context
# SCRUM-1


@activity.defn
@beartype
async def get_value_step(
context: StepContext,
) -> StepOutcome:
key: str = context.current_step.get # noqa: F841
raise NotImplementedError("Not implemented yet")


# Note: This is here just for clarity. We could have just imported get_value_step directly
# They do the same thing, so we dont need to mock the get_value_step function
mock_get_value_step = get_value_step

get_value_step = activity.defn(name="get_value_step")(
get_value_step if not testing else mock_get_value_step
)
11 changes: 1 addition & 10 deletions agents-api/agents_api/activities/task_steps/if_else_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
StepContext,
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


@activity.defn
@beartype
async def if_else_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for logging, so we just evaluate the expression
Expand All @@ -27,12 +27,3 @@ async def if_else_step(context: StepContext) -> StepOutcome:
except BaseException as e:
activity.logger.error(f"Error in if_else_step: {e}")
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported if_else_step directly
# They do the same thing, so we dont need to mock the if_else_step function
mock_if_else_step = if_else_step

if_else_step = activity.defn(name="if_else_step")(
if_else_step if not testing else mock_if_else_step
)
9 changes: 1 addition & 8 deletions agents-api/agents_api/activities/task_steps/log_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
StepOutcome,
)
from ...common.utils.template import render_template
from ...env import testing


@activity.defn
@beartype
async def log_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for logging, so we just evaluate the expression
Expand All @@ -30,10 +30,3 @@ async def log_step(context: StepContext) -> StepOutcome:
except BaseException as e:
activity.logger.error(f"Error in log_step: {e}")
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported log_step directly
# They do the same thing, so we dont need to mock the log_step function
mock_log_step = log_step

log_step = activity.defn(name="log_step")(log_step if not testing else mock_log_step)
11 changes: 1 addition & 10 deletions agents-api/agents_api/activities/task_steps/map_reduce_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
StepContext,
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


@activity.defn
@beartype
async def map_reduce_step(context: StepContext) -> StepOutcome:
try:
Expand All @@ -26,12 +26,3 @@ async def map_reduce_step(context: StepContext) -> StepOutcome:
except BaseException as e:
logging.error(f"Error in map_reduce_step: {e}")
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported if_else_step directly
# They do the same thing, so we dont need to mock the if_else_step function
mock_if_else_step = map_reduce_step

map_reduce_step = activity.defn(name="map_reduce_step")(
map_reduce_step if not testing else mock_if_else_step
)
12 changes: 2 additions & 10 deletions agents-api/agents_api/activities/task_steps/pg_query_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

from ... import queries
from ...app import lifespan
from ...env import pg_dsn, testing
from ...env import pg_dsn
from ..container import container


@activity.defn
@lifespan(container)
@beartype
async def pg_query_step(
Expand All @@ -21,12 +22,3 @@ async def pg_query_step(
module = getattr(queries, module_name)
query = getattr(module, name)
return await query(**values, connection_pool=container.state.postgres_pool)


# Note: This is here just for clarity. We could have just imported pg_query_step directly
# They do the same thing, so we dont need to mock the pg_query_step function
mock_pg_query_step = pg_query_step

pg_query_step = activity.defn(name="pg_query_step")(
pg_query_step if not testing else mock_pg_query_step
)
11 changes: 1 addition & 10 deletions agents-api/agents_api/activities/task_steps/return_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
StepContext,
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


@activity.defn
@beartype
async def return_step(context: StepContext) -> StepOutcome:
try:
Expand All @@ -24,12 +24,3 @@ async def return_step(context: StepContext) -> StepOutcome:
except BaseException as e:
activity.logger.error(f"Error in log_step: {e}")
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported return_step directly
# They do the same thing, so we dont need to mock the return_step function
mock_return_step = return_step

return_step = activity.defn(name="return_step")(
return_step if not testing else mock_return_step
)
11 changes: 1 addition & 10 deletions agents-api/agents_api/activities/task_steps/set_value_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

from ...activities.utils import simple_eval_dict
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing

# TODO: We should use this step to signal to the parent workflow and set the value on the workflow context
# SCRUM-2


@activity.defn
@beartype
async def set_value_step(
context: StepContext,
Expand All @@ -29,12 +29,3 @@ async def set_value_step(
except BaseException as e:
activity.logger.error(f"Error in set_value_step: {e}")
return StepOutcome(error=str(e) or repr(e))


# Note: This is here just for clarity. We could have just imported set_value_step directly
# They do the same thing, so we dont need to mock the set_value_step function
mock_set_value_step = set_value_step

set_value_step = activity.defn(name="set_value_step")(
set_value_step if not testing else mock_set_value_step
)
9 changes: 1 addition & 8 deletions agents-api/agents_api/activities/task_steps/switch_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
StepContext,
StepOutcome,
)
from ...env import testing
from ..utils import get_evaluator


@activity.defn
@beartype
async def switch_step(context: StepContext) -> StepOutcome:
try:
Expand All @@ -34,10 +34,3 @@ async def switch_step(context: StepContext) -> StepOutcome:
except BaseException as e:
activity.logger.error(f"Error in switch_step: {e}")
return StepOutcome(error=str(e))


mock_switch_step = switch_step

switch_step = activity.defn(name="switch_step")(
switch_step if not testing else mock_switch_step
)
12 changes: 3 additions & 9 deletions agents-api/agents_api/activities/task_steps/transition_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@
from ...autogen.openapi_model import CreateTransitionRequest, Transition
from ...clients.temporal import get_workflow_handle
from ...common.protocol.tasks import ExecutionInput, StepContext
from ...env import (
temporal_activity_after_retry_timeout,
testing,
transition_requests_per_minute,
)
from ...env import temporal_activity_after_retry_timeout, transition_requests_per_minute
from ...exceptions import LastErrorInput, TooManyRequestsError
from ...queries.executions.create_execution_transition import (
create_execution_transition,
Expand Down Expand Up @@ -74,9 +70,7 @@ async def transition_step(
return transition


# NOTE: Here because needed by a different step
original_transition_step = transition_step
mock_transition_step = transition_step

transition_step = activity.defn(name="transition_step")(
transition_step if not testing else mock_transition_step
)
transition_step = activity.defn(transition_step)
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from ...autogen.openapi_model import WaitForInputStep
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing
from .base_evaluate import base_evaluate


@activity.defn
@beartype
async def wait_for_input_step(context: StepContext) -> StepOutcome:
try:
Expand All @@ -21,10 +21,3 @@ async def wait_for_input_step(context: StepContext) -> StepOutcome:
except BaseException as e:
activity.logger.error(f"Error in wait_for_input_step: {e}")
return StepOutcome(error=str(e))


mock_wait_for_input_step = wait_for_input_step

wait_for_input_step = activity.defn(name="wait_for_input_step")(
wait_for_input_step if not testing else mock_wait_for_input_step
)
13 changes: 1 addition & 12 deletions agents-api/agents_api/activities/task_steps/yield_step.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from typing import Callable

from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import TransitionTarget, YieldStep
from ...common.protocol.tasks import ExecutionInput, StepContext, StepOutcome
from ...env import testing
from .base_evaluate import base_evaluate


@activity.defn
@beartype
async def yield_step(context: StepContext) -> StepOutcome:
try:
Expand Down Expand Up @@ -39,12 +37,3 @@ async def yield_step(context: StepContext) -> StepOutcome:
except BaseException as e:
activity.logger.error(f"Error in yield_step: {e}")
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported yield_step directly
# They do the same thing, so we dont need to mock the yield_step function
mock_yield_step: Callable[[StepContext], StepOutcome] = yield_step

yield_step: Callable[[StepContext], StepOutcome] = activity.defn(name="yield_step")(
yield_step if not testing else mock_yield_step
)
Loading

0 comments on commit 505a25d

Please sign in to comment.