diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..1fafbd5
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,3 @@
+.env
+.git
+.gitignore
diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py
index 9d7d438..d0da110 100644
--- a/callbacks/control_bar.py
+++ b/callbacks/control_bar.py
@@ -26,7 +26,7 @@
 from components.annotation_class import annotation_class_item
 from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEY_MODES, KEYBINDS
 from utils.annotations import Annotations
-from utils.data_utils import tiled_dataset
+from utils.data_utils import tiled_datasets, tiled_masks, tiled_results
 from utils.plot_utils import generate_notification, generate_notification_bg_icon_col
 
 # TODO - temporary local file path and user for annotation saving and exporting
@@ -332,6 +332,10 @@ def open_annotation_class_modal(
         disable_class_creation = True
         error_msg.append("Label Already in Use!")
         error_msg.append(html.Br())
+    if new_label == "Unlabeled":
+        disable_class_creation = True
+        error_msg.append("Label name cannot be 'Unlabeled'")
+        error_msg.append(html.Br())
     if new_color in current_colors:
         disable_class_creation = True
         error_msg.append("Color Already in use!")
@@ -758,7 +762,7 @@ def populate_load_annotations_dropdown_menu_options(modal_opened, image_src):
     if not modal_opened:
         raise PreventUpdate
 
-    data = tiled_dataset.DEV_load_exported_json_data(
+    data = tiled_masks.DEV_load_exported_json_data(
         EXPORT_FILE_PATH, USER_NAME, image_src
     )
     if not data:
@@ -804,10 +808,10 @@
     )["index"]
 
     # TODO : when quering from the server, load (data) for user, source, time
-    data = tiled_dataset.DEV_load_exported_json_data(
+    data = tiled_masks.DEV_load_exported_json_data(
         EXPORT_FILE_PATH, USER_NAME, image_src
     )
-    data = tiled_dataset.DEV_filter_json_data_by_timestamp(
+    data = tiled_masks.DEV_filter_json_data_by_timestamp(
         data, str(selected_annotation_timestamp)
     )
     data = data[0]["data"]
@@ -853,10 +857,10 @@ def populate_classification_results(
     image_src, refresh_tiled, toggle, dropdown_enabled, slider_enabled
 ):
     if refresh_tiled:
-        tiled_dataset.refresh_data()
+        tiled_datasets.refresh_data_client()
 
     data_options = [
-        item for item in tiled_dataset.get_data_project_names() if "seg" not in item
+        item for item in tiled_datasets.get_data_project_names() if "seg" not in item
     ]
     results = []
     value = None
@@ -872,9 +876,10 @@
         disabled_toggle = False
         disabled_slider = slider_enabled
     else:
+        # TODO: Match by mask uid instead of image_src
        results = [
             item
-            for item in tiled_dataset.get_data_project_names()
+            for item in tiled_results.get_data_project_names()
             if ("seg" in item and image_src in item)
         ]
         if results:
diff --git a/callbacks/image_viewer.py b/callbacks/image_viewer.py
index a864640..ea29e75 100644
--- a/callbacks/image_viewer.py
+++ b/callbacks/image_viewer.py
@@ -16,11 +16,12 @@
 from dash.exceptions import PreventUpdate
 
 from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEYBINDS
-from utils.data_utils import tiled_dataset
+from utils.data_utils import tiled_datasets, tiled_masks, tiled_results
 from utils.plot_utils import (
     create_viewfinder,
     downscale_view,
     generate_notification,
+    generate_segmentation_colormap,
     get_view_finder_max_min,
     resize_canvas,
 )
@@ -108,7 +109,7 @@ def render_image(
 
         if image_idx:
             image_idx -= 1  # slider starts at 1, so subtract 1 to get the correct index
-            tf = tiled_dataset.get_data_sequence_by_name(project_name)[image_idx]
+            tf = tiled_datasets.get_data_sequence_by_name(project_name)[image_idx]
         if toggle_seg_result:
             # if toggle is true and overlay exists already (2 images in data) this will
             # be handled in hide_show_segmentation_overlay callback
@@ -117,9 +118,12 @@
                 and ctx.triggered_id == "show-result-overlay-toggle"
             ):
                 return [dash.no_update] * 7 + ["hidden"]
-            if str(image_idx + 1) in tiled_dataset.get_annotated_segmented_results():
-                result = tiled_dataset.get_data_sequence_by_name(seg_result_selection)[
-                    image_idx
+            annotation_indices = tiled_masks.get_annotated_segmented_results()
+            if str(image_idx + 1) in annotation_indices:
+                # Will not return an error since we already checked if image_idx+1 is in the list
+                mapped_index = annotation_indices.index(str(image_idx + 1))
+                result = tiled_results.get_data_sequence_by_name(seg_result_selection)[
+                    mapped_index
                 ]
             else:
                 result = None
@@ -127,20 +131,14 @@
         tf = np.zeros((500, 500))
     fig = px.imshow(tf, binary_string=True)
     if toggle_seg_result and result is not None:
-        unique_segmentation_values = np.unique(result)
-        normalized_range = np.linspace(
-            0, 1, len(unique_segmentation_values)
-        )  # heatmap requires a normalized range
-        color_list = (
-            px.colors.qualitative.Plotly
-        )  # TODO placeholder - replace with user defined classess
-        colorscale = [
-            [normalized_range[i], color_list[i % len(color_list)]]
-            for i in range(len(unique_segmentation_values))
-        ]
+        colorscale, max_class_id = generate_segmentation_colormap(
+            all_annotation_class_store
+        )
         fig.add_trace(
             go.Heatmap(
                 z=result,
+                zmin=-0.5,
+                zmax=max_class_id + 0.5,
                 colorscale=colorscale,
                 showscale=False,
             )
@@ -485,7 +483,7 @@ def update_slider_values(project_name, annotation_store):
     """
     # Retrieve data shape if project_name is valid and points to a 3d array
     data_shape = (
-        tiled_dataset.get_data_shape_by_name(project_name) if project_name else None
+        tiled_datasets.get_data_shape_by_name(project_name) if project_name else None
     )
     disable_slider = data_shape is None
     if not disable_slider:
diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py
index 69a206c..c7c8d40 100644
--- a/callbacks/segmentation.py
+++ b/callbacks/segmentation.py
@@ -7,7 +7,7 @@
 from dash import ALL, Input, Output, State, callback, no_update
 from dash.exceptions import PreventUpdate
 
-from utils.data_utils import tiled_dataset
+from utils.data_utils import tiled_masks
 
 MODE = os.getenv("MODE", "")
 
@@ -74,17 +74,19 @@
     """
     if n_clicks:
         if MODE == "dev":
+            mask_uri = tiled_masks.save_annotations_data(
+                global_store, all_annotations, project_name
+            )
             job_uid = str(uuid.uuid4())
             return (
                 dmc.Text(
-                    f"Workflow has been succesfully submitted with uid: {job_uid}",
+                    f"Workflow has been succesfully submitted with uid: {job_uid} and mask uri: {mask_uri}",
                     size="sm",
                 ),
                 job_uid,
             )
         else:
-
-            tiled_dataset.save_annotations_data(
+            mask_uri = tiled_masks.save_annotations_data(
                 global_store, all_annotations, project_name
             )
             job_submitted = requests.post(
diff --git a/components/annotation_class.py b/components/annotation_class.py
index a855609..d97b836 100644
--- a/components/annotation_class.py
+++ b/components/annotation_class.py
@@ -33,7 +33,7 @@ def annotation_class_item(class_color, class_label, existing_ids, data=None):
         annotations = data["annotations"]
         is_visible = data["is_visible"]
     else:
-        class_id = 1 if not existing_ids else max(existing_ids) + 1
+        class_id = 0 if not existing_ids else max(existing_ids) + 1
         annotations = {}
         is_visible = True
     class_color_transparent = class_color + "50"
diff --git a/components/control_bar.py b/components/control_bar.py
index b4ab2b0..007f856 100644
--- a/components/control_bar.py
+++ b/components/control_bar.py
@@ -6,7 +6,7 @@
 
 from components.annotation_class import annotation_class_item
 from constants import ANNOT_ICONS, KEYBINDS
-from utils.data_utils import tiled_dataset
+from utils.data_utils import tiled_datasets
 
 
 def _tooltip(text, children):
@@ -62,7 +62,7 @@ def layout():
     Returns the layout for the control panel in the app UI
     """
     DATA_OPTIONS = [
-        item for item in tiled_dataset.get_data_project_names() if "seg" not in item
+        item for item in tiled_datasets.get_data_project_names() if "seg" not in item
     ]
     return drawer_section(
         dmc.Stack(
diff --git a/examples/plot_mask.py b/examples/plot_mask.py
new file mode 100644
index 0000000..fa4735d
--- /dev/null
+++ b/examples/plot_mask.py
@@ -0,0 +1,77 @@
+import os
+import sys
+
+import matplotlib.patches as mpatches
+import matplotlib.pyplot as plt
+from dotenv import load_dotenv
+from matplotlib.colors import ListedColormap
+from tiled.client import from_uri
+
+
+def plot_mask(mask_uri, api_key, slice_idx, output_path):
+    """
+    Saves a plot of a given mask using metadata information such as class colors and labels.
+    It is assumed that the given uri is the uri of a mask container, with associated meta data
+    and a mask array under the key "mask".
+    The given slice index references mask slices, not the original data.
+    However, the printed slice index in the figure will be the index of the original data.
+    """
+    # Retrieve mask and metadata
+    mask_client = from_uri(mask_uri, api_key=api_key)
+    mask = mask_client["mask"][slice_idx]
+
+    meta_data = mask_client.metadata
+    mask_idx = meta_data["mask_idx"]
+
+    if slice_idx > len(mask_idx):
+        raise ValueError("Slice index out of range")
+
+    class_meta_data = meta_data["classes"]
+    max_class_id = len(class_meta_data.keys()) - 1
+
+    colors = [
+        annotation_class["color"] for _, annotation_class in class_meta_data.items()
+    ]
+    labels = [
+        annotation_class["label"] for _, annotation_class in class_meta_data.items()
+    ]
+    # Add color for unlabeled pixels
+    colors = ["#D3D3D3"] + colors
+    labels = ["Unlabeled"] + labels
+
+    plt.imshow(
+        mask,
+        cmap=ListedColormap(colors),
+        vmin=-1.5,
+        vmax=max_class_id + 0.5,
+    )
+    plt.title(meta_data["project_name"] + ", slice: " + mask_idx[slice_idx])
+
+    # create a patch for every color
+    patches = [
+        mpatches.Patch(color=colors[i], label=labels[i]) for i in range(len(labels))
+    ]
+    # Plot legend below the image
+    plt.legend(
+        handles=patches, loc="upper center", bbox_to_anchor=(0.5, -0.075), ncol=3
+    )
+    plt.savefig(output_path, bbox_inches="tight")
+
+
+if __name__ == "__main__":
+    """
+    Example usage: python3 plot_mask.py http://localhost:8000/api/v1/metadata/mlex_store/mlex_store/username/dataset/uuid
+    """
+
+    load_dotenv()
+    api_key = os.getenv("MASK_TILED_API_KEY", None)
+
+    if len(sys.argv) < 2:
+        print("Usage: python3 plot_mask.py [slice_idx] [output_path]")
+        sys.exit(1)
+
+    mask_uri = sys.argv[1]
+    slice_idx = int(sys.argv[2]) if len(sys.argv) > 2 else 0
+    output_path = sys.argv[3] if len(sys.argv) > 3 else "mask.png"
+
+    plot_mask(mask_uri, api_key, slice_idx, output_path)
diff --git a/examples/requirements.txt b/examples/requirements.txt
new file mode 100644
index 0000000..618af5e
--- /dev/null
+++ b/examples/requirements.txt
@@ -0,0 +1,2 @@
+matplotlib
+tiled[client]
diff --git a/requirements.txt b/requirements.txt
index 7a2fbcc..6eca7ec 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -33,3 +33,4 @@ scipy
 dash-extensions==1.0.1
 dash-bootstrap-components==1.5.0
 dash_auth==2.0.0
+canonicaljson
diff --git a/utils/annotations.py b/utils/annotations.py
index 765ac73..c2eafcd 100644
--- a/utils/annotations.py
+++ b/utils/annotations.py
@@ -1,6 +1,8 @@
+import hashlib
 import io
 import zipfile
 
+import canonicaljson
 import numpy as np
 import scipy.sparse as sp
 from matplotlib.path import Path
@@ -15,26 +17,39 @@ def __init__(self, annotation_store, global_store):
             for annotation_class in annotation_store:
                 slices.extend(list(annotation_class["annotations"].keys()))
             slices = set(slices)
-            annotations = {key: [] for key in slices}
+            # Slices need to be sorted to ensure that the exported mask slices
+            # have the same order as the original data set
+            annotations = {key: [] for key in sorted(slices, key=int)}
+
+            all_class_labels = [
+                annotation_class["class_id"] for annotation_class in annotation_store
+            ]
+            annotation_classes = {}
 
             for annotation_class in annotation_store:
+                condensed_id = str(all_class_labels.index(annotation_class["class_id"]))
+                annotation_classes[condensed_id] = {
+                    "label": annotation_class["label"],
+                    "color": annotation_class["color"],
+                }
                 for image_idx, slice_data in annotation_class["annotations"].items():
                     for shape in slice_data:
                         self._set_annotation_type(shape)
                         self._set_annotation_svg(shape)
                         annotation = {
-                            "id": annotation_class["class_id"],
+                            "class_id": condensed_id,
                             "type": self.annotation_type,
-                            "class": annotation_class["label"],
-                            # TODO: This is the same across all images in a dataset
-                            "image_shape": global_store["image_shapes"][0],
                             "svg_data": self.svg_data,
                         }
                         annotations[image_idx].append(annotation)
         else:
-            annotations = []
+            annotations = None
+            annotation_classes = None
 
+        self.annotation_classes = annotation_classes
         self.annotations = annotations
+        self.annotations_hash = self.get_annotations_hash()
+        self.image_shape = global_store["image_shapes"][0]
 
     def get_annotations(self):
         return self.annotations
@@ -42,6 +57,15 @@ def get_annotations(self):
     def get_annotation_mask(self):
         return self.annotation_mask
 
+    def get_annotation_classes(self):
+        return self.annotation_classes
+
+    def get_annotations_hash(self):
+        hash_object = hashlib.md5()
+        hash_object.update(canonicaljson.encode_canonical_json(self.annotations))
+        hash_object.update(canonicaljson.encode_canonical_json(self.annotation_classes))
+        return hash_object.hexdigest()
+
     def get_annotation_mask_as_bytes(self):
         buffer = io.BytesIO()
         zip_buffer = io.BytesIO()
@@ -81,26 +105,29 @@
     def create_annotation_mask(self, sparse=False):
         self.sparse = sparse
         annotation_mask = []
+        image_height = self.image_shape[0]
+        image_width = self.image_shape[1]
+
         for slice_idx, slice_data in self.annotations.items():
-            image_height = slice_data[0]["image_shape"][0]
-            image_width = slice_data[0]["image_shape"][1]
-            slice_mask = np.zeros([image_height, image_width], dtype=np.uint8)
+            slice_mask = np.full(
+                [image_height, image_width], fill_value=-1, dtype=np.int8
+            )
             for shape in slice_data:
                 if shape["type"] == "Closed Freeform":
                     shape_mask = ShapeConversion.closed_path_to_array(
-                        shape["svg_data"], shape["image_shape"], shape["id"]
+                        shape["svg_data"], self.image_shape, shape["class_id"]
                    )
                 elif shape["type"] == "Rectangle":
                     shape_mask = ShapeConversion.rectangle_to_array(
-                        shape["svg_data"], shape["image_shape"], shape["id"]
+                        shape["svg_data"], self.image_shape, shape["class_id"]
                     )
                 elif shape["type"] == "Ellipse":
                     shape_mask = ShapeConversion.ellipse_to_array(
-                        shape["svg_data"], shape["image_shape"], shape["id"]
+                        shape["svg_data"], self.image_shape, shape["class_id"]
                     )
                 else:
                     continue
-                slice_mask[shape_mask > 0] = shape_mask[shape_mask > 0]
+                slice_mask[shape_mask >= 0] = shape_mask[shape_mask >= 0]
             annotation_mask.append(slice_mask)
 
         if sparse:
@@ -154,7 +181,7 @@ def ellipse_to_array(self, svg_data, image_shape, mask_class):
         c_radius = abs(svg_data["y0"] - svg_data["y1"]) / 2  # Vertical radius
 
         # Create mask and draw ellipse
-        mask = np.zeros((image_height, image_width), dtype=np.uint8)
+        mask = np.full((image_height, image_width), fill_value=-1, dtype=np.int8)
         rr, cc = draw.ellipse(
             cy, cx, c_radius, r_radius
         )  # Vertical radius first, then horizontal
@@ -181,7 +208,7 @@ def rectangle_to_array(self, svg_data, image_shape, mask_class):
         y1 = max(min(y1, image_height - 1), 0)
 
         # # Draw the rectangle
-        mask = np.zeros((image_height, image_width), dtype=np.uint8)
+        mask = np.full((image_height, image_width), fill_value=-1, dtype=np.int8)
         rr, cc = draw.rectangle(start=(y0, x0), end=(y1, x1))
 
         # Convert coordinates to integers
@@ -217,8 +244,9 @@ def closed_path_to_array(self, svg_data, image_shape, mask_class):
         is_inside = polygon_path.contains_points(points)
 
         # Reshape the result back into the 2D shape
-        mask = is_inside.reshape(image_height, image_width).astype(int)
+        mask = is_inside.reshape(image_height, image_width).astype(np.int8)
 
-        # Set the class value for the pixels inside the polygon
+        # Set the class value for the pixels inside the polygon, -1 for the rest
+        mask[mask == 0] = -1
         mask[mask == 1] = mask_class
         return mask
diff --git a/utils/data_utils.py b/utils/data_utils.py
index bb54d29..a60b919 100644
--- a/utils/data_utils.py
+++ b/utils/data_utils.py
@@ -14,6 +14,11 @@
 
 DATA_TILED_URI = os.getenv("DATA_TILED_URI")
 DATA_TILED_API_KEY = os.getenv("DATA_TILED_API_KEY")
+MASK_TILED_URI = os.getenv("MASK_TILED_URI")
+MASK_TILED_API_KEY = os.getenv("MASK_TILED_API_KEY")
+SEG_TILED_URI = os.getenv("SEG_TILED_URI")
+SEG_TILED_API_KEY = os.getenv("SEG_TILED_API_KEY")
+USER_NAME = os.getenv("USER_NAME", "user1")
 
 
 class TiledDataLoader:
@@ -22,14 +27,14 @@ def __init__(
     ):
         self.data_tiled_uri = data_tiled_uri
         self.data_tiled_api_key = data_tiled_api_key
-        self.data = from_uri(
+        self.data_client = from_uri(
             self.data_tiled_uri,
             api_key=self.data_tiled_api_key,
             timeout=httpx.Timeout(30.0),
         )
-    def refresh_data(self):
-        self.data = from_uri(
+    def refresh_data_client(self):
+        self.data_client = from_uri(
             self.data_tiled_uri,
             api_key=self.data_tiled_api_key,
             timeout=httpx.Timeout(30.0),
         )
@@ -42,8 +47,8 @@ def get_data_project_names(self):
         """
         project_names = [
             project
-            for project in list(self.data)
-            if isinstance(self.data[project], (Container, ArrayClient))
+            for project in list(self.data_client)
+            if isinstance(self.data_client[project], (Container, ArrayClient))
         ]
         return project_names
 
@@ -53,7 +58,7 @@ def get_data_sequence_by_name(self, project_name):
         but can also be additionally encapsulated in a folder, multiple container
         or in a .nxs file. We make use of specs to figure out the path to the 3d data.
         """
-        project_client = self.data[project_name]
+        project_client = self.data_client[project_name]
         # If the project directly points to an array, directly return it
         if isinstance(project_client, ArrayClient):
             return project_client
@@ -84,6 +89,37 @@ def get_data_shape_by_name(self, project_name):
             return project_container.shape
         return None
 
+    def get_data_uri_by_name(self, project_name):
+        """
+        Retrieve uri of the data
+        """
+        project_container = self.get_data_sequence_by_name(project_name)
+        if project_container:
+            return project_container.uri
+        return None
+
+
+tiled_datasets = TiledDataLoader(
+    data_tiled_uri=DATA_TILED_URI, data_tiled_api_key=DATA_TILED_API_KEY
+)
+
+
+class TiledMaskHandler:
+    """
+    This class is used to handle the masks that are generated from the annotations.
+    """
+
+    def __init__(
+        self, mask_tiled_uri=MASK_TILED_URI, mask_tiled_api_key=MASK_TILED_API_KEY
+    ):
+        self.mask_tiled_uri = mask_tiled_uri
+        self.mask_tiled_api_key = mask_tiled_api_key
+        self.mask_client = from_uri(
+            self.mask_tiled_uri,
+            api_key=self.mask_tiled_api_key,
+            timeout=httpx.Timeout(30.0),
+        )
+
     @staticmethod
     def get_annotated_segmented_results(json_file_path="exported_annotation_data.json"):
         annotated_slices = []
@@ -130,32 +166,60 @@ def DEV_filter_json_data_by_timestamp(data, timestamp):
 
     def save_annotations_data(self, global_store, all_annotations, project_name):
         """
-        Transforms annotations data to a pixelated mask and outputs to
-        the Tiled server
-
-        # TODO: Save data to Tiled server after transformation
+        Transforms annotations data to a pixelated mask and outputs to the Tiled server
         """
         annotations = Annotations(all_annotations, global_store)
-        annotations.create_annotation_mask(sparse=True)  # TODO: Check sparse status
+        # TODO: Check sparse status, it may be worthwhile to store the mask as a sparse array
+        # if our machine learning models can handle sparse arrays
+        annotations.create_annotation_mask(sparse=False)
 
         # Get metadata and annotation data
-        metadata = annotations.get_annotations()
+        annnotations_per_slice = annotations.get_annotations()
+        annotation_classes = annotations.get_annotation_classes()
+        annotations_hash = annotations.get_annotations_hash()
+
+        metadata = {
+            "project_name": project_name,
+            "data_uri": tiled_datasets.get_data_uri_by_name(project_name),
+            "image_shape": global_store["image_shapes"][0],
+            "mask_idx": list(annnotations_per_slice.keys()),
+            "classes": annotation_classes,
+            "annotations": annnotations_per_slice,
+            "unlabeled_class_id": -1,
+        }
+
         mask = annotations.get_annotation_mask()
 
-        # Get raw images associated with each annotated slice
-        img_idx = list(metadata.keys())
-        img = self.data[project_name]
-        raw = []
-        for idx in img_idx:
-            ar = img[int(idx)]
-            raw.append(ar)
         try:
-            raw = np.stack(raw)
             mask = np.stack(mask)
         except ValueError:
             return "No annotations to process."
 
-        return
-
-
-tiled_dataset = TiledDataLoader()
+        # Store the mask in the Tiled server under /username/project_name/uuid/mask"
+        container_keys = [USER_NAME, project_name]
+        last_container = self.mask_client
+        for key in container_keys:
+            if key not in last_container.keys():
+                last_container = last_container.create_container(key=key)
+            else:
+                last_container = last_container[key]
+
+        # Add json metadata to a container with the md5 hash as key
+        # if a mask with that hash does not already exist
+        if annotations_hash not in last_container.keys():
+            last_container = last_container.create_container(
+                key=annotations_hash, metadata=metadata
+            )
+            mask = last_container.write_array(key="mask", array=mask)
+        else:
+            last_container = last_container[annotations_hash]
+        return last_container.uri
+
+
+tiled_masks = TiledMaskHandler(
+    mask_tiled_uri=MASK_TILED_URI, mask_tiled_api_key=MASK_TILED_API_KEY
+)
+
+tiled_results = TiledDataLoader(
+    data_tiled_uri=SEG_TILED_URI, data_tiled_api_key=SEG_TILED_API_KEY
+)
diff --git a/utils/plot_utils.py b/utils/plot_utils.py
index 55a7dd1..f42ccb3 100644
--- a/utils/plot_utils.py
+++ b/utils/plot_utils.py
@@ -1,6 +1,7 @@
 import random
 
 import dash_mantine_components as dmc
+import numpy as np
 import plotly.express as px
 import plotly.graph_objects as go
 from dash_iconify import DashIconify
@@ -170,6 +171,39 @@ def resize_canvas(h, w, H, W, figure):
     return figure, image_center_coor
 
 
+def generate_segmentation_colormap(all_annotations_data):
+    """
+    Generates a discrete colormap for the segmentation overlay
+    based on the color information per class.
+
+    The discrete colormap maps values from 0 to 1 to colors,
+    but is meant to be applied to images with class ids as values,
+    with these varying from 0 to the number of classes - 1.
+    To account for numerical inaccuracies, it is best to center the plot range
+    around the class ids, by setting cmin=-0.5 and cmax=max_class_id+0.5.
+    """
+    max_class_id = max(
+        [annotation_class["class_id"] for annotation_class in all_annotations_data]
+    )
+    # heatmap requires a normalized range from 0 to 1
+    # We need to specify color for at least the range limits (0 and 1)
+    # as well for every additional class
+    # due to using zero-based class ids, we need to add 2 to the max class id
+    normalized_range = np.linspace(0, 1, max_class_id + 2)
+    color_list = [
+        annotation_class["color"] for annotation_class in all_annotations_data
+    ]
+    # We need to repeat each color twice, to create discrete color segments
+    # This loop contains the range limits 0 and 1 once,
+    # but every other value in between twice
+    colorscale = [
+        [normalized_range[i + j], color_list[i % len(color_list)]]
+        for i in range(0, normalized_range.size - 1)
+        for j in range(2)
+    ]
+    return colorscale, max_class_id
+
+
 def generate_notification(title, color, icon, message=""):
     return dmc.Notification(
         title=title,