diff --git a/.env.example b/.env.example
index dd72de4..35f00f4 100644
--- a/.env.example
+++ b/.env.example
@@ -30,6 +30,10 @@
 SEG_TILED_URI=
 # Replace with your API key
 SEG_TILED_API_KEY=
+# Directory where the segmentation application will store trained models and segmentation
+# results. If using podman, this is the directory that will be mounted as a volume.
+RESULTS_DIR=${PWD}/data/results
+
 # Development environment variables, to be removed in upcoming versions
 DASH_DEPLOYMENT_LOC='Local'
 EXPORT_FILE_PATH='data/exported_annotations.json'
@@ -38,3 +42,8 @@ MODE='dev'
 # Basic authentication for segmentation application when deploying on a publicly accessible server
 USER_NAME=
 USER_PASSWORD=
+
+# Prefect environment variables
+PREFECT_API_URL=http://prefect:4200/api
+FLOW_NAME="Parent flow/launch_parent_flow"
+TIMEZONE="US/Pacific"
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cc2ea1e..6df81cc 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -12,6 +12,12 @@ repos:
       - id: check-symlinks
       - id: check-yaml
       - id: debug-statements
+  - repo: https://github.com/gitguardian/ggshield
+    rev: v1.25.0
+    hooks:
+      - id: ggshield
+        language_version: python3
+        stages: [commit]
   # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
     rev: 24.2.0
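For reference, the variables introduced above are read with os.getenv in callbacks/segmentation.py further down. A minimal sketch of the startup wiring, with defaults mirroring .env.example; the Path(...).mkdir guard is an illustration (the mounted results directory should exist before the first flow run), not part of this patch:

import os
from pathlib import Path

# Defaults mirror the values suggested in .env.example.
PREFECT_API_URL = os.getenv("PREFECT_API_URL", "http://prefect:4200/api")
FLOW_NAME = os.getenv("FLOW_NAME", "Parent flow/launch_parent_flow")
TIMEZONE = os.getenv("TIMEZONE", "US/Pacific")
RESULTS_DIR = os.getenv("RESULTS_DIR", "data/results")

# RESULTS_DIR is bind-mounted into the worker containers, so make sure it
# exists on the host before any flow run is scheduled.
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)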
\\"import time; time.sleep(10)\\"', + "model_params": {"io_parameters": {"uid": "uid0001"}}, + "volumes": [f"{RESULTS_DIR}:/app/work/results"], }, + ], +} + +INFERENCE_PARAMS_EXAMPLE = { + "flow_type": "podman", + "params_list": [ { - "mlex_app": "high-res-segmentation", - "description": "test_1", - "service_type": "backend", - "working_directory": "/data/mlex_repo/mlex_tiled/data", - "job_kwargs": { - "uri": "mlexchange1/random-forest-dc:1.1", - "type": "docker", - "cmd": "python segment.py data/data/20221222_085501_looking_from_above_spiralUP_CounterClockwise_endPointAtDoor_0-1000 data/seg-results-test/spiral/model/random-forest.model data/seg-results-test/spiral/output '{\"show_progress\": 1}'", # noqa: E501 - "kwargs": { - "job_type": "train", - "experiment_id": "124", - "dataset": "name_of_dataset", - "params": '{"show_progress": 1}', - }, - }, + "image_name": "ghcr.io/mlexchange/mlex_dlsia_segmentation_prototype", + "image_tag": "main", + "command": 'python -c \\"import time; time.sleep(30)\\"', + "model_params": {"io_parameters": {"uid": "uid0001"}}, + "volumes": [f"{RESULTS_DIR}:/app/work/results"], }, ], - "host_list": ["vaughan.als.lbl.gov"], - "dependencies": {"0": [], "1": [0]}, - "requirements": {"num_processors": 2, "num_gpus": 0, "num_nodes": 1}, } @callback( - Output("output-details", "children"), - Output("submitted-job-id", "data"), + Output("notifications-container", "children", allow_duplicate=True), Output("model-parameter-values", "data"), - Input("run-model", "n_clicks"), + Input("run-train", "n_clicks"), State("annotation-store", "data"), State({"type": "annotation-class-store", "index": ALL}, "data"), State("project-name-src", "value"), State("model-parameters", "children"), + State("job-name", "value"), + prevent_initial_call=True, ) -def run_job(n_clicks, global_store, all_annotations, project_name, model_parameters): +def run_train( + n_clicks, global_store, all_annotations, project_name, model_parameters, job_name +): """ - This callback collects parameters from the UI and submits a job to the computing api. + This callback collects parameters from the UI and submits a training job to Prefect. If the app is run from "dev" mode, then only a placeholder job_uid will be created. - The job_uid is saved in a dcc.Store for reference by the check_job callback below. 
diff --git a/components/control_bar.py b/components/control_bar.py
index 4809e5f..03d2343 100644
--- a/components/control_bar.py
+++ b/components/control_bar.py
@@ -607,15 +607,46 @@ def layout():
             html.Div(id="model-parameters"),
             dcc.Store(id="model-parameter-values", data={}),
             dmc.Space(h=25),
-            dmc.Center(
-                dmc.Button(
-                    "Run model",
-                    id="run-model",
-                    variant="light",
-                    style={"width": "160px", "margin": "5px"},
-                )
+            ControlItem(
+                "Name",
+                "job-name-input",
+                dmc.TextInput(
+                    placeholder="Name your job...",
+                    id="job-name",
+                ),
+            ),
+            dmc.Space(h=10),
+            dmc.Button(
+                "Train",
+                id="run-train",
+                variant="light",
+                style={"width": "100%", "margin": "5px"},
+            ),
+            dmc.Space(h=10),
+            ControlItem(
+                "Train Jobs",
+                "selected-train-job",
+                dmc.Select(
+                    placeholder="Select a job...",
+                    id="train-job-selector",
+                ),
+            ),
+            dmc.Space(h=10),
+            dmc.Button(
+                "Inference",
+                id="run-inference",
+                variant="light",
+                style={"width": "100%", "margin": "5px"},
+            ),
+            dmc.Space(h=10),
+            ControlItem(
+                "Inference Jobs",
+                "selected-inference-job",
+                dmc.Select(
+                    placeholder="Select a job...",
+                    id="inference-job-selector",
+                ),
             ),
-            html.Div(id="output-details"),
             dmc.Space(h=25),
             dmc.Switch(
                 id="show-result-overlay-toggle",
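The two selectors added above are populated by check_train_job and check_inference_job in callbacks/segmentation.py, and they are linked by flow-run name: run_inference schedules runs named f"{train_run_name} {timestamp}", and the inference dropdown is then filtered with that train run's name. An illustrative query (run name and output hypothetical):

# Inference runs whose names contain the selected train run's name,
# newest first, already shaped for dmc.Select.
options = query_flow_run(
    ["high-res-segmentation", "inference"],
    flow_run_name="DLSIA ABC",
)
# e.g. [{"label": "✅ DLSIA ABC 2024/03/11 15:40:12", "value": "<flow-run-id>"}]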
diff --git a/constants.py b/constants.py
index 544a1ed..f7b441f 100644
--- a/constants.py
+++ b/constants.py
@@ -29,6 +29,7 @@
     "export-annotation": "entypo:export",
     "no-more-slices": "pajamas:warning-solid",
     "export": "entypo:export",
+    "submit": "formkit:submit",
 }
 
 ANNOT_NOTIFICATION_MSGS = {
diff --git a/docker-compose.yml b/docker-compose.yml
index 5f1a7d5..127c653 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -9,11 +9,15 @@ services:
       DATA_TILED_URI: '${DATA_TILED_URI}'
       DATA_TILED_API_KEY: '${DATA_TILED_API_KEY}'
       MASK_TILED_URI: '${MASK_TILED_URI}'
-      MASK_TILED_API_KEY: '${TILED_API_KEY}'
+      MASK_TILED_API_KEY: '${MASK_TILED_API_KEY}'
      SEG_TILED_URI: '${SEG_TILED_URI}'
       SEG_TILED_API_KEY: '${SEG_TILED_API_KEY}'
       USER_NAME: '${USER_NAME}'
       USER_PASSWORD: '${USER_PASSWORD}'
+      RESULTS_DIR: '${RESULTS_DIR}'
+      PREFECT_API_URL: '${PREFECT_API_URL}'
+      FLOW_NAME: '${FLOW_NAME}'
+      TIMEZONE: '${TIMEZONE}'
     volumes:
       - ./app.py:/app/app.py
       - ./constants.py:/app/constants.py
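One gap worth noting: callbacks/segmentation.py also reads PREFECT_TAGS, which is neither listed in .env.example nor passed through here, and os.getenv would return a string where the code expects a list. If the variable is ever exposed, a safer read might look like this (hypothetical helper, not part of the patch):

import os

def get_prefect_tags(default="high-res-segmentation"):
    # Environment values are strings; split a comma-separated value so that
    # PREFECT_TAGS + ["train"] keeps concatenating lists.
    return os.getenv("PREFECT_TAGS", default).split(",")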
+ """ + if MODE == "dev": + data = [ + {"label": "❌ DLSIA ABC 03/11/2024 15:38PM", "value": "uid0001"}, + {"label": "🕑 DLSIA XYC 03/11/2024 14:21PM", "value": "uid0002"}, + {"label": "✅ DLSIA CBA 03/11/2024 10:02AM", "value": "uid0003"}, + ] + return data, None else: - # TODO - connect with API - raise PreventUpdate + if train_job_id is not None: + job_name = get_flow_run_name(train_job_id) + if job_name is not None: + if MODE == "dev": + data = [ + { + "label": "❌ DLSIA ABC 03/11/2024 15:38PM", + "value": "uid0001", + }, + { + "label": "🕑 DLSIA XYC 03/11/2024 14:21PM", + "value": "uid0002", + }, + { + "label": "✅ DLSIA CBA 03/11/2024 10:02AM", + "value": "uid0003", + }, + ] + else: + data = query_flow_run( + PREFECT_TAGS + ["inference"], flow_run_name=job_name + ) + selected_value = None if len(data) == 0 else no_update + return data, selected_value + return [], None diff --git a/components/control_bar.py b/components/control_bar.py index 4809e5f..03d2343 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -607,15 +607,46 @@ def layout(): html.Div(id="model-parameters"), dcc.Store(id="model-parameter-values", data={}), dmc.Space(h=25), - dmc.Center( - dmc.Button( - "Run model", - id="run-model", - variant="light", - style={"width": "160px", "margin": "5px"}, - ) + ControlItem( + "Name", + "job-name-input", + dmc.TextInput( + placeholder="Name your job...", + id="job-name", + ), + ), + dmc.Space(h=10), + dmc.Button( + "Train", + id="run-train", + variant="light", + style={"width": "100%", "margin": "5px"}, + ), + dmc.Space(h=10), + ControlItem( + "Train Jobs", + "selected-train-job", + dmc.Select( + placeholder="Select a job...", + id="train-job-selector", + ), + ), + dmc.Space(h=10), + dmc.Button( + "Inference", + id="run-inference", + variant="light", + style={"width": "100%", "margin": "5px"}, + ), + dmc.Space(h=10), + ControlItem( + "Inference Jobs", + "selected-inference-job", + dmc.Select( + placeholder="Select a job...", + id="inference-job-selector", + ), ), - html.Div(id="output-details"), dmc.Space(h=25), dmc.Switch( id="show-result-overlay-toggle", diff --git a/constants.py b/constants.py index 544a1ed..f7b441f 100644 --- a/constants.py +++ b/constants.py @@ -29,6 +29,7 @@ "export-annotation": "entypo:export", "no-more-slices": "pajamas:warning-solid", "export": "entypo:export", + "submit": "formkit:submit", } ANNOT_NOTIFICATION_MSGS = { diff --git a/docker-compose.yml b/docker-compose.yml index 5f1a7d5..127c653 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,11 +9,15 @@ services: DATA_TILED_URI: '${DATA_TILED_URI}' DATA_TILED_API_KEY: '${DATA_TILED_API_KEY}' MASK_TILED_URI: '${MASK_TILED_URI}' - MASK_TILED_API_KEY: '${TILED_API_KEY}' + MASK_TILED_API_KEY: '${MASK_TILED_API_KEY}' SEG_TILED_URI: '${SEG_TILED_URI}' SEG_TILED_API_KEY: '${SEG_TILED_API_KEY}' USER_NAME: '${USER_NAME}' USER_PASSWORD: '${USER_PASSWORD}' + RESULTS_DIR: '${RESULTS_DIR}' + PREFECT_API_URL: '${PREFECT_API_URL}' + FLOW_NAME: '${FLOW_NAME}' + TIMEZONE: "${TIMEZONE}" volumes: - ./app.py:/app/app.py - ./constants.py:/app/constants.py diff --git a/requirements.txt b/requirements.txt index 6eca7ec..2dcbdbe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ numpy packaging==23.1 pandas==2.0.2 plotly==5.17.0 +prefect-client==2.14.21 python-dateutil==2.8.2 pytz==2023.3 six==1.16.0 diff --git a/utils/prefect.py b/utils/prefect.py new file mode 100644 index 0000000..d837309 --- /dev/null +++ b/utils/prefect.py @@ -0,0 +1,83 @@ +import asyncio +from typing 
diff --git a/utils/prefect.py b/utils/prefect.py
new file mode 100644
index 0000000..d837309
--- /dev/null
+++ b/utils/prefect.py
@@ -0,0 +1,83 @@
+import asyncio
+from typing import Optional
+
+from prefect import get_client
+from prefect.client.schemas.filters import (
+    FlowRunFilter,
+    FlowRunFilterName,
+    FlowRunFilterTags,
+)
+
+
+async def _schedule(
+    deployment_name: str,
+    flow_run_name: str,
+    parameters: Optional[dict] = None,
+    tags: Optional[list] = [],
+):
+    async with get_client() as client:
+        deployment = await client.read_deployment_by_name(deployment_name)
+        assert (
+            deployment
+        ), f"No deployment found in config for deployment_name {deployment_name}"
+        flow_run = await client.create_flow_run_from_deployment(
+            deployment.id,
+            parameters=parameters,
+            name=flow_run_name,
+            tags=tags,
+        )
+        return flow_run.id
+
+
+def schedule_prefect_flow(
+    deployment_name: str,
+    parameters: Optional[dict] = None,
+    flow_run_name: Optional[str] = None,
+    tags: Optional[list] = [],
+):
+    if not flow_run_name:
+        model_name = parameters["model_name"]
+        flow_run_name = f"{deployment_name}: {model_name}"
+    flow_run_id = asyncio.run(
+        _schedule(deployment_name, flow_run_name, parameters, tags)
+    )
+    return flow_run_id
+
+
+async def _get_name(flow_run_id):
+    async with get_client() as client:
+        flow_run = await client.read_flow_run(flow_run_id)
+        if flow_run.state.is_final():
+            if flow_run.state.is_completed():
+                return flow_run.name
+        return None
+
+
+def get_flow_run_name(flow_run_id):
+    """Retrieves the name of the flow run with the given id, but only once it
+    has completed successfully; returns None otherwise."""
+    return asyncio.run(_get_name(flow_run_id))
+
+
+async def _flow_run_query(tags, flow_run_name=None):
+    flow_runs_by_name = []
+    async with get_client() as client:
+        flow_runs = await client.read_flow_runs(
+            flow_run_filter=FlowRunFilter(
+                name=FlowRunFilterName(like_=flow_run_name),
+                tags=FlowRunFilterTags(all_=tags),
+            ),
+            sort="START_TIME_DESC",
+        )
+    for flow_run in flow_runs:
+        if flow_run.state_name == "Failed":
+            flow_name = f"❌ {flow_run.name}"
+        elif flow_run.state_name == "Completed":
+            flow_name = f"✅ {flow_run.name}"
+        else:
+            flow_name = f"🕑 {flow_run.name}"
+        flow_runs_by_name.append({"label": flow_name, "value": str(flow_run.id)})
+    return flow_runs_by_name
+
+
+def query_flow_run(tags, flow_run_name=None):
+    return asyncio.run(_flow_run_query(tags, flow_run_name))
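Taken together, these helpers give a synchronous facade over the async Prefect client: each call wraps its coroutine in asyncio.run, which spins up a fresh event loop. That suits Dash callbacks, which run in a synchronous context, but would fail inside an already-running loop. A hypothetical round trip, with deployment name and tags matching the values configured in .env.example:

flow_run_id = schedule_prefect_flow(
    "Parent flow/launch_parent_flow",
    parameters={"flow_type": "podman", "params_list": []},
    flow_run_name="demo 2024/03/11 15:38:00",
    tags=["high-res-segmentation", "train"],
)

# None until the run finishes, and non-None only if it completed successfully.
name = get_flow_run_name(flow_run_id)

# Dropdown-ready [{"label": ..., "value": ...}] entries for all training runs.
options = query_flow_run(["high-res-segmentation", "train"])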