Skip to content

Commit

Permalink
Merge pull request #172 from Wiebke/save-and-load-masks-tiled
Browse files Browse the repository at this point in the history
Save pixelated masks to Tiled
  • Loading branch information
taxe10 committed Mar 6, 2024
2 parents 26a8dcb + f0d6a42 commit 555c27f
Show file tree
Hide file tree
Showing 12 changed files with 286 additions and 72 deletions.
3 changes: 3 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
.env
.git
.gitignore
19 changes: 12 additions & 7 deletions callbacks/control_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from components.annotation_class import annotation_class_item
from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEY_MODES, KEYBINDS
from utils.annotations import Annotations
from utils.data_utils import tiled_dataset
from utils.data_utils import tiled_datasets, tiled_masks, tiled_results
from utils.plot_utils import generate_notification, generate_notification_bg_icon_col

# TODO - temporary local file path and user for annotation saving and exporting
Expand Down Expand Up @@ -332,6 +332,10 @@ def open_annotation_class_modal(
disable_class_creation = True
error_msg.append("Label Already in Use!")
error_msg.append(html.Br())
if new_label == "Unlabeled":
disable_class_creation = True
error_msg.append("Label name cannot be 'Unlabeled'")
error_msg.append(html.Br())
if new_color in current_colors:
disable_class_creation = True
error_msg.append("Color Already in use!")
Expand Down Expand Up @@ -758,7 +762,7 @@ def populate_load_annotations_dropdown_menu_options(modal_opened, image_src):
if not modal_opened:
raise PreventUpdate

data = tiled_dataset.DEV_load_exported_json_data(
data = tiled_masks.DEV_load_exported_json_data(
EXPORT_FILE_PATH, USER_NAME, image_src
)
if not data:
Expand Down Expand Up @@ -804,10 +808,10 @@ def load_and_apply_selected_annotations(selected_annotation, image_src, img_idx)
)["index"]

# TODO : when quering from the server, load (data) for user, source, time
data = tiled_dataset.DEV_load_exported_json_data(
data = tiled_masks.DEV_load_exported_json_data(
EXPORT_FILE_PATH, USER_NAME, image_src
)
data = tiled_dataset.DEV_filter_json_data_by_timestamp(
data = tiled_masks.DEV_filter_json_data_by_timestamp(
data, str(selected_annotation_timestamp)
)
data = data[0]["data"]
Expand Down Expand Up @@ -853,10 +857,10 @@ def populate_classification_results(
image_src, refresh_tiled, toggle, dropdown_enabled, slider_enabled
):
if refresh_tiled:
tiled_dataset.refresh_data()
tiled_datasets.refresh_data_client()

data_options = [
item for item in tiled_dataset.get_data_project_names() if "seg" not in item
item for item in tiled_datasets.get_data_project_names() if "seg" not in item
]
results = []
value = None
Expand All @@ -872,9 +876,10 @@ def populate_classification_results(
disabled_toggle = False
disabled_slider = slider_enabled
else:
# TODO: Match by mask uid instead of image_src
results = [
item
for item in tiled_dataset.get_data_project_names()
for item in tiled_results.get_data_project_names()
if ("seg" in item and image_src in item)
]
if results:
Expand Down
32 changes: 15 additions & 17 deletions callbacks/image_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
from dash.exceptions import PreventUpdate

from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEYBINDS
from utils.data_utils import tiled_dataset
from utils.data_utils import tiled_datasets, tiled_masks, tiled_results
from utils.plot_utils import (
create_viewfinder,
downscale_view,
generate_notification,
generate_segmentation_colormap,
get_view_finder_max_min,
resize_canvas,
)
Expand Down Expand Up @@ -108,7 +109,7 @@ def render_image(

if image_idx:
image_idx -= 1 # slider starts at 1, so subtract 1 to get the correct index
tf = tiled_dataset.get_data_sequence_by_name(project_name)[image_idx]
tf = tiled_datasets.get_data_sequence_by_name(project_name)[image_idx]
if toggle_seg_result:
# if toggle is true and overlay exists already (2 images in data) this will
# be handled in hide_show_segmentation_overlay callback
Expand All @@ -117,30 +118,27 @@ def render_image(
and ctx.triggered_id == "show-result-overlay-toggle"
):
return [dash.no_update] * 7 + ["hidden"]
if str(image_idx + 1) in tiled_dataset.get_annotated_segmented_results():
result = tiled_dataset.get_data_sequence_by_name(seg_result_selection)[
image_idx
annotation_indices = tiled_masks.get_annotated_segmented_results()
if str(image_idx + 1) in annotation_indices:
# Will not return an error since we already checked if image_idx+1 is in the list
mapped_index = annotation_indices.index(str(image_idx + 1))
result = tiled_results.get_data_sequence_by_name(seg_result_selection)[
mapped_index
]
else:
result = None
else:
tf = np.zeros((500, 500))
fig = px.imshow(tf, binary_string=True)
if toggle_seg_result and result is not None:
unique_segmentation_values = np.unique(result)
normalized_range = np.linspace(
0, 1, len(unique_segmentation_values)
) # heatmap requires a normalized range
color_list = (
px.colors.qualitative.Plotly
) # TODO placeholder - replace with user defined classess
colorscale = [
[normalized_range[i], color_list[i % len(color_list)]]
for i in range(len(unique_segmentation_values))
]
colorscale, max_class_id = generate_segmentation_colormap(
all_annotation_class_store
)
fig.add_trace(
go.Heatmap(
z=result,
zmin=-0.5,
zmax=max_class_id + 0.5,
colorscale=colorscale,
showscale=False,
)
Expand Down Expand Up @@ -485,7 +483,7 @@ def update_slider_values(project_name, annotation_store):
"""
# Retrieve data shape if project_name is valid and points to a 3d array
data_shape = (
tiled_dataset.get_data_shape_by_name(project_name) if project_name else None
tiled_datasets.get_data_shape_by_name(project_name) if project_name else None
)
disable_slider = data_shape is None
if not disable_slider:
Expand Down
10 changes: 6 additions & 4 deletions callbacks/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dash import ALL, Input, Output, State, callback, no_update
from dash.exceptions import PreventUpdate

from utils.data_utils import tiled_dataset
from utils.data_utils import tiled_masks

MODE = os.getenv("MODE", "")

Expand Down Expand Up @@ -74,17 +74,19 @@ def run_job(n_clicks, global_store, all_annotations, project_name):
"""
if n_clicks:
if MODE == "dev":
mask_uri = tiled_masks.save_annotations_data(
global_store, all_annotations, project_name
)
job_uid = str(uuid.uuid4())
return (
dmc.Text(
f"Workflow has been succesfully submitted with uid: {job_uid}",
f"Workflow has been succesfully submitted with uid: {job_uid} and mask uri: {mask_uri}",
size="sm",
),
job_uid,
)
else:

tiled_dataset.save_annotations_data(
mask_uri = tiled_masks.save_annotations_data(
global_store, all_annotations, project_name
)
job_submitted = requests.post(
Expand Down
2 changes: 1 addition & 1 deletion components/annotation_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def annotation_class_item(class_color, class_label, existing_ids, data=None):
annotations = data["annotations"]
is_visible = data["is_visible"]
else:
class_id = 1 if not existing_ids else max(existing_ids) + 1
class_id = 0 if not existing_ids else max(existing_ids) + 1
annotations = {}
is_visible = True
class_color_transparent = class_color + "50"
Expand Down
4 changes: 2 additions & 2 deletions components/control_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from components.annotation_class import annotation_class_item
from constants import ANNOT_ICONS, KEYBINDS
from utils.data_utils import tiled_dataset
from utils.data_utils import tiled_datasets


def _tooltip(text, children):
Expand Down Expand Up @@ -62,7 +62,7 @@ def layout():
Returns the layout for the control panel in the app UI
"""
DATA_OPTIONS = [
item for item in tiled_dataset.get_data_project_names() if "seg" not in item
item for item in tiled_datasets.get_data_project_names() if "seg" not in item
]
return drawer_section(
dmc.Stack(
Expand Down
77 changes: 77 additions & 0 deletions examples/plot_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os
import sys

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from dotenv import load_dotenv
from matplotlib.colors import ListedColormap
from tiled.client import from_uri


def plot_mask(mask_uri, api_key, slice_idx, output_path):
"""
Saves a plot of a given mask using metadata information such as class colors and labels.
It is assumed that the given uri is the uri of a mask container, with associated meta data
and a mask array under the key "mask".
The given slice index references mask slices, not the original data.
However, the printed slice index in the figure will be the index of the original data.
"""
# Retrieve mask and metadata
mask_client = from_uri(mask_uri, api_key=api_key)
mask = mask_client["mask"][slice_idx]

meta_data = mask_client.metadata
mask_idx = meta_data["mask_idx"]

if slice_idx > len(mask_idx):
raise ValueError("Slice index out of range")

class_meta_data = meta_data["classes"]
max_class_id = len(class_meta_data.keys()) - 1

colors = [
annotation_class["color"] for _, annotation_class in class_meta_data.items()
]
labels = [
annotation_class["label"] for _, annotation_class in class_meta_data.items()
]
# Add color for unlabeled pixels
colors = ["#D3D3D3"] + colors
labels = ["Unlabeled"] + labels

plt.imshow(
mask,
cmap=ListedColormap(colors),
vmin=-1.5,
vmax=max_class_id + 0.5,
)
plt.title(meta_data["project_name"] + ", slice: " + mask_idx[slice_idx])

# create a patch for every color
patches = [
mpatches.Patch(color=colors[i], label=labels[i]) for i in range(len(labels))
]
# Plot legend below the image
plt.legend(
handles=patches, loc="upper center", bbox_to_anchor=(0.5, -0.075), ncol=3
)
plt.savefig(output_path, bbox_inches="tight")


if __name__ == "__main__":
"""
Example usage: python3 plot_mask.py http://localhost:8000/api/v1/metadata/mlex_store/mlex_store/username/dataset/uuid
"""

load_dotenv()
api_key = os.getenv("MASK_TILED_API_KEY", None)

if len(sys.argv) < 2:
print("Usage: python3 plot_mask.py <mask_uri> [slice_idx] [output_path]")
sys.exit(1)

mask_uri = sys.argv[1]
slice_idx = int(sys.argv[2]) if len(sys.argv) > 2 else 0
output_path = sys.argv[3] if len(sys.argv) > 3 else "mask.png"

plot_mask(mask_uri, api_key, slice_idx, output_path)
2 changes: 2 additions & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
matplotlib
tiled[client]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ scipy
dash-extensions==1.0.1
dash-bootstrap-components==1.5.0
dash_auth==2.0.0
canonicaljson
Loading

0 comments on commit 555c27f

Please sign in to comment.