Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add indices migration #942

Merged
merged 14 commits into from
Dec 10, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def search_docs_by_embedding(
snippet_counter[count(item)] :=
owners[owner_type, owner_id_str],
owner_id = to_uuid(owner_id_str),
*docs {{
*docs:owner_id_metadata_doc_id_idx {{
owner_type,
owner_id,
doc_id: item,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/docs/search_docs_by_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def search_docs_by_text(

candidate[doc_id] :=
input[owner_type, owner_id],
*docs {{
*docs:owner_id_metadata_doc_id_idx {{
owner_type,
owner_id,
doc_id,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/execution/count_executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def count_executions(

counter[count(id)] :=
input[task_id],
*executions {
*executions:task_id_execution_id_idx {
task_id,
execution_id: id,
}
Expand Down
279 changes: 73 additions & 206 deletions agents-api/agents_api/models/execution/create_execution_transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,62 +25,8 @@
from .update_execution import update_execution


def validate_transition_targets(data: CreateTransitionRequest) -> None:
# Make sure the current/next targets are valid
match data.type:
case "finish_branch":
pass # TODO: Implement
case "finish" | "error" | "cancelled":
pass

### FIXME: HACK: Fix this and uncomment

### assert (
### data.next is None
### ), "Next target must be None for finish/finish_branch/error/cancelled"

case "init_branch" | "init":
assert (
data.next and data.current.step == data.next.step == 0
), "Next target must be same as current for init_branch/init and step 0"

case "wait":
assert data.next is None, "Next target must be None for wait"

case "resume" | "step":
assert data.next is not None, "Next target must be provided for resume/step"

if data.next.workflow == data.current.workflow:
assert (
data.next.step > data.current.step
), "Next step must be greater than current"

case _:
raise ValueError(f"Invalid transition type: {data.type}")


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(
Transition,
transform=lambda d: {
**d,
"id": d["transition_id"],
"current": {"workflow": d["current"][0], "step": d["current"][1]},
"next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
},
one=True,
_kind="inserted",
)
@cozo_query
@increase_counter("create_execution_transition")
@beartype
def create_execution_transition(
def _create_execution_transition(
*,
developer_id: UUID,
execution_id: UUID,
Expand Down Expand Up @@ -140,7 +86,7 @@ def create_execution_transition(
]

last_transition_type[min_cost(type_created_at)] :=
*transitions {{
*transitions:execution_id_type_created_at_idx {{
execution_id: to_uuid("{str(execution_id)}"),
type,
created_at,
Expand Down Expand Up @@ -225,167 +171,88 @@ def create_execution_transition(
)


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(
Transition,
transform=lambda d: {
**d,
"id": d["transition_id"],
"current": {"workflow": d["current"][0], "step": d["current"][1]},
"next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]},
},
one=True,
_kind="inserted",
)
@cozo_query_async
@increase_counter("create_execution_transition_async")
@beartype
async def create_execution_transition_async(
*,
developer_id: UUID,
execution_id: UUID,
data: CreateTransitionRequest,
# Only one of these needed
transition_id: UUID | None = None,
task_token: str | None = None,
# Only required for updating the execution status as well
update_execution_status: bool = False,
task_id: UUID | None = None,
) -> tuple[list[str | None], dict]:
transition_id = transition_id or uuid4()
data.metadata = data.metadata or {}
data.execution_id = execution_id

# Dump to json
if isinstance(data.output, list):
data.output = [
item.model_dump(mode="json") if hasattr(item, "model_dump") else item
for item in data.output
]

elif hasattr(data.output, "model_dump"):
data.output = data.output.model_dump(mode="json")

# TODO: This is a hack to make sure the transition is valid
# (parallel transitions are whack, we should do something better)
is_parallel = data.current.workflow.startswith("PAR:")

# Prepare the transition data
transition_data = data.model_dump(exclude_unset=True, exclude={"id"})

# Parse the current and next targets
validate_transition_targets(data)
current_target = transition_data.pop("current")
next_target = transition_data.pop("next")

transition_data["current"] = (current_target["workflow"], current_target["step"])
transition_data["next"] = next_target and (
next_target["workflow"],
next_target["step"],
)

columns, transition_values = cozo_process_mutate_data(
{
**transition_data,
"task_token": str(task_token), # Converting to str for JSON serialisation
"transition_id": str(transition_id),
"execution_id": str(execution_id),
}
)

# Make sure the transition is valid
check_last_transition_query = f"""
valid_transition[start, end] <- [
{", ".join(f'["{start}", "{end}"]' for start, ends in valid_transitions.items() for end in ends)}
]
def validate_transition_targets(data: CreateTransitionRequest) -> None:
# Make sure the current/next targets are valid
match data.type:
case "finish_branch":
pass # TODO: Implement
case "finish" | "error" | "cancelled":
pass

last_transition_type[min_cost(type_created_at)] :=
*transitions {{
execution_id: to_uuid("{str(execution_id)}"),
type,
created_at,
}},
type_created_at = [type, -created_at]
### FIXME: HACK: Fix this and uncomment

matched[collect(last_type)] :=
last_transition_type[data],
last_type_data = first(data),
last_type = if(is_null(last_type_data), "init", last_type_data),
valid_transition[last_type, $next_type]
### assert (
### data.next is None
### ), "Next target must be None for finish/finish_branch/error/cancelled"
Comment on lines +175 to +186
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Make sure the current/next targets are valid
match data.type:
case "finish_branch":
pass # TODO: Implement
case "finish" | "error" | "cancelled":
pass
last_transition_type[min_cost(type_created_at)] :=
*transitions {{
execution_id: to_uuid("{str(execution_id)}"),
type,
created_at,
}},
type_created_at = [type, -created_at]
### FIXME: HACK: Fix this and uncomment
matched[collect(last_type)] :=
last_transition_type[data],
last_type_data = first(data),
last_type = if(is_null(last_type_data), "init", last_type_data),
valid_transition[last_type, $next_type]
### assert (
### data.next is None
### ), "Next target must be None for finish/finish_branch/error/cancelled"
"""Validates the transition targets based on the transition type.
Args:
data (CreateTransitionRequest): The transition request data to validate.
Raises:
ValueError: If the transition type is invalid.
AssertionError: If the transition targets are not valid for the given type.
"""
match data.type:
case "finish_branch":
pass # TODO: Implement
case "finish" | "error" | "cancelled":
assert (
data.next is None
), "Next target must be None for finish/finish_branch/error/cancelled"

Add a docstring to the `validate_transition_targets` function.


?[valid] :=
matched[prev_transitions],
found = length(prev_transitions),
valid = if($next_type == "init", found == 0, found > 0),
assert(valid, "Invalid transition"),
case "init_branch" | "init":
assert (
data.next and data.current.step == data.next.step == 0
), "Next target must be same as current for init_branch/init and step 0"

:limit 1
"""
case "wait":
assert data.next is None, "Next target must be None for wait"

# Prepare the insert query
insert_query = f"""
?[{columns}] <- $transition_values
case "resume" | "step":
assert data.next is not None, "Next target must be provided for resume/step"

:insert transitions {{
{columns}
}}

:returning
"""
if data.next.workflow == data.current.workflow:
assert (
data.next.step > data.current.step
), "Next step must be greater than current"

validate_status_query, update_execution_query, update_execution_params = (
"",
"",
{},
)
case _:
raise ValueError(f"Invalid transition type: {data.type}")

if update_execution_status:
assert (
task_id is not None
), "task_id is required for updating the execution status"

# Prepare the execution update query
[*_, validate_status_query, update_execution_query], update_execution_params = (
update_execution.__wrapped__(
developer_id=developer_id,
task_id=task_id,
execution_id=execution_id,
data=UpdateExecutionRequest(
status=transition_to_execution_status[data.type]
),
output=data.output if data.type != "error" else None,
error=str(data.output)
if data.type == "error" and data.output
else None,
# Synchronous entry point, composed by explicit calls instead of `@` syntax.
# Wrappers are applied innermost-first around `_create_execution_transition`:
#   increase_counter -> cozo_query -> wrap_in_class -> rewrap_exceptions
create_execution_transition = rewrap_exceptions(
    {
        # Each listed exception type is paired with an HTTPException
        # pre-configured with status_code=400.
        exc_type: partialclass(HTTPException, status_code=400)
        for exc_type in (QueryException, ValidationError, TypeError)
    }
)(
    wrap_in_class(
        Transition,
        # Rename `transition_id` to `id` and expand the two-element
        # `current`/`next` values into {"workflow", "step"} mappings; a falsy
        # `next` is passed through unchanged.
        transform=lambda d: {
            **d,
            "id": d["transition_id"],
            "current": {"workflow": d["current"][0], "step": d["current"][1]},
            "next": (
                {"workflow": d["next"][0], "step": d["next"][1]}
                if d["next"]
                else d["next"]
            ),
        },
        one=True,
        _kind="inserted",
    )(
        cozo_query(
            increase_counter("create_execution_transition")(
                _create_execution_transition
            )
        )
    )
)

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(
developer_id,
"executions",
execution_id=execution_id,
parents=[("agents", "agent_id"), ("tasks", "task_id")],
),
validate_status_query if not is_parallel else None,
update_execution_query if not is_parallel else None,
check_last_transition_query if not is_parallel else None,
insert_query,
]

return (
queries,
{
"transition_values": transition_values,
"next_type": data.type,
"valid_transitions": valid_transitions,
**update_execution_params,
# Asynchronous entry point, composed by explicit calls instead of `@` syntax.
# It wraps the same `_create_execution_transition` function, but runs it
# through `cozo_query_async` and counts it under its own metric name.
# Wrappers are applied innermost-first:
#   increase_counter -> cozo_query_async -> wrap_in_class -> rewrap_exceptions
create_execution_transition_async = rewrap_exceptions(
    {
        # Each listed exception type is paired with an HTTPException
        # pre-configured with status_code=400.
        exc_type: partialclass(HTTPException, status_code=400)
        for exc_type in (QueryException, ValidationError, TypeError)
    }
)(
    wrap_in_class(
        Transition,
        # Rename `transition_id` to `id` and expand the two-element
        # `current`/`next` values into {"workflow", "step"} mappings; a falsy
        # `next` is passed through unchanged.
        transform=lambda d: {
            **d,
            "id": d["transition_id"],
            "current": {"workflow": d["current"][0], "step": d["current"][1]},
            "next": (
                {"workflow": d["next"][0], "step": d["next"][1]}
                if d["next"]
                else d["next"]
            ),
        },
        one=True,
        _kind="inserted",
    )(
        cozo_query_async(
            increase_counter("create_execution_transition_async")(
                _create_execution_transition
            )
        )
    )
)
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/execution/get_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_execution(

?[id, task_id, status, input, output, error, session_id, metadata, created_at, updated_at] :=
input[execution_id],
*executions {
*executions:execution_id_status_idx {
task_id,
execution_id,
status,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_paused_execution_token(

check_status_query = """
?[execution_id, status] :=
*executions {
*executions:execution_id_status_idx {
execution_id,
status,
},
Expand All @@ -55,7 +55,7 @@ def get_paused_execution_token(
*executions {
execution_id,
},
*transitions {
*transitions:execution_id_type_created_at_idx {
execution_id,
created_at,
task_token,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def list_execution_transitions(

query = f"""
?[id, execution_id, type, current, next, output, metadata, updated_at, created_at] :=
*transitions {{
*transitions:execution_id_type_created_at_idx {{
execution_id,
transition_id: id,
type,
Expand Down
4 changes: 3 additions & 1 deletion agents-api/agents_api/models/execution/update_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ def update_execution(

validate_status_query = """
valid_status[count(status)] :=
*executions {
*executions:execution_id_status_idx {
status,
execution_id: to_uuid($execution_id),
whiterabbit1983 marked this conversation as resolved.
Show resolved Hide resolved
task_id: to_uuid($task_id),
},
status in $valid_previous_statuses

Expand Down Expand Up @@ -124,5 +125,6 @@ def update_execution(
"values": values,
"valid_previous_statuses": valid_previous_statuses,
"execution_id": str(execution_id),
"task_id": task_id,
},
)
Loading
Loading