Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retrieve segmentation results and populate job parameters #178

Merged
merged 24 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
a8cf70a
Put Tiled client refresh into its own callback
Wiebke Mar 8, 2024
192e900
Remove unused stores
Wiebke Mar 8, 2024
6304ad3
Adapt result retrieval to store element
Wiebke Mar 8, 2024
8bb83a5
Add place holder for dvc results link
Wiebke Mar 8, 2024
0dd2ee4
Delete existing annotations json
Wiebke Mar 8, 2024
2525f4a
Assemble parameters for training job
Wiebke Mar 8, 2024
d92b046
:bug: Fix icon typo and parameter assembly
Wiebke Mar 8, 2024
8f3400e
Add parameters to inference
Wiebke Mar 8, 2024
6987c5a
Replace single result store by 2
Wiebke Mar 8, 2024
2b960e3
Populate result stores on dropdown change
Wiebke Mar 8, 2024
46d5c3f
Add number of classes to model parameters
Wiebke Mar 8, 2024
b6fe3bd
Add example script for copying mask as result
Wiebke Mar 8, 2024
bce4bcd
:bug: Fix default export location
Wiebke Mar 8, 2024
5696eef
:bug: Fix copy mask script (all bools inverted)
Wiebke Mar 9, 2024
8b53e25
:bug: Always check subflows for results
Wiebke Mar 9, 2024
a27b6cd
:bug: Slice into array when copying
Wiebke Mar 9, 2024
cc5c004
Add conda Prefect flows
Wiebke Mar 9, 2024
13a9951
:bug: Guard against `None` for string params
Wiebke Mar 9, 2024
aaff3de
Add missing `network` parameter
Wiebke Mar 9, 2024
9b91f7e
Update `.env.example` for conda flows
Wiebke Mar 9, 2024
cc02a53
Add default values for weights and dilation array
Wiebke Mar 10, 2024
9b06cc5
Change `val_pct` range from 0-100 to 0-1
Wiebke Mar 10, 2024
affc0c9
:bug: Test if `shapes` exist before editing them
Wiebke Mar 10, 2024
b7d5f2f
Change slider and switch styles
Wiebke Mar 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.env
.git
.gitignore
*-env/
5 changes: 5 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,8 @@ USER_PASSWORD=<to-be-specified-per-deployment>
PREFECT_API_URL=http://prefect:4200/api
FLOW_NAME="Parent flow/launch_parent_flow"
TIMEZONE="US/Pacific"

