Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Boilerplate for triggering ml job run #91

Merged
merged 11 commits into from
Sep 21, 2023
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
pip install -r requirements.txt
```

2. Configure a connection to the Tiled server via a `.env` file with the following environment variables:
2. Create a `.env` file with the following environment variables:

```
TILED_URI=https://mlex-segmentation.als.lbl.gov
TILED_URI='https://mlex-segmentation.als.lbl.gov'
API_KEY=<key-provided-on-request>
MODE='dev'
```

3. Start a local server:
Expand Down
175 changes: 136 additions & 39 deletions callbacks/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,147 @@
from dash import callback, Input, Output, State, no_update
from utils.annotations import Annotations
from utils.data_utils import data
import numpy as np
from dash.exceptions import PreventUpdate
import os
import uuid
import requests
import time
import dash_mantine_components as dmc
from utils import data_utils

MODE = os.getenv("MODE", "")

def _build_demo_workflow():
    """Assemble the hard-coded two-job demo workflow payload.

    Both jobs use the same app, container image and working directory; they
    differ only in the command line, the experiment id and the JSON
    parameter string. The parameter string is stored under ``params`` and
    is also appended, single-quoted, as the last argument of the command.
    """
    train_params = '{"n_estimators": 30, "oob_score": true, "max_depth": 8}'
    segment_params = '{"show_progress": 1}'

    train_cmd = (
        "python random_forest.py"
        " data/seg-results/spiral/image-train"
        " data/seg-results-test/spiral/feature"
        " data/seg-results/spiral/mask"
        " data/seg-results-test/spiral/model"
        f" '{train_params}'"
    )
    # The model path below sits under the directory passed to the training
    # command above - presumably the file that step writes; TODO confirm.
    segment_cmd = (
        "python segment.py"
        " data/data/20221222_085501_looking_from_above_spiralUP_CounterClockwise_endPointAtDoor_0-1000"
        " data/seg-results-test/spiral/model/random-forest.model"
        " data/seg-results-test/spiral/output"
        f" '{segment_params}'"
    )

    def job(cmd, experiment_id, params):
        # Fields shared by every demo job; only the arguments vary.
        return {
            "mlex_app": "high-res-segmentation",
            "description": "test_1",
            "service_type": "backend",
            "working_directory": "/data/mlex_repo/mlex_tiled/data",
            "job_kwargs": {
                "uri": "mlexchange1/random-forest-dc:1.1",
                "type": "docker",
                "cmd": cmd,
                "kwargs": {
                    # NOTE(review): "train" is used for the segment.py job
                    # as well - confirm this is intended.
                    "job_type": "train",
                    "experiment_id": experiment_id,
                    "dataset": "name_of_dataset",
                    "params": params,
                },
            },
        }

    return {
        "user_uid": "high_res_user",
        "job_list": [
            job(train_cmd, "123", train_params),
            job(segment_cmd, "124", segment_params),
        ],
        "host_list": ["vaughan.als.lbl.gov"],
        # Keyed by job index as a string; presumably job 1 waits for job 0
        # (verify against the computing API).
        "dependencies": {"0": [], "1": [0]},
        "requirements": {"num_processors": 2, "num_gpus": 0, "num_nodes": 1},
    }


# Placeholder payload posted as-is by run_job until the workflow is
# parameterized from user inputs.
DEMO_WORKFLOW = _build_demo_workflow()


# NEXT STEPS:
# - this function returns a job ID, which would be associated with the workflow run on vaughan
# - then we need another callback to pick up this ID and start polling for successful output
@callback(
Output("output-placeholder", "children"),
Output("output-details", "children"),
Output("submitted-job-id", "data"),
Input("run-model", "n_clicks"),
State("annotation-store", "data"),
State("project-name-src", "value"),
)
def run_job(n_clicks, annotation_store, project_name):
# As a placeholder, pulling together the inputs we'd need if we were going to submit a job
"""
This callback collects parameters from the UI and submits a job to the computing api.
If the app is run from "dev" mode, then only a placeholder job_uid will be created.
The job_uid is saved in a dcc.Store for reference by the check_job callback below.

# TODO: Appropriately paramaterize the DEMO_WORKFLOW json depending on user inputs
and relevant file paths
"""
if n_clicks:
if MODE == "dev":
job_uid = str(uuid.uuid4())
return (
dmc.Text(
f"Workflow has been succesfully submitted with uid: {job_uid}",
size="sm",
),
job_uid,
)
else:
data_utils.save_annotations_data(annotation_store, project_name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing on Vaughan gives us the following error:

File "/usr/local/lib/python3.9/site-packages/dash/_callback.py", line 450, in add_context
 output_value = func(*func_args, **func_kwargs)  # %% callback invoked %%
File "/app/callbacks/segmentation.py", line 83, in run_job
 data_utils.save_annotations_data(annotation_store, project_name)
File "/app/utils/data_utils.py", line 98, in save_annotations_data
 annotations.create_annotation_metadata()
File "/app/utils/annotations.py", line 61, in create_annotation_metadata
 self.set_annotation_image_shape(image_idx)
File "/app/utils/annotations.py", line 165, in set_annotation_image_shape
 self.annotation_image_shape = self.annotation_store["image_shapes"][0]
 KeyError: 'image_shapes'

I suspect this may come from interacting with a Tiled server that has only a single tiff-sequence in it, so we technically never actively selected a project. Interacting with the GUI more (changing slider value, 'selecting' the single project) does remove this error and the attempt to submit the job is made.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could also test: App is loaded, then immediately click the "Run Model" button. And what if the annotation store is empty?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Wiebke I think this is happening because of this line, where my guess is that DATA_OPTIONS evaluates to None, which means that the slider is disabled so this block isn't hit.

I think you're probably right that this is because of a different structure on the Tiled server on your end. What's the structure of the data variable you get after running:

client = from_uri(TILED_URI, api_key=API_KEY)
data = client["data"]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does indeed seem to have been an issue with our previous local Tiled setup, and it was resolved by the updated population of the project list.

job_submitted = requests.post(
"http://job-service:8080/api/v0/workflows", json=DEMO_WORKFLOW
)
job_uid = job_submitted.json()
if job_submitted.status_code == 200:
return (
dmc.Text(
f"Workflow has been succesfully submitted with uid: {job_uid}",
size="sm",
),
job_uid,
)
else:
return (
dmc.Text(
f"Workflow presented error code: {job_submitted.status_code}",
size="sm",
),
job_uid,
)
return no_update, no_update


@callback(
Output("output-details", "children", allow_duplicate=True),
Output("submitted-job-id", "data", allow_duplicate=True),
Input("submitted-job-id", "data"),
Input("model-check", "n_intervals"),
prevent_initial_call=True,
)
def check_job(job_id, n_intervals):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is redundant because the models will most likely take a long time to finish and the user won't wait for it.
It could be worth adding a notification that shows while they are away (or the first time they come back - would need to track in the DB if the notification has been seen) that would inform them that since their last visit, the ML job(s) has(have) finished and list them.

Two weeks ago, when we talked about showing ML output, we agreed that it would go under Data Selection, and that if the particular project has ML output, another dropdown would appear where users could select what they wish to view. It should definitely not be a toggle, since they are not limited to one ML output per project. See #62; I edited the issue description with the aforementioned requirements for this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We discussed having a simple toggle at this week's meeting, since this feature for now is just implemented as an MVP. Will refine that piece further in #88.

Good point about the length of time. I'd disagree that this function is entirely redundant, but we may want to refine the mechanism by which we check for results, depending on the length of time and how we want to manage user concurrency.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From Tanny: using the user ID, send a GET request to the computing API to retrieve all the jobs associated with that user ID and the segmentation, together with the status of each of these jobs. This happens on page load and then again at an interval.

"""
This callback checks to see if a job has completed successfully and will only
update if there is a job_id present in the submitted-job-id dcc.Store. Will
wait 3sec in "dev" mode to simulate.

# TODO: Connect with the computing API when not in "dev" mode
"""
output_layout = [
dmc.Text(
f"Workflow {job_id} completed successfully. Click button below to view segmentation results.",
size="sm",
),
dmc.Space(h=20),
dmc.Switch(
size="sm",
radius="lg",
label="Show output results",
id="show-results",
checked=False,
),
]

annotations = Annotations(annotation_store)
annotations.create_annotation_metadata()
annotations.create_annotation_mask(
sparse=False
) # TODO: Would sparse need to be true?

# Get metadata and annotation data
metadata = annotations.get_annotations()
mask = annotations.get_annotation_mask()

# Get raw images associated with each annotated slice
# Actually we can just pass the indices and have the job point to Tiled directly
img_idx = list(metadata.keys())
img = data[project_name]
raw = []
for idx in img_idx:
ar = img[int(idx)]
raw.append(ar)
try:
raw = np.stack(raw)
mask = np.stack(mask)
except ValueError:
return "No annotations to process."

# Some checks to validate that things are in the format we'd expect
print(metadata)
print(mask.shape)
print(raw.shape)

return "Running the model..."
return no_update
if MODE == "dev":
if job_id:
time.sleep(3)
return (
output_layout,
None,
)
raise PreventUpdate
else:
# TODO - connect with API
raise PreventUpdate
23 changes: 16 additions & 7 deletions components/control_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,15 @@ def _control_item(title, title_id, item):
)


def _accordion_item(title, icon, value, children, id):
def _accordion_item(title, icon, value, children, id, loading=True):
if loading:
panel = dmc.LoadingOverlay(
dmc.AccordionPanel(children=children, id=id),
loaderProps={"size": 0},
overlayOpacity=0.4,
)
else:
panel = dmc.AccordionPanel(children=children, id=id)
return dmc.AccordionItem(
[
dmc.AccordionControl(
Expand All @@ -32,11 +40,7 @@ def _accordion_item(title, icon, value, children, id):
width=20,
),
),
dmc.LoadingOverlay(
dmc.AccordionPanel(children=children, id=id),
loaderProps={"size": 0},
overlayOpacity=0.4,
),
panel,
],
value=value,
)
Expand Down Expand Up @@ -735,8 +739,9 @@ def layout():
variant="light",
style={"width": "160px", "margin": "5px"},
),
html.Div(id="output-placeholder"),
html.Div(id="output-details"),
],
loading=False,
),
],
),
Expand Down Expand Up @@ -821,6 +826,10 @@ def drawer_section(children):
dcc.Download(id="export-annotation-metadata"),
dcc.Download(id="export-annotation-mask"),
dcc.Store(id="project-data"),
dcc.Store(id="submitted-job-id"),
dcc.Interval(
id="model-check", interval=5000
), # TODO: May want to increase frequency
html.Div(id="dummy-output"),
EventListener(
events=[
Expand Down
38 changes: 35 additions & 3 deletions utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
import json
import requests
from urllib.parse import urlparse

from utils.annotations import Annotations
import numpy as np
from tiled.client import from_uri
from tiled.client.cache import Cache
from dotenv import load_dotenv


Expand Down Expand Up @@ -88,6 +87,39 @@ def DEV_filter_json_data_by_timestamp(data, timestamp):
return [data for data in data if data["time"] == timestamp]


def save_annotations_data(annotation_store, project_name):
    """
    Build the annotation mask and collect the matching raw image slices.

    The annotation store is turned into per-slice metadata and a dense
    (``sparse=False``) mask. The raw slice for every annotated index is
    then looked up in ``data[project_name]``, and both the slices and the
    mask are stacked with ``np.stack``.

    Returns the string "No annotations to process." when stacking raises
    ``ValueError``, otherwise ``None``. The stacked arrays are currently
    discarded.

    # TODO: Save data to Tiled server after transformation
    """
    annotations = Annotations(annotation_store)
    annotations.create_annotation_metadata()
    # TODO: Would sparse need to be true?
    annotations.create_annotation_mask(sparse=False)

    # Per-slice annotation metadata and the corresponding mask data
    metadata = annotations.get_annotations()
    mask = annotations.get_annotation_mask()

    # Raw images for each annotated slice; metadata keys are converted
    # with int() before indexing into the project's image stack
    slice_keys = list(metadata.keys())
    project_images = data[project_name]
    raw = [project_images[int(key)] for key in slice_keys]

    try:
        raw = np.stack(raw)
        mask = np.stack(mask)
    except ValueError:
        return "No annotations to process."
    else:
        return None


load_dotenv()

TILED_URI = os.getenv("TILED_URI")
Expand Down
Loading