From 7d75edcadfd188bebc16c18dc524426dd304774c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Tue, 27 Feb 2024 20:37:36 -0800 Subject: [PATCH 01/17] Rename `tiled_dataset.data` to `tiled_dataset.data_client` Consistent with naming in other projects --- callbacks/control_bar.py | 2 +- utils/data_utils.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index 9d7d438..f0a1d6b 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -853,7 +853,7 @@ def populate_classification_results( image_src, refresh_tiled, toggle, dropdown_enabled, slider_enabled ): if refresh_tiled: - tiled_dataset.refresh_data() + tiled_dataset.refresh_data_client() data_options = [ item for item in tiled_dataset.get_data_project_names() if "seg" not in item diff --git a/utils/data_utils.py b/utils/data_utils.py index bb54d29..c542590 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -22,14 +22,14 @@ def __init__( ): self.data_tiled_uri = data_tiled_uri self.data_tiled_api_key = data_tiled_api_key - self.data = from_uri( + self.data_client = from_uri( self.data_tiled_uri, api_key=self.data_tiled_api_key, timeout=httpx.Timeout(30.0), ) - def refresh_data(self): - self.data = from_uri( + def refresh_data_client(self): + self.data_client = from_uri( self.data_tiled_uri, api_key=self.data_tiled_api_key, timeout=httpx.Timeout(30.0), @@ -42,8 +42,8 @@ def get_data_project_names(self): """ project_names = [ project - for project in list(self.data) - if isinstance(self.data[project], (Container, ArrayClient)) + for project in list(self.data_client) + if isinstance(self.data_client[project], (Container, ArrayClient)) ] return project_names @@ -53,7 +53,7 @@ def get_data_sequence_by_name(self, project_name): but can also be additionally encapsulated in a folder, multiple container or in a .nxs file. We make use of specs to figure out the path to the 3d data. """ - project_client = self.data[project_name] + project_client = self.data_client[project_name] # If the project directly points to an array, directly return it if isinstance(project_client, ArrayClient): return project_client From c97d589828fdb2fc24d521850befe064b1776570 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Tue, 27 Feb 2024 21:08:34 -0800 Subject: [PATCH 02/17] Set up data loader for input and results and add separate mask handler for exporting and loading mask, to replace all annotation loading functionality through the JSON file --- callbacks/control_bar.py | 15 ++++++------- callbacks/image_viewer.py | 10 ++++----- callbacks/segmentation.py | 4 ++-- components/control_bar.py | 4 ++-- utils/data_utils.py | 44 ++++++++++++++++++++++++++++++++------- 5 files changed, 53 insertions(+), 24 deletions(-) diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index f0a1d6b..f5c208b 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -26,7 +26,7 @@ from components.annotation_class import annotation_class_item from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEY_MODES, KEYBINDS from utils.annotations import Annotations -from utils.data_utils import tiled_dataset +from utils.data_utils import tiled_datasets, tiled_masks, tiled_results from utils.plot_utils import generate_notification, generate_notification_bg_icon_col # TODO - temporary local file path and user for annotation saving and exporting @@ -758,7 +758,7 @@ def populate_load_annotations_dropdown_menu_options(modal_opened, image_src): if not modal_opened: raise PreventUpdate - data = tiled_dataset.DEV_load_exported_json_data( + data = tiled_masks.DEV_load_exported_json_data( EXPORT_FILE_PATH, USER_NAME, image_src ) if not data: @@ -804,10 +804,10 @@ def load_and_apply_selected_annotations(selected_annotation, image_src, img_idx) )["index"] # TODO : when quering from the server, load (data) for user, source, time - data = tiled_dataset.DEV_load_exported_json_data( + data = tiled_masks.DEV_load_exported_json_data( EXPORT_FILE_PATH, USER_NAME, image_src ) - data = tiled_dataset.DEV_filter_json_data_by_timestamp( + data = tiled_masks.DEV_filter_json_data_by_timestamp( data, str(selected_annotation_timestamp) ) data = data[0]["data"] @@ -853,10 +853,10 @@ def populate_classification_results( image_src, refresh_tiled, toggle, dropdown_enabled, slider_enabled ): if refresh_tiled: - tiled_dataset.refresh_data_client() + tiled_datasets.refresh_data_client() data_options = [ - item for item in tiled_dataset.get_data_project_names() if "seg" not in item + item for item in tiled_datasets.get_data_project_names() if "seg" not in item ] results = [] value = None @@ -872,9 +872,10 @@ def populate_classification_results( disabled_toggle = False disabled_slider = slider_enabled else: + # TODO: Match by mask uid instead of image_src results = [ item - for item in tiled_dataset.get_data_project_names() + for item in tiled_results.get_data_project_names() if ("seg" in item and image_src in item) ] if results: diff --git a/callbacks/image_viewer.py b/callbacks/image_viewer.py index a864640..48c6507 100644 --- a/callbacks/image_viewer.py +++ b/callbacks/image_viewer.py @@ -16,7 +16,7 @@ from dash.exceptions import PreventUpdate from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEYBINDS -from utils.data_utils import tiled_dataset +from utils.data_utils import tiled_datasets, tiled_masks, tiled_results from utils.plot_utils import ( create_viewfinder, downscale_view, @@ -108,7 +108,7 @@ def render_image( if image_idx: image_idx -= 1 # slider starts at 1, so subtract 1 to get the correct index - tf = tiled_dataset.get_data_sequence_by_name(project_name)[image_idx] + tf = tiled_datasets.get_data_sequence_by_name(project_name)[image_idx] if toggle_seg_result: # if toggle is true and overlay exists already (2 images in data) this will # be handled in hide_show_segmentation_overlay callback @@ -117,8 +117,8 @@ def render_image( and ctx.triggered_id == "show-result-overlay-toggle" ): return [dash.no_update] * 7 + ["hidden"] - if str(image_idx + 1) in tiled_dataset.get_annotated_segmented_results(): - result = tiled_dataset.get_data_sequence_by_name(seg_result_selection)[ + if str(image_idx + 1) in tiled_masks.get_annotated_segmented_results(): + result = tiled_results.get_data_sequence_by_name(seg_result_selection)[ image_idx ] else: @@ -485,7 +485,7 @@ def update_slider_values(project_name, annotation_store): """ # Retrieve data shape if project_name is valid and points to a 3d array data_shape = ( - tiled_dataset.get_data_shape_by_name(project_name) if project_name else None + tiled_datasets.get_data_shape_by_name(project_name) if project_name else None ) disable_slider = data_shape is None if not disable_slider: diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py index 69a206c..40f47b6 100644 --- a/callbacks/segmentation.py +++ b/callbacks/segmentation.py @@ -7,7 +7,7 @@ from dash import ALL, Input, Output, State, callback, no_update from dash.exceptions import PreventUpdate -from utils.data_utils import tiled_dataset +from utils.data_utils import tiled_masks MODE = os.getenv("MODE", "") @@ -84,7 +84,7 @@ def run_job(n_clicks, global_store, all_annotations, project_name): ) else: - tiled_dataset.save_annotations_data( + tiled_masks.save_annotations_data( global_store, all_annotations, project_name ) job_submitted = requests.post( diff --git a/components/control_bar.py b/components/control_bar.py index b4ab2b0..007f856 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -6,7 +6,7 @@ from components.annotation_class import annotation_class_item from constants import ANNOT_ICONS, KEYBINDS -from utils.data_utils import tiled_dataset +from utils.data_utils import tiled_datasets def _tooltip(text, children): @@ -62,7 +62,7 @@ def layout(): Returns the layout for the control panel in the app UI """ DATA_OPTIONS = [ - item for item in tiled_dataset.get_data_project_names() if "seg" not in item + item for item in tiled_datasets.get_data_project_names() if "seg" not in item ] return drawer_section( dmc.Stack( diff --git a/utils/data_utils.py b/utils/data_utils.py index c542590..c51feb8 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -14,6 +14,10 @@ DATA_TILED_URI = os.getenv("DATA_TILED_URI") DATA_TILED_API_KEY = os.getenv("DATA_TILED_API_KEY") +MASK_TILED_URI = os.getenv("MASK_TILED_URI") +MASK_TILED_API_KEY = os.getenv("MASK_TILED_API_KEY") +SEG_TILED_URI = os.getenv("SEG_TILED_URI") +SEG_TILED_API_KEY = os.getenv("SEG_TILED_API_KEY") class TiledDataLoader: @@ -84,6 +88,23 @@ def get_data_shape_by_name(self, project_name): return project_container.shape return None + +class TiledMaskHandler: + """ + This class is used to handle the masks that are generated from the annotations. + """ + + def __init__( + self, mask_tiled_uri=MASK_TILED_URI, mask_tiled_api_key=MASK_TILED_API_KEY + ): + self.mask_tiled_uri = mask_tiled_uri + self.mask_tiled_api_key = mask_tiled_api_key + self.mask_client = from_uri( + self.mask_tiled_uri, + api_key=self.mask_tiled_api_key, + timeout=httpx.Timeout(30.0), + ) + @staticmethod def get_annotated_segmented_results(json_file_path="exported_annotation_data.json"): annotated_slices = [] @@ -128,7 +149,8 @@ def DEV_load_exported_json_data(file_path, USER_NAME, PROJECT_NAME): def DEV_filter_json_data_by_timestamp(data, timestamp): return [data for data in data if data["time"] == timestamp] - def save_annotations_data(self, global_store, all_annotations, project_name): + @staticmethod + def save_annotations_data(global_store, all_annotations, project_name): """ Transforms annotations data to a pixelated mask and outputs to the Tiled server @@ -144,13 +166,9 @@ def save_annotations_data(self, global_store, all_annotations, project_name): # Get raw images associated with each annotated slice img_idx = list(metadata.keys()) - img = self.data[project_name] - raw = [] - for idx in img_idx: - ar = img[int(idx)] - raw.append(ar) + metadata["mask_idx"] = img_idx + metadata["project_name"] = project_name try: - raw = np.stack(raw) mask = np.stack(mask) except ValueError: return "No annotations to process." @@ -158,4 +176,14 @@ def save_annotations_data(self, global_store, all_annotations, project_name): return -tiled_dataset = TiledDataLoader() +tiled_datasets = TiledDataLoader( + data_tiled_uri=DATA_TILED_URI, data_tiled_api_key=DATA_TILED_API_KEY +) + +tiled_masks = TiledMaskHandler( + mask_tiled_uri=MASK_TILED_URI, mask_tiled_api_key=MASK_TILED_API_KEY +) + +tiled_results = TiledDataLoader( + data_tiled_uri=SEG_TILED_URI, data_tiled_api_key=SEG_TILED_API_KEY +) From 98284371fd1197a282cb975649f4c73a974dafe1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Tue, 27 Feb 2024 22:24:15 -0800 Subject: [PATCH 03/17] Generate overlay colormap based on selected colors Resolves #169 --- callbacks/image_viewer.py | 17 ++++++----------- utils/plot_utils.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/callbacks/image_viewer.py b/callbacks/image_viewer.py index 48c6507..3f02a78 100644 --- a/callbacks/image_viewer.py +++ b/callbacks/image_viewer.py @@ -21,6 +21,7 @@ create_viewfinder, downscale_view, generate_notification, + generate_segmentation_colormap, get_view_finder_max_min, resize_canvas, ) @@ -127,20 +128,14 @@ def render_image( tf = np.zeros((500, 500)) fig = px.imshow(tf, binary_string=True) if toggle_seg_result and result is not None: - unique_segmentation_values = np.unique(result) - normalized_range = np.linspace( - 0, 1, len(unique_segmentation_values) - ) # heatmap requires a normalized range - color_list = ( - px.colors.qualitative.Plotly - ) # TODO placeholder - replace with user defined classess - colorscale = [ - [normalized_range[i], color_list[i % len(color_list)]] - for i in range(len(unique_segmentation_values)) - ] + colorscale, max_class_id = generate_segmentation_colormap( + all_annotation_class_store + ) fig.add_trace( go.Heatmap( z=result, + zmin=-0.5, + zmax=max_class_id + 0.5, colorscale=colorscale, showscale=False, ) diff --git a/utils/plot_utils.py b/utils/plot_utils.py index 55a7dd1..9e70224 100644 --- a/utils/plot_utils.py +++ b/utils/plot_utils.py @@ -1,6 +1,7 @@ import random import dash_mantine_components as dmc +import numpy as np import plotly.express as px import plotly.graph_objects as go from dash_iconify import DashIconify @@ -170,6 +171,37 @@ def resize_canvas(h, w, H, W, figure): return figure, image_center_coor +def generate_segmentation_colormap(all_annotations_data): + """ + Generates a discrete colormap for the segmentation overlay + based on the color information per class. + + The discrete colormap maps values from 0 to 1 to colors, + but is meant to be applied to images with class ids as values, + with these varying from 0 to the number of classes - 1. + To account for numerical inaccuracies, it is best to center the plot range + around the class ids, by setting cmin=-0.5 and cmax=max_class_id+0.5. + """ + max_class_id = max( + [annotation_class["class_id"] for annotation_class in all_annotations_data] + ) + # heatmap requires a normalized range from 0 to 1 + # We need to specify color for at least the range limits (0 and 1) + # as well for every additional class + # due to using zero-based class ids, we need to add 2 to the max class id + normalized_range = np.linspace(0, 1, max_class_id + 2) + color_list = [ + annotation_class["color"] for annotation_class in all_annotations_data + ] + colorscale = [ + [normalized_range[i + j], color_list[i % len(color_list)]] + for i in range(0, normalized_range.size - 1) + for j in range(2) + ] + + return colorscale, max_class_id + + def generate_notification(title, color, icon, message=""): return dmc.Notification( title=title, From 3fa5801b089abde6d56ee4607a454364fc7fa1db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 28 Feb 2024 20:03:42 -0800 Subject: [PATCH 04/17] :boom: Change class labels to be 0-based instead of 1-based --- components/annotation_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/components/annotation_class.py b/components/annotation_class.py index a855609..d97b836 100644 --- a/components/annotation_class.py +++ b/components/annotation_class.py @@ -33,7 +33,7 @@ def annotation_class_item(class_color, class_label, existing_ids, data=None): annotations = data["annotations"] is_visible = data["is_visible"] else: - class_id = 1 if not existing_ids else max(existing_ids) + 1 + class_id = 0 if not existing_ids else max(existing_ids) + 1 annotations = {} is_visible = True class_color_transparent = class_color + "50" From dab2b920a39b73e57d3f6368ee78ea18c6bf4d37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 28 Feb 2024 20:12:11 -0800 Subject: [PATCH 05/17] :bug: Use mapped index when accessing segmented frames with annotations --- callbacks/image_viewer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/callbacks/image_viewer.py b/callbacks/image_viewer.py index 3f02a78..ea29e75 100644 --- a/callbacks/image_viewer.py +++ b/callbacks/image_viewer.py @@ -118,9 +118,12 @@ def render_image( and ctx.triggered_id == "show-result-overlay-toggle" ): return [dash.no_update] * 7 + ["hidden"] - if str(image_idx + 1) in tiled_masks.get_annotated_segmented_results(): + annotation_indices = tiled_masks.get_annotated_segmented_results() + if str(image_idx + 1) in annotation_indices: + # Will not return an error since we already checked if image_idx+1 is in the list + mapped_index = annotation_indices.index(str(image_idx + 1)) result = tiled_results.get_data_sequence_by_name(seg_result_selection)[ - image_idx + mapped_index ] else: result = None From 570ce78a3d6840b0aef6f6b20ec686eb1ae388dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Thu, 29 Feb 2024 11:23:42 -0800 Subject: [PATCH 06/17] Save masks to Tiled on job submission --- callbacks/segmentation.py | 4 +++- utils/data_utils.py | 27 +++++++++++++++++++-------- utils/plot_utils.py | 4 +++- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py index 40f47b6..bef1ad3 100644 --- a/callbacks/segmentation.py +++ b/callbacks/segmentation.py @@ -74,6 +74,9 @@ def run_job(n_clicks, global_store, all_annotations, project_name): """ if n_clicks: if MODE == "dev": + tiled_masks.save_annotations_data( + global_store, all_annotations, project_name + ) job_uid = str(uuid.uuid4()) return ( dmc.Text( @@ -83,7 +86,6 @@ def run_job(n_clicks, global_store, all_annotations, project_name): job_uid, ) else: - tiled_masks.save_annotations_data( global_store, all_annotations, project_name ) diff --git a/utils/data_utils.py b/utils/data_utils.py index c51feb8..1f0df06 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -18,6 +18,7 @@ MASK_TILED_API_KEY = os.getenv("MASK_TILED_API_KEY") SEG_TILED_URI = os.getenv("SEG_TILED_URI") SEG_TILED_API_KEY = os.getenv("SEG_TILED_API_KEY") +USER_NAME = os.getenv("USER_NAME", "user1") class TiledDataLoader: @@ -149,16 +150,13 @@ def DEV_load_exported_json_data(file_path, USER_NAME, PROJECT_NAME): def DEV_filter_json_data_by_timestamp(data, timestamp): return [data for data in data if data["time"] == timestamp] - @staticmethod - def save_annotations_data(global_store, all_annotations, project_name): + def save_annotations_data(self, global_store, all_annotations, project_name): """ - Transforms annotations data to a pixelated mask and outputs to - the Tiled server - - # TODO: Save data to Tiled server after transformation + Transforms annotations data to a pixelated mask and outputs to the Tiled server """ annotations = Annotations(all_annotations, global_store) - annotations.create_annotation_mask(sparse=True) # TODO: Check sparse status + # TODO: Check sparse status + annotations.create_annotation_mask(sparse=False) # Get metadata and annotation data metadata = annotations.get_annotations() @@ -173,7 +171,20 @@ def save_annotations_data(global_store, all_annotations, project_name): except ValueError: return "No annotations to process." - return + # Store the mask in the Tiled server under /username/project_name/uid/mask" + container_keys = [USER_NAME, project_name] + last_container = self.mask_client + for key in container_keys: + if key not in last_container.keys(): + last_container = last_container.create_container(key=key) + else: + last_container = last_container[key] + # Add json metadata to a container with a uuid as key + # (uuid will be created by Tiled, since no key is given) + last_container = last_container.create_container(metadata=metadata) + mask = last_container.write_array(key="mask", array=mask) + # print("Created a mask array with the following uri: ", mask.uri) + return mask.uri tiled_datasets = TiledDataLoader( diff --git a/utils/plot_utils.py b/utils/plot_utils.py index 9e70224..f42ccb3 100644 --- a/utils/plot_utils.py +++ b/utils/plot_utils.py @@ -193,12 +193,14 @@ def generate_segmentation_colormap(all_annotations_data): color_list = [ annotation_class["color"] for annotation_class in all_annotations_data ] + # We need to repeat each color twice, to create discrete color segments + # This loop contains the range limits 0 and 1 once, + # but every other value in between twice colorscale = [ [normalized_range[i + j], color_list[i % len(color_list)]] for i in range(0, normalized_range.size - 1) for j in range(2) ] - return colorscale, max_class_id From 228cd31407ca41f6b6590e8dc6398450cfe67de2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Fri, 1 Mar 2024 15:13:37 -0800 Subject: [PATCH 07/17] Change values of unlabeled pixels to -1 --- utils/annotations.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/utils/annotations.py b/utils/annotations.py index 765ac73..c69dccd 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -84,7 +84,9 @@ def create_annotation_mask(self, sparse=False): for slice_idx, slice_data in self.annotations.items(): image_height = slice_data[0]["image_shape"][0] image_width = slice_data[0]["image_shape"][1] - slice_mask = np.zeros([image_height, image_width], dtype=np.uint8) + slice_mask = np.full( + [image_height, image_width], fill_value=-1, dtype=np.int8 + ) for shape in slice_data: if shape["type"] == "Closed Freeform": shape_mask = ShapeConversion.closed_path_to_array( @@ -100,7 +102,7 @@ def create_annotation_mask(self, sparse=False): ) else: continue - slice_mask[shape_mask > 0] = shape_mask[shape_mask > 0] + slice_mask[shape_mask >= 0] = shape_mask[shape_mask >= 0] annotation_mask.append(slice_mask) if sparse: @@ -154,7 +156,7 @@ def ellipse_to_array(self, svg_data, image_shape, mask_class): c_radius = abs(svg_data["y0"] - svg_data["y1"]) / 2 # Vertical radius # Create mask and draw ellipse - mask = np.zeros((image_height, image_width), dtype=np.uint8) + mask = np.full((image_height, image_width), fill_value=-1, dtype=np.int8) rr, cc = draw.ellipse( cy, cx, c_radius, r_radius ) # Vertical radius first, then horizontal @@ -181,7 +183,7 @@ def rectangle_to_array(self, svg_data, image_shape, mask_class): y1 = max(min(y1, image_height - 1), 0) # # Draw the rectangle - mask = np.zeros((image_height, image_width), dtype=np.uint8) + mask = np.full((image_height, image_width), fill_value=-1, dtype=np.int8) rr, cc = draw.rectangle(start=(y0, x0), end=(y1, x1)) # Convert coordinates to integers @@ -217,8 +219,9 @@ def closed_path_to_array(self, svg_data, image_shape, mask_class): is_inside = polygon_path.contains_points(points) # Reshape the result back into the 2D shape - mask = is_inside.reshape(image_height, image_width).astype(int) + mask = is_inside.reshape(image_height, image_width).astype(np.int8) # Set the class value for the pixels inside the polygon mask[mask == 1] = mask_class + mask[mask == 0] = -1 return mask From 57a41bd100ddb26adb0f0a19f987c63651bbd115 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Fri, 1 Mar 2024 15:18:37 -0800 Subject: [PATCH 08/17] Restructure meta-data generation to remove redundancies --- utils/annotations.py | 34 ++++++++++++++++++++++++---------- utils/data_utils.py | 22 +++++++++++++++------- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/utils/annotations.py b/utils/annotations.py index c69dccd..a830412 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -17,24 +17,34 @@ def __init__(self, annotation_store, global_store): slices = set(slices) annotations = {key: [] for key in slices} + all_class_labels = [ + annotation_class["class_id"] for annotation_class in annotation_store + ] + annotation_classes = {} + for annotation_class in annotation_store: + condensed_id = str(all_class_labels.index(annotation_class["class_id"])) + annotation_classes[condensed_id] = { + "label": annotation_class["label"], + "color": annotation_class["color"], + } for image_idx, slice_data in annotation_class["annotations"].items(): for shape in slice_data: self._set_annotation_type(shape) self._set_annotation_svg(shape) annotation = { - "id": annotation_class["class_id"], + "class_id": condensed_id, "type": self.annotation_type, - "class": annotation_class["label"], - # TODO: This is the same across all images in a dataset - "image_shape": global_store["image_shapes"][0], "svg_data": self.svg_data, } annotations[image_idx].append(annotation) else: - annotations = [] + annotations = None + annotation_classes = None + self.annotation_classes = annotation_classes self.annotations = annotations + self.image_shape = global_store["image_shapes"][0] def get_annotations(self): return self.annotations @@ -42,6 +52,9 @@ def get_annotations(self): def get_annotation_mask(self): return self.annotation_mask + def get_annotation_classes(self): + return self.annotation_classes + def get_annotation_mask_as_bytes(self): buffer = io.BytesIO() zip_buffer = io.BytesIO() @@ -81,24 +94,25 @@ def create_annotation_mask(self, sparse=False): self.sparse = sparse annotation_mask = [] + image_height = self.image_shape[0] + image_width = self.image_shape[1] + for slice_idx, slice_data in self.annotations.items(): - image_height = slice_data[0]["image_shape"][0] - image_width = slice_data[0]["image_shape"][1] slice_mask = np.full( [image_height, image_width], fill_value=-1, dtype=np.int8 ) for shape in slice_data: if shape["type"] == "Closed Freeform": shape_mask = ShapeConversion.closed_path_to_array( - shape["svg_data"], shape["image_shape"], shape["id"] + shape["svg_data"], self.image_shape, shape["class_id"] ) elif shape["type"] == "Rectangle": shape_mask = ShapeConversion.rectangle_to_array( - shape["svg_data"], shape["image_shape"], shape["id"] + shape["svg_data"], self.image_shape, shape["class_id"] ) elif shape["type"] == "Ellipse": shape_mask = ShapeConversion.ellipse_to_array( - shape["svg_data"], shape["image_shape"], shape["id"] + shape["svg_data"], self.image_shape, shape["class_id"] ) else: continue diff --git a/utils/data_utils.py b/utils/data_utils.py index 1f0df06..8bc1690 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -155,23 +155,31 @@ def save_annotations_data(self, global_store, all_annotations, project_name): Transforms annotations data to a pixelated mask and outputs to the Tiled server """ annotations = Annotations(all_annotations, global_store) - # TODO: Check sparse status + # TODO: Check sparse status, it may be worthwhile to store the mask as a sparse array + # if our machine learning models can handle sparse arrays annotations.create_annotation_mask(sparse=False) # Get metadata and annotation data - metadata = annotations.get_annotations() + annnotations_per_slice = annotations.get_annotations() + annotation_classes = annotations.get_annotation_classes() + + metadata = { + "classes": annotation_classes, + "unlabeled_class_id": -1, + "annotations": annnotations_per_slice, + "image_shape": global_store["image_shapes"][0], + "project_name": project_name, + "mask_idx": list(annnotations_per_slice.keys()), + } + print("Metadata: ", metadata) mask = annotations.get_annotation_mask() - # Get raw images associated with each annotated slice - img_idx = list(metadata.keys()) - metadata["mask_idx"] = img_idx - metadata["project_name"] = project_name try: mask = np.stack(mask) except ValueError: return "No annotations to process." - # Store the mask in the Tiled server under /username/project_name/uid/mask" + # Store the mask in the Tiled server under /username/project_name/uuid/mask" container_keys = [USER_NAME, project_name] last_container = self.mask_client for key in container_keys: From 9337fb78f1ddacc43b270c0c29c653c720a6d6d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Fri, 1 Mar 2024 15:26:11 -0800 Subject: [PATCH 09/17] Sort mask slices by index to ensure consistent order --- utils/annotations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/utils/annotations.py b/utils/annotations.py index a830412..1526bbf 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -15,7 +15,9 @@ def __init__(self, annotation_store, global_store): for annotation_class in annotation_store: slices.extend(list(annotation_class["annotations"].keys())) slices = set(slices) - annotations = {key: [] for key in slices} + # Slices need to be sorted to ensure that the exported mask slices + # have the same order as the original data set + annotations = {key: [] for key in sorted(slices, key=int)} all_class_labels = [ annotation_class["class_id"] for annotation_class in annotation_store From 70a1146b0150ee9729dc6b236cebe9d203bcb187 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 15:00:39 -0800 Subject: [PATCH 10/17] Add `.dockerignore` ignoring `.env` --- .dockerignore | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .dockerignore diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..1fafbd5 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +.env +.git +.gitignore From 55d90e3f764f0bba38199f328be23546c6bdbf93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 15:51:19 -0800 Subject: [PATCH 11/17] Add example script that saves a mask plot based on a given uri --- examples/plot_mask.py | 72 +++++++++++++++++++++++++++++++++++++++ examples/requirements.txt | 2 ++ 2 files changed, 74 insertions(+) create mode 100644 examples/plot_mask.py create mode 100644 examples/requirements.txt diff --git a/examples/plot_mask.py b/examples/plot_mask.py new file mode 100644 index 0000000..eee8d8c --- /dev/null +++ b/examples/plot_mask.py @@ -0,0 +1,72 @@ +import os +import sys + +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +from dotenv import load_dotenv +from matplotlib.colors import ListedColormap +from tiled.client import from_uri + + +def plot_mask(mask_uri, api_key, slice_idx, output_path): + """ + Saves a plot of a given mask using metadata information such as class colors and labels. + It is assumed that the given uri is the uri of a mask container, with associated meta data + and a mask array under the key "mask". + The given slice index references mask slices, not the original data. + However, the printed slice index in the figure will be the index of the original data. + """ + # Retrieve mask and metadata + mask_client = from_uri(mask_uri, api_key=api_key) + mask = mask_client["mask"][slice_idx] + meta_data = mask_client.metadata + mask_idx = meta_data["mask_idx"] + + if slice_idx > len(mask_idx): + raise ValueError("Slice index out of range") + + class_meta_data = meta_data["classes"] + max_class_id = len(class_meta_data.keys()) - 1 + + colors = [ + annotation_class["color"] for _, annotation_class in class_meta_data.items() + ] + labels = [ + annotation_class["label"] for _, annotation_class in class_meta_data.items() + ] + plt.imshow( + mask, + cmap=ListedColormap(colors), + vmin=-0.5, + vmax=max_class_id + 0.5, + ) + plt.title(meta_data["project_name"] + ", slice: " + mask_idx[slice_idx]) + + # create a patch for every color + patches = [ + mpatches.Patch(color=colors[i], label=labels[i]) for i in range(len(labels)) + ] + # Plot legend below the image + plt.legend( + handles=patches, loc="upper center", bbox_to_anchor=(0.5, -0.075), ncol=3 + ) + plt.savefig(output_path, bbox_inches="tight") + + +if __name__ == "__main__": + """ + Example usage: python3 plot_mask.py http://localhost:8000/api/v1/metadata/mlex_store/mlex_store/username/dataset/uuid + """ + + load_dotenv() + api_key = os.getenv("MASK_TILED_API_KEY", None) + + if len(sys.argv) < 2: + print("Usage: python3 plot_mask.py [slice_idx] [output_path]") + sys.exit(1) + + mask_uri = sys.argv[1] + slice_idx = int(sys.argv[2]) if len(sys.argv) > 2 else 0 + output_path = sys.argv[3] if len(sys.argv) > 3 else "mask.png" + + plot_mask(mask_uri, api_key, slice_idx, output_path) diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..618af5e --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1,2 @@ +matplotlib +tiled[client] From 3a951b9934aeb440e59a1245facdcb514a9ecea8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 18:19:53 -0800 Subject: [PATCH 12/17] Include `data_uri` in mask meta data --- utils/data_utils.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/utils/data_utils.py b/utils/data_utils.py index 8bc1690..a580b28 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -89,6 +89,20 @@ def get_data_shape_by_name(self, project_name): return project_container.shape return None + def get_data_uri_by_name(self, project_name): + """ + Retrieve uri of the data + """ + project_container = self.get_data_sequence_by_name(project_name) + if project_container: + return project_container.uri + return None + + +tiled_datasets = TiledDataLoader( + data_tiled_uri=DATA_TILED_URI, data_tiled_api_key=DATA_TILED_API_KEY +) + class TiledMaskHandler: """ @@ -169,6 +183,7 @@ def save_annotations_data(self, global_store, all_annotations, project_name): "annotations": annnotations_per_slice, "image_shape": global_store["image_shapes"][0], "project_name": project_name, + "data_uri": tiled_datasets.get_data_uri_by_name(project_name), "mask_idx": list(annnotations_per_slice.keys()), } print("Metadata: ", metadata) @@ -191,13 +206,9 @@ def save_annotations_data(self, global_store, all_annotations, project_name): # (uuid will be created by Tiled, since no key is given) last_container = last_container.create_container(metadata=metadata) mask = last_container.write_array(key="mask", array=mask) - # print("Created a mask array with the following uri: ", mask.uri) - return mask.uri + return mask.uri -tiled_datasets = TiledDataLoader( - data_tiled_uri=DATA_TILED_URI, data_tiled_api_key=DATA_TILED_API_KEY -) tiled_masks = TiledMaskHandler( mask_tiled_uri=MASK_TILED_URI, mask_tiled_api_key=MASK_TILED_API_KEY From 0b88839e0c67881686439659d384040cc20e2d28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 19:31:40 -0800 Subject: [PATCH 13/17] :bug: Fix export of free-forms for class with index 0 --- utils/annotations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/annotations.py b/utils/annotations.py index 1526bbf..7877448 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -237,7 +237,7 @@ def closed_path_to_array(self, svg_data, image_shape, mask_class): # Reshape the result back into the 2D shape mask = is_inside.reshape(image_height, image_width).astype(np.int8) - # Set the class value for the pixels inside the polygon - mask[mask == 1] = mask_class + # Set the class value for the pixels inside the polygon, -1 for the rest mask[mask == 0] = -1 + mask[mask == 1] = mask_class return mask From b709dbdae055c774eae85aaf4f0396a7607b19f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 19:38:31 -0800 Subject: [PATCH 14/17] :sparkles: Use hash as key for masks instead of randomly generated uuid --- utils/annotations.py | 9 +++++++++ utils/data_utils.py | 26 ++++++++++++++++---------- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/utils/annotations.py b/utils/annotations.py index 7877448..6978c5f 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -1,4 +1,6 @@ +import hashlib import io +import json import zipfile import numpy as np @@ -46,6 +48,7 @@ def __init__(self, annotation_store, global_store): self.annotation_classes = annotation_classes self.annotations = annotations + self.annotations_hash = self.get_annotations_hash() self.image_shape = global_store["image_shapes"][0] def get_annotations(self): @@ -57,6 +60,12 @@ def get_annotation_mask(self): def get_annotation_classes(self): return self.annotation_classes + def get_annotations_hash(self): + hash_object = hashlib.md5() + hash_object.update(json.dumps(self.annotations, sort_keys=True).encode()) + hash_object.update(json.dumps(self.annotation_classes, sort_keys=True).encode()) + return hash_object.hexdigest() + def get_annotation_mask_as_bytes(self): buffer = io.BytesIO() zip_buffer = io.BytesIO() diff --git a/utils/data_utils.py b/utils/data_utils.py index a580b28..a60b919 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -176,17 +176,18 @@ def save_annotations_data(self, global_store, all_annotations, project_name): # Get metadata and annotation data annnotations_per_slice = annotations.get_annotations() annotation_classes = annotations.get_annotation_classes() + annotations_hash = annotations.get_annotations_hash() metadata = { - "classes": annotation_classes, - "unlabeled_class_id": -1, - "annotations": annnotations_per_slice, - "image_shape": global_store["image_shapes"][0], "project_name": project_name, "data_uri": tiled_datasets.get_data_uri_by_name(project_name), + "image_shape": global_store["image_shapes"][0], "mask_idx": list(annnotations_per_slice.keys()), + "classes": annotation_classes, + "annotations": annnotations_per_slice, + "unlabeled_class_id": -1, } - print("Metadata: ", metadata) + mask = annotations.get_annotation_mask() try: @@ -202,12 +203,17 @@ def save_annotations_data(self, global_store, all_annotations, project_name): last_container = last_container.create_container(key=key) else: last_container = last_container[key] - # Add json metadata to a container with a uuid as key - # (uuid will be created by Tiled, since no key is given) - last_container = last_container.create_container(metadata=metadata) - mask = last_container.write_array(key="mask", array=mask) - return mask.uri + # Add json metadata to a container with the md5 hash as key + # if a mask with that hash does not already exist + if annotations_hash not in last_container.keys(): + last_container = last_container.create_container( + key=annotations_hash, metadata=metadata + ) + mask = last_container.write_array(key="mask", array=mask) + else: + last_container = last_container[annotations_hash] + return last_container.uri tiled_masks = TiledMaskHandler( From ee37e3581ecb5115aa2f10540cd33d24749a00e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 19:40:11 -0800 Subject: [PATCH 15/17] Check for unlabled label name and add color for unlabled pixels in plotting --- callbacks/control_bar.py | 4 ++++ examples/plot_mask.py | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index f5c208b..d0da110 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -332,6 +332,10 @@ def open_annotation_class_modal( disable_class_creation = True error_msg.append("Label Already in Use!") error_msg.append(html.Br()) + if new_label == "Unlabeled": + disable_class_creation = True + error_msg.append("Label name cannot be 'Unlabeled'") + error_msg.append(html.Br()) if new_color in current_colors: disable_class_creation = True error_msg.append("Color Already in use!") diff --git a/examples/plot_mask.py b/examples/plot_mask.py index eee8d8c..fa4735d 100644 --- a/examples/plot_mask.py +++ b/examples/plot_mask.py @@ -19,6 +19,7 @@ def plot_mask(mask_uri, api_key, slice_idx, output_path): # Retrieve mask and metadata mask_client = from_uri(mask_uri, api_key=api_key) mask = mask_client["mask"][slice_idx] + meta_data = mask_client.metadata mask_idx = meta_data["mask_idx"] @@ -34,10 +35,14 @@ def plot_mask(mask_uri, api_key, slice_idx, output_path): labels = [ annotation_class["label"] for _, annotation_class in class_meta_data.items() ] + # Add color for unlabeled pixels + colors = ["#D3D3D3"] + colors + labels = ["Unlabeled"] + labels + plt.imshow( mask, cmap=ListedColormap(colors), - vmin=-0.5, + vmin=-1.5, vmax=max_class_id + 0.5, ) plt.title(meta_data["project_name"] + ", slice: " + mask_idx[slice_idx]) From 212530d7cf5832938fb56517fbb9783841502d9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Sat, 2 Mar 2024 19:41:17 -0800 Subject: [PATCH 16/17] Capture `mask_uri` for job submission --- callbacks/segmentation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py index bef1ad3..c7c8d40 100644 --- a/callbacks/segmentation.py +++ b/callbacks/segmentation.py @@ -74,19 +74,19 @@ def run_job(n_clicks, global_store, all_annotations, project_name): """ if n_clicks: if MODE == "dev": - tiled_masks.save_annotations_data( + mask_uri = tiled_masks.save_annotations_data( global_store, all_annotations, project_name ) job_uid = str(uuid.uuid4()) return ( dmc.Text( - f"Workflow has been succesfully submitted with uid: {job_uid}", + f"Workflow has been succesfully submitted with uid: {job_uid} and mask uri: {mask_uri}", size="sm", ), job_uid, ) else: - tiled_masks.save_annotations_data( + mask_uri = tiled_masks.save_annotations_data( global_store, all_annotations, project_name ) job_submitted = requests.post( From f0d6a428ee4e20f1e2ca3c1036ecd710118f88c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Mon, 4 Mar 2024 12:33:47 -0800 Subject: [PATCH 17/17] Use `canonicaljson` instead of `json.dumps(..., sort_keys=True)` for hashing --- requirements.txt | 1 + utils/annotations.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7a2fbcc..6eca7ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,3 +33,4 @@ scipy dash-extensions==1.0.1 dash-bootstrap-components==1.5.0 dash_auth==2.0.0 +canonicaljson diff --git a/utils/annotations.py b/utils/annotations.py index 6978c5f..c2eafcd 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -1,8 +1,8 @@ import hashlib import io -import json import zipfile +import canonicaljson import numpy as np import scipy.sparse as sp from matplotlib.path import Path @@ -62,8 +62,8 @@ def get_annotation_classes(self): def get_annotations_hash(self): hash_object = hashlib.md5() - hash_object.update(json.dumps(self.annotations, sort_keys=True).encode()) - hash_object.update(json.dumps(self.annotation_classes, sort_keys=True).encode()) + hash_object.update(canonicaljson.encode_canonical_json(self.annotations)) + hash_object.update(canonicaljson.encode_canonical_json(self.annotation_classes)) return hash_object.hexdigest() def get_annotation_mask_as_bytes(self):