# Environment variables for conda-based Prefect flows
CONDA_ENV_NAME="dlsia"
TRAIN_SCRIPT_PATH="src/train.py"
SEGMENT_SCRIPT_PATH="src/segment.py"
47 changes: 31 additions & 16 deletions assets/models.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"contents": [
{
"model_name": "DSLIA MSDNet",
"model_name": "MSDNet",
"version": "0.0.1",
"type": "supervised",
"user": "mlexchange team",
Expand Down Expand Up @@ -48,6 +48,7 @@
"name": "dilation_array",
"title": "Dilation Array",
"param_key": "dilation_array",
"value": "[1, 2, 4]",
"placeholder": "e.g. [1, 2, 4]",
"error": "Provide a list of ints for dilation",
"debounce": 1000,
Expand Down Expand Up @@ -230,6 +231,7 @@
"name": "weights",
"title": "Class Weights",
"param_key": "weights",
"value": "[1]",
"placeholder": "e.g [0.1, 0.4, 0.5]",
"error": "Provide a list with a float for each class",
"debounce": 1000,
Expand Down Expand Up @@ -312,20 +314,21 @@
"title": "Validation %",
"param_key": "val_pct",
"min": 0,
"max": 100,
"step": 5,
"value": 20,
"max": 1,
"step": 0.05,
"value": 0.2,
"precision": 2,
"marks": [
{
"value": 0,
"label": "0%"
},
{
"value": 50,
"value": 0.5,
"label": "50%"
},
{
"value": 100,
"value": 1,
"label": "100%"
}
],
Expand Down Expand Up @@ -434,7 +437,7 @@
"reference": "https://dlsia.readthedocs.io/en/latest/"
},
{
"model_name": "DSLIA TUNet",
"model_name": "TUNet",
"version": "0.0.1",
"type": "supervised",
"user": "mlexchange team",
Expand Down Expand Up @@ -653,6 +656,7 @@
"name": "weights",
"title": "Class Weights",
"param_key": "weights",
"value": "[1]",
"placeholder": "e.g [0.1, 0.4, 0.5]",
"error": "Provide a list with a float for each class",
"debounce": 1000,
Expand Down Expand Up @@ -735,16 +739,21 @@
"title": "Validation %",
"param_key": "val_pct",
"min": 0,
"max": 100,
"step": 5,
"value": 20,
"max": 1,
"step": 0.05,
"value": 0.2,
"precision": 2,
"marks": [
{
"value": 0,
"label": "0%"
},
{
"value": 100,
"value": 0.5,
"label": "50%"
},
{
"value": 1,
"label": "100%"
}
],
Expand Down Expand Up @@ -853,7 +862,7 @@
"reference": "https://dlsia.readthedocs.io/en/latest/"
},
{
"model_name": "DSLIA TUNet3+",
"model_name": "TUNet3+",
"version": "0.0.1",
"type": "supervised",
"user": "mlexchange team",
Expand Down Expand Up @@ -1080,6 +1089,7 @@
"name": "weights",
"title": "Class Weights",
"param_key": "weights",
"value": "[1]",
"placeholder": "e.g [0.1, 0.4, 0.5]",
"error": "Provide a list with a float for each class",
"debounce": 1000,
Expand Down Expand Up @@ -1162,16 +1172,21 @@
"title": "Validation %",
"param_key": "val_pct",
"min": 0,
"max": 100,
"step": 5,
"value": 20,
"max": 1,
"step": 0.05,
"value": 0.2,
"precision": 2,
"marks": [
{
"value": 0,
"label": "0%"
},
{
"value": 100,
"value": 0.5,
"label": "50%"
},
{
"value": 1,
"label": "100%"
}
],
Expand Down
90 changes: 42 additions & 48 deletions callbacks/control_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from components.parameter_items import ParameterItems
from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEY_MODES, KEYBINDS
from utils.annotations import Annotations
from utils.data_utils import models, tiled_datasets, tiled_masks, tiled_results
from utils.data_utils import models, tiled_datasets, tiled_masks
from utils.plot_utils import generate_notification, generate_notification_bg_icon_col

# TODO - temporary local file path and user for annotation saving and exporting
Expand Down Expand Up @@ -230,16 +230,18 @@ def annotation_mode(
patched_figure["layout"]["dragmode"] = "drawrect"
annotation_store["dragmode"] = "drawrect"
styles[trigger] = active

elif trigger == "pan-and-zoom" and pan_and_zoom > 0:
patched_figure["layout"]["dragmode"] = "pan"
annotation_store["dragmode"] = "pan"
styles[trigger] = active

# disable shape editing when in pan/zoom mode
for shape in fig["layout"]["shapes"]:
shape["editable"] = trigger != "pan-and-zoom" and pan_and_zoom > 0
patched_figure["layout"]["shapes"] = fig["layout"]["shapes"]
# if no shapes have been added yet,
# none need to be set to not editable
if "shapes" in fig["layout"]:
for shape in fig["layout"]["shapes"]:
shape["editable"] = trigger != "pan-and-zoom" and pan_and_zoom > 0
patched_figure["layout"]["shapes"] = fig["layout"]["shapes"]
return (
patched_figure,
styles["closed-freeform"],
Expand Down Expand Up @@ -853,64 +855,46 @@ def open_controls_drawer(n_clicks, is_opened):
return no_update, no_update


@callback(Output("project-name-src", "data"), Input("refresh-tiled", "n_clicks"))
def refresh_data_client(refresh_tiled):
if refresh_tiled:
tiled_datasets.refresh_data_client()
data_options = [
item for item in tiled_datasets.get_data_project_names() if "seg" not in item
]
return data_options


@callback(
Output("result-selector", "data"),
Output("result-selector", "value"),
Output("result-selector", "disabled"),
Output("show-result-overlay-toggle", "checked"),
Output("show-result-overlay-toggle", "disabled"),
Output("seg-result-opacity-slider", "disabled"),
Output("project-name-src", "data"),
Input("project-name-src", "value"),
Input("refresh-tiled", "n_clicks"),
Input("show-result-overlay-toggle", "checked"),
State("result-selector", "disabled"),
Input("seg-results-train-store", "data"),
Input("seg-results-inference-store", "data"),
State("seg-result-opacity-slider", "disabled"),
)
def populate_classification_results(
image_src, refresh_tiled, toggle, dropdown_enabled, slider_enabled
def update_result_controls(
toggle, seg_result_train, seg_result_inference, slider_disabled
):
if refresh_tiled:
tiled_datasets.refresh_data_client()

data_options = [
item for item in tiled_datasets.get_data_project_names() if "seg" not in item
]
results = []
value = None
checked = False
disabled_dropdown = True
disabled_toggle = True
disabled_slider = True
disable_toggle = True
disable_slider = True
# Disable opacity slider if result overlay is unchecked
if ctx.triggered_id == "show-result-overlay-toggle":
results = no_update
value = no_update
checked = no_update
disabled_dropdown = dropdown_enabled
disabled_toggle = False
disabled_slider = slider_enabled
# Must have been enabled to be source of trigger
disable_toggle = no_update
disable_slider = not slider_disabled
else:
# TODO: Match by mask uid instead of image_src
results = [
item
for item in tiled_results.get_data_project_names()
if ("seg" in item and image_src in item)
]
if results:
value = results[0]
disabled_dropdown = False
if seg_result_train or seg_result_inference:
checked = False
disabled_toggle = False
disabled_slider = False

disable_toggle = False
disable_slider = False
return (
results,
value,
disabled_dropdown,
checked,
disabled_toggle,
disabled_slider,
data_options,
disable_toggle,
disable_slider,
)


Expand Down Expand Up @@ -961,6 +945,10 @@ def update_model_parameters(model_name):
),
)
def validate_class_weights(all_annotation_classes, weights):

if weights is None:
return "Provide a list with a float for each class"

parsed_weights = weights.strip("[]").split(",")
try:
parsed_weights = [float(weight.strip()) for weight in parsed_weights]
Expand Down Expand Up @@ -996,11 +984,17 @@ def validate_class_weights(all_annotation_classes, weights):
),
)
def validate_dilation_array(dilation_array):

if dilation_array is None:
return "Provide a list of ints for dilation"

parsed_dilation_array = dilation_array.strip("[]").split(",")
try:
parsed_dilation_array = [
int(array_entry.strip()) for array_entry in parsed_dilation_array
]
if len(parsed_dilation_array) == 0:
return "Provide a list of ints for dilation"
# Check if all elements in the list are floats
return False
except ValueError:
Expand Down
34 changes: 24 additions & 10 deletions callbacks/image_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from dash.exceptions import PreventUpdate

from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEYBINDS
from utils.data_utils import tiled_datasets, tiled_masks, tiled_results
from utils.data_utils import tiled_datasets, tiled_results
from utils.plot_utils import (
create_viewfinder,
downscale_view,
Expand Down Expand Up @@ -70,7 +70,8 @@ def hide_show_segmentation_overlay(toggle_seg_result, opacity):
State("image-metadata", "data"),
State("screen-size", "data"),
State("current-class-selection", "data"),
State("result-selector", "value"),
State("seg-results-train-store", "data"),
State("seg-results-inference-store", "data"),
State("seg-result-opacity-slider", "value"),
State("image-viewer", "figure"),
prevent_initial_call=True,
Expand All @@ -85,7 +86,8 @@ def render_image(
image_metadata,
screen_size,
current_color,
seg_result_selection,
seg_result_train,
seg_result_inference,
opacity,
fig,
):
Expand Down Expand Up @@ -118,13 +120,25 @@ def render_image(
and ctx.triggered_id == "show-result-overlay-toggle"
):
return [dash.no_update] * 7 + ["hidden"]
annotation_indices = tiled_masks.get_annotated_segmented_results()
if str(image_idx + 1) in annotation_indices:
# Will not return an error since we already checked if image_idx+1 is in the list
mapped_index = annotation_indices.index(str(image_idx + 1))
result = tiled_results.get_data_sequence_by_name(seg_result_selection)[
mapped_index
]
# Check if the stored results are for the current project and image
if seg_result_train or seg_result_inference:
seg_result = (
seg_result_inference if seg_result_inference else seg_result_train
)
if "mask_idx" in seg_result and seg_result["mask_idx"] is not None:
annotation_indices = seg_result["mask_idx"]
if str(image_idx) in annotation_indices:
# Will not return an error since we already checked if image_idx is in the list
mapped_index = annotation_indices.index(str(image_idx))
result = tiled_results.get_data_by_trimmed_uri(
seg_result["seg_result_trimmed_uri"], slice=mapped_index
)
else:
result = None
else:
result = tiled_results.get_data_by_trimmed_uri(
seg_result["seg_result_trimmed_uri"], slice=image_idx
)
else:
result = None
else:
Expand Down
Loading
Loading