diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py
index 7323f73..7aad256 100644
--- a/callbacks/segmentation.py
+++ b/callbacks/segmentation.py
@@ -9,7 +9,11 @@
 from constants import ANNOT_ICONS
 from utils.data_utils import extract_parameters_from_html, tiled_masks
 from utils.plot_utils import generate_notification
-from utils.prefect import get_flow_run_name, query_flow_run, schedule_prefect_flow
+from utils.prefect import (
+    get_flow_run_name,
+    get_flow_runs_by_name,
+    schedule_prefect_flow,
+)
 
 MODE = os.getenv("MODE", "")
 RESULTS_DIR = os.getenv("RESULTS_DIR", "")
@@ -27,14 +31,18 @@
             "image_name": "ghcr.io/mlexchange/mlex_dlsia_segmentation_prototype",
             "image_tag": "main",
             "command": 'python -c \\"import time; time.sleep(30)\\"',
-            "model_params": {"io_parameters": {"uid": "uid0001"}},
+            "params": {
+                "io_parameters": {"uid_save": "uid0001", "uid_retrieve": "uid0001"}
+            },
             "volumes": [f"{RESULTS_DIR}:/app/work/results"],
         },
         {
             "image_name": "ghcr.io/mlexchange/mlex_dlsia_segmentation_prototype",
             "image_tag": "main",
             "command": 'python -c \\"import time; time.sleep(10)\\"',
-            "model_params": {"io_parameters": {"uid": "uid0001"}},
+            "params": {
+                "io_parameters": {"uid_save": "uid0001", "uid_retrieve": "uid0001"}
+            },
             "volumes": [f"{RESULTS_DIR}:/app/work/results"],
         },
     ],
@@ -47,7 +55,9 @@
             "image_name": "ghcr.io/mlexchange/mlex_dlsia_segmentation_prototype",
             "image_tag": "main",
             "command": 'python -c \\"import time; time.sleep(30)\\"',
-            "model_params": {"io_parameters": {"uid": "uid0001"}},
+            "params": {
+                "io_parameters": {"uid_save": "uid0001", "uid_retrieve": "uid0001"}
+            },
             "volumes": [f"{RESULTS_DIR}:/app/work/results"],
         },
     ],
@@ -100,7 +110,7 @@ def run_train(
             FLOW_NAME,
             parameters=TRAIN_PARAMS_EXAMPLE,
             flow_run_name=f"{job_name} {current_time}",
-            tags=PREFECT_TAGS + ["train"],
+            tags=PREFECT_TAGS + ["train", project_name],
         )
         job_message = f"Job has been succesfully submitted with uid: {job_uid}"
         notification_color = "indigo"
@@ -123,9 +133,10 @@ def run_train(
     Output("notifications-container", "children", allow_duplicate=True),
     Input("run-inference", "n_clicks"),
     State("train-job-selector", "value"),
+    State("project-name-src", "value"),
     prevent_initial_call=True,
 )
-def run_inference(n_clicks, train_job_id):
+def run_inference(n_clicks, train_job_id, project_name):
     """
     This callback collects parameters from the UI and submits an inference job to Prefect.
     If the app is run from "dev" mode, then only a placeholder job_uid will be created.
@@ -151,7 +162,7 @@ def run_inference(n_clicks, train_job_id):
                 FLOW_NAME,
                 parameters=INFERENCE_PARAMS_EXAMPLE,
                 flow_run_name=f"{job_name} {current_time}",
-                tags=PREFECT_TAGS + ["inference"],
+                tags=PREFECT_TAGS + ["inference", project_name],
             )
             job_message = (
                 f"Job has been succesfully submitted with uid: {job_uid}"
@@ -195,7 +206,7 @@ def check_train_job(n_intervals):
             {"label": "✅ DLSIA CBA 03/11/2024 10:02AM", "value": "uid0003"},
         ]
     else:
-        data = query_flow_run(PREFECT_TAGS + ["train"])
+        data = get_flow_runs_by_name(tags=PREFECT_TAGS + ["train"])
     return data
 
 
@@ -204,8 +215,9 @@
     Output("inference-job-selector", "value"),
     Input("model-check", "n_intervals"),
     Input("train-job-selector", "value"),
+    State("project-name-src", "value"),
 )
-def check_inference_job(n_intervals, train_job_id):
+def check_inference_job(n_intervals, train_job_id, project_name):
     """
     This callback populates the inference job selector dropdown with job names and ids from Prefect.
     The list of jobs is filtered by the selected train job in the train job selector dropdown.
@@ -240,8 +252,9 @@ def check_inference_job(n_intervals, train_job_id):
             },
         ]
     else:
-        data = query_flow_run(
-            PREFECT_TAGS + ["inference"], flow_run_name=job_name
-        )
+        data = get_flow_runs_by_name(
+            flow_run_name=job_name,
+            tags=PREFECT_TAGS + ["inference", project_name],
+        )
     selected_value = None if len(data) == 0 else no_update
     return data, selected_value
diff --git a/utils/prefect.py b/utils/prefect.py
index d837309..4dba0b2 100644
--- a/utils/prefect.py
+++ b/utils/prefect.py
@@ -5,6 +5,7 @@
 from prefect.client.schemas.filters import (
     FlowRunFilter,
     FlowRunFilterName,
+    FlowRunFilterParentFlowRunId,
     FlowRunFilterTags,
 )
 
@@ -58,26 +59,47 @@ def get_flow_run_name(flow_run_id):
     return asyncio.run(_get_name(flow_run_id))
 
 
-async def _flow_run_query(tags, flow_run_name=None):
-    flow_runs_by_name = []
+async def _flow_run_query(
+    tags=None, flow_run_name=None, parent_flow_run_id=None, sort="START_TIME_DESC"
+):
+    flow_run_filter_parent_flow_run_id = (
+        FlowRunFilterParentFlowRunId(any_=[parent_flow_run_id])
+        if parent_flow_run_id
+        else None
+    )
     async with get_client() as client:
         flow_runs = await client.read_flow_runs(
             flow_run_filter=FlowRunFilter(
                 name=FlowRunFilterName(like_=flow_run_name),
+                parent_flow_run_id=flow_run_filter_parent_flow_run_id,
                 tags=FlowRunFilterTags(all_=tags),
             ),
-            sort="START_TIME_DESC",
+            sort=sort,
         )
-        for flow_run in flow_runs:
-            if flow_run.state_name == "Failed":
-                flow_name = f"❌ {flow_run.name}"
-            elif flow_run.state_name == "Completed":
-                flow_name = f"✅ {flow_run.name}"
-            else:
-                flow_name = f"🕑 {flow_run.name}"
-            flow_runs_by_name.append({"label": flow_name, "value": str(flow_run.id)})
-        return flow_runs_by_name
+        return flow_runs
 
 
-def query_flow_run(tags, flow_run_name=None):
-    return asyncio.run(_flow_run_query(tags, flow_run_name))
+def get_flow_runs_by_name(flow_run_name=None, tags=None):
+    flow_runs_by_name = []
+    flow_runs = asyncio.run(_flow_run_query(tags, flow_run_name=flow_run_name))
+    for flow_run in flow_runs:
+        if flow_run.state_name in {"Failed", "Crashed"}:
+            flow_name = f"❌ {flow_run.name}"
+        elif flow_run.state_name == "Completed":
+            flow_name = f"✅ {flow_run.name}"
+        elif flow_run.state_name == "Cancelled":
+            flow_name = f"🚫 {flow_run.name}"
+        else:
+            flow_name = f"🕑 {flow_run.name}"
+        flow_runs_by_name.append({"label": flow_name, "value": str(flow_run.id)})
+    return flow_runs_by_name
+
+
+def get_children_flow_run_ids(parent_flow_run_id, sort="START_TIME_ASC"):
+    children_flow_runs = asyncio.run(
+        _flow_run_query(parent_flow_run_id=parent_flow_run_id, sort=sort)
+    )
+    children_flow_run_ids = [
+        str(children_flow_run.id) for children_flow_run in children_flow_runs
+    ]
+    return children_flow_run_ids
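
For illustration only, a minimal sketch of how the refactored helpers from utils/prefect.py could be exercised outside the Dash callbacks. The base tags, project name, and dropdown-option handling below are assumed example values, not part of this diff or of the application's configuration:

    # Sketch: calling the new Prefect helpers from a plain Python script.
    # Assumes utils/prefect.py from this diff is importable; tag and
    # project values here are made up for the example.
    from utils.prefect import get_children_flow_run_ids, get_flow_runs_by_name

    PREFECT_TAGS = ["segmentation-app"]  # hypothetical base tags

    # Dropdown-ready options for a project's train runs, newest first;
    # each label is prefixed with a status emoji (✅, ❌, 🚫, or 🕑).
    train_options = get_flow_runs_by_name(tags=PREFECT_TAGS + ["train", "demo-project"])

    if train_options:
        # Child flow run ids of the most recent train run, oldest first.
        child_ids = get_children_flow_run_ids(parent_flow_run_id=train_options[0]["value"])
        print(child_ids)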