From 7d75edcadfd188bebc16c18dc524426dd304774c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Tue, 27 Feb 2024 20:37:36 -0800 Subject: [PATCH 01/38] Rename `tiled_dataset.data` to `tiled_dataset.data_client` Consistent with naming in other projects --- callbacks/control_bar.py | 2 +- utils/data_utils.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index 9d7d438..f0a1d6b 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -853,7 +853,7 @@ def populate_classification_results( image_src, refresh_tiled, toggle, dropdown_enabled, slider_enabled ): if refresh_tiled: - tiled_dataset.refresh_data() + tiled_dataset.refresh_data_client() data_options = [ item for item in tiled_dataset.get_data_project_names() if "seg" not in item diff --git a/utils/data_utils.py b/utils/data_utils.py index bb54d29..c542590 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -22,14 +22,14 @@ def __init__( ): self.data_tiled_uri = data_tiled_uri self.data_tiled_api_key = data_tiled_api_key - self.data = from_uri( + self.data_client = from_uri( self.data_tiled_uri, api_key=self.data_tiled_api_key, timeout=httpx.Timeout(30.0), ) - def refresh_data(self): - self.data = from_uri( + def refresh_data_client(self): + self.data_client = from_uri( self.data_tiled_uri, api_key=self.data_tiled_api_key, timeout=httpx.Timeout(30.0), @@ -42,8 +42,8 @@ def get_data_project_names(self): """ project_names = [ project - for project in list(self.data) - if isinstance(self.data[project], (Container, ArrayClient)) + for project in list(self.data_client) + if isinstance(self.data_client[project], (Container, ArrayClient)) ] return project_names @@ -53,7 +53,7 @@ def get_data_sequence_by_name(self, project_name): but can also be additionally encapsulated in a folder, multiple container or in a .nxs file. We make use of specs to figure out the path to the 3d data. """ - project_client = self.data[project_name] + project_client = self.data_client[project_name] # If the project directly points to an array, directly return it if isinstance(project_client, ArrayClient): return project_client From c97d589828fdb2fc24d521850befe064b1776570 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Tue, 27 Feb 2024 21:08:34 -0800 Subject: [PATCH 02/38] Set up data loader for input and results and add separate mask handler for exporting and loading mask, to replace all annotation loading functionality through the JSON file --- callbacks/control_bar.py | 15 ++++++------- callbacks/image_viewer.py | 10 ++++----- callbacks/segmentation.py | 4 ++-- components/control_bar.py | 4 ++-- utils/data_utils.py | 44 ++++++++++++++++++++++++++++++++------- 5 files changed, 53 insertions(+), 24 deletions(-) diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index f0a1d6b..f5c208b 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -26,7 +26,7 @@ from components.annotation_class import annotation_class_item from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEY_MODES, KEYBINDS from utils.annotations import Annotations -from utils.data_utils import tiled_dataset +from utils.data_utils import tiled_datasets, tiled_masks, tiled_results from utils.plot_utils import generate_notification, generate_notification_bg_icon_col # TODO - temporary local file path and user for annotation saving and exporting @@ -758,7 +758,7 @@ def populate_load_annotations_dropdown_menu_options(modal_opened, image_src): if not modal_opened: raise PreventUpdate - data = tiled_dataset.DEV_load_exported_json_data( + data = tiled_masks.DEV_load_exported_json_data( EXPORT_FILE_PATH, USER_NAME, image_src ) if not data: @@ -804,10 +804,10 @@ def load_and_apply_selected_annotations(selected_annotation, image_src, img_idx) )["index"] # TODO : when quering from the server, load (data) for user, source, time - data = tiled_dataset.DEV_load_exported_json_data( + data = tiled_masks.DEV_load_exported_json_data( EXPORT_FILE_PATH, USER_NAME, image_src ) - data = tiled_dataset.DEV_filter_json_data_by_timestamp( + data = tiled_masks.DEV_filter_json_data_by_timestamp( data, str(selected_annotation_timestamp) ) data = data[0]["data"] @@ -853,10 +853,10 @@ def populate_classification_results( image_src, refresh_tiled, toggle, dropdown_enabled, slider_enabled ): if refresh_tiled: - tiled_dataset.refresh_data_client() + tiled_datasets.refresh_data_client() data_options = [ - item for item in tiled_dataset.get_data_project_names() if "seg" not in item + item for item in tiled_datasets.get_data_project_names() if "seg" not in item ] results = [] value = None @@ -872,9 +872,10 @@ def populate_classification_results( disabled_toggle = False disabled_slider = slider_enabled else: + # TODO: Match by mask uid instead of image_src results = [ item - for item in tiled_dataset.get_data_project_names() + for item in tiled_results.get_data_project_names() if ("seg" in item and image_src in item) ] if results: diff --git a/callbacks/image_viewer.py b/callbacks/image_viewer.py index a864640..48c6507 100644 --- a/callbacks/image_viewer.py +++ b/callbacks/image_viewer.py @@ -16,7 +16,7 @@ from dash.exceptions import PreventUpdate from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEYBINDS -from utils.data_utils import tiled_dataset +from utils.data_utils import tiled_datasets, tiled_masks, tiled_results from utils.plot_utils import ( create_viewfinder, downscale_view, @@ -108,7 +108,7 @@ def render_image( if image_idx: image_idx -= 1 # slider starts at 1, so subtract 1 to get the correct index - tf = tiled_dataset.get_data_sequence_by_name(project_name)[image_idx] + tf = tiled_datasets.get_data_sequence_by_name(project_name)[image_idx] if toggle_seg_result: # if toggle is true and overlay exists already (2 images in data) this will # be handled in hide_show_segmentation_overlay callback @@ -117,8 +117,8 @@ def render_image( and ctx.triggered_id == "show-result-overlay-toggle" ): return [dash.no_update] * 7 + ["hidden"] - if str(image_idx + 1) in tiled_dataset.get_annotated_segmented_results(): - result = tiled_dataset.get_data_sequence_by_name(seg_result_selection)[ + if str(image_idx + 1) in tiled_masks.get_annotated_segmented_results(): + result = tiled_results.get_data_sequence_by_name(seg_result_selection)[ image_idx ] else: @@ -485,7 +485,7 @@ def update_slider_values(project_name, annotation_store): """ # Retrieve data shape if project_name is valid and points to a 3d array data_shape = ( - tiled_dataset.get_data_shape_by_name(project_name) if project_name else None + tiled_datasets.get_data_shape_by_name(project_name) if project_name else None ) disable_slider = data_shape is None if not disable_slider: diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py index 69a206c..40f47b6 100644 --- a/callbacks/segmentation.py +++ b/callbacks/segmentation.py @@ -7,7 +7,7 @@ from dash import ALL, Input, Output, State, callback, no_update from dash.exceptions import PreventUpdate -from utils.data_utils import tiled_dataset +from utils.data_utils import tiled_masks MODE = os.getenv("MODE", "") @@ -84,7 +84,7 @@ def run_job(n_clicks, global_store, all_annotations, project_name): ) else: - tiled_dataset.save_annotations_data( + tiled_masks.save_annotations_data( global_store, all_annotations, project_name ) job_submitted = requests.post( diff --git a/components/control_bar.py b/components/control_bar.py index b4ab2b0..007f856 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -6,7 +6,7 @@ from components.annotation_class import annotation_class_item from constants import ANNOT_ICONS, KEYBINDS -from utils.data_utils import tiled_dataset +from utils.data_utils import tiled_datasets def _tooltip(text, children): @@ -62,7 +62,7 @@ def layout(): Returns the layout for the control panel in the app UI """ DATA_OPTIONS = [ - item for item in tiled_dataset.get_data_project_names() if "seg" not in item + item for item in tiled_datasets.get_data_project_names() if "seg" not in item ] return drawer_section( dmc.Stack( diff --git a/utils/data_utils.py b/utils/data_utils.py index c542590..c51feb8 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -14,6 +14,10 @@ DATA_TILED_URI = os.getenv("DATA_TILED_URI") DATA_TILED_API_KEY = os.getenv("DATA_TILED_API_KEY") +MASK_TILED_URI = os.getenv("MASK_TILED_URI") +MASK_TILED_API_KEY = os.getenv("MASK_TILED_API_KEY") +SEG_TILED_URI = os.getenv("SEG_TILED_URI") +SEG_TILED_API_KEY = os.getenv("SEG_TILED_API_KEY") class TiledDataLoader: @@ -84,6 +88,23 @@ def get_data_shape_by_name(self, project_name): return project_container.shape return None + +class TiledMaskHandler: + """ + This class is used to handle the masks that are generated from the annotations. + """ + + def __init__( + self, mask_tiled_uri=MASK_TILED_URI, mask_tiled_api_key=MASK_TILED_API_KEY + ): + self.mask_tiled_uri = mask_tiled_uri + self.mask_tiled_api_key = mask_tiled_api_key + self.mask_client = from_uri( + self.mask_tiled_uri, + api_key=self.mask_tiled_api_key, + timeout=httpx.Timeout(30.0), + ) + @staticmethod def get_annotated_segmented_results(json_file_path="exported_annotation_data.json"): annotated_slices = [] @@ -128,7 +149,8 @@ def DEV_load_exported_json_data(file_path, USER_NAME, PROJECT_NAME): def DEV_filter_json_data_by_timestamp(data, timestamp): return [data for data in data if data["time"] == timestamp] - def save_annotations_data(self, global_store, all_annotations, project_name): + @staticmethod + def save_annotations_data(global_store, all_annotations, project_name): """ Transforms annotations data to a pixelated mask and outputs to the Tiled server @@ -144,13 +166,9 @@ def save_annotations_data(self, global_store, all_annotations, project_name): # Get raw images associated with each annotated slice img_idx = list(metadata.keys()) - img = self.data[project_name] - raw = [] - for idx in img_idx: - ar = img[int(idx)] - raw.append(ar) + metadata["mask_idx"] = img_idx + metadata["project_name"] = project_name try: - raw = np.stack(raw) mask = np.stack(mask) except ValueError: return "No annotations to process." @@ -158,4 +176,14 @@ def save_annotations_data(self, global_store, all_annotations, project_name): return -tiled_dataset = TiledDataLoader() +tiled_datasets = TiledDataLoader( + data_tiled_uri=DATA_TILED_URI, data_tiled_api_key=DATA_TILED_API_KEY +) + +tiled_masks = TiledMaskHandler( + mask_tiled_uri=MASK_TILED_URI, mask_tiled_api_key=MASK_TILED_API_KEY +) + +tiled_results = TiledDataLoader( + data_tiled_uri=SEG_TILED_URI, data_tiled_api_key=SEG_TILED_API_KEY +) From 98284371fd1197a282cb975649f4c73a974dafe1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Tue, 27 Feb 2024 22:24:15 -0800 Subject: [PATCH 03/38] Generate overlay colormap based on selected colors Resolves #169 --- callbacks/image_viewer.py | 17 ++++++----------- utils/plot_utils.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/callbacks/image_viewer.py b/callbacks/image_viewer.py index 48c6507..3f02a78 100644 --- a/callbacks/image_viewer.py +++ b/callbacks/image_viewer.py @@ -21,6 +21,7 @@ create_viewfinder, downscale_view, generate_notification, + generate_segmentation_colormap, get_view_finder_max_min, resize_canvas, ) @@ -127,20 +128,14 @@ def render_image( tf = np.zeros((500, 500)) fig = px.imshow(tf, binary_string=True) if toggle_seg_result and result is not None: - unique_segmentation_values = np.unique(result) - normalized_range = np.linspace( - 0, 1, len(unique_segmentation_values) - ) # heatmap requires a normalized range - color_list = ( - px.colors.qualitative.Plotly - ) # TODO placeholder - replace with user defined classess - colorscale = [ - [normalized_range[i], color_list[i % len(color_list)]] - for i in range(len(unique_segmentation_values)) - ] + colorscale, max_class_id = generate_segmentation_colormap( + all_annotation_class_store + ) fig.add_trace( go.Heatmap( z=result, + zmin=-0.5, + zmax=max_class_id + 0.5, colorscale=colorscale, showscale=False, ) diff --git a/utils/plot_utils.py b/utils/plot_utils.py index 55a7dd1..9e70224 100644 --- a/utils/plot_utils.py +++ b/utils/plot_utils.py @@ -1,6 +1,7 @@ import random import dash_mantine_components as dmc +import numpy as np import plotly.express as px import plotly.graph_objects as go from dash_iconify import DashIconify @@ -170,6 +171,37 @@ def resize_canvas(h, w, H, W, figure): return figure, image_center_coor +def generate_segmentation_colormap(all_annotations_data): + """ + Generates a discrete colormap for the segmentation overlay + based on the color information per class. + + The discrete colormap maps values from 0 to 1 to colors, + but is meant to be applied to images with class ids as values, + with these varying from 0 to the number of classes - 1. + To account for numerical inaccuracies, it is best to center the plot range + around the class ids, by setting cmin=-0.5 and cmax=max_class_id+0.5. + """ + max_class_id = max( + [annotation_class["class_id"] for annotation_class in all_annotations_data] + ) + # heatmap requires a normalized range from 0 to 1 + # We need to specify color for at least the range limits (0 and 1) + # as well for every additional class + # due to using zero-based class ids, we need to add 2 to the max class id + normalized_range = np.linspace(0, 1, max_class_id + 2) + color_list = [ + annotation_class["color"] for annotation_class in all_annotations_data + ] + colorscale = [ + [normalized_range[i + j], color_list[i % len(color_list)]] + for i in range(0, normalized_range.size - 1) + for j in range(2) + ] + + return colorscale, max_class_id + + def generate_notification(title, color, icon, message=""): return dmc.Notification( title=title, From 3fa5801b089abde6d56ee4607a454364fc7fa1db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 28 Feb 2024 20:03:42 -0800 Subject: [PATCH 04/38] :boom: Change class labels to be 0-based instead of 1-based --- components/annotation_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/components/annotation_class.py b/components/annotation_class.py index a855609..d97b836 100644 --- a/components/annotation_class.py +++ b/components/annotation_class.py @@ -33,7 +33,7 @@ def annotation_class_item(class_color, class_label, existing_ids, data=None): annotations = data["annotations"] is_visible = data["is_visible"] else: - class_id = 1 if not existing_ids else max(existing_ids) + 1 + class_id = 0 if not existing_ids else max(existing_ids) + 1 annotations = {} is_visible = True class_color_transparent = class_color + "50" From dab2b920a39b73e57d3f6368ee78ea18c6bf4d37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 28 Feb 2024 20:12:11 -0800 Subject: [PATCH 05/38] :bug: Use mapped index when accessing segmented frames with annotations --- callbacks/image_viewer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/callbacks/image_viewer.py b/callbacks/image_viewer.py index 3f02a78..ea29e75 100644 --- a/callbacks/image_viewer.py +++ b/callbacks/image_viewer.py @@ -118,9 +118,12 @@ def render_image( and ctx.triggered_id == "show-result-overlay-toggle" ): return [dash.no_update] * 7 + ["hidden"] - if str(image_idx + 1) in tiled_masks.get_annotated_segmented_results(): + annotation_indices = tiled_masks.get_annotated_segmented_results() + if str(image_idx + 1) in annotation_indices: + # Will not return an error since we already checked if image_idx+1 is in the list + mapped_index = annotation_indices.index(str(image_idx + 1)) result = tiled_results.get_data_sequence_by_name(seg_result_selection)[ - image_idx + mapped_index ] else: result = None From 570ce78a3d6840b0aef6f6b20ec686eb1ae388dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Thu, 29 Feb 2024 11:23:42 -0800 Subject: [PATCH 06/38] Save masks to Tiled on job submission --- callbacks/segmentation.py | 4 +++- utils/data_utils.py | 27 +++++++++++++++++++-------- utils/plot_utils.py | 4 +++- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py index 40f47b6..bef1ad3 100644 --- a/callbacks/segmentation.py +++ b/callbacks/segmentation.py @@ -74,6 +74,9 @@ def run_job(n_clicks, global_store, all_annotations, project_name): """ if n_clicks: if MODE == "dev": + tiled_masks.save_annotations_data( + global_store, all_annotations, project_name + ) job_uid = str(uuid.uuid4()) return ( dmc.Text( @@ -83,7 +86,6 @@ def run_job(n_clicks, global_store, all_annotations, project_name): job_uid, ) else: - tiled_masks.save_annotations_data( global_store, all_annotations, project_name ) diff --git a/utils/data_utils.py b/utils/data_utils.py index c51feb8..1f0df06 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -18,6 +18,7 @@ MASK_TILED_API_KEY = os.getenv("MASK_TILED_API_KEY") SEG_TILED_URI = os.getenv("SEG_TILED_URI") SEG_TILED_API_KEY = os.getenv("SEG_TILED_API_KEY") +USER_NAME = os.getenv("USER_NAME", "user1") class TiledDataLoader: @@ -149,16 +150,13 @@ def DEV_load_exported_json_data(file_path, USER_NAME, PROJECT_NAME): def DEV_filter_json_data_by_timestamp(data, timestamp): return [data for data in data if data["time"] == timestamp] - @staticmethod - def save_annotations_data(global_store, all_annotations, project_name): + def save_annotations_data(self, global_store, all_annotations, project_name): """ - Transforms annotations data to a pixelated mask and outputs to - the Tiled server - - # TODO: Save data to Tiled server after transformation + Transforms annotations data to a pixelated mask and outputs to the Tiled server """ annotations = Annotations(all_annotations, global_store) - annotations.create_annotation_mask(sparse=True) # TODO: Check sparse status + # TODO: Check sparse status + annotations.create_annotation_mask(sparse=False) # Get metadata and annotation data metadata = annotations.get_annotations() @@ -173,7 +171,20 @@ def save_annotations_data(global_store, all_annotations, project_name): except ValueError: return "No annotations to process." - return + # Store the mask in the Tiled server under /username/project_name/uid/mask" + container_keys = [USER_NAME, project_name] + last_container = self.mask_client + for key in container_keys: + if key not in last_container.keys(): + last_container = last_container.create_container(key=key) + else: + last_container = last_container[key] + # Add json metadata to a container with a uuid as key + # (uuid will be created by Tiled, since no key is given) + last_container = last_container.create_container(metadata=metadata) + mask = last_container.write_array(key="mask", array=mask) + # print("Created a mask array with the following uri: ", mask.uri) + return mask.uri tiled_datasets = TiledDataLoader( diff --git a/utils/plot_utils.py b/utils/plot_utils.py index 9e70224..f42ccb3 100644 --- a/utils/plot_utils.py +++ b/utils/plot_utils.py @@ -193,12 +193,14 @@ def generate_segmentation_colormap(all_annotations_data): color_list = [ annotation_class["color"] for annotation_class in all_annotations_data ] + # We need to repeat each color twice, to create discrete color segments + # This loop contains the range limits 0 and 1 once, + # but every other value in between twice colorscale = [ [normalized_range[i + j], color_list[i % len(color_list)]] for i in range(0, normalized_range.size - 1) for j in range(2) ] - return colorscale, max_class_id From 228cd31407ca41f6b6590e8dc6398450cfe67de2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Fri, 1 Mar 2024 15:13:37 -0800 Subject: [PATCH 07/38] Change values of unlabeled pixels to -1 --- utils/annotations.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/utils/annotations.py b/utils/annotations.py index 765ac73..c69dccd 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -84,7 +84,9 @@ def create_annotation_mask(self, sparse=False): for slice_idx, slice_data in self.annotations.items(): image_height = slice_data[0]["image_shape"][0] image_width = slice_data[0]["image_shape"][1] - slice_mask = np.zeros([image_height, image_width], dtype=np.uint8) + slice_mask = np.full( + [image_height, image_width], fill_value=-1, dtype=np.int8 + ) for shape in slice_data: if shape["type"] == "Closed Freeform": shape_mask = ShapeConversion.closed_path_to_array( @@ -100,7 +102,7 @@ def create_annotation_mask(self, sparse=False): ) else: continue - slice_mask[shape_mask > 0] = shape_mask[shape_mask > 0] + slice_mask[shape_mask >= 0] = shape_mask[shape_mask >= 0] annotation_mask.append(slice_mask) if sparse: @@ -154,7 +156,7 @@ def ellipse_to_array(self, svg_data, image_shape, mask_class): c_radius = abs(svg_data["y0"] - svg_data["y1"]) / 2 # Vertical radius # Create mask and draw ellipse - mask = np.zeros((image_height, image_width), dtype=np.uint8) + mask = np.full((image_height, image_width), fill_value=-1, dtype=np.int8) rr, cc = draw.ellipse( cy, cx, c_radius, r_radius ) # Vertical radius first, then horizontal @@ -181,7 +183,7 @@ def rectangle_to_array(self, svg_data, image_shape, mask_class): y1 = max(min(y1, image_height - 1), 0) # # Draw the rectangle - mask = np.zeros((image_height, image_width), dtype=np.uint8) + mask = np.full((image_height, image_width), fill_value=-1, dtype=np.int8) rr, cc = draw.rectangle(start=(y0, x0), end=(y1, x1)) # Convert coordinates to integers @@ -217,8 +219,9 @@ def closed_path_to_array(self, svg_data, image_shape, mask_class): is_inside = polygon_path.contains_points(points) # Reshape the result back into the 2D shape - mask = is_inside.reshape(image_height, image_width).astype(int) + mask = is_inside.reshape(image_height, image_width).astype(np.int8) # Set the class value for the pixels inside the polygon mask[mask == 1] = mask_class + mask[mask == 0] = -1 return mask From 57a41bd100ddb26adb0f0a19f987c63651bbd115 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Fri, 1 Mar 2024 15:18:37 -0800 Subject: [PATCH 08/38] Restructure meta-data generation to remove redundancies --- utils/annotations.py | 34 ++++++++++++++++++++++++---------- utils/data_utils.py | 22 +++++++++++++++------- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/utils/annotations.py b/utils/annotations.py index c69dccd..a830412 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -17,24 +17,34 @@ def __init__(self, annotation_store, global_store): slices = set(slices) annotations = {key: [] for key in slices} + all_class_labels = [ + annotation_class["class_id"] for annotation_class in annotation_store + ] + annotation_classes = {} + for annotation_class in annotation_store: + condensed_id = str(all_class_labels.index(annotation_class["class_id"])) + annotation_classes[condensed_id] = { + "label": annotation_class["label"], + "color": annotation_class["color"], + } for image_idx, slice_data in annotation_class["annotations"].items(): for shape in slice_data: self._set_annotation_type(shape) self._set_annotation_svg(shape) annotation = { - "id": annotation_class["class_id"], + "class_id": condensed_id, "type": self.annotation_type, - "class": annotation_class["label"], - # TODO: This is the same across all images in a dataset - "image_shape": global_store["image_shapes"][0], "svg_data": self.svg_data, } annotations[image_idx].append(annotation) else: - annotations = [] + annotations = None + annotation_classes = None + self.annotation_classes = annotation_classes self.annotations = annotations + self.image_shape = global_store["image_shapes"][0] def get_annotations(self): return self.annotations @@ -42,6 +52,9 @@ def get_annotations(self): def get_annotation_mask(self): return self.annotation_mask + def get_annotation_classes(self): + return self.annotation_classes + def get_annotation_mask_as_bytes(self): buffer = io.BytesIO() zip_buffer = io.BytesIO() @@ -81,24 +94,25 @@ def create_annotation_mask(self, sparse=False): self.sparse = sparse annotation_mask = [] + image_height = self.image_shape[0] + image_width = self.image_shape[1] + for slice_idx, slice_data in self.annotations.items(): - image_height = slice_data[0]["image_shape"][0] - image_width = slice_data[0]["image_shape"][1] slice_mask = np.full( [image_height, image_width], fill_value=-1, dtype=np.int8 ) for shape in slice_data: if shape["type"] == "Closed Freeform": shape_mask = ShapeConversion.closed_path_to_array( - shape["svg_data"], shape["image_shape"], shape["id"] + shape["svg_data"], self.image_shape, shape["class_id"] ) elif shape["type"] == "Rectangle": shape_mask = ShapeConversion.rectangle_to_array( - shape["svg_data"], shape["image_shape"], shape["id"] + shape["svg_data"], self.image_shape, shape["class_id"] ) elif shape["type"] == "Ellipse": shape_mask = ShapeConversion.ellipse_to_array( - shape["svg_data"], shape["image_shape"], shape["id"] + shape["svg_data"], self.image_shape, shape["class_id"] ) else: continue diff --git a/utils/data_utils.py b/utils/data_utils.py index 1f0df06..8bc1690 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -155,23 +155,31 @@ def save_annotations_data(self, global_store, all_annotations, project_name): Transforms annotations data to a pixelated mask and outputs to the Tiled server """ annotations = Annotations(all_annotations, global_store) - # TODO: Check sparse status + # TODO: Check sparse status, it may be worthwhile to store the mask as a sparse array + # if our machine learning models can handle sparse arrays annotations.create_annotation_mask(sparse=False) # Get metadata and annotation data - metadata = annotations.get_annotations() + annnotations_per_slice = annotations.get_annotations() + annotation_classes = annotations.get_annotation_classes() + + metadata = { + "classes": annotation_classes, + "unlabeled_class_id": -1, + "annotations": annnotations_per_slice, + "image_shape": global_store["image_shapes"][0], + "project_name": project_name, + "mask_idx": list(annnotations_per_slice.keys()), + } + print("Metadata: ", metadata) mask = annotations.get_annotation_mask() - # Get raw images associated with each annotated slice - img_idx = list(metadata.keys()) - metadata["mask_idx"] = img_idx - metadata["project_name"] = project_name try: mask = np.stack(mask) except ValueError: return "No annotations to process." - # Store the mask in the Tiled server under /username/project_name/uid/mask" + # Store the mask in the Tiled server under /username/project_name/uuid/mask" container_keys = [USER_NAME, project_name] last_container = self.mask_client for key in container_keys: From 9337fb78f1ddacc43b270c0c29c653c720a6d6d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Fri, 1 Mar 2024 15:26:11 -0800 Subject: [PATCH 09/38] Sort mask slices by index to ensure consistent order --- utils/annotations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/utils/annotations.py b/utils/annotations.py index a830412..1526bbf 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -15,7 +15,9 @@ def __init__(self, annotation_store, global_store): for annotation_class in annotation_store: slices.extend(list(annotation_class["annotations"].keys())) slices = set(slices) - annotations = {key: [] for key in slices} + # Slices need to be sorted to ensure that the exported mask slices + # have the same order as the original data set + annotations = {key: [] for key in sorted(slices, key=int)} all_class_labels = [ annotation_class["class_id"] for annotation_class in annotation_store From 77332e4db0389f98cb4c70c62c7794986ab5baa3 Mon Sep 17 00:00:00 2001 From: zhuowenzhao Date: Fri, 1 Mar 2024 16:11:57 -0800 Subject: [PATCH 10/38] :wrench: added new pkg required for using dash_component_editor --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 7a2fbcc..fac494c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,3 +33,4 @@ scipy dash-extensions==1.0.1 dash-bootstrap-components==1.5.0 dash_auth==2.0.0 +dash_daq==0.1.0 From 0a51ed006873a81bbcd9235aadffc7860f810d65 Mon Sep 17 00:00:00 2001 From: zhuowenzhao Date: Fri, 1 Mar 2024 16:13:30 -0800 Subject: [PATCH 11/38] :sparkles: added content registry example --- utils/content_registry.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 utils/content_registry.py diff --git a/utils/content_registry.py b/utils/content_registry.py new file mode 100644 index 0000000..4e1c12e --- /dev/null +++ b/utils/content_registry.py @@ -0,0 +1,30 @@ +import json +from copy import deepcopy + +class Models: + def __init__(self, modelfile_path='./assets/mode_description.json'): + self.path = modelfile_path + f = open('./assets/mode_description.json') + + self.contents = json.load(f)['contents'] + self.modelname_list = [content['model_name'] for content in self.contents] + self.models = {} + + for i, n in enumerate(self.modelname_list): + self.models[n] = self.contents[i] + + @staticmethod + def remove_key_from_dict_list(data, key): + new_data = [] + for item in data: + if key in item: + new_item = deepcopy(item) + new_item.pop(key) + new_data.append(new_item) + else: + new_data.append(item) + + return new_data + + +models = Models() \ No newline at end of file From c02246d189e927f11f3c323e91e60f74f08ee3dc Mon Sep 17 00:00:00 2001 From: zhuowenzhao Date: Fri, 1 Mar 2024 16:14:41 -0800 Subject: [PATCH 12/38] :sparkles: added a model description example --- assets/mode_description.json | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100755 assets/mode_description.json diff --git a/assets/mode_description.json b/assets/mode_description.json new file mode 100755 index 0000000..f394cdc --- /dev/null +++ b/assets/mode_description.json @@ -0,0 +1,6 @@ +{ + "contents":[ + {"model_name": "random_forest", "version": "1.0.0", "type": "supervised", "user": "mlexchange team", "uri": "xxx", "application": ["classification", "segmentation"], "description": "xxx", "gui_parameters": [{"type": "int", "name": "num-tree", "title": "Number of Trees", "param_key": "n_estimators", "value": "30"}, {"type": "int", "name": "tree-depth", "title": "Tree Depth", "param_key": "max_depth", "value": "8"}], "cmd": ["xxx"], "reference": "Adapted from Dash Plotly image segmentation example"}, + {"model_name": "kmeans", "version": "1.0.0", "type": "unsupervised", "user": "mlexchange team", "uri": "xxx", "application": ["segmentation", "clustering"], "description": "xxx", "gui_parameters": [{"type": "int", "name": "num-cluster", "title": "Number of Clusters", "param_key": "n_clusters", "value": "2"}, {"type": "int", "name": "num-iter", "title": "Max Iteration", "param_key": "max_iter", "value": "300"}], "cmd": ["xxx", "xxxx"], "reference": "Nicholas Schwartz & Howard Yanxon"} + ] +} \ No newline at end of file From a9b9b3f8468564e077492a690aec0a7ac567af7e Mon Sep 17 00:00:00 2001 From: zhuowenzhao Date: Fri, 1 Mar 2024 16:15:30 -0800 Subject: [PATCH 13/38] :sparkles: added automatic dash gui generator --- app.py | 21 ++ components/control_bar.py | 18 ++ components/dash_component_editor.py | 408 ++++++++++++++++++++++++++++ 3 files changed, 447 insertions(+) create mode 100644 components/dash_component_editor.py diff --git a/app.py b/app.py index c8514d9..b14303d 100644 --- a/app.py +++ b/app.py @@ -10,6 +10,9 @@ from components.control_bar import layout as control_bar_layout from components.image_viewer import layout as image_viewer_layout +from utils.content_registry import models +from components.dash_component_editor import JSONParameterEditor + USER_NAME = os.getenv("USER_NAME") USER_PASSWORD = os.getenv("USER_PASSWORD") @@ -31,8 +34,26 @@ control_bar_layout(), image_viewer_layout(), dcc.Store(id="current-class-selection", data="#FFA200"), + dcc.Store(id="gui-components-values", data={}) ], ) +### automatic Dash gui callback ### +@callback( + Output("gui-layouts", "children"), + Input("model-list", "value"), +) +def update_gui_parameters(model_name): + data = models.models[model_name] + if data["gui_parameters"]: + item_list = JSONParameterEditor( _id={'type': str(uuid.uuid4())}, # pattern match _id (base id), name + json_blob=models.remove_key_from_dict_list(data["gui_parameters"], "comp_group"), + ) + item_list.init_callbacks(app) + return [html.H4("Model Parameters"), item_list] + else: + return[""] + + if __name__ == "__main__": app.run_server(debug=True) diff --git a/components/control_bar.py b/components/control_bar.py index b4ab2b0..08c65ab 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -7,6 +7,7 @@ from components.annotation_class import annotation_class_item from constants import ANNOT_ICONS, KEYBINDS from utils.data_utils import tiled_dataset +from utils.content_registry import models def _tooltip(text, children): @@ -603,6 +604,23 @@ def layout(): "run-model", id="model-configuration", children=[ + _control_item( + "Model Selection", + "model-selector", + dmc.Select( + id="model-list", + data=models.modelname_list, + value=( + models.modelname_list[0] + if models.modelname_list[0] + else None + ), + placeholder="Select an model...", + ), + ), + dmc.Space(h=25), + html.Div(id="gui-layouts"), + dmc.Space(h=25), dmc.Center( dmc.Button( "Run model", diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py new file mode 100644 index 0000000..aecc456 --- /dev/null +++ b/components/dash_component_editor.py @@ -0,0 +1,408 @@ +import re +from typing import Callable +# noinspection PyUnresolvedReferences +from inspect import signature, _empty + +from dash import html, dcc, dash_table, Input, Output, State, MATCH, ALL +import dash_bootstrap_components as dbc +import dash_daq as daq + +import base64 +#import PIL.Image +import io +#import plotly.express as px +# Procedural dash form generation + + +""" +{'name', 'title', 'value', 'type', +""" + + +class SimpleItem(dbc.Col): + def __init__(self, + name, + base_id, + title=None, + param_key=None, + type='number', + debounce=True, + **kwargs): + + if param_key == None: + param_key = name + self.label = dbc.Label(title) + self.input = dbc.Input(type=type, + debounce=debounce, + id={**base_id, + 'name': name, + 'param_key': param_key}, + **kwargs) + + super(SimpleItem, self).__init__(children=[self.label, self.input]) + + +class FloatItem(SimpleItem): + pass + + +class IntItem(SimpleItem): + def __init__(self, *args, **kwargs): + if 'min' not in kwargs: + kwargs['min'] = -9007199254740991 + super(IntItem, self).__init__(*args, step=1, **kwargs) + + +class StrItem(SimpleItem): + def __init__(self, *args, **kwargs): + super(StrItem, self).__init__(*args, type='text', **kwargs) + + +class SliderItem(dbc.Col): + def __init__(self, + name, + base_id, + title=None, + param_key=None, + debounce=True, + visible=True, + **kwargs): + + if param_key == None: + param_key = name + self.label = dbc.Label(title) + self.input = dcc.Slider(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'input'}, + tooltip={"placement": "bottom", "always_visible": True}, + **kwargs) + + style = {} + if not visible: + style['display'] = 'none' + + super(SliderItem, self).__init__(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'form_group'}, + children=[self.label, self.input], + style=style) + + +class DropdownItem(dbc.Col): + def __init__(self, + name, + base_id, + title=None, + param_key=None, + debounce=True, + visible=True, + **kwargs): + + if param_key == None: + param_key = name + self.label = dbc.Label(title) + self.input = dcc.Dropdown(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'input'}, + **kwargs) + + style = {} + if not visible: + style['display'] = 'none' + + super(DropdownItem, self).__init__(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'form_group'}, + children=[self.label, self.input], + style=style) + + +class RadioItem(dbc.Col): + def __init__(self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs): + + if param_key == None: + param_key = name + self.label = dbc.Label(title) + self.input = dbc.RadioItems(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'input'}, + **kwargs) + + style = {} + if not visible: + style['display'] = 'none' + + super(RadioItem, self).__init__(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'form_group'}, + children=[self.label, self.input], + style=style) + + +class BoolItem(dbc.Col): + def __init__(self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs): + + if param_key == None: + param_key = name + self.label = dbc.Label(title) + self.input = daq.ToggleSwitch(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'input'}, + **kwargs) + self.output_label = dbc.Label('False/True') + + style = {} + if not visible: + style['display'] = 'none' + + super(BoolItem, self).__init__(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'form_group'}, + children=[self.label, self.input, self.output_label], + style=style) + + +class ImgItem(dbc.Col): + def __init__(self, + name, + src, + base_id, + title=None, + param_key=None, + width='100px', + visible=True, + **kwargs): + + if param_key == None: + param_key = name + + if not (width.endswith('px') or width.endswith('%')): + width = width + 'px' + + self.label = dbc.Label(title) + + encoded_image = base64.b64encode(open(src, 'rb').read()) + self.src = 'data:image/png;base64,{}'.format(encoded_image.decode()) + self.input_img = html.Img(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'input'}, + src=self.src, + style={'height':'auto', 'width':width}, + **kwargs) + + style = {} + if not visible: + style['display'] = 'none' + + super(ImgItem, self).__init__(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'form_group'}, + children=[self.label, self.input_img], + style=style) + + +# class GraphItem(dbc.Col): +# def __init__(self, +# name, +# base_id, +# title=None, +# param_key=None, +# visible=True, +# figure = None, +# **kwargs): +# +# self.name = name +# if param_key == None: +# param_key = name +# self.label = dbc.Label(title) +# self.input_graph = dcc.Graph(id={**base_id, +# 'name': name, +# 'param_key': param_key, +# 'layer': 'input'}, +# **kwargs) +# +# self.input_upload = dcc.Upload(id={**base_id, +# 'name': name+'_upload', +# 'param_key': param_key, +# 'layer': 'input'}, +# children=html.Div([ +# 'Drag and Drop or ', +# html.A('Select Files') +# ]), +# style={ +# 'width': '95%', +# 'height': '60px', +# 'lineHeight': '60px', +# 'borderWidth': '1px', +# 'borderStyle': 'dashed', +# 'borderRadius': '5px', +# 'textAlign': 'center', +# 'margin': '10px' +# }, +# multiple = False) +# +# style = {} +# if not visible: +# style['display'] = 'none' +# +# super(GraphItem, self).__init__(id={**base_id, +# 'name': name, +# 'param_key': param_key, +# 'layer': 'form_group'}, +# children=[self.label, self.input_upload, self.input_graph], +# style=style) +# +# # Issue: cannot get inputs from the callback decorator +# def return_upload(self, *args): +# print(f'before if, args {args}') +# if args: +# print(f'args {args}') +# img_bytes = base64.b64decode(contents.split(",")[1]) +# img = PIL.Image.open(io.BytesIO(img_bytes)) +# fig = px.imshow(img, binary_string=True) +# return fig +# +# def init_callbacks(self, app): +# app.callback(Output({**self.id, +# 'name': self.name, +# 'layer': 'input'}, 'figure', allow_duplicate=True), +# Input({**self.id, +# 'name': self.name+'_upload', +# 'layer': 'input'}, +# 'contents'), +# State({**self.id, +# 'name': self.name+'_upload', +# 'layer': 'input'}, 'last_modified'), +# State({**self.id, +# 'name': self.name+'_upload', +# 'layer': 'input'}, 'filename'), +# prevent_initial_call=True)(self.return_upload()) + + + +class ParameterEditor(dbc.Form): + + type_map = {float: FloatItem, + int: IntItem, + str: StrItem, + } + + def __init__(self, _id, parameters, **kwargs): + self._parameters = parameters + + super(ParameterEditor, self).__init__(id=_id, children=[], className='kwarg-editor', **kwargs) + self.children = self.build_children() + + def init_callbacks(self, app): + app.callback(Output(self.id, 'n_submit'), + Input({**self.id, + 'name': ALL}, + 'value'), + State(self.id, 'n_submit'), + ) + + for child in self.children: + if hasattr(child,"init_callbacks"): + child.init_callbacks(app) + + + @property + def values(self): + return {param['name']: param.get('value', None) for param in self._parameters} + + @property + def parameters(self): + return {param['name']: param for param in self._parameters} + + def _determine_type(self, parameter_dict): + if 'type' in parameter_dict: + if parameter_dict['type'] in self.type_map: + return parameter_dict['type'] + elif parameter_dict['type'].__name__ in self.type_map: + return parameter_dict['type'].__name__ + elif type(parameter_dict['value']) in self.type_map: + return type(parameter_dict['value']) + raise TypeError(f'No item type could be determined for this parameter: {parameter_dict}') + + def build_children(self, values=None): + children = [] + for parameter_dict in self._parameters: + parameter_dict = parameter_dict.copy() + if values and parameter_dict['name'] in values: + parameter_dict['value'] = values[parameter_dict['name']] + type = self._determine_type(parameter_dict) + parameter_dict.pop('type', None) + item = self.type_map[type](**parameter_dict, base_id=self.id) + children.append(item) + + return children + + +class JSONParameterEditor(ParameterEditor): + type_map = {'float': FloatItem, + 'int': IntItem, + 'str': StrItem, + 'slider': SliderItem, + 'dropdown': DropdownItem, + 'radio': RadioItem, + 'bool': BoolItem, + 'img': ImgItem, + #'graph': GraphItem, + } + + def __init__(self, _id, json_blob, **kwargs): + super(ParameterEditor, self).__init__(id=_id, children=[], className='kwarg-editor', **kwargs) + self._json_blob = json_blob + self.children = self.build_children() + + def build_children(self, values=None): + children = [] + for json_record in self._json_blob: + ... + # build a parameter dict from self.json_blob + ... + type = json_record.get('type', self._determine_type(json_record)) + json_record = json_record.copy() + if values and json_record['name'] in values: + json_record['value'] = values[json_record['name']] + json_record.pop('type', None) + item = self.type_map[type](**json_record, base_id=self.id) + children.append(item) + + return children + + +class KwargsEditor(ParameterEditor): + def __init__(self, instance_index, func: Callable, **kwargs): + self.func = func + self._instance_index = instance_index + + parameters = [{'name': name, 'value': param.default} for name, param in signature(func).parameters.items() + if param.default is not _empty] + + super(KwargsEditor, self).__init__(dict(index=instance_index, type='kwargs-editor'), parameters=parameters, **kwargs) + + def new_record(self): + return {name: p.default for name, p in signature(self.func).parameters.items() if p.default is not _empty} From 28b2edcd5b0ec3747214f00da502b5e6a1776ef2 Mon Sep 17 00:00:00 2001 From: zhuowenzhao Date: Fri, 1 Mar 2024 16:16:24 -0800 Subject: [PATCH 14/38] :sparkles: added callback to retrieve model paramters from gui --- callbacks/segmentation.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py index 69a206c..588c1ce 100644 --- a/callbacks/segmentation.py +++ b/callbacks/segmentation.py @@ -57,13 +57,15 @@ @callback( Output("output-details", "children"), - Output("submitted-job-id", "data"), + Output("submitted-job-id", "data"), + Output("gui-components-values", "data"), Input("run-model", "n_clicks"), State("annotation-store", "data"), State({"type": "annotation-class-store", "index": ALL}, "data"), State("project-name-src", "value"), + State("gui-layouts", "children") ) -def run_job(n_clicks, global_store, all_annotations, project_name): +def run_job(n_clicks, global_store, all_annotations, project_name, children): """ This callback collects parameters from the UI and submits a job to the computing api. If the app is run from "dev" mode, then only a placeholder job_uid will be created. @@ -72,7 +74,17 @@ def run_job(n_clicks, global_store, all_annotations, project_name): # TODO: Appropriately paramaterize the DEMO_WORKFLOW json depending on user inputs and relevant file paths """ + input_params = {} if n_clicks: + if len(children) >= 2: + params = children[1] + for param in params['props']['children']: + key = param["props"]["children"][1]["props"]["id"]["param_key"] + value = param["props"]["children"][1]["props"]["value"] + input_params[key] = value + + # return the input values in dictionary and saved to dcc.Store "gui-components-values" + print(f'input_param:\n{input_params}') if MODE == "dev": job_uid = str(uuid.uuid4()) return ( @@ -81,6 +93,7 @@ def run_job(n_clicks, global_store, all_annotations, project_name): size="sm", ), job_uid, + input_params ) else: @@ -98,6 +111,7 @@ def run_job(n_clicks, global_store, all_annotations, project_name): size="sm", ), job_uid, + input_params ) else: return ( @@ -106,8 +120,9 @@ def run_job(n_clicks, global_store, all_annotations, project_name): size="sm", ), job_uid, + input_params ) - return no_update, no_update + return no_update, no_update, input_params @callback( From a7bd9267e43b83044c7bded7fed308cd93f17f2c Mon Sep 17 00:00:00 2001 From: zhuowenzhao Date: Fri, 1 Mar 2024 16:30:56 -0800 Subject: [PATCH 15/38] :wrench: cleaned lb import --- components/dash_component_editor.py | 82 +---------------------------- 1 file changed, 2 insertions(+), 80 deletions(-) diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py index aecc456..7fa677a 100644 --- a/components/dash_component_editor.py +++ b/components/dash_component_editor.py @@ -3,7 +3,7 @@ # noinspection PyUnresolvedReferences from inspect import signature, _empty -from dash import html, dcc, dash_table, Input, Output, State, MATCH, ALL +from dash import html, dcc, Input, Output, State, ALL import dash_bootstrap_components as dbc import dash_daq as daq @@ -14,6 +14,7 @@ # Procedural dash form generation + """ {'name', 'title', 'value', 'type', """ @@ -223,85 +224,6 @@ def __init__(self, style=style) -# class GraphItem(dbc.Col): -# def __init__(self, -# name, -# base_id, -# title=None, -# param_key=None, -# visible=True, -# figure = None, -# **kwargs): -# -# self.name = name -# if param_key == None: -# param_key = name -# self.label = dbc.Label(title) -# self.input_graph = dcc.Graph(id={**base_id, -# 'name': name, -# 'param_key': param_key, -# 'layer': 'input'}, -# **kwargs) -# -# self.input_upload = dcc.Upload(id={**base_id, -# 'name': name+'_upload', -# 'param_key': param_key, -# 'layer': 'input'}, -# children=html.Div([ -# 'Drag and Drop or ', -# html.A('Select Files') -# ]), -# style={ -# 'width': '95%', -# 'height': '60px', -# 'lineHeight': '60px', -# 'borderWidth': '1px', -# 'borderStyle': 'dashed', -# 'borderRadius': '5px', -# 'textAlign': 'center', -# 'margin': '10px' -# }, -# multiple = False) -# -# style = {} -# if not visible: -# style['display'] = 'none' -# -# super(GraphItem, self).__init__(id={**base_id, -# 'name': name, -# 'param_key': param_key, -# 'layer': 'form_group'}, -# children=[self.label, self.input_upload, self.input_graph], -# style=style) -# -# # Issue: cannot get inputs from the callback decorator -# def return_upload(self, *args): -# print(f'before if, args {args}') -# if args: -# print(f'args {args}') -# img_bytes = base64.b64decode(contents.split(",")[1]) -# img = PIL.Image.open(io.BytesIO(img_bytes)) -# fig = px.imshow(img, binary_string=True) -# return fig -# -# def init_callbacks(self, app): -# app.callback(Output({**self.id, -# 'name': self.name, -# 'layer': 'input'}, 'figure', allow_duplicate=True), -# Input({**self.id, -# 'name': self.name+'_upload', -# 'layer': 'input'}, -# 'contents'), -# State({**self.id, -# 'name': self.name+'_upload', -# 'layer': 'input'}, 'last_modified'), -# State({**self.id, -# 'name': self.name+'_upload', -# 'layer': 'input'}, 'filename'), -# prevent_initial_call=True)(self.return_upload()) - - - class ParameterEditor(dbc.Form): type_map = {float: FloatItem, From 70a1146b0150ee9729dc6b236cebe9d203bcb187 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 15:00:39 -0800 Subject: [PATCH 16/38] Add `.dockerignore` ignoring `.env` --- .dockerignore | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .dockerignore diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..1fafbd5 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +.env +.git +.gitignore From 55d90e3f764f0bba38199f328be23546c6bdbf93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 15:51:19 -0800 Subject: [PATCH 17/38] Add example script that saves a mask plot based on a given uri --- examples/plot_mask.py | 72 +++++++++++++++++++++++++++++++++++++++ examples/requirements.txt | 2 ++ 2 files changed, 74 insertions(+) create mode 100644 examples/plot_mask.py create mode 100644 examples/requirements.txt diff --git a/examples/plot_mask.py b/examples/plot_mask.py new file mode 100644 index 0000000..eee8d8c --- /dev/null +++ b/examples/plot_mask.py @@ -0,0 +1,72 @@ +import os +import sys + +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +from dotenv import load_dotenv +from matplotlib.colors import ListedColormap +from tiled.client import from_uri + + +def plot_mask(mask_uri, api_key, slice_idx, output_path): + """ + Saves a plot of a given mask using metadata information such as class colors and labels. + It is assumed that the given uri is the uri of a mask container, with associated meta data + and a mask array under the key "mask". + The given slice index references mask slices, not the original data. + However, the printed slice index in the figure will be the index of the original data. + """ + # Retrieve mask and metadata + mask_client = from_uri(mask_uri, api_key=api_key) + mask = mask_client["mask"][slice_idx] + meta_data = mask_client.metadata + mask_idx = meta_data["mask_idx"] + + if slice_idx > len(mask_idx): + raise ValueError("Slice index out of range") + + class_meta_data = meta_data["classes"] + max_class_id = len(class_meta_data.keys()) - 1 + + colors = [ + annotation_class["color"] for _, annotation_class in class_meta_data.items() + ] + labels = [ + annotation_class["label"] for _, annotation_class in class_meta_data.items() + ] + plt.imshow( + mask, + cmap=ListedColormap(colors), + vmin=-0.5, + vmax=max_class_id + 0.5, + ) + plt.title(meta_data["project_name"] + ", slice: " + mask_idx[slice_idx]) + + # create a patch for every color + patches = [ + mpatches.Patch(color=colors[i], label=labels[i]) for i in range(len(labels)) + ] + # Plot legend below the image + plt.legend( + handles=patches, loc="upper center", bbox_to_anchor=(0.5, -0.075), ncol=3 + ) + plt.savefig(output_path, bbox_inches="tight") + + +if __name__ == "__main__": + """ + Example usage: python3 plot_mask.py http://localhost:8000/api/v1/metadata/mlex_store/mlex_store/username/dataset/uuid + """ + + load_dotenv() + api_key = os.getenv("MASK_TILED_API_KEY", None) + + if len(sys.argv) < 2: + print("Usage: python3 plot_mask.py [slice_idx] [output_path]") + sys.exit(1) + + mask_uri = sys.argv[1] + slice_idx = int(sys.argv[2]) if len(sys.argv) > 2 else 0 + output_path = sys.argv[3] if len(sys.argv) > 3 else "mask.png" + + plot_mask(mask_uri, api_key, slice_idx, output_path) diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..618af5e --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1,2 @@ +matplotlib +tiled[client] From 3a951b9934aeb440e59a1245facdcb514a9ecea8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 18:19:53 -0800 Subject: [PATCH 18/38] Include `data_uri` in mask meta data --- utils/data_utils.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/utils/data_utils.py b/utils/data_utils.py index 8bc1690..a580b28 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -89,6 +89,20 @@ def get_data_shape_by_name(self, project_name): return project_container.shape return None + def get_data_uri_by_name(self, project_name): + """ + Retrieve uri of the data + """ + project_container = self.get_data_sequence_by_name(project_name) + if project_container: + return project_container.uri + return None + + +tiled_datasets = TiledDataLoader( + data_tiled_uri=DATA_TILED_URI, data_tiled_api_key=DATA_TILED_API_KEY +) + class TiledMaskHandler: """ @@ -169,6 +183,7 @@ def save_annotations_data(self, global_store, all_annotations, project_name): "annotations": annnotations_per_slice, "image_shape": global_store["image_shapes"][0], "project_name": project_name, + "data_uri": tiled_datasets.get_data_uri_by_name(project_name), "mask_idx": list(annnotations_per_slice.keys()), } print("Metadata: ", metadata) @@ -191,13 +206,9 @@ def save_annotations_data(self, global_store, all_annotations, project_name): # (uuid will be created by Tiled, since no key is given) last_container = last_container.create_container(metadata=metadata) mask = last_container.write_array(key="mask", array=mask) - # print("Created a mask array with the following uri: ", mask.uri) - return mask.uri + return mask.uri -tiled_datasets = TiledDataLoader( - data_tiled_uri=DATA_TILED_URI, data_tiled_api_key=DATA_TILED_API_KEY -) tiled_masks = TiledMaskHandler( mask_tiled_uri=MASK_TILED_URI, mask_tiled_api_key=MASK_TILED_API_KEY From 0b88839e0c67881686439659d384040cc20e2d28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 19:31:40 -0800 Subject: [PATCH 19/38] :bug: Fix export of free-forms for class with index 0 --- utils/annotations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/annotations.py b/utils/annotations.py index 1526bbf..7877448 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -237,7 +237,7 @@ def closed_path_to_array(self, svg_data, image_shape, mask_class): # Reshape the result back into the 2D shape mask = is_inside.reshape(image_height, image_width).astype(np.int8) - # Set the class value for the pixels inside the polygon - mask[mask == 1] = mask_class + # Set the class value for the pixels inside the polygon, -1 for the rest mask[mask == 0] = -1 + mask[mask == 1] = mask_class return mask From b709dbdae055c774eae85aaf4f0396a7607b19f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 19:38:31 -0800 Subject: [PATCH 20/38] :sparkles: Use hash as key for masks instead of randomly generated uuid --- utils/annotations.py | 9 +++++++++ utils/data_utils.py | 26 ++++++++++++++++---------- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/utils/annotations.py b/utils/annotations.py index 7877448..6978c5f 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -1,4 +1,6 @@ +import hashlib import io +import json import zipfile import numpy as np @@ -46,6 +48,7 @@ def __init__(self, annotation_store, global_store): self.annotation_classes = annotation_classes self.annotations = annotations + self.annotations_hash = self.get_annotations_hash() self.image_shape = global_store["image_shapes"][0] def get_annotations(self): @@ -57,6 +60,12 @@ def get_annotation_mask(self): def get_annotation_classes(self): return self.annotation_classes + def get_annotations_hash(self): + hash_object = hashlib.md5() + hash_object.update(json.dumps(self.annotations, sort_keys=True).encode()) + hash_object.update(json.dumps(self.annotation_classes, sort_keys=True).encode()) + return hash_object.hexdigest() + def get_annotation_mask_as_bytes(self): buffer = io.BytesIO() zip_buffer = io.BytesIO() diff --git a/utils/data_utils.py b/utils/data_utils.py index a580b28..a60b919 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -176,17 +176,18 @@ def save_annotations_data(self, global_store, all_annotations, project_name): # Get metadata and annotation data annnotations_per_slice = annotations.get_annotations() annotation_classes = annotations.get_annotation_classes() + annotations_hash = annotations.get_annotations_hash() metadata = { - "classes": annotation_classes, - "unlabeled_class_id": -1, - "annotations": annnotations_per_slice, - "image_shape": global_store["image_shapes"][0], "project_name": project_name, "data_uri": tiled_datasets.get_data_uri_by_name(project_name), + "image_shape": global_store["image_shapes"][0], "mask_idx": list(annnotations_per_slice.keys()), + "classes": annotation_classes, + "annotations": annnotations_per_slice, + "unlabeled_class_id": -1, } - print("Metadata: ", metadata) + mask = annotations.get_annotation_mask() try: @@ -202,12 +203,17 @@ def save_annotations_data(self, global_store, all_annotations, project_name): last_container = last_container.create_container(key=key) else: last_container = last_container[key] - # Add json metadata to a container with a uuid as key - # (uuid will be created by Tiled, since no key is given) - last_container = last_container.create_container(metadata=metadata) - mask = last_container.write_array(key="mask", array=mask) - return mask.uri + # Add json metadata to a container with the md5 hash as key + # if a mask with that hash does not already exist + if annotations_hash not in last_container.keys(): + last_container = last_container.create_container( + key=annotations_hash, metadata=metadata + ) + mask = last_container.write_array(key="mask", array=mask) + else: + last_container = last_container[annotations_hash] + return last_container.uri tiled_masks = TiledMaskHandler( From ee37e3581ecb5115aa2f10540cd33d24749a00e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 19:40:11 -0800 Subject: [PATCH 21/38] Check for unlabled label name and add color for unlabled pixels in plotting --- callbacks/control_bar.py | 4 ++++ examples/plot_mask.py | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index f5c208b..d0da110 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -332,6 +332,10 @@ def open_annotation_class_modal( disable_class_creation = True error_msg.append("Label Already in Use!") error_msg.append(html.Br()) + if new_label == "Unlabeled": + disable_class_creation = True + error_msg.append("Label name cannot be 'Unlabeled'") + error_msg.append(html.Br()) if new_color in current_colors: disable_class_creation = True error_msg.append("Color Already in use!") diff --git a/examples/plot_mask.py b/examples/plot_mask.py index eee8d8c..fa4735d 100644 --- a/examples/plot_mask.py +++ b/examples/plot_mask.py @@ -19,6 +19,7 @@ def plot_mask(mask_uri, api_key, slice_idx, output_path): # Retrieve mask and metadata mask_client = from_uri(mask_uri, api_key=api_key) mask = mask_client["mask"][slice_idx] + meta_data = mask_client.metadata mask_idx = meta_data["mask_idx"] @@ -34,10 +35,14 @@ def plot_mask(mask_uri, api_key, slice_idx, output_path): labels = [ annotation_class["label"] for _, annotation_class in class_meta_data.items() ] + # Add color for unlabeled pixels + colors = ["#D3D3D3"] + colors + labels = ["Unlabeled"] + labels + plt.imshow( mask, cmap=ListedColormap(colors), - vmin=-0.5, + vmin=-1.5, vmax=max_class_id + 0.5, ) plt.title(meta_data["project_name"] + ", slice: " + mask_idx[slice_idx]) From 212530d7cf5832938fb56517fbb9783841502d9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 19:41:17 -0800 Subject: [PATCH 22/38] Capture `mask_uri` for job submission --- callbacks/segmentation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py index bef1ad3..c7c8d40 100644 --- a/callbacks/segmentation.py +++ b/callbacks/segmentation.py @@ -74,19 +74,19 @@ def run_job(n_clicks, global_store, all_annotations, project_name): """ if n_clicks: if MODE == "dev": - tiled_masks.save_annotations_data( + mask_uri = tiled_masks.save_annotations_data( global_store, all_annotations, project_name ) job_uid = str(uuid.uuid4()) return ( dmc.Text( - f"Workflow has been succesfully submitted with uid: {job_uid}", + f"Workflow has been succesfully submitted with uid: {job_uid} and mask uri: {mask_uri}", size="sm", ), job_uid, ) else: - tiled_masks.save_annotations_data( + mask_uri = tiled_masks.save_annotations_data( global_store, all_annotations, project_name ) job_submitted = requests.post( From f0d6a428ee4e20f1e2ca3c1036ecd710118f88c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Mon, 4 Mar 2024 12:33:47 -0800 Subject: [PATCH 23/38] Use `canonicaljson` instead of `json.dumps(..., sort_keys=True)` for hashing --- requirements.txt | 1 + utils/annotations.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7a2fbcc..6eca7ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,3 +33,4 @@ scipy dash-extensions==1.0.1 dash-bootstrap-components==1.5.0 dash_auth==2.0.0 +canonicaljson diff --git a/utils/annotations.py b/utils/annotations.py index 6978c5f..c2eafcd 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -1,8 +1,8 @@ import hashlib import io -import json import zipfile +import canonicaljson import numpy as np import scipy.sparse as sp from matplotlib.path import Path @@ -62,8 +62,8 @@ def get_annotation_classes(self): def get_annotations_hash(self): hash_object = hashlib.md5() - hash_object.update(json.dumps(self.annotations, sort_keys=True).encode()) - hash_object.update(json.dumps(self.annotation_classes, sort_keys=True).encode()) + hash_object.update(canonicaljson.encode_canonical_json(self.annotations)) + hash_object.update(canonicaljson.encode_canonical_json(self.annotation_classes)) return hash_object.hexdigest() def get_annotation_mask_as_bytes(self): From 10dab134c4151eb2ee4f48b587790dcde7dc7e4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Mon, 4 Mar 2024 17:55:28 -0800 Subject: [PATCH 24/38] Apply `pre-commit run --all files` --- app.py | 19 +- assets/mode_description.json | 6 +- callbacks/segmentation.py | 20 +- components/control_bar.py | 2 +- components/dash_component_editor.py | 389 ++++++++++++++-------------- utils/content_registry.py | 21 +- 6 files changed, 231 insertions(+), 226 deletions(-) diff --git a/app.py b/app.py index b14303d..7b9538c 100644 --- a/app.py +++ b/app.py @@ -8,10 +8,9 @@ from callbacks.image_viewer import * # noqa: F403, F401 from callbacks.segmentation import * # noqa: F403, F401 from components.control_bar import layout as control_bar_layout +from components.dash_component_editor import JSONParameterEditor from components.image_viewer import layout as image_viewer_layout - from utils.content_registry import models -from components.dash_component_editor import JSONParameterEditor USER_NAME = os.getenv("USER_NAME") USER_PASSWORD = os.getenv("USER_PASSWORD") @@ -34,10 +33,11 @@ control_bar_layout(), image_viewer_layout(), dcc.Store(id="current-class-selection", data="#FFA200"), - dcc.Store(id="gui-components-values", data={}) + dcc.Store(id="gui-components-values", data={}), ], ) + ### automatic Dash gui callback ### @callback( Output("gui-layouts", "children"), @@ -45,14 +45,17 @@ ) def update_gui_parameters(model_name): data = models.models[model_name] - if data["gui_parameters"]: - item_list = JSONParameterEditor( _id={'type': str(uuid.uuid4())}, # pattern match _id (base id), name - json_blob=models.remove_key_from_dict_list(data["gui_parameters"], "comp_group"), - ) + if data["gui_parameters"]: + item_list = JSONParameterEditor( + _id={"type": str(uuid.uuid4())}, # pattern match _id (base id), name + json_blob=models.remove_key_from_dict_list( + data["gui_parameters"], "comp_group" + ), + ) item_list.init_callbacks(app) return [html.H4("Model Parameters"), item_list] else: - return[""] + return [""] if __name__ == "__main__": diff --git a/assets/mode_description.json b/assets/mode_description.json index f394cdc..7c95a20 100755 --- a/assets/mode_description.json +++ b/assets/mode_description.json @@ -1,6 +1,6 @@ -{ +{ "contents":[ {"model_name": "random_forest", "version": "1.0.0", "type": "supervised", "user": "mlexchange team", "uri": "xxx", "application": ["classification", "segmentation"], "description": "xxx", "gui_parameters": [{"type": "int", "name": "num-tree", "title": "Number of Trees", "param_key": "n_estimators", "value": "30"}, {"type": "int", "name": "tree-depth", "title": "Tree Depth", "param_key": "max_depth", "value": "8"}], "cmd": ["xxx"], "reference": "Adapted from Dash Plotly image segmentation example"}, {"model_name": "kmeans", "version": "1.0.0", "type": "unsupervised", "user": "mlexchange team", "uri": "xxx", "application": ["segmentation", "clustering"], "description": "xxx", "gui_parameters": [{"type": "int", "name": "num-cluster", "title": "Number of Clusters", "param_key": "n_clusters", "value": "2"}, {"type": "int", "name": "num-iter", "title": "Max Iteration", "param_key": "max_iter", "value": "300"}], "cmd": ["xxx", "xxxx"], "reference": "Nicholas Schwartz & Howard Yanxon"} - ] -} \ No newline at end of file + ] +} diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py index 588c1ce..c477bcb 100644 --- a/callbacks/segmentation.py +++ b/callbacks/segmentation.py @@ -57,13 +57,13 @@ @callback( Output("output-details", "children"), - Output("submitted-job-id", "data"), + Output("submitted-job-id", "data"), Output("gui-components-values", "data"), Input("run-model", "n_clicks"), State("annotation-store", "data"), State({"type": "annotation-class-store", "index": ALL}, "data"), State("project-name-src", "value"), - State("gui-layouts", "children") + State("gui-layouts", "children"), ) def run_job(n_clicks, global_store, all_annotations, project_name, children): """ @@ -78,13 +78,13 @@ def run_job(n_clicks, global_store, all_annotations, project_name, children): if n_clicks: if len(children) >= 2: params = children[1] - for param in params['props']['children']: - key = param["props"]["children"][1]["props"]["id"]["param_key"] + for param in params["props"]["children"]: + key = param["props"]["children"][1]["props"]["id"]["param_key"] value = param["props"]["children"][1]["props"]["value"] input_params[key] = value - - # return the input values in dictionary and saved to dcc.Store "gui-components-values" - print(f'input_param:\n{input_params}') + + # return the input values in dictionary and saved to dcc.Store "gui-components-values" + print(f"input_param:\n{input_params}") if MODE == "dev": job_uid = str(uuid.uuid4()) return ( @@ -93,7 +93,7 @@ def run_job(n_clicks, global_store, all_annotations, project_name, children): size="sm", ), job_uid, - input_params + input_params, ) else: @@ -111,7 +111,7 @@ def run_job(n_clicks, global_store, all_annotations, project_name, children): size="sm", ), job_uid, - input_params + input_params, ) else: return ( @@ -120,7 +120,7 @@ def run_job(n_clicks, global_store, all_annotations, project_name, children): size="sm", ), job_uid, - input_params + input_params, ) return no_update, no_update, input_params diff --git a/components/control_bar.py b/components/control_bar.py index 08c65ab..c8f6bf9 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -6,8 +6,8 @@ from components.annotation_class import annotation_class_item from constants import ANNOT_ICONS, KEYBINDS -from utils.data_utils import tiled_dataset from utils.content_registry import models +from utils.data_utils import tiled_dataset def _tooltip(text, children): diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py index 7fa677a..0ab0f14 100644 --- a/components/dash_component_editor.py +++ b/components/dash_component_editor.py @@ -1,44 +1,46 @@ +import base64 + +# import PIL.Image +import io import re -from typing import Callable + # noinspection PyUnresolvedReferences -from inspect import signature, _empty +from inspect import _empty, signature +from typing import Callable -from dash import html, dcc, Input, Output, State, ALL import dash_bootstrap_components as dbc import dash_daq as daq +from dash import ALL, Input, Output, State, dcc, html -import base64 -#import PIL.Image -import io -#import plotly.express as px +# import plotly.express as px # Procedural dash form generation - """ -{'name', 'title', 'value', 'type', +{'name', 'title', 'value', 'type', """ -class SimpleItem(dbc.Col): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - type='number', - debounce=True, - **kwargs): - +class SimpleItem(dbc.Col): + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + type="number", + debounce=True, + **kwargs, + ): if param_key == None: param_key = name self.label = dbc.Label(title) - self.input = dbc.Input(type=type, - debounce=debounce, - id={**base_id, - 'name': name, - 'param_key': param_key}, - **kwargs) + self.input = dbc.Input( + type=type, + debounce=debounce, + id={**base_id, "name": name, "param_key": param_key}, + **kwargs, + ) super(SimpleItem, self).__init__(children=[self.label, self.input]) @@ -49,253 +51,241 @@ class FloatItem(SimpleItem): class IntItem(SimpleItem): def __init__(self, *args, **kwargs): - if 'min' not in kwargs: - kwargs['min'] = -9007199254740991 + if "min" not in kwargs: + kwargs["min"] = -9007199254740991 super(IntItem, self).__init__(*args, step=1, **kwargs) class StrItem(SimpleItem): def __init__(self, *args, **kwargs): - super(StrItem, self).__init__(*args, type='text', **kwargs) + super(StrItem, self).__init__(*args, type="text", **kwargs) class SliderItem(dbc.Col): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - debounce=True, - visible=True, - **kwargs): - + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + debounce=True, + visible=True, + **kwargs, + ): if param_key == None: param_key = name self.label = dbc.Label(title) - self.input = dcc.Slider(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - tooltip={"placement": "bottom", "always_visible": True}, - **kwargs) + self.input = dcc.Slider( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + tooltip={"placement": "bottom", "always_visible": True}, + **kwargs, + ) style = {} if not visible: - style['display'] = 'none' + style["display"] = "none" - super(SliderItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) + super(SliderItem, self).__init__( + id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, + children=[self.label, self.input], + style=style, + ) class DropdownItem(dbc.Col): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - debounce=True, - visible=True, - **kwargs): - + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + debounce=True, + visible=True, + **kwargs, + ): if param_key == None: param_key = name self.label = dbc.Label(title) - self.input = dcc.Dropdown(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - **kwargs) + self.input = dcc.Dropdown( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) style = {} if not visible: - style['display'] = 'none' + style["display"] = "none" - super(DropdownItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) + super(DropdownItem, self).__init__( + id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, + children=[self.label, self.input], + style=style, + ) class RadioItem(dbc.Col): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - visible=True, - **kwargs): - + def __init__( + self, name, base_id, title=None, param_key=None, visible=True, **kwargs + ): if param_key == None: param_key = name self.label = dbc.Label(title) - self.input = dbc.RadioItems(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - **kwargs) + self.input = dbc.RadioItems( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) style = {} if not visible: - style['display'] = 'none' + style["display"] = "none" - super(RadioItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) + super(RadioItem, self).__init__( + id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, + children=[self.label, self.input], + style=style, + ) class BoolItem(dbc.Col): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - visible=True, - **kwargs): - + def __init__( + self, name, base_id, title=None, param_key=None, visible=True, **kwargs + ): if param_key == None: param_key = name self.label = dbc.Label(title) - self.input = daq.ToggleSwitch(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - **kwargs) - self.output_label = dbc.Label('False/True') + self.input = daq.ToggleSwitch( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + self.output_label = dbc.Label("False/True") style = {} if not visible: - style['display'] = 'none' + style["display"] = "none" - super(BoolItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input, self.output_label], - style=style) + super(BoolItem, self).__init__( + id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, + children=[self.label, self.input, self.output_label], + style=style, + ) class ImgItem(dbc.Col): - def __init__(self, - name, - src, - base_id, - title=None, - param_key=None, - width='100px', - visible=True, - **kwargs): - + def __init__( + self, + name, + src, + base_id, + title=None, + param_key=None, + width="100px", + visible=True, + **kwargs, + ): if param_key == None: param_key = name - - if not (width.endswith('px') or width.endswith('%')): - width = width + 'px' - + + if not (width.endswith("px") or width.endswith("%")): + width = width + "px" + self.label = dbc.Label(title) - - encoded_image = base64.b64encode(open(src, 'rb').read()) - self.src = 'data:image/png;base64,{}'.format(encoded_image.decode()) - self.input_img = html.Img(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - src=self.src, - style={'height':'auto', 'width':width}, - **kwargs) + + encoded_image = base64.b64encode(open(src, "rb").read()) + self.src = "data:image/png;base64,{}".format(encoded_image.decode()) + self.input_img = html.Img( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + src=self.src, + style={"height": "auto", "width": width}, + **kwargs, + ) style = {} if not visible: - style['display'] = 'none' + style["display"] = "none" - super(ImgItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input_img], - style=style) + super(ImgItem, self).__init__( + id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, + children=[self.label, self.input_img], + style=style, + ) class ParameterEditor(dbc.Form): - - type_map = {float: FloatItem, - int: IntItem, - str: StrItem, - } + type_map = { + float: FloatItem, + int: IntItem, + str: StrItem, + } def __init__(self, _id, parameters, **kwargs): self._parameters = parameters - super(ParameterEditor, self).__init__(id=_id, children=[], className='kwarg-editor', **kwargs) + super(ParameterEditor, self).__init__( + id=_id, children=[], className="kwarg-editor", **kwargs + ) self.children = self.build_children() def init_callbacks(self, app): - app.callback(Output(self.id, 'n_submit'), - Input({**self.id, - 'name': ALL}, - 'value'), - State(self.id, 'n_submit'), - ) - + app.callback( + Output(self.id, "n_submit"), + Input({**self.id, "name": ALL}, "value"), + State(self.id, "n_submit"), + ) + for child in self.children: - if hasattr(child,"init_callbacks"): - child.init_callbacks(app) - - + if hasattr(child, "init_callbacks"): + child.init_callbacks(app) + @property def values(self): - return {param['name']: param.get('value', None) for param in self._parameters} + return {param["name"]: param.get("value", None) for param in self._parameters} @property def parameters(self): - return {param['name']: param for param in self._parameters} + return {param["name"]: param for param in self._parameters} def _determine_type(self, parameter_dict): - if 'type' in parameter_dict: - if parameter_dict['type'] in self.type_map: - return parameter_dict['type'] - elif parameter_dict['type'].__name__ in self.type_map: - return parameter_dict['type'].__name__ - elif type(parameter_dict['value']) in self.type_map: - return type(parameter_dict['value']) - raise TypeError(f'No item type could be determined for this parameter: {parameter_dict}') + if "type" in parameter_dict: + if parameter_dict["type"] in self.type_map: + return parameter_dict["type"] + elif parameter_dict["type"].__name__ in self.type_map: + return parameter_dict["type"].__name__ + elif type(parameter_dict["value"]) in self.type_map: + return type(parameter_dict["value"]) + raise TypeError( + f"No item type could be determined for this parameter: {parameter_dict}" + ) def build_children(self, values=None): children = [] for parameter_dict in self._parameters: parameter_dict = parameter_dict.copy() - if values and parameter_dict['name'] in values: - parameter_dict['value'] = values[parameter_dict['name']] + if values and parameter_dict["name"] in values: + parameter_dict["value"] = values[parameter_dict["name"]] type = self._determine_type(parameter_dict) - parameter_dict.pop('type', None) - item = self.type_map[type](**parameter_dict, base_id=self.id) + parameter_dict.pop("type", None) + item = self.type_map[type](**parameter_dict, base_id=self.id) children.append(item) return children - + class JSONParameterEditor(ParameterEditor): - type_map = {'float': FloatItem, - 'int': IntItem, - 'str': StrItem, - 'slider': SliderItem, - 'dropdown': DropdownItem, - 'radio': RadioItem, - 'bool': BoolItem, - 'img': ImgItem, - #'graph': GraphItem, - } + type_map = { + "float": FloatItem, + "int": IntItem, + "str": StrItem, + "slider": SliderItem, + "dropdown": DropdownItem, + "radio": RadioItem, + "bool": BoolItem, + "img": ImgItem, + #'graph': GraphItem, + } def __init__(self, _id, json_blob, **kwargs): - super(ParameterEditor, self).__init__(id=_id, children=[], className='kwarg-editor', **kwargs) + super(ParameterEditor, self).__init__( + id=_id, children=[], className="kwarg-editor", **kwargs + ) self._json_blob = json_blob self.children = self.build_children() @@ -305,11 +295,11 @@ def build_children(self, values=None): ... # build a parameter dict from self.json_blob ... - type = json_record.get('type', self._determine_type(json_record)) + type = json_record.get("type", self._determine_type(json_record)) json_record = json_record.copy() - if values and json_record['name'] in values: - json_record['value'] = values[json_record['name']] - json_record.pop('type', None) + if values and json_record["name"] in values: + json_record["value"] = values[json_record["name"]] + json_record.pop("type", None) item = self.type_map[type](**json_record, base_id=self.id) children.append(item) @@ -321,10 +311,21 @@ def __init__(self, instance_index, func: Callable, **kwargs): self.func = func self._instance_index = instance_index - parameters = [{'name': name, 'value': param.default} for name, param in signature(func).parameters.items() - if param.default is not _empty] + parameters = [ + {"name": name, "value": param.default} + for name, param in signature(func).parameters.items() + if param.default is not _empty + ] - super(KwargsEditor, self).__init__(dict(index=instance_index, type='kwargs-editor'), parameters=parameters, **kwargs) + super(KwargsEditor, self).__init__( + dict(index=instance_index, type="kwargs-editor"), + parameters=parameters, + **kwargs, + ) def new_record(self): - return {name: p.default for name, p in signature(self.func).parameters.items() if p.default is not _empty} + return { + name: p.default + for name, p in signature(self.func).parameters.items() + if p.default is not _empty + } diff --git a/utils/content_registry.py b/utils/content_registry.py index 4e1c12e..c441d98 100644 --- a/utils/content_registry.py +++ b/utils/content_registry.py @@ -1,18 +1,19 @@ import json from copy import deepcopy + class Models: - def __init__(self, modelfile_path='./assets/mode_description.json'): - self.path = modelfile_path - f = open('./assets/mode_description.json') - - self.contents = json.load(f)['contents'] - self.modelname_list = [content['model_name'] for content in self.contents] + def __init__(self, modelfile_path="./assets/mode_description.json"): + self.path = modelfile_path + f = open("./assets/mode_description.json") + + self.contents = json.load(f)["contents"] + self.modelname_list = [content["model_name"] for content in self.contents] self.models = {} for i, n in enumerate(self.modelname_list): self.models[n] = self.contents[i] - + @staticmethod def remove_key_from_dict_list(data, key): new_data = [] @@ -23,8 +24,8 @@ def remove_key_from_dict_list(data, key): new_data.append(new_item) else: new_data.append(item) - - return new_data + + return new_data -models = Models() \ No newline at end of file +models = Models() From 0b8f71bccdc2b26b5980844f386c29fda54becc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Mon, 4 Mar 2024 18:10:40 -0800 Subject: [PATCH 25/38] :boom: Move children generation callback into control_bar From there, it cannot reference the variable `app`, but we likely do not need this and the callback can become just a function so we will deal with this later. --- app.py | 23 ----------------------- callbacks/control_bar.py | 22 ++++++++++++++++++++++ 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/app.py b/app.py index 7b9538c..1f17230 100644 --- a/app.py +++ b/app.py @@ -8,9 +8,7 @@ from callbacks.image_viewer import * # noqa: F403, F401 from callbacks.segmentation import * # noqa: F403, F401 from components.control_bar import layout as control_bar_layout -from components.dash_component_editor import JSONParameterEditor from components.image_viewer import layout as image_viewer_layout -from utils.content_registry import models USER_NAME = os.getenv("USER_NAME") USER_PASSWORD = os.getenv("USER_PASSWORD") @@ -37,26 +35,5 @@ ], ) - -### automatic Dash gui callback ### -@callback( - Output("gui-layouts", "children"), - Input("model-list", "value"), -) -def update_gui_parameters(model_name): - data = models.models[model_name] - if data["gui_parameters"]: - item_list = JSONParameterEditor( - _id={"type": str(uuid.uuid4())}, # pattern match _id (base id), name - json_blob=models.remove_key_from_dict_list( - data["gui_parameters"], "comp_group" - ), - ) - item_list.init_callbacks(app) - return [html.H4("Model Parameters"), item_list] - else: - return [""] - - if __name__ == "__main__": app.run_server(debug=True) diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index 9d7d438..0664f50 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -2,6 +2,7 @@ import os import random import time +import uuid import dash_mantine_components as dmc import plotly.express as px @@ -24,8 +25,10 @@ from dash_iconify import DashIconify from components.annotation_class import annotation_class_item +from components.dash_component_editor import JSONParameterEditor from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEY_MODES, KEYBINDS from utils.annotations import Annotations +from utils.content_registry import models from utils.data_utils import tiled_dataset from utils.plot_utils import generate_notification, generate_notification_bg_icon_col @@ -912,3 +915,22 @@ def update_current_annotated_slices_values(all_classes): ] disabled = True if len(dropdown_values) == 0 else False return dropdown_values, disabled + + +@callback( + Output("gui-layouts", "children"), + Input("model-list", "value"), +) +def update_gui_parameters(model_name): + data = models.models[model_name] + if data["gui_parameters"]: + item_list = JSONParameterEditor( + _id={"type": str(uuid.uuid4())}, # pattern match _id (base id), name + json_blob=models.remove_key_from_dict_list( + data["gui_parameters"], "comp_group" + ), + ) + # item_list.init_callbacks(app) + return [html.H4("Model Parameters"), item_list] + else: + return [""] From 2c6a4b504127fe3086683535b0f99f509d5c3d94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Mon, 4 Mar 2024 18:11:38 -0800 Subject: [PATCH 26/38] Fix remaining `flake8 warnings` E711 comparison to None should be 'if cond is None:' E265 block comment should start with '# ' F401 Unused import --- components/dash_component_editor.py | 27 ++++++--------------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py index 0ab0f14..dd9cd20 100644 --- a/components/dash_component_editor.py +++ b/components/dash_component_editor.py @@ -1,10 +1,4 @@ import base64 - -# import PIL.Image -import io -import re - -# noinspection PyUnresolvedReferences from inspect import _empty, signature from typing import Callable @@ -12,14 +6,6 @@ import dash_daq as daq from dash import ALL, Input, Output, State, dcc, html -# import plotly.express as px -# Procedural dash form generation - - -""" -{'name', 'title', 'value', 'type', -""" - class SimpleItem(dbc.Col): def __init__( @@ -32,7 +18,7 @@ def __init__( debounce=True, **kwargs, ): - if param_key == None: + if param_key is None: param_key = name self.label = dbc.Label(title) self.input = dbc.Input( @@ -72,7 +58,7 @@ def __init__( visible=True, **kwargs, ): - if param_key == None: + if param_key is None: param_key = name self.label = dbc.Label(title) self.input = dcc.Slider( @@ -103,7 +89,7 @@ def __init__( visible=True, **kwargs, ): - if param_key == None: + if param_key is None: param_key = name self.label = dbc.Label(title) self.input = dcc.Dropdown( @@ -126,7 +112,7 @@ class RadioItem(dbc.Col): def __init__( self, name, base_id, title=None, param_key=None, visible=True, **kwargs ): - if param_key == None: + if param_key is None: param_key = name self.label = dbc.Label(title) self.input = dbc.RadioItems( @@ -149,7 +135,7 @@ class BoolItem(dbc.Col): def __init__( self, name, base_id, title=None, param_key=None, visible=True, **kwargs ): - if param_key == None: + if param_key is None: param_key = name self.label = dbc.Label(title) self.input = daq.ToggleSwitch( @@ -181,7 +167,7 @@ def __init__( visible=True, **kwargs, ): - if param_key == None: + if param_key is None: param_key = name if not (width.endswith("px") or width.endswith("%")): @@ -279,7 +265,6 @@ class JSONParameterEditor(ParameterEditor): "radio": RadioItem, "bool": BoolItem, "img": ImgItem, - #'graph': GraphItem, } def __init__(self, _id, json_blob, **kwargs): From 132a5ea982645083eb1e8128307305b9c1e92499 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Mon, 4 Mar 2024 19:56:00 -0800 Subject: [PATCH 27/38] Move `dcc.Store` elements from app file into control bar component --- app.py | 4 +--- components/control_bar.py | 8 ++++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/app.py b/app.py index 1f17230..d2ed8d2 100644 --- a/app.py +++ b/app.py @@ -2,7 +2,7 @@ import dash_auth import dash_mantine_components as dmc -from dash import Dash, dcc +from dash import Dash from callbacks.control_bar import * # noqa: F403, F401 from callbacks.image_viewer import * # noqa: F403, F401 @@ -30,8 +30,6 @@ children=[ control_bar_layout(), image_viewer_layout(), - dcc.Store(id="current-class-selection", data="#FFA200"), - dcc.Store(id="gui-components-values", data={}), ], ) diff --git a/components/control_bar.py b/components/control_bar.py index c8f6bf9..38f7e09 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -453,6 +453,9 @@ def layout(): }, className="add-class-btn", ), + dcc.Store( + id="current-class-selection", data="#FFA200" + ), dmc.Space(h=20), ], ), @@ -605,7 +608,7 @@ def layout(): id="model-configuration", children=[ _control_item( - "Model Selection", + "Model", "model-selector", dmc.Select( id="model-list", @@ -615,11 +618,12 @@ def layout(): if models.modelname_list[0] else None ), - placeholder="Select an model...", + placeholder="Select a model...", ), ), dmc.Space(h=25), html.Div(id="gui-layouts"), + dcc.Store(id="gui-components-values", data={}), dmc.Space(h=25), dmc.Center( dmc.Button( From b51c99803d63ba6b6868b16437b3d78d5abb2c3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Mon, 4 Mar 2024 20:07:45 -0800 Subject: [PATCH 28/38] :bug: Fix (non-)use of path parameter in `Models` Additionally pretty-prints model file for easier editing --- assets/mode_description.json | 6 ---- assets/models.json | 69 ++++++++++++++++++++++++++++++++++++ utils/content_registry.py | 4 +-- 3 files changed, 71 insertions(+), 8 deletions(-) delete mode 100755 assets/mode_description.json create mode 100755 assets/models.json diff --git a/assets/mode_description.json b/assets/mode_description.json deleted file mode 100755 index 7c95a20..0000000 --- a/assets/mode_description.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "contents":[ - {"model_name": "random_forest", "version": "1.0.0", "type": "supervised", "user": "mlexchange team", "uri": "xxx", "application": ["classification", "segmentation"], "description": "xxx", "gui_parameters": [{"type": "int", "name": "num-tree", "title": "Number of Trees", "param_key": "n_estimators", "value": "30"}, {"type": "int", "name": "tree-depth", "title": "Tree Depth", "param_key": "max_depth", "value": "8"}], "cmd": ["xxx"], "reference": "Adapted from Dash Plotly image segmentation example"}, - {"model_name": "kmeans", "version": "1.0.0", "type": "unsupervised", "user": "mlexchange team", "uri": "xxx", "application": ["segmentation", "clustering"], "description": "xxx", "gui_parameters": [{"type": "int", "name": "num-cluster", "title": "Number of Clusters", "param_key": "n_clusters", "value": "2"}, {"type": "int", "name": "num-iter", "title": "Max Iteration", "param_key": "max_iter", "value": "300"}], "cmd": ["xxx", "xxxx"], "reference": "Nicholas Schwartz & Howard Yanxon"} - ] -} diff --git a/assets/models.json b/assets/models.json new file mode 100755 index 0000000..bf77000 --- /dev/null +++ b/assets/models.json @@ -0,0 +1,69 @@ +{ + "contents": [ + { + "model_name": "random_forest", + "version": "1.0.0", + "type": "supervised", + "user": "mlexchange team", + "uri": "xxx", + "application": [ + "classification", + "segmentation" + ], + "description": "xxx", + "gui_parameters": [ + { + "type": "int", + "name": "num-tree", + "title": "Number of Trees", + "param_key": "n_estimators", + "value": "30" + }, + { + "type": "int", + "name": "tree-depth", + "title": "Tree Depth", + "param_key": "max_depth", + "value": "8" + } + ], + "cmd": [ + "xxx" + ], + "reference": "Adapted from Dash Plotly image segmentation example" + }, + { + "model_name": "kmeans", + "version": "1.0.0", + "type": "unsupervised", + "user": "mlexchange team", + "uri": "xxx", + "application": [ + "segmentation", + "clustering" + ], + "description": "xxx", + "gui_parameters": [ + { + "type": "int", + "name": "num-cluster", + "title": "Number of Clusters", + "param_key": "n_clusters", + "value": "2" + }, + { + "type": "int", + "name": "num-iter", + "title": "Max Iteration", + "param_key": "max_iter", + "value": "300" + } + ], + "cmd": [ + "xxx", + "xxxx" + ], + "reference": "Nicholas Schwartz & Howard Yanxon" + } + ] +} diff --git a/utils/content_registry.py b/utils/content_registry.py index c441d98..ed9203c 100644 --- a/utils/content_registry.py +++ b/utils/content_registry.py @@ -3,9 +3,9 @@ class Models: - def __init__(self, modelfile_path="./assets/mode_description.json"): + def __init__(self, modelfile_path="./assets/models.json"): self.path = modelfile_path - f = open("./assets/mode_description.json") + f = open(self.path) self.contents = json.load(f)["contents"] self.modelname_list = [content["model_name"] for content in self.contents] From b3915e7f2f71a0ff2c1c12b932f0f5478d309f8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Tue, 5 Mar 2024 08:28:06 -0800 Subject: [PATCH 29/38] Elevate `_control_item ` function to class `ControlItem`, delete unused code --- callbacks/control_bar.py | 7 +-- components/control_bar.py | 35 +++-------- components/dash_component_editor.py | 92 +++++++---------------------- 3 files changed, 34 insertions(+), 100 deletions(-) diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index 0664f50..9b15112 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -921,16 +921,15 @@ def update_current_annotated_slices_values(all_classes): Output("gui-layouts", "children"), Input("model-list", "value"), ) -def update_gui_parameters(model_name): +def update_model_parameters(model_name): data = models.models[model_name] if data["gui_parameters"]: item_list = JSONParameterEditor( - _id={"type": str(uuid.uuid4())}, # pattern match _id (base id), name + _id={"type": str(uuid.uuid4())}, json_blob=models.remove_key_from_dict_list( data["gui_parameters"], "comp_group" ), ) - # item_list.init_callbacks(app) - return [html.H4("Model Parameters"), item_list] + return [item_list] else: return [""] diff --git a/components/control_bar.py b/components/control_bar.py index 38f7e09..4ca57e7 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -5,6 +5,7 @@ from dash_iconify import DashIconify from components.annotation_class import annotation_class_item +from components.dash_component_editor import ControlItem from constants import ANNOT_ICONS, KEYBINDS from utils.content_registry import models from utils.data_utils import tiled_dataset @@ -19,24 +20,6 @@ def _tooltip(text, children): ) -def _control_item(title, title_id, item): - """ - Returns a customized layout for a control item - """ - return dmc.Grid( - [ - dmc.Text( - title, - id=title_id, - size="sm", - style={"width": "100px", "margin": "auto", "paddingRight": "5px"}, - align="right", - ), - html.Div(item, style={"width": "265px", "margin": "auto"}), - ] - ) - - def _accordion_item(title, icon, value, children, id): """ Returns a customized layout for an accordion item @@ -79,7 +62,7 @@ def layout(): id="data-selection-controls", children=[ dmc.Space(h=5), - _control_item( + ControlItem( "Dataset", "image-selector", dmc.Grid( @@ -115,7 +98,7 @@ def layout(): ), ), dmc.Space(h=25), - _control_item( + ControlItem( "Slice 1", "image-selection-text", [ @@ -178,7 +161,7 @@ def layout(): ], ), dmc.Space(h=25), - _control_item( + ControlItem( _tooltip( "Jump to your annotated slices", "Annotated slices", @@ -208,7 +191,7 @@ def layout(): children=html.Div( [ dmc.Space(h=5), - _control_item( + ControlItem( "Brightness", "bightness-text", [ @@ -252,7 +235,7 @@ def layout(): ], ), dmc.Space(h=20), - _control_item( + ControlItem( "Contrast", "contrast-text", dmc.Grid( @@ -607,7 +590,7 @@ def layout(): "run-model", id="model-configuration", children=[ - _control_item( + ControlItem( "Model", "model-selector", dmc.Select( @@ -646,7 +629,7 @@ def layout(): styles={"trackLabel": {"cursor": "pointer"}}, ), dmc.Space(h=25), - _control_item( + ControlItem( "Results", "", dmc.Select( @@ -655,7 +638,7 @@ def layout(): ), ), dmc.Space(h=25), - _control_item( + ControlItem( "Opacity", "", dmc.Slider( diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py index dd9cd20..5b232a0 100644 --- a/components/dash_component_editor.py +++ b/components/dash_component_editor.py @@ -1,12 +1,30 @@ -import base64 -from inspect import _empty, signature -from typing import Callable - import dash_bootstrap_components as dbc import dash_daq as daq +import dash_mantine_components as dmc from dash import ALL, Input, Output, State, dcc, html +class ControlItem(dmc.Grid): + """ + Customized layout for a control item + """ + + def __init__(self, title, title_id, item, **kwargs): + super(ControlItem, self).__init__( + [ + dmc.Text( + title, + id=title_id, + size="sm", + style={"width": "100px", "margin": "auto", "paddingRight": "5px"}, + align="right", + ), + html.Div(item, style={"width": "265px", "margin": "auto"}), + ], + **kwargs, + ) + + class SimpleItem(dbc.Col): def __init__( self, @@ -155,46 +173,6 @@ def __init__( ) -class ImgItem(dbc.Col): - def __init__( - self, - name, - src, - base_id, - title=None, - param_key=None, - width="100px", - visible=True, - **kwargs, - ): - if param_key is None: - param_key = name - - if not (width.endswith("px") or width.endswith("%")): - width = width + "px" - - self.label = dbc.Label(title) - - encoded_image = base64.b64encode(open(src, "rb").read()) - self.src = "data:image/png;base64,{}".format(encoded_image.decode()) - self.input_img = html.Img( - id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, - src=self.src, - style={"height": "auto", "width": width}, - **kwargs, - ) - - style = {} - if not visible: - style["display"] = "none" - - super(ImgItem, self).__init__( - id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, - children=[self.label, self.input_img], - style=style, - ) - - class ParameterEditor(dbc.Form): type_map = { float: FloatItem, @@ -264,7 +242,6 @@ class JSONParameterEditor(ParameterEditor): "dropdown": DropdownItem, "radio": RadioItem, "bool": BoolItem, - "img": ImgItem, } def __init__(self, _id, json_blob, **kwargs): @@ -289,28 +266,3 @@ def build_children(self, values=None): children.append(item) return children - - -class KwargsEditor(ParameterEditor): - def __init__(self, instance_index, func: Callable, **kwargs): - self.func = func - self._instance_index = instance_index - - parameters = [ - {"name": name, "value": param.default} - for name, param in signature(func).parameters.items() - if param.default is not _empty - ] - - super(KwargsEditor, self).__init__( - dict(index=instance_index, type="kwargs-editor"), - parameters=parameters, - **kwargs, - ) - - def new_record(self): - return { - name: p.default - for name, p in signature(self.func).parameters.items() - if p.default is not _empty - } From 8b34636b5ea58c90bf1e954fb1c632a0e7168814 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Tue, 5 Mar 2024 17:24:33 -0800 Subject: [PATCH 30/38] :sparkles: Add Dlsia model parameters --- app.py | 2 +- assets/models.json | 1083 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 1051 insertions(+), 34 deletions(-) diff --git a/app.py b/app.py index d2ed8d2..efda09c 100644 --- a/app.py +++ b/app.py @@ -34,4 +34,4 @@ ) if __name__ == "__main__": - app.run_server(debug=True) + app.run_server(host="0.0.0.0", port=8075, debug=True) diff --git a/assets/models.json b/assets/models.json index bf77000..245638e 100755 --- a/assets/models.json +++ b/assets/models.json @@ -1,69 +1,1086 @@ { "contents": [ { - "model_name": "random_forest", - "version": "1.0.0", + "model_name": "DSLIA MSDNet", + "version": "0.0.1", "type": "supervised", "user": "mlexchange team", - "uri": "xxx", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", "application": [ - "classification", "segmentation" ], - "description": "xxx", + "description": "MSDNets in DLSIA for image segmentation", "gui_parameters": [ { "type": "int", - "name": "num-tree", - "title": "Number of Trees", - "param_key": "n_estimators", - "value": "30" + "name": "layer_width", + "title": "Layers Width", + "param_key": "layer_width", + "value": 1, + "comp_group": "train_model" }, { "type": "int", - "name": "tree-depth", - "title": "Tree Depth", - "param_key": "max_depth", - "value": "8" + "name": "num_layers", + "title": "Number of Layers", + "param_key": "num_layers", + "value": 3, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "custom_dilation", + "title": "Custom Dilation", + "param_key": "custom_dilation", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "int", + "name": "max_dilation", + "title": "Maximum Dilation", + "param_key": "max_dilation", + "value": 5, + "comp_group": "train_model" + }, + { + "type": "str", + "name": "dilation_array", + "title": "Dilation Array", + "param_key": "dilation_array", + "value": "[1, 2, 4]", + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "Number of epoch", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "options": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "MSELoss", + "options": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1.0, 1.0, 1.0]", + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "ReLU", + "options": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 100, + "step": 5, + "value": 20, + "marks": { + "0": "0", + "100": "100" + }, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Training Batch Size", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_val", + "title": "Shuffle Training", + "param_key": "shuffle_val", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Validation Batch Size", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_inference", + "title": "Shuffle Inference", + "param_key": "shuffle_inference", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "prediction_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Inference Batch Size", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "prediction_model" } ], "cmd": [ - "xxx" + "python3 src/train_model.py", + "python3 src/segment.py" ], - "reference": "Adapted from Dash Plotly image segmentation example" + "reference": "https://dlsia.readthedocs.io/en/latest/" }, { - "model_name": "kmeans", - "version": "1.0.0", - "type": "unsupervised", + "model_name": "DSLIA TUNet", + "version": "0.0.1", + "type": "supervised", "user": "mlexchange team", - "uri": "xxx", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", "application": [ - "segmentation", - "clustering" + "segmentation" ], - "description": "xxx", + "description": "TUNet in DLSIA for image segmentation", "gui_parameters": [ { "type": "int", - "name": "num-cluster", - "title": "Number of Clusters", - "param_key": "n_clusters", - "value": "2" + "name": "depth", + "title": "Depth", + "param_key": "depth", + "value": 4, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "base_channels", + "title": "Base Channels", + "param_key": "base_channels", + "value": 32, + "comp_group": "train_model" }, { "type": "int", - "name": "num-iter", - "title": "Max Iteration", - "param_key": "max_iter", - "value": "300" + "name": "growth_rate", + "title": "Growth Rate", + "param_key": "growth_rate", + "value": 2, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "hidden_rate", + "title": "Hidden Rate", + "param_key": "hidden_rate", + "value": 1, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "Number of epoch", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "options": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "MSELoss", + "options": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1.0, 1.0, 1.0]", + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "ReLU", + "options": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 100, + "step": 5, + "value": 20, + "marks": { + "0": "0", + "100": "100" + }, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Training Batch Size", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_val", + "title": "Shuffle Training", + "param_key": "shuffle_val", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Validation Batch Size", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_inference", + "title": "Shuffle Inference", + "param_key": "shuffle_inference", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "prediction_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Inference Batch Size", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "prediction_model" + } + ], + "cmd": [ + "python3 src/train_model.py", + "python3 src/segment.py" + ], + "reference": "https://dlsia.readthedocs.io/en/latest/" + }, + { + "model_name": "DSLIA TUNet3+", + "version": "0.0.1", + "type": "supervised", + "user": "mlexchange team", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", + "application": [ + "segmentation" + ], + "description": "TUNet3+ DLSIA for image segmentation", + "gui_parameters": [ + { + "type": "int", + "name": "depth", + "title": "Depth", + "param_key": "depth", + "value": 4, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "base_channels", + "title": "Base Channels", + "param_key": "base_channels", + "value": 32, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "growth_rate", + "title": "Growth Rate", + "param_key": "growth_rate", + "value": 2, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "hidden_rate", + "title": "Hidden Rate", + "param_key": "hidden_rate", + "value": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "carryover_channels", + "title": "Carryover Channels", + "param_key": "carryover_channels", + "value": 32, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "Number of epoch", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "options": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "MSELoss", + "options": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1.0, 1.0, 1.0]", + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "ReLU", + "options": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 100, + "step": 5, + "value": 20, + "marks": { + "0": "0", + "100": "100" + }, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Training Batch Size", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_val", + "title": "Shuffle Training", + "param_key": "shuffle_val", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Validation Batch Size", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_inference", + "title": "Shuffle Inference", + "param_key": "shuffle_inference", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "prediction_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Inference Batch Size", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "prediction_model" } ], "cmd": [ - "xxx", - "xxxx" + "python3 src/train_model.py", + "python3 src/segment.py" ], - "reference": "Nicholas Schwartz & Howard Yanxon" + "reference": "https://dlsia.readthedocs.io/en/latest/" } ] } From fcc511f48deb509e761f237b45d715c51d606a49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Tue, 5 Mar 2024 17:32:36 -0800 Subject: [PATCH 31/38] Remove shuffling for validation and inference --- assets/models.json | 108 --------------------------------------------- 1 file changed, 108 deletions(-) diff --git a/assets/models.json b/assets/models.json index 245638e..c8bd7c2 100755 --- a/assets/models.json +++ b/assets/models.json @@ -305,24 +305,6 @@ "value": 32, "comp_group": "train_model" }, - { - "type": "radio", - "name": "shuffle_val", - "title": "Shuffle Training", - "param_key": "shuffle_val", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "train_model" - }, { "type": "slider", "name": "batch_size_val", @@ -334,24 +316,6 @@ "value": 32, "comp_group": "train_model" }, - { - "type": "radio", - "name": "shuffle_inference", - "title": "Shuffle Inference", - "param_key": "shuffle_inference", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "prediction_model" - }, { "type": "slider", "name": "batch_size_inference", @@ -657,24 +621,6 @@ "value": 32, "comp_group": "train_model" }, - { - "type": "radio", - "name": "shuffle_val", - "title": "Shuffle Training", - "param_key": "shuffle_val", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "train_model" - }, { "type": "slider", "name": "batch_size_val", @@ -686,24 +632,6 @@ "value": 32, "comp_group": "train_model" }, - { - "type": "radio", - "name": "shuffle_inference", - "title": "Shuffle Inference", - "param_key": "shuffle_inference", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "prediction_model" - }, { "type": "slider", "name": "batch_size_inference", @@ -1017,24 +945,6 @@ "value": 32, "comp_group": "train_model" }, - { - "type": "radio", - "name": "shuffle_val", - "title": "Shuffle Training", - "param_key": "shuffle_val", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "train_model" - }, { "type": "slider", "name": "batch_size_val", @@ -1046,24 +956,6 @@ "value": 32, "comp_group": "train_model" }, - { - "type": "radio", - "name": "shuffle_inference", - "title": "Shuffle Inference", - "param_key": "shuffle_inference", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "prediction_model" - }, { "type": "slider", "name": "batch_size_inference", From ed16b22ea817f882b46c76a94f74c0635f8751b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 10:53:03 -0800 Subject: [PATCH 32/38] Switch from Bootstrap to Mantine for generated control elements --- assets/models.json | 338 +++++++++++++++++++++------- components/dash_component_editor.py | 151 ++++++++----- 2 files changed, 361 insertions(+), 128 deletions(-) diff --git a/assets/models.json b/assets/models.json index c8bd7c2..df59436 100755 --- a/assets/models.json +++ b/assets/models.json @@ -14,7 +14,7 @@ { "type": "int", "name": "layer_width", - "title": "Layers Width", + "title": "Layer Width", "param_key": "layer_width", "value": 1, "comp_group": "train_model" @@ -28,21 +28,11 @@ "comp_group": "train_model" }, { - "type": "radio", + "type": "bool", "name": "custom_dilation", "title": "Custom Dilation", "param_key": "custom_dilation", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], + "checked": false, "comp_group": "train_model" }, { @@ -64,11 +54,25 @@ { "type": "slider", "name": "num_epochs", - "title": "Number of epoch", + "title": "# Epochs", "param_key": "num_epochs", "min": 1, "max": 1000, "value": 30, + "marks": [ + { + "value": 1, + "label": "1" + }, + { + "value": 100, + "label": "100" + }, + { + "value": 1000, + "label": "1000" + } + ], "comp_group": "train_model" }, { @@ -77,7 +81,7 @@ "title": "Optimizer", "param_key": "optimizer", "value": "Adam", - "options": [ + "data": [ { "label": "Adadelta", "value": "Adadelta" @@ -131,7 +135,7 @@ "title": "Criterion", "param_key": "criterion", "value": "MSELoss", - "options": [ + "data": [ { "label": "L1Loss", "value": "L1Loss" @@ -241,7 +245,7 @@ "title": "Activation", "param_key": "activation", "value": "ReLU", - "options": [ + "data": [ { "label": "ReLU", "value": "ReLU" @@ -270,61 +274,115 @@ "max": 100, "step": 5, "value": 20, - "marks": { - "0": "0", - "100": "100" - }, + "marks": [ + { + "value": 0, + "label": "0%" + }, + { + "value": 50, + "label": "50%" + }, + { + "value": 100, + "label": "100%" + } + ], "comp_group": "train_model" }, { - "type": "radio", + "type": "bool", "name": "shuffle_train", "title": "Shuffle Training", "param_key": "shuffle_train", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], + "checked": true, "comp_group": "train_model" }, { "type": "slider", "name": "batch_size_train", - "title": "Training Batch Size", + "title": "Batch Size Training", "param_key": "batch_size_train", "min": 16, "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "train_model" }, { "type": "slider", "name": "batch_size_val", - "title": "Validation Batch Size", + "title": "Batch Size Validation", "param_key": "batch_size_val", "min": 16, "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "train_model" }, { "type": "slider", "name": "batch_size_inference", - "title": "Inference Batch Size", + "title": "Batch Size Inference", "param_key": "batch_size_inference", "min": 16, "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "prediction_model" } ], @@ -380,11 +438,25 @@ { "type": "slider", "name": "num_epochs", - "title": "Number of epoch", + "title": "# Epochs", "param_key": "num_epochs", "min": 1, "max": 1000, "value": 30, + "marks": [ + { + "value": 1, + "label": "1" + }, + { + "value": 100, + "label": "100" + }, + { + "value": 1000, + "label": "1000" + } + ], "comp_group": "train_model" }, { @@ -393,7 +465,7 @@ "title": "Optimizer", "param_key": "optimizer", "value": "Adam", - "options": [ + "data": [ { "label": "Adadelta", "value": "Adadelta" @@ -447,7 +519,7 @@ "title": "Criterion", "param_key": "criterion", "value": "MSELoss", - "options": [ + "data": [ { "label": "L1Loss", "value": "L1Loss" @@ -557,7 +629,7 @@ "title": "Activation", "param_key": "activation", "value": "ReLU", - "options": [ + "data": [ { "label": "ReLU", "value": "ReLU" @@ -586,28 +658,24 @@ "max": 100, "step": 5, "value": 20, - "marks": { - "0": "0", - "100": "100" - }, + "marks": [ + { + "value": 0, + "label": "0%" + }, + { + "value": 100, + "label": "100%" + } + ], "comp_group": "train_model" }, { - "type": "radio", + "type": "bool", "name": "shuffle_train", "title": "Shuffle Training", "param_key": "shuffle_train", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], + "checked": true, "comp_group": "train_model" }, { @@ -619,6 +687,24 @@ "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "train_model" }, { @@ -630,6 +716,24 @@ "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "train_model" }, { @@ -641,6 +745,24 @@ "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "prediction_model" } ], @@ -704,11 +826,25 @@ { "type": "slider", "name": "num_epochs", - "title": "Number of epoch", + "title": "# Epochs", "param_key": "num_epochs", "min": 1, "max": 1000, "value": 30, + "marks": [ + { + "value": 1, + "label": "1" + }, + { + "value": 100, + "label": "100" + }, + { + "value": 1000, + "label": "1000" + } + ], "comp_group": "train_model" }, { @@ -717,7 +853,7 @@ "title": "Optimizer", "param_key": "optimizer", "value": "Adam", - "options": [ + "data": [ { "label": "Adadelta", "value": "Adadelta" @@ -771,7 +907,7 @@ "title": "Criterion", "param_key": "criterion", "value": "MSELoss", - "options": [ + "data": [ { "label": "L1Loss", "value": "L1Loss" @@ -881,7 +1017,7 @@ "title": "Activation", "param_key": "activation", "value": "ReLU", - "options": [ + "data": [ { "label": "ReLU", "value": "ReLU" @@ -910,28 +1046,24 @@ "max": 100, "step": 5, "value": 20, - "marks": { - "0": "0", - "100": "100" - }, + "marks": [ + { + "value": 0, + "label": "0%" + }, + { + "value": 100, + "label": "100%" + } + ], "comp_group": "train_model" }, { - "type": "radio", + "type": "bool", "name": "shuffle_train", "title": "Shuffle Training", "param_key": "shuffle_train", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], + "checked": true, "comp_group": "train_model" }, { @@ -943,6 +1075,24 @@ "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "train_model" }, { @@ -954,6 +1104,24 @@ "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "train_model" }, { @@ -965,6 +1133,24 @@ "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "prediction_model" } ], diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py index 5b232a0..186c987 100644 --- a/components/dash_component_editor.py +++ b/components/dash_component_editor.py @@ -1,7 +1,6 @@ import dash_bootstrap_components as dbc -import dash_daq as daq import dash_mantine_components as dmc -from dash import ALL, Input, Output, State, dcc, html +from dash import ALL, Input, Output, State, html class ControlItem(dmc.Grid): @@ -9,9 +8,9 @@ class ControlItem(dmc.Grid): Customized layout for a control item """ - def __init__(self, title, title_id, item, **kwargs): + def __init__(self, title, title_id, item, style={}): super(ControlItem, self).__init__( - [ + children=[ dmc.Text( title, id=title_id, @@ -21,67 +20,93 @@ def __init__(self, title, title_id, item, **kwargs): ), html.Div(item, style={"width": "265px", "margin": "auto"}), ], - **kwargs, + style=style, ) -class SimpleItem(dbc.Col): +class NumberItem(ControlItem): def __init__( self, name, base_id, title=None, param_key=None, - type="number", - debounce=True, + visible=True, **kwargs, ): if param_key is None: param_key = name - self.label = dbc.Label(title) - self.input = dbc.Input( - type=type, - debounce=debounce, - id={**base_id, "name": name, "param_key": param_key}, + self.input = dmc.NumberInput( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, **kwargs, ) - super(SimpleItem, self).__init__(children=[self.label, self.input]) - + style = {} + if not visible: + style["display"] = "none" -class FloatItem(SimpleItem): - pass + super(NumberItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "title", + }, + item=self.input, + style=style, + ) -class IntItem(SimpleItem): - def __init__(self, *args, **kwargs): - if "min" not in kwargs: - kwargs["min"] = -9007199254740991 - super(IntItem, self).__init__(*args, step=1, **kwargs) +class StrItem(ControlItem): + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs, + ): + if param_key is None: + param_key = name + self.input = dmc.TextInput( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + style = {} + if not visible: + style["display"] = "none" -class StrItem(SimpleItem): - def __init__(self, *args, **kwargs): - super(StrItem, self).__init__(*args, type="text", **kwargs) + super(StrItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "title", + }, + item=self.input, + style=style, + ) -class SliderItem(dbc.Col): +class SliderItem(ControlItem): def __init__( self, name, base_id, title=None, param_key=None, - debounce=True, visible=True, **kwargs, ): if param_key is None: param_key = name - self.label = dbc.Label(title) - self.input = dcc.Slider( + self.input = dmc.Slider( id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, - tooltip={"placement": "bottom", "always_visible": True}, + labelAlwaysOn=False, **kwargs, ) @@ -90,27 +115,31 @@ def __init__( style["display"] = "none" super(SliderItem, self).__init__( - id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, - children=[self.label, self.input], + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "title", + }, + item=self.input, style=style, ) -class DropdownItem(dbc.Col): +class DropdownItem(ControlItem): def __init__( self, name, base_id, title=None, param_key=None, - debounce=True, visible=True, **kwargs, ): if param_key is None: param_key = name - self.label = dbc.Label(title) - self.input = dcc.Dropdown( + self.input = dmc.Select( id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, **kwargs, ) @@ -120,20 +149,32 @@ def __init__( style["display"] = "none" super(DropdownItem, self).__init__( - id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, - children=[self.label, self.input], + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, style=style, ) -class RadioItem(dbc.Col): +class RadioItem(ControlItem): def __init__( self, name, base_id, title=None, param_key=None, visible=True, **kwargs ): if param_key is None: param_key = name - self.label = dbc.Label(title) - self.input = dbc.RadioItems( + + options = [ + dmc.Radio(option["label"], value=option["value"]) + for option in kwargs["options"] + ] + kwargs.pop("options", None) + self.input = dmc.RadioGroup( + options, id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, **kwargs, ) @@ -143,24 +184,30 @@ def __init__( style["display"] = "none" super(RadioItem, self).__init__( - id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, - children=[self.label, self.input], + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, style=style, ) -class BoolItem(dbc.Col): +class BoolItem(dmc.Grid): def __init__( self, name, base_id, title=None, param_key=None, visible=True, **kwargs ): if param_key is None: param_key = name - self.label = dbc.Label(title) - self.input = daq.ToggleSwitch( + + self.input = dmc.Switch( id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + label=title, **kwargs, ) - self.output_label = dbc.Label("False/True") style = {} if not visible: @@ -168,15 +215,15 @@ def __init__( super(BoolItem, self).__init__( id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, - children=[self.label, self.input, self.output_label], + children=[self.input, dmc.Space(h=25)], style=style, ) class ParameterEditor(dbc.Form): type_map = { - float: FloatItem, - int: IntItem, + float: NumberItem, + int: NumberItem, str: StrItem, } @@ -235,8 +282,8 @@ def build_children(self, values=None): class JSONParameterEditor(ParameterEditor): type_map = { - "float": FloatItem, - "int": IntItem, + "float": NumberItem, + "int": NumberItem, "str": StrItem, "slider": SliderItem, "dropdown": DropdownItem, From 5b177c467645d5351d9555035cb78e37a5b909c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 13:34:42 -0800 Subject: [PATCH 33/38] Clean up `Models`, `JSONParameterEditor`, and parameter retrieval --- assets/models.json | 2 +- callbacks/control_bar.py | 23 +++---- callbacks/segmentation.py | 18 ++--- components/control_bar.py | 7 +- components/dash_component_editor.py | 100 ++++++++-------------------- utils/content_registry.py | 31 --------- utils/data_utils.py | 42 ++++++++++++ 7 files changed, 90 insertions(+), 133 deletions(-) delete mode 100644 utils/content_registry.py diff --git a/assets/models.json b/assets/models.json index df59436..f8f8b6e 100755 --- a/assets/models.json +++ b/assets/models.json @@ -22,7 +22,7 @@ { "type": "int", "name": "num_layers", - "title": "Number of Layers", + "title": "# Layers", "param_key": "num_layers", "value": 3, "comp_group": "train_model" diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index 9b15112..ae3ba28 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -25,11 +25,10 @@ from dash_iconify import DashIconify from components.annotation_class import annotation_class_item -from components.dash_component_editor import JSONParameterEditor +from components.dash_component_editor import ParameterItems from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEY_MODES, KEYBINDS from utils.annotations import Annotations -from utils.content_registry import models -from utils.data_utils import tiled_dataset +from utils.data_utils import models, tiled_dataset from utils.plot_utils import generate_notification, generate_notification_bg_icon_col # TODO - temporary local file path and user for annotation saving and exporting @@ -918,18 +917,16 @@ def update_current_annotated_slices_values(all_classes): @callback( - Output("gui-layouts", "children"), + Output("model-parameters", "children"), Input("model-list", "value"), ) def update_model_parameters(model_name): - data = models.models[model_name] - if data["gui_parameters"]: - item_list = JSONParameterEditor( - _id={"type": str(uuid.uuid4())}, - json_blob=models.remove_key_from_dict_list( - data["gui_parameters"], "comp_group" - ), + model = models[model_name] + if model["gui_parameters"]: + # TODO: Retain old parameters if they exist + item_list = ParameterItems( + _id={"type": str(uuid.uuid4())}, json_blob=model["gui_parameters"] ) - return [item_list] + return item_list else: - return [""] + return html.Div("Model has no parameters") diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py index c477bcb..2029b18 100644 --- a/callbacks/segmentation.py +++ b/callbacks/segmentation.py @@ -7,7 +7,7 @@ from dash import ALL, Input, Output, State, callback, no_update from dash.exceptions import PreventUpdate -from utils.data_utils import tiled_dataset +from utils.data_utils import extract_parameters_from_html, tiled_dataset MODE = os.getenv("MODE", "") @@ -58,14 +58,14 @@ @callback( Output("output-details", "children"), Output("submitted-job-id", "data"), - Output("gui-components-values", "data"), + Output("model-parameter-values", "data"), Input("run-model", "n_clicks"), State("annotation-store", "data"), State({"type": "annotation-class-store", "index": ALL}, "data"), State("project-name-src", "value"), - State("gui-layouts", "children"), + State("model-parameters", "children"), ) -def run_job(n_clicks, global_store, all_annotations, project_name, children): +def run_job(n_clicks, global_store, all_annotations, project_name, model_parameters): """ This callback collects parameters from the UI and submits a job to the computing api. If the app is run from "dev" mode, then only a placeholder job_uid will be created. @@ -76,14 +76,8 @@ def run_job(n_clicks, global_store, all_annotations, project_name, children): """ input_params = {} if n_clicks: - if len(children) >= 2: - params = children[1] - for param in params["props"]["children"]: - key = param["props"]["children"][1]["props"]["id"]["param_key"] - value = param["props"]["children"][1]["props"]["value"] - input_params[key] = value - - # return the input values in dictionary and saved to dcc.Store "gui-components-values" + input_params = extract_parameters_from_html(model_parameters) + # return the input values in dictionary and save to the model parameter store print(f"input_param:\n{input_params}") if MODE == "dev": job_uid = str(uuid.uuid4()) diff --git a/components/control_bar.py b/components/control_bar.py index 4ca57e7..75a46ea 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -7,8 +7,7 @@ from components.annotation_class import annotation_class_item from components.dash_component_editor import ControlItem from constants import ANNOT_ICONS, KEYBINDS -from utils.content_registry import models -from utils.data_utils import tiled_dataset +from utils.data_utils import models, tiled_dataset def _tooltip(text, children): @@ -605,8 +604,8 @@ def layout(): ), ), dmc.Space(h=25), - html.Div(id="gui-layouts"), - dcc.Store(id="gui-components-values", data={}), + html.Div(id="model-parameters"), + dcc.Store(id="model-parameter-values", data={}), dmc.Space(h=25), dmc.Center( dmc.Button( diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py index 186c987..8404820 100644 --- a/components/dash_component_editor.py +++ b/components/dash_component_editor.py @@ -1,6 +1,6 @@ import dash_bootstrap_components as dbc import dash_mantine_components as dmc -from dash import ALL, Input, Output, State, html +from dash import html class ControlItem(dmc.Grid): @@ -51,7 +51,7 @@ def __init__( **base_id, "name": name, "param_key": param_key, - "layer": "title", + "layer": "label", }, item=self.input, style=style, @@ -85,7 +85,7 @@ def __init__( **base_id, "name": name, "param_key": param_key, - "layer": "title", + "layer": "label", }, item=self.input, style=style, @@ -120,7 +120,7 @@ def __init__( **base_id, "name": name, "param_key": param_key, - "layer": "title", + "layer": "label", }, item=self.input, style=style, @@ -196,7 +196,7 @@ def __init__( ) -class BoolItem(dmc.Grid): +class BoolItem(ControlItem): def __init__( self, name, base_id, title=None, param_key=None, visible=True, **kwargs ): @@ -214,45 +214,33 @@ def __init__( style["display"] = "none" super(BoolItem, self).__init__( - id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, - children=[self.input, dmc.Space(h=25)], + title="", # title is already in the switch + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, style=style, ) -class ParameterEditor(dbc.Form): +class ParameterItems(dbc.Form): type_map = { - float: NumberItem, - int: NumberItem, - str: StrItem, + "float": NumberItem, + "int": NumberItem, + "str": StrItem, + "slider": SliderItem, + "dropdown": DropdownItem, + "radio": RadioItem, + "bool": BoolItem, } - def __init__(self, _id, parameters, **kwargs): - self._parameters = parameters - - super(ParameterEditor, self).__init__( - id=_id, children=[], className="kwarg-editor", **kwargs - ) - self.children = self.build_children() - - def init_callbacks(self, app): - app.callback( - Output(self.id, "n_submit"), - Input({**self.id, "name": ALL}, "value"), - State(self.id, "n_submit"), - ) - - for child in self.children: - if hasattr(child, "init_callbacks"): - child.init_callbacks(app) - - @property - def values(self): - return {param["name"]: param.get("value", None) for param in self._parameters} - - @property - def parameters(self): - return {param["name"]: param for param in self._parameters} + def __init__(self, _id, json_blob, values=None): + super(ParameterItems, self).__init__(id=_id, children=[]) + self._json_blob = json_blob + self.children = self.build_children(values=values) def _determine_type(self, parameter_dict): if "type" in parameter_dict: @@ -266,49 +254,17 @@ def _determine_type(self, parameter_dict): f"No item type could be determined for this parameter: {parameter_dict}" ) - def build_children(self, values=None): - children = [] - for parameter_dict in self._parameters: - parameter_dict = parameter_dict.copy() - if values and parameter_dict["name"] in values: - parameter_dict["value"] = values[parameter_dict["name"]] - type = self._determine_type(parameter_dict) - parameter_dict.pop("type", None) - item = self.type_map[type](**parameter_dict, base_id=self.id) - children.append(item) - - return children - - -class JSONParameterEditor(ParameterEditor): - type_map = { - "float": NumberItem, - "int": NumberItem, - "str": StrItem, - "slider": SliderItem, - "dropdown": DropdownItem, - "radio": RadioItem, - "bool": BoolItem, - } - - def __init__(self, _id, json_blob, **kwargs): - super(ParameterEditor, self).__init__( - id=_id, children=[], className="kwarg-editor", **kwargs - ) - self._json_blob = json_blob - self.children = self.build_children() - def build_children(self, values=None): children = [] for json_record in self._json_blob: - ... - # build a parameter dict from self.json_blob - ... + # Build a parameter dict from self.json_blob type = json_record.get("type", self._determine_type(json_record)) json_record = json_record.copy() if values and json_record["name"] in values: json_record["value"] = values[json_record["name"]] json_record.pop("type", None) + if "comp_group" in json_record: + json_record.pop("comp_group", None) item = self.type_map[type](**json_record, base_id=self.id) children.append(item) diff --git a/utils/content_registry.py b/utils/content_registry.py deleted file mode 100644 index ed9203c..0000000 --- a/utils/content_registry.py +++ /dev/null @@ -1,31 +0,0 @@ -import json -from copy import deepcopy - - -class Models: - def __init__(self, modelfile_path="./assets/models.json"): - self.path = modelfile_path - f = open(self.path) - - self.contents = json.load(f)["contents"] - self.modelname_list = [content["model_name"] for content in self.contents] - self.models = {} - - for i, n in enumerate(self.modelname_list): - self.models[n] = self.contents[i] - - @staticmethod - def remove_key_from_dict_list(data, key): - new_data = [] - for item in data: - if key in item: - new_item = deepcopy(item) - new_item.pop(key) - new_data.append(new_item) - else: - new_data.append(item) - - return new_data - - -models = Models() diff --git a/utils/data_utils.py b/utils/data_utils.py index bb54d29..43a24e9 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -159,3 +159,45 @@ def save_annotations_data(self, global_store, all_annotations, project_name): tiled_dataset = TiledDataLoader() + + +class Models: + def __init__(self, modelfile_path="./assets/models.json"): + self.path = modelfile_path + f = open(self.path) + + contents = json.load(f)["contents"] + self.modelname_list = [content["model_name"] for content in contents] + self.models = {} + + for i, n in enumerate(self.modelname_list): + self.models[n] = contents[i] + + def __getitem__(self, key): + try: + return self.models[key] + except KeyError: + raise KeyError(f"A model with name {key} does not exist.") + + +models = Models() + + +def extract_parameters_from_html(model_parameters_html): + """ + Extracts parameters from the children component of a + """ + input_params = {} + for param in model_parameters_html["props"]["children"]: + # param["props"]["children"][0] is the label + # param["props"]["children"][1] is the input + parameter_container = param["props"]["children"][1] + # The achtual parameter item is the first and only child of the parameter container + parameter_item = parameter_container["props"]["children"]["props"] + key = parameter_item["id"]["param_key"] + if "value" in parameter_item: + value = parameter_item["value"] + elif "checked" in parameter_item: + value = parameter_item["checked"] + input_params[key] = value + return input_params From 790d4efc35bb96ce2995d82c37fa27ded0df893f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 13:41:17 -0800 Subject: [PATCH 34/38] Rename `dash_component_editor` to better represent new structure --- callbacks/control_bar.py | 2 +- components/control_bar.py | 2 +- components/{dash_component_editor.py => parameter_items.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename components/{dash_component_editor.py => parameter_items.py} (100%) diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index ae3ba28..6011b29 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -25,7 +25,7 @@ from dash_iconify import DashIconify from components.annotation_class import annotation_class_item -from components.dash_component_editor import ParameterItems +from components.parameter_items import ParameterItems from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEY_MODES, KEYBINDS from utils.annotations import Annotations from utils.data_utils import models, tiled_dataset diff --git a/components/control_bar.py b/components/control_bar.py index 75a46ea..2e09354 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -5,7 +5,7 @@ from dash_iconify import DashIconify from components.annotation_class import annotation_class_item -from components.dash_component_editor import ControlItem +from components.parameter_items import ControlItem from constants import ANNOT_ICONS, KEYBINDS from utils.data_utils import models, tiled_dataset diff --git a/components/dash_component_editor.py b/components/parameter_items.py similarity index 100% rename from components/dash_component_editor.py rename to components/parameter_items.py From 6193c53a0d221798edda5d334bc75f1c3b2b2e4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 15:51:07 -0800 Subject: [PATCH 35/38] Check if `image_shapes` was initialized --- utils/annotations.py | 4 ++-- utils/data_utils.py | 13 ++++++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/utils/annotations.py b/utils/annotations.py index c2eafcd..e7875f1 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -11,7 +11,7 @@ class Annotations: - def __init__(self, annotation_store, global_store): + def __init__(self, annotation_store, image_shape): if annotation_store: slices = [] for annotation_class in annotation_store: @@ -49,7 +49,7 @@ def __init__(self, annotation_store, global_store): self.annotation_classes = annotation_classes self.annotations = annotations self.annotations_hash = self.get_annotations_hash() - self.image_shape = global_store["image_shapes"][0] + self.image_shape = image_shape def get_annotations(self): return self.annotations diff --git a/utils/data_utils.py b/utils/data_utils.py index 87cc635..3311019 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -168,7 +168,18 @@ def save_annotations_data(self, global_store, all_annotations, project_name): """ Transforms annotations data to a pixelated mask and outputs to the Tiled server """ - annotations = Annotations(all_annotations, global_store) + if "image_shapes" in global_store: + image_shape = global_store["image_shapes"][0] + else: + print("Global store was not filled.") + data_shape = ( + tiled_datasets.get_data_shape_by_name(project_name) + if project_name + else None + ) + image_shape = (data_shape[1], data_shape[2]) + + annotations = Annotations(all_annotations, image_shape) # TODO: Check sparse status, it may be worthwhile to store the mask as a sparse array # if our machine learning models can handle sparse arrays annotations.create_annotation_mask(sparse=False) From d09a45a6e27e43b0d19be69bd463b0607a5528d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 16:31:38 -0800 Subject: [PATCH 36/38] Give the generated parameters some space --- components/parameter_items.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/components/parameter_items.py b/components/parameter_items.py index 8404820..f165125 100644 --- a/components/parameter_items.py +++ b/components/parameter_items.py @@ -41,7 +41,7 @@ def __init__( **kwargs, ) - style = {} + style = {"padding": "15px 0px 0px 0px"} if not visible: style["display"] = "none" @@ -75,7 +75,7 @@ def __init__( **kwargs, ) - style = {} + style = {"padding": "15px 0px 0px 0px"} if not visible: style["display"] = "none" @@ -110,7 +110,7 @@ def __init__( **kwargs, ) - style = {} + style = {"padding": "15px 0px 15px 0px"} if not visible: style["display"] = "none" @@ -144,7 +144,7 @@ def __init__( **kwargs, ) - style = {} + style = {"padding": "15px 0px 0px 0px"} if not visible: style["display"] = "none" @@ -179,7 +179,7 @@ def __init__( **kwargs, ) - style = {} + style = {"padding": "15px 0px 0px 0px"} if not visible: style["display"] = "none" @@ -209,7 +209,7 @@ def __init__( **kwargs, ) - style = {} + style = {"padding": "15px 0px 0px 0px"} if not visible: style["display"] = "none" From 8e6e4883da2b118b32db3cb43edcec45fa97dbff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 16:32:15 -0800 Subject: [PATCH 37/38] :whale: Add missing environment variables --- docker-compose.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docker-compose.yml b/docker-compose.yml index c0a7f2b..5f1a7d5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,6 +8,12 @@ services: environment: DATA_TILED_URI: '${DATA_TILED_URI}' DATA_TILED_API_KEY: '${DATA_TILED_API_KEY}' + MASK_TILED_URI: '${MASK_TILED_URI}' + MASK_TILED_API_KEY: '${TILED_API_KEY}' + SEG_TILED_URI: '${SEG_TILED_URI}' + SEG_TILED_API_KEY: '${SEG_TILED_API_KEY}' + USER_NAME: '${USER_NAME}' + USER_PASSWORD: '${USER_PASSWORD}' volumes: - ./app.py:/app/app.py - ./constants.py:/app/constants.py From 02732a622565ae15e1c42988b7adcce7b39a2a16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 16:35:53 -0800 Subject: [PATCH 38/38] Change default activation and learning rate step --- assets/models.json | 18 +++++++++++++++--- components/control_bar.py | 2 +- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/assets/models.json b/assets/models.json index f8f8b6e..3b22cb1 100755 --- a/assets/models.json +++ b/assets/models.json @@ -237,6 +237,10 @@ "title": "Learning Rate", "param_key": "learning_rate", "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "precision": 4, "comp_group": "train_model" }, { @@ -244,7 +248,7 @@ "name": "activation", "title": "Activation", "param_key": "activation", - "value": "ReLU", + "value": "Sigmoid", "data": [ { "label": "ReLU", @@ -621,6 +625,10 @@ "title": "Learning Rate", "param_key": "learning_rate", "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "precision": 4, "comp_group": "train_model" }, { @@ -628,7 +636,7 @@ "name": "activation", "title": "Activation", "param_key": "activation", - "value": "ReLU", + "value": "Sigmoid", "data": [ { "label": "ReLU", @@ -1009,6 +1017,10 @@ "title": "Learning Rate", "param_key": "learning_rate", "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "precision": 4, "comp_group": "train_model" }, { @@ -1016,7 +1028,7 @@ "name": "activation", "title": "Activation", "param_key": "activation", - "value": "ReLU", + "value": "Sigmoid", "data": [ { "label": "ReLU", diff --git a/components/control_bar.py b/components/control_bar.py index c1e538b..4809e5f 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -603,7 +603,7 @@ def layout(): placeholder="Select a model...", ), ), - dmc.Space(h=25), + dmc.Space(h=15), html.Div(id="model-parameters"), dcc.Store(id="model-parameter-values", data={}), dmc.Space(h=25),