Skip to content

Commit

Permalink
add mutation for cancel task
Browse files Browse the repository at this point in the history
  • Loading branch information
jatinriverlane committed Sep 25, 2024
1 parent 84fb0f9 commit 9589de4
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 26 deletions.
61 changes: 35 additions & 26 deletions pyaqueduct/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from pyaqueduct.client.experiment_types import ExperimentData, ExperimentsInfo, TagsData
from pyaqueduct.client.extension_types import (
ExtensionCancelResultData,
ExtensionData,
ExtensionExecutionResultData,
)
Expand All @@ -30,6 +31,7 @@
)
from pyaqueduct.schemas.mutations import (
add_tags_to_experiment_mutation,
cancel_task_mutation,
create_experiment_mutation,
execute_extension_action_mutation,
remove_experiment_mutation,
Expand Down Expand Up @@ -134,9 +136,7 @@ def create_experiment(
create_experiment_mutation,
{"title": title, "description": description, "tags": tags or []},
)
experiment_obj = ExperimentData.from_dict(
data["createExperiment"] # pylint: disable=unsubscriptable-object
)
experiment_obj = ExperimentData.from_dict(data["createExperiment"])
logging.info("Created experiment - %s - %s", experiment_obj.uuid, experiment_obj.title)
return experiment_obj

Expand All @@ -163,9 +163,7 @@ def update_experiment(
"description": description,
},
)
experiment_obj = ExperimentData.from_dict(
data["updateExperiment"] # pylint: disable=unsubscriptable-object
)
experiment_obj = ExperimentData.from_dict(data["updateExperiment"])
logging.info("Updated experiment - %s", experiment_obj.uuid)
return experiment_obj

Expand Down Expand Up @@ -204,9 +202,7 @@ def get_experiments(
"tags": tags,
},
)
experiments_obj = ExperimentsInfo.from_dict(
data["experiments"] # pylint: disable=unsubscriptable-object
)
experiments_obj = ExperimentsInfo.from_dict(data["experiments"])
logging.info(
"Fetched %s experiments, total %s experiments",
len(experiments_obj.experiments),
Expand All @@ -229,9 +225,7 @@ def get_experiment(self, experiment_uuid: UUID) -> ExperimentData:
get_experiment_query,
{"type": "UUID", "value": str(experiment_uuid)},
)
experiment_obj = ExperimentData.from_dict(
data["experiment"] # pylint: disable=unsubscriptable-object
)
experiment_obj = ExperimentData.from_dict(data["experiment"])
logging.info("Fetched experiment - %s", experiment_obj.title)
return experiment_obj

Expand All @@ -250,9 +244,7 @@ def get_experiment_by_eid(self, eid: str) -> ExperimentData:
get_experiment_query,
{"type": "EID", "value": eid},
)
experiment_obj = ExperimentData.from_dict(
data["experiment"] # pylint: disable=unsubscriptable-object
)
experiment_obj = ExperimentData.from_dict(data["experiment"])
logging.info("Fetched experiment - %s", experiment_obj.title)
return experiment_obj

Expand All @@ -272,9 +264,7 @@ def add_tags_to_experiment(self, experiment_uuid: UUID, tags: List[str]) -> Expe
add_tags_to_experiment_mutation,
{"uuid": str(experiment_uuid), "tags": tags},
)
experiment_obj = ExperimentData.from_dict(
data["addTagsToExperiment"] # pylint: disable=unsubscriptable-object
)
experiment_obj = ExperimentData.from_dict(data["addTagsToExperiment"])
logging.info("Added tags %s to experiment <%s>", tags, experiment_obj.title)
return experiment_obj

Expand Down Expand Up @@ -307,9 +297,7 @@ def remove_tag_from_experiment(self, experiment_uuid: UUID, tag: str) -> Experim
{"uuid": str(experiment_uuid), "tag": tag},
)

experiment_obj = ExperimentData.from_dict(
data["removeTagFromExperiment"] # pylint: disable=unsubscriptable-object
)
experiment_obj = ExperimentData.from_dict(data["removeTagFromExperiment"])
logging.info("Removed tag %s from experiment <%s>", tag, experiment_obj.title)
return experiment_obj

Expand Down Expand Up @@ -349,7 +337,7 @@ def get_tags(self, limit: int, offset: int, dangling: bool = True) -> TagsData:
get_all_tags_query,
{"limit": limit, "offset": offset, "dangling": dangling},
)
tags_obj = TagsData.from_dict(data["tags"]) # pylint: disable=unsubscriptable-object
tags_obj = TagsData.from_dict(data["tags"])
logging.info("Fetched %s tags, total %s tags", len(tags_obj.tags), tags_obj.total_count)
return tags_obj

Expand Down Expand Up @@ -439,7 +427,7 @@ def get_extensions(self) -> List[ExtensionData]:
extensions_list = list(
map(
ExtensionData.from_dict,
extensions_response["extensions"], # pylint: disable=unsubscriptable-object
extensions_response["extensions"],
)
)
logging.info("Fetched %s extensions", len(extensions_list))
Expand Down Expand Up @@ -480,13 +468,34 @@ def execute_extension_action(
error.errors if error.errors else "Unknown error occurred in the remote operation."
) from error

result = ExtensionExecutionResultData.from_dict(
extension_result["executeExtension"] # pylint: disable=unsubscriptable-object
)
result = ExtensionExecutionResultData.from_dict(extension_result["executeExtension"])
logging.info(
"Executed a %s / %s extension action with result code %d",
extension,
action,
result.returnCode,
)
return result

def cancel_task(self, task_id: str) -> ExtensionCancelResultData:
"""Stops and cancels task running in Celery
Args:
task_id: Task identifier
"""
try:
revoke_result = self._gql_client.execute(
cancel_task_mutation,
variable_values={"taskId": task_id},
)
except gql_exceptions.TransportServerError as error:
if error.code:
process_response_common(codes(error.code))
raise
except gql_exceptions.TransportQueryError as error:
raise RemoteOperationError(
error.errors if error.errors else "Unknown error occoured in the remote operation"
)

result = ExtensionCancelResultData.from_dict(revoke_result["cancelTask"])
return result
19 changes: 19 additions & 0 deletions pyaqueduct/client/extension_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,22 @@ def from_dict(cls, data: dict) -> ExtensionExecutionResultData:
Object populated with server response data.
"""
return ExtensionExecutionResultData(**data)


class ExtensionCancelResultData(BaseModel):
"""Results for task cancellation"""

returnCode: int
message: str

@classmethod
def from_dict(cls, data: dict) -> ExtensionCancelResultData:
"""Compose an object from a server response.
Args:
data: server response
Returns:
Object populated with server response data
"""
return ExtensionCancelResultData(**data)
18 changes: 18 additions & 0 deletions pyaqueduct/schemas/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,21 @@
}
"""
)

cancel_task_mutation = gql(
"""
mutation CancelTask (
$taskId: UUID!
) {
cancelTask(
cancelTaskInput: {
taskId: $taskId
}
) {
taskId
resultCode
taskStatus
}
}
"""
)

1 comment on commit 9589de4

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
pyaqueduct
   experiment.py56198%90
pyaqueduct/client
   client.py1574869%54–60, 110–115, 243–249, 314–322, 365–366, 403–404, 418–423, 462–467, 486–501
   experiment_types.py39197%20
   extension_types.py38197%89
TOTAL4025187% 

Tests Skipped Failures Errors Time
24 0 💤 0 ❌ 0 🔥 1.458s ⏱️

Please sign in to comment.