Boilerplate for triggering ml job run #91

Merged: 11 commits, Sep 21, 2023
3 changes: 2 additions & 1 deletion README.md
@@ -26,8 +26,9 @@ pip install -r requirements-dev.txt
2. Configure a connection to the Tiled server via a `.env` file with the following environment variables:

```
TILED_URI='https://mlex-segmentation.als.lbl.gov'
API_KEY=<key-provided-on-request>
MODE='dev'
```

3. Start a local server:
175 changes: 140 additions & 35 deletions callbacks/segmentation.py
@@ -1,50 +1,155 @@
import os
import time
import uuid

import dash_mantine_components as dmc
import requests
from dash import ALL, Input, Output, State, callback, no_update
from dash.exceptions import PreventUpdate

from utils import data_utils

MODE = os.getenv("MODE", "")

# Static demo workflow for the computing API: two chained jobs that train
# a random-forest model and then run segmentation (job 1 depends on job 0)
DEMO_WORKFLOW = {
    "user_uid": "high_res_user",
    "job_list": [
        {
            "mlex_app": "high-res-segmentation",
            "description": "test_1",
            "service_type": "backend",
            "working_directory": "/data/mlex_repo/mlex_tiled/data",
            "job_kwargs": {
                "uri": "mlexchange1/random-forest-dc:1.1",
                "type": "docker",
                "cmd": 'python random_forest.py data/seg-results/spiral/image-train data/seg-results-test/spiral/feature data/seg-results/spiral/mask data/seg-results-test/spiral/model \'{"n_estimators": 30, "oob_score": true, "max_depth": 8}\'',
                "kwargs": {
                    "job_type": "train",
                    "experiment_id": "123",
                    "dataset": "name_of_dataset",
                    "params": '{"n_estimators": 30, "oob_score": true, "max_depth": 8}',
                },
            },
        },
        {
            "mlex_app": "high-res-segmentation",
            "description": "test_1",
            "service_type": "backend",
            "working_directory": "/data/mlex_repo/mlex_tiled/data",
            "job_kwargs": {
                "uri": "mlexchange1/random-forest-dc:1.1",
                "type": "docker",
                "cmd": "python segment.py data/data/20221222_085501_looking_from_above_spiralUP_CounterClockwise_endPointAtDoor_0-1000 data/seg-results-test/spiral/model/random-forest.model data/seg-results-test/spiral/output '{\"show_progress\": 1}'",
                "kwargs": {
                    "job_type": "train",
                    "experiment_id": "124",
                    "dataset": "name_of_dataset",
                    "params": '{"show_progress": 1}',
                },
            },
        },
    ],
    "host_list": ["vaughan.als.lbl.gov"],
    "dependencies": {"0": [], "1": [0]},
    "requirements": {"num_processors": 2, "num_gpus": 0, "num_nodes": 1},
}


# NEXT STEPS:
# - this function returns a job ID, which would be associated with the workflow run on vaughan
# - then we need another callback to pick up this ID and start polling for successful output
@callback(
    Output("output-details", "children"),
    Output("submitted-job-id", "data"),
    Input("run-model", "n_clicks"),
    State("annotation-store", "data"),
    State({"type": "annotation-class-store", "index": ALL}, "data"),
    State("project-name-src", "value"),
)
def run_job(n_clicks, global_store, all_annotations, project_name):
    """
    This callback collects parameters from the UI and submits a job to the computing API.
    If the app is run in "dev" mode, only a placeholder job_uid will be created.
    The job_uid is saved in a dcc.Store for reference by the check_job callback below.

    # TODO: Appropriately parameterize the DEMO_WORKFLOW json depending on user inputs
    # and relevant file paths
    """
    if n_clicks:
        if MODE == "dev":
            # In dev mode, skip the computing API and fabricate a job uid
            job_uid = str(uuid.uuid4())
            return (
                dmc.Text(
                    f"Workflow has been successfully submitted with uid: {job_uid}",
                    size="sm",
                ),
                job_uid,
            )
        else:
            # Save the annotation mask and metadata before submission
            data_utils.save_annotations_data(
                global_store, all_annotations, project_name
            )
            job_submitted = requests.post(
                "http://job-service:8080/api/v0/workflows", json=DEMO_WORKFLOW
            )
            job_uid = job_submitted.json()
            if job_submitted.status_code == 200:
                return (
                    dmc.Text(
                        f"Workflow has been successfully submitted with uid: {job_uid}",
                        size="sm",
                    ),
                    job_uid,
                )
            else:
                return (
                    dmc.Text(
                        f"Workflow presented error code: {job_submitted.status_code}",
                        size="sm",
                    ),
                    job_uid,
                )
    return no_update, no_update
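The parameterization TODO in run_job's docstring might take a shape like the following sketch; the helper name and the choice of fields to substitute are assumptions, not part of this PR:

```
import copy
import json


def parameterize_workflow(workflow, dataset_name, model_params):
    # Hypothetical helper: replace the hard-coded demo values with
    # values collected from the UI before submission.
    # NOTE: the "cmd" strings embed paths and params too and would
    # need the same treatment.
    workflow = copy.deepcopy(workflow)
    for job in workflow["job_list"]:
        job["job_kwargs"]["kwargs"]["dataset"] = dataset_name
        job["job_kwargs"]["kwargs"]["params"] = json.dumps(model_params)
    return workflow


# e.g. parameterize_workflow(DEMO_WORKFLOW, project_name, {"n_estimators": 30})
```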


@callback(
    Output("output-details", "children", allow_duplicate=True),
    Output("submitted-job-id", "data", allow_duplicate=True),
    Input("submitted-job-id", "data"),
    Input("model-check", "n_intervals"),
    prevent_initial_call=True,
)
def check_job(job_id, n_intervals):
Collaborator commented:

This is redundant because the models will most likely take a long time to finish, and the user won't wait around for them.
It could be worth adding a notification that shows while they are away (or the first time they come back; we would need to track in the DB whether the notification has been seen), informing them that ML job(s) have finished since their last visit, and listing them.
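For illustration, such a notification could look roughly like the sketch below, using the components already in this app; the helper name, the IDs, and the assumption that a dmc.NotificationsProvider wraps the layout are mine, not part of this PR:

```
import dash_mantine_components as dmc


def finished_jobs_notification(finished_jobs):
    # Hypothetical: rendered once on page load if jobs completed since
    # the user's last visit (seen/unseen state tracked in the DB).
    return dmc.Notification(
        id="ml-jobs-finished",
        title="ML jobs finished",
        message=f"{len(finished_jobs)} job(s) completed since your last visit.",
        action="show",
    )
```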

Two weeks ago, when we talked about showing ML output, we agreed that it would go under Data Selection: if a particular project has ML output, another dropdown would appear where users can select what they wish to view. It is definitely not a toggle, since they are not limited to one ML output per project. See #62; I edited the issue description with these requirements.

Collaborator (Author) replied:

We discussed having a simple toggle at this week's meeting, since this feature is for now just implemented as an MVP. We will refine that piece further in #88.

Good point about the length of time. I'd disagree that this function is entirely redundant, though we may want to refine the mechanism by which we check for results, depending on how long jobs run and how we want to manage user concurrency.

Collaborator (Author) added:

From Tanny: using the user ID, send a GET request to the computing API to retrieve all jobs associated with that user ID and the segmentation app, along with the status of each job; this happens on page load and then at an interval.
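A rough sketch of that flow; the endpoint path, query parameters, and response shape below are assumptions about the computing API, not its documented interface (only the host matches the POST used in run_job):

```
import requests


def get_user_jobs(user_uid):
    # Assumed endpoint and query parameters; the real computing API
    # routes may differ.
    response = requests.get(
        "http://job-service:8080/api/v0/jobs",
        params={"user": user_uid, "mlex_app": "high-res-segmentation"},
    )
    response.raise_for_status()
    # Assumed response shape: a list of job records with "uid" and "status".
    return {job["uid"]: job["status"] for job in response.json()}
```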

"""
This callback checks to see if a job has completed successfully and will only
update if there is a job_id present in the submitted-job-id dcc.Store. Will
wait 3sec in "dev" mode to simulate.

# TODO: Connect with the computing API when not in "dev" mode
"""
output_layout = [
dmc.Text(
f"Workflow {job_id} completed successfully. Click button below to view segmentation results.",
size="sm",
),
dmc.Space(h=20),
dmc.Switch(
size="sm",
radius="lg",
label="Show output results",
id="show-results",
checked=False,
),
]

if MODE == "dev":
if job_id:
time.sleep(3)
return (
output_layout,
None,
)
raise PreventUpdate
else:
# TODO - connect with API
raise PreventUpdate
19 changes: 16 additions & 3 deletions components/control_bar.py
@@ -36,10 +36,18 @@ def _control_item(title, title_id, item)
)


def _accordion_item(title, icon, value, children, id, loading=True):
    """
    Returns a customized layout for an accordion item
    """
    if loading:
        # Wrap the panel in a LoadingOverlay so the section dims while
        # callbacks populate it (loaderProps size 0 hides the spinner)
        panel = dmc.LoadingOverlay(
            dmc.AccordionPanel(children=children, id=id),
            loaderProps={"size": 0},
            overlayOpacity=0.4,
        )
    else:
        panel = dmc.AccordionPanel(children=children, id=id)
    return dmc.AccordionItem(
        [
            dmc.AccordionControl(
@@ -50,7 +58,7 @@ def _accordion_item(title, icon, value, children, id):
                    width=20,
                ),
            ),
            panel,
        ],
        value=value,
    )
@@ -615,8 +623,9 @@ def layout():
variant="light",
style={"width": "160px", "margin": "5px"},
),
html.Div(id="output-details"),
],
loading=False,
),
],
),
@@ -692,6 +701,10 @@ def drawer_section(children):
dcc.Download(id="export-annotation-metadata"),
dcc.Download(id="export-annotation-mask"),
dcc.Store(id="project-data"),
dcc.Store(id="submitted-job-id"),
dcc.Interval(
id="model-check", interval=5000
), # TODO: May want to increase frequency
html.Div(id="dummy-output"),
EventListener(
events=[
33 changes: 33 additions & 0 deletions utils/data_utils.py
@@ -2,12 +2,15 @@
import os

import httpx
import numpy as np
import requests
from dotenv import load_dotenv
from tiled.client import from_uri
from tiled.client.array import ArrayClient
from tiled.client.container import Container

from utils.annotations import Annotations


def DEV_download_google_sample_data():
"""
@@ -90,6 +93,36 @@ def DEV_filter_json_data_by_timestamp(data, timestamp):
return [data for data in data if data["time"] == timestamp]


def save_annotations_data(global_store, all_annotations, project_name):
    """
    Transforms annotations data to a pixelated mask and outputs to
    the Tiled server

    # TODO: Save data to Tiled server after transformation
    """
    annotations = Annotations(all_annotations, global_store)
    annotations.create_annotation_mask(sparse=True)  # TODO: Check sparse status

    # Get metadata and annotation data
    metadata = annotations.get_annotations()
    mask = annotations.get_annotation_mask()

    # Get raw images associated with each annotated slice
    img_idx = list(metadata.keys())
    img = data[project_name]
    raw = []
    for idx in img_idx:
        ar = img[int(idx)]
        raw.append(ar)
    try:
        raw = np.stack(raw)
        mask = np.stack(mask)
    except ValueError:
        return "No annotations to process."

    return
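A minimal sketch of the "save to Tiled" TODO above, assuming the target Tiled server is writable and using the tiled client's write_array method; the key scheme and destination container are assumptions, not part of this PR:

```
from tiled.client import from_uri


def save_mask_to_tiled(mask, metadata, project_name):
    # Sketch only: assumes write access, and that TILED_URI and API_KEY
    # are the same environment variables loaded below in this module.
    client = from_uri(TILED_URI, api_key=os.getenv("API_KEY"))
    client.write_array(mask, metadata=metadata, key=f"{project_name}_mask")
```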


load_dotenv()

TILED_URI = os.getenv("TILED_URI")