diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..1fafbd5 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +.env +.git +.gitignore diff --git a/app.py b/app.py index c8514d9..efda09c 100644 --- a/app.py +++ b/app.py @@ -2,7 +2,7 @@ import dash_auth import dash_mantine_components as dmc -from dash import Dash, dcc +from dash import Dash from callbacks.control_bar import * # noqa: F403, F401 from callbacks.image_viewer import * # noqa: F403, F401 @@ -30,9 +30,8 @@ children=[ control_bar_layout(), image_viewer_layout(), - dcc.Store(id="current-class-selection", data="#FFA200"), ], ) if __name__ == "__main__": - app.run_server(debug=True) + app.run_server(host="0.0.0.0", port=8075, debug=True) diff --git a/assets/models.json b/assets/models.json new file mode 100755 index 0000000..3b22cb1 --- /dev/null +++ b/assets/models.json @@ -0,0 +1,1176 @@ +{ + "contents": [ + { + "model_name": "DSLIA MSDNet", + "version": "0.0.1", + "type": "supervised", + "user": "mlexchange team", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", + "application": [ + "segmentation" + ], + "description": "MSDNets in DLSIA for image segmentation", + "gui_parameters": [ + { + "type": "int", + "name": "layer_width", + "title": "Layer Width", + "param_key": "layer_width", + "value": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "num_layers", + "title": "# Layers", + "param_key": "num_layers", + "value": 3, + "comp_group": "train_model" + }, + { + "type": "bool", + "name": "custom_dilation", + "title": "Custom Dilation", + "param_key": "custom_dilation", + "checked": false, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "max_dilation", + "title": "Maximum Dilation", + "param_key": "max_dilation", + "value": 5, + "comp_group": "train_model" + }, + { + "type": "str", + "name": "dilation_array", + "title": "Dilation Array", + "param_key": "dilation_array", + "value": "[1, 2, 4]", + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "# Epochs", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "marks": [ + { + "value": 1, + "label": "1" + }, + { + "value": 100, + "label": "100" + }, + { + "value": 1000, + "label": "1000" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "data": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "MSELoss", + "data": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1.0, 1.0, 1.0]", + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "precision": 4, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "Sigmoid", + "data": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 100, + "step": 5, + "value": 20, + "marks": [ + { + "value": 0, + "label": "0%" + }, + { + "value": 50, + "label": "50%" + }, + { + "value": 100, + "label": "100%" + } + ], + "comp_group": "train_model" + }, + { + "type": "bool", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "checked": true, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Batch Size Training", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Batch Size Validation", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Batch Size Inference", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "prediction_model" + } + ], + "cmd": [ + "python3 src/train_model.py", + "python3 src/segment.py" + ], + "reference": "https://dlsia.readthedocs.io/en/latest/" + }, + { + "model_name": "DSLIA TUNet", + "version": "0.0.1", + "type": "supervised", + "user": "mlexchange team", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", + "application": [ + "segmentation" + ], + "description": "TUNet in DLSIA for image segmentation", + "gui_parameters": [ + { + "type": "int", + "name": "depth", + "title": "Depth", + "param_key": "depth", + "value": 4, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "base_channels", + "title": "Base Channels", + "param_key": "base_channels", + "value": 32, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "growth_rate", + "title": "Growth Rate", + "param_key": "growth_rate", + "value": 2, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "hidden_rate", + "title": "Hidden Rate", + "param_key": "hidden_rate", + "value": 1, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "# Epochs", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "marks": [ + { + "value": 1, + "label": "1" + }, + { + "value": 100, + "label": "100" + }, + { + "value": 1000, + "label": "1000" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "data": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "MSELoss", + "data": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1.0, 1.0, 1.0]", + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "precision": 4, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "Sigmoid", + "data": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 100, + "step": 5, + "value": 20, + "marks": [ + { + "value": 0, + "label": "0%" + }, + { + "value": 100, + "label": "100%" + } + ], + "comp_group": "train_model" + }, + { + "type": "bool", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "checked": true, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Training Batch Size", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Validation Batch Size", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Inference Batch Size", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "prediction_model" + } + ], + "cmd": [ + "python3 src/train_model.py", + "python3 src/segment.py" + ], + "reference": "https://dlsia.readthedocs.io/en/latest/" + }, + { + "model_name": "DSLIA TUNet3+", + "version": "0.0.1", + "type": "supervised", + "user": "mlexchange team", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", + "application": [ + "segmentation" + ], + "description": "TUNet3+ DLSIA for image segmentation", + "gui_parameters": [ + { + "type": "int", + "name": "depth", + "title": "Depth", + "param_key": "depth", + "value": 4, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "base_channels", + "title": "Base Channels", + "param_key": "base_channels", + "value": 32, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "growth_rate", + "title": "Growth Rate", + "param_key": "growth_rate", + "value": 2, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "hidden_rate", + "title": "Hidden Rate", + "param_key": "hidden_rate", + "value": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "carryover_channels", + "title": "Carryover Channels", + "param_key": "carryover_channels", + "value": 32, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "# Epochs", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "marks": [ + { + "value": 1, + "label": "1" + }, + { + "value": 100, + "label": "100" + }, + { + "value": 1000, + "label": "1000" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "data": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "MSELoss", + "data": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1.0, 1.0, 1.0]", + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "precision": 4, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "Sigmoid", + "data": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 100, + "step": 5, + "value": 20, + "marks": [ + { + "value": 0, + "label": "0%" + }, + { + "value": 100, + "label": "100%" + } + ], + "comp_group": "train_model" + }, + { + "type": "bool", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "checked": true, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Training Batch Size", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Validation Batch Size", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Inference Batch Size", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], + "comp_group": "prediction_model" + } + ], + "cmd": [ + "python3 src/train_model.py", + "python3 src/segment.py" + ], + "reference": "https://dlsia.readthedocs.io/en/latest/" + } + ] +} diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index 9d7d438..b3086ed 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -2,6 +2,7 @@ import os import random import time +import uuid import dash_mantine_components as dmc import plotly.express as px @@ -24,9 +25,10 @@ from dash_iconify import DashIconify from components.annotation_class import annotation_class_item +from components.parameter_items import ParameterItems from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEY_MODES, KEYBINDS from utils.annotations import Annotations -from utils.data_utils import tiled_dataset +from utils.data_utils import models, tiled_datasets, tiled_masks, tiled_results from utils.plot_utils import generate_notification, generate_notification_bg_icon_col # TODO - temporary local file path and user for annotation saving and exporting @@ -332,6 +334,10 @@ def open_annotation_class_modal( disable_class_creation = True error_msg.append("Label Already in Use!") error_msg.append(html.Br()) + if new_label == "Unlabeled": + disable_class_creation = True + error_msg.append("Label name cannot be 'Unlabeled'") + error_msg.append(html.Br()) if new_color in current_colors: disable_class_creation = True error_msg.append("Color Already in use!") @@ -758,7 +764,7 @@ def populate_load_annotations_dropdown_menu_options(modal_opened, image_src): if not modal_opened: raise PreventUpdate - data = tiled_dataset.DEV_load_exported_json_data( + data = tiled_masks.DEV_load_exported_json_data( EXPORT_FILE_PATH, USER_NAME, image_src ) if not data: @@ -804,10 +810,10 @@ def load_and_apply_selected_annotations(selected_annotation, image_src, img_idx) )["index"] # TODO : when quering from the server, load (data) for user, source, time - data = tiled_dataset.DEV_load_exported_json_data( + data = tiled_masks.DEV_load_exported_json_data( EXPORT_FILE_PATH, USER_NAME, image_src ) - data = tiled_dataset.DEV_filter_json_data_by_timestamp( + data = tiled_masks.DEV_filter_json_data_by_timestamp( data, str(selected_annotation_timestamp) ) data = data[0]["data"] @@ -853,10 +859,10 @@ def populate_classification_results( image_src, refresh_tiled, toggle, dropdown_enabled, slider_enabled ): if refresh_tiled: - tiled_dataset.refresh_data() + tiled_datasets.refresh_data_client() data_options = [ - item for item in tiled_dataset.get_data_project_names() if "seg" not in item + item for item in tiled_datasets.get_data_project_names() if "seg" not in item ] results = [] value = None @@ -872,9 +878,10 @@ def populate_classification_results( disabled_toggle = False disabled_slider = slider_enabled else: + # TODO: Match by mask uid instead of image_src results = [ item - for item in tiled_dataset.get_data_project_names() + for item in tiled_results.get_data_project_names() if ("seg" in item and image_src in item) ] if results: @@ -912,3 +919,19 @@ def update_current_annotated_slices_values(all_classes): ] disabled = True if len(dropdown_values) == 0 else False return dropdown_values, disabled + + +@callback( + Output("model-parameters", "children"), + Input("model-list", "value"), +) +def update_model_parameters(model_name): + model = models[model_name] + if model["gui_parameters"]: + # TODO: Retain old parameters if they exist + item_list = ParameterItems( + _id={"type": str(uuid.uuid4())}, json_blob=model["gui_parameters"] + ) + return item_list + else: + return html.Div("Model has no parameters") diff --git a/callbacks/image_viewer.py b/callbacks/image_viewer.py index a864640..ea29e75 100644 --- a/callbacks/image_viewer.py +++ b/callbacks/image_viewer.py @@ -16,11 +16,12 @@ from dash.exceptions import PreventUpdate from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEYBINDS -from utils.data_utils import tiled_dataset +from utils.data_utils import tiled_datasets, tiled_masks, tiled_results from utils.plot_utils import ( create_viewfinder, downscale_view, generate_notification, + generate_segmentation_colormap, get_view_finder_max_min, resize_canvas, ) @@ -108,7 +109,7 @@ def render_image( if image_idx: image_idx -= 1 # slider starts at 1, so subtract 1 to get the correct index - tf = tiled_dataset.get_data_sequence_by_name(project_name)[image_idx] + tf = tiled_datasets.get_data_sequence_by_name(project_name)[image_idx] if toggle_seg_result: # if toggle is true and overlay exists already (2 images in data) this will # be handled in hide_show_segmentation_overlay callback @@ -117,9 +118,12 @@ def render_image( and ctx.triggered_id == "show-result-overlay-toggle" ): return [dash.no_update] * 7 + ["hidden"] - if str(image_idx + 1) in tiled_dataset.get_annotated_segmented_results(): - result = tiled_dataset.get_data_sequence_by_name(seg_result_selection)[ - image_idx + annotation_indices = tiled_masks.get_annotated_segmented_results() + if str(image_idx + 1) in annotation_indices: + # Will not return an error since we already checked if image_idx+1 is in the list + mapped_index = annotation_indices.index(str(image_idx + 1)) + result = tiled_results.get_data_sequence_by_name(seg_result_selection)[ + mapped_index ] else: result = None @@ -127,20 +131,14 @@ def render_image( tf = np.zeros((500, 500)) fig = px.imshow(tf, binary_string=True) if toggle_seg_result and result is not None: - unique_segmentation_values = np.unique(result) - normalized_range = np.linspace( - 0, 1, len(unique_segmentation_values) - ) # heatmap requires a normalized range - color_list = ( - px.colors.qualitative.Plotly - ) # TODO placeholder - replace with user defined classess - colorscale = [ - [normalized_range[i], color_list[i % len(color_list)]] - for i in range(len(unique_segmentation_values)) - ] + colorscale, max_class_id = generate_segmentation_colormap( + all_annotation_class_store + ) fig.add_trace( go.Heatmap( z=result, + zmin=-0.5, + zmax=max_class_id + 0.5, colorscale=colorscale, showscale=False, ) @@ -485,7 +483,7 @@ def update_slider_values(project_name, annotation_store): """ # Retrieve data shape if project_name is valid and points to a 3d array data_shape = ( - tiled_dataset.get_data_shape_by_name(project_name) if project_name else None + tiled_datasets.get_data_shape_by_name(project_name) if project_name else None ) disable_slider = data_shape is None if not disable_slider: diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py index 185e9a5..ecc062d 100644 --- a/callbacks/segmentation.py +++ b/callbacks/segmentation.py @@ -8,7 +8,7 @@ from dash import ALL, Input, Output, State, callback, no_update from constants import ANNOT_ICONS -from utils.data_utils import tiled_dataset +from utils.data_utils import extract_parameters_from_html, tiled_masks from utils.plot_utils import generate_notification from utils.prefect import get_flow_run_name, query_flow_run, schedule_prefect_flow @@ -57,14 +57,16 @@ @callback( Output("notifications-container", "children", allow_duplicate=True), + Output("model-parameter-values", "data"), Input("run-train", "n_clicks"), State("annotation-store", "data"), State({"type": "annotation-class-store", "index": ALL}, "data"), State("project-name-src", "value"), + State("model-parameters", "children"), State("job-name", "value"), prevent_initial_call=True, ) -def run_train(n_clicks, global_store, all_annotations, project_name, job_name): +def run_train(n_clicks, global_store, all_annotations, project_name, model_parameters, job_name): """ This callback collects parameters from the UI and submits a training job to Prefect. If the app is run from "dev" mode, then only a placeholder job_uid will be created. @@ -72,13 +74,20 @@ def run_train(n_clicks, global_store, all_annotations, project_name, job_name): # TODO: Appropriately paramaterize the job json depending on user inputs and relevant file paths """ + input_params = {} if n_clicks: + input_params = extract_parameters_from_html(model_parameters) + # return the input values in dictionary and save to the model parameter store + print(f"input_param:\n{input_params}") if MODE == "dev": + mask_uri = tiled_masks.save_annotations_data( + global_store, all_annotations, project_name + ) job_uid = str(uuid.uuid4()) - job_message = f"Job has been succesfully submitted with uid: {job_uid}" + job_message = f"Workflow has been succesfully submitted with uid: {job_uid} and mask uri: {mask_uri}" notification_color = "indigo" else: - tiled_dataset.save_annotations_data( + mask_uri = tiled_masks.save_annotations_data( global_store, all_annotations, project_name ) try: @@ -105,8 +114,8 @@ def run_train(n_clicks, global_store, all_annotations, project_name, job_name): "Job Submission", notification_color, ANNOT_ICONS["submit"], job_message ) - return notification - return no_update + return notification, input_params + return no_update, no_update @callback( diff --git a/components/annotation_class.py b/components/annotation_class.py index a855609..d97b836 100644 --- a/components/annotation_class.py +++ b/components/annotation_class.py @@ -33,7 +33,7 @@ def annotation_class_item(class_color, class_label, existing_ids, data=None): annotations = data["annotations"] is_visible = data["is_visible"] else: - class_id = 1 if not existing_ids else max(existing_ids) + 1 + class_id = 0 if not existing_ids else max(existing_ids) + 1 annotations = {} is_visible = True class_color_transparent = class_color + "50" diff --git a/components/control_bar.py b/components/control_bar.py index 968b8d5..03d2343 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -5,8 +5,9 @@ from dash_iconify import DashIconify from components.annotation_class import annotation_class_item +from components.parameter_items import ControlItem from constants import ANNOT_ICONS, KEYBINDS -from utils.data_utils import tiled_dataset +from utils.data_utils import models, tiled_datasets def _tooltip(text, children): @@ -18,24 +19,6 @@ def _tooltip(text, children): ) -def _control_item(title, title_id, item): - """ - Returns a customized layout for a control item - """ - return dmc.Grid( - [ - dmc.Text( - title, - id=title_id, - size="sm", - style={"width": "100px", "margin": "auto", "paddingRight": "5px"}, - align="right", - ), - html.Div(item, style={"width": "265px", "margin": "auto"}), - ] - ) - - def _accordion_item(title, icon, value, children, id): """ Returns a customized layout for an accordion item @@ -62,7 +45,7 @@ def layout(): Returns the layout for the control panel in the app UI """ DATA_OPTIONS = [ - item for item in tiled_dataset.get_data_project_names() if "seg" not in item + item for item in tiled_datasets.get_data_project_names() if "seg" not in item ] return drawer_section( dmc.Stack( @@ -78,7 +61,7 @@ def layout(): id="data-selection-controls", children=[ dmc.Space(h=5), - _control_item( + ControlItem( "Dataset", "image-selector", dmc.Grid( @@ -114,7 +97,7 @@ def layout(): ), ), dmc.Space(h=25), - _control_item( + ControlItem( "Slice 1", "image-selection-text", [ @@ -177,7 +160,7 @@ def layout(): ], ), dmc.Space(h=25), - _control_item( + ControlItem( _tooltip( "Jump to your annotated slices", "Annotated slices", @@ -207,7 +190,7 @@ def layout(): children=html.Div( [ dmc.Space(h=5), - _control_item( + ControlItem( "Brightness", "bightness-text", [ @@ -251,7 +234,7 @@ def layout(): ], ), dmc.Space(h=20), - _control_item( + ControlItem( "Contrast", "contrast-text", dmc.Grid( @@ -452,6 +435,9 @@ def layout(): }, className="add-class-btn", ), + dcc.Store( + id="current-class-selection", data="#FFA200" + ), dmc.Space(h=20), ], ), @@ -603,7 +589,25 @@ def layout(): "run-model", id="model-configuration", children=[ - _control_item( + ControlItem( + "Model", + "model-selector", + dmc.Select( + id="model-list", + data=models.modelname_list, + value=( + models.modelname_list[0] + if models.modelname_list[0] + else None + ), + placeholder="Select a model...", + ), + ), + dmc.Space(h=15), + html.Div(id="model-parameters"), + dcc.Store(id="model-parameter-values", data={}), + dmc.Space(h=25), + ControlItem( "Name", "job-name-input", dmc.TextInput( @@ -619,7 +623,7 @@ def layout(): style={"width": "100%", "margin": "5px"}, ), dmc.Space(h=10), - _control_item( + ControlItem( "Train Jobs", "selected-train-job", dmc.Select( @@ -635,7 +639,7 @@ def layout(): style={"width": "100%", "margin": "5px"}, ), dmc.Space(h=10), - _control_item( + ControlItem( "Inference Jobs", "selected-inference-job", dmc.Select( @@ -655,7 +659,7 @@ def layout(): styles={"trackLabel": {"cursor": "pointer"}}, ), dmc.Space(h=25), - _control_item( + ControlItem( "Results", "", dmc.Select( @@ -664,7 +668,7 @@ def layout(): ), ), dmc.Space(h=25), - _control_item( + ControlItem( "Opacity", "", dmc.Slider( diff --git a/components/parameter_items.py b/components/parameter_items.py new file mode 100644 index 0000000..f165125 --- /dev/null +++ b/components/parameter_items.py @@ -0,0 +1,271 @@ +import dash_bootstrap_components as dbc +import dash_mantine_components as dmc +from dash import html + + +class ControlItem(dmc.Grid): + """ + Customized layout for a control item + """ + + def __init__(self, title, title_id, item, style={}): + super(ControlItem, self).__init__( + children=[ + dmc.Text( + title, + id=title_id, + size="sm", + style={"width": "100px", "margin": "auto", "paddingRight": "5px"}, + align="right", + ), + html.Div(item, style={"width": "265px", "margin": "auto"}), + ], + style=style, + ) + + +class NumberItem(ControlItem): + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs, + ): + if param_key is None: + param_key = name + self.input = dmc.NumberInput( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(NumberItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class StrItem(ControlItem): + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs, + ): + if param_key is None: + param_key = name + self.input = dmc.TextInput( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(StrItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class SliderItem(ControlItem): + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs, + ): + if param_key is None: + param_key = name + self.input = dmc.Slider( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + labelAlwaysOn=False, + **kwargs, + ) + + style = {"padding": "15px 0px 15px 0px"} + if not visible: + style["display"] = "none" + + super(SliderItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class DropdownItem(ControlItem): + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs, + ): + if param_key is None: + param_key = name + self.input = dmc.Select( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(DropdownItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class RadioItem(ControlItem): + def __init__( + self, name, base_id, title=None, param_key=None, visible=True, **kwargs + ): + if param_key is None: + param_key = name + + options = [ + dmc.Radio(option["label"], value=option["value"]) + for option in kwargs["options"] + ] + kwargs.pop("options", None) + self.input = dmc.RadioGroup( + options, + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(RadioItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class BoolItem(ControlItem): + def __init__( + self, name, base_id, title=None, param_key=None, visible=True, **kwargs + ): + if param_key is None: + param_key = name + + self.input = dmc.Switch( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + label=title, + **kwargs, + ) + + style = {"padding": "15px 0px 0px 0px"} + if not visible: + style["display"] = "none" + + super(BoolItem, self).__init__( + title="", # title is already in the switch + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, + style=style, + ) + + +class ParameterItems(dbc.Form): + type_map = { + "float": NumberItem, + "int": NumberItem, + "str": StrItem, + "slider": SliderItem, + "dropdown": DropdownItem, + "radio": RadioItem, + "bool": BoolItem, + } + + def __init__(self, _id, json_blob, values=None): + super(ParameterItems, self).__init__(id=_id, children=[]) + self._json_blob = json_blob + self.children = self.build_children(values=values) + + def _determine_type(self, parameter_dict): + if "type" in parameter_dict: + if parameter_dict["type"] in self.type_map: + return parameter_dict["type"] + elif parameter_dict["type"].__name__ in self.type_map: + return parameter_dict["type"].__name__ + elif type(parameter_dict["value"]) in self.type_map: + return type(parameter_dict["value"]) + raise TypeError( + f"No item type could be determined for this parameter: {parameter_dict}" + ) + + def build_children(self, values=None): + children = [] + for json_record in self._json_blob: + # Build a parameter dict from self.json_blob + type = json_record.get("type", self._determine_type(json_record)) + json_record = json_record.copy() + if values and json_record["name"] in values: + json_record["value"] = values[json_record["name"]] + json_record.pop("type", None) + if "comp_group" in json_record: + json_record.pop("comp_group", None) + item = self.type_map[type](**json_record, base_id=self.id) + children.append(item) + + return children diff --git a/docker-compose.yml b/docker-compose.yml index 9002b3c..b4b5672 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,6 +8,12 @@ services: environment: DATA_TILED_URI: '${DATA_TILED_URI}' DATA_TILED_API_KEY: '${DATA_TILED_API_KEY}' + MASK_TILED_URI: '${MASK_TILED_URI}' + MASK_TILED_API_KEY: '${TILED_API_KEY}' + SEG_TILED_URI: '${SEG_TILED_URI}' + SEG_TILED_API_KEY: '${SEG_TILED_API_KEY}' + USER_NAME: '${USER_NAME}' + USER_PASSWORD: '${USER_PASSWORD}' RESULTS_DIR: '${RESULTS_DIR}' PREFECT_API_URL: '${PREFECT_API_URL}' FLOW_NAME: '${FLOW_NAME}' diff --git a/examples/plot_mask.py b/examples/plot_mask.py new file mode 100644 index 0000000..fa4735d --- /dev/null +++ b/examples/plot_mask.py @@ -0,0 +1,77 @@ +import os +import sys + +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +from dotenv import load_dotenv +from matplotlib.colors import ListedColormap +from tiled.client import from_uri + + +def plot_mask(mask_uri, api_key, slice_idx, output_path): + """ + Saves a plot of a given mask using metadata information such as class colors and labels. + It is assumed that the given uri is the uri of a mask container, with associated meta data + and a mask array under the key "mask". + The given slice index references mask slices, not the original data. + However, the printed slice index in the figure will be the index of the original data. + """ + # Retrieve mask and metadata + mask_client = from_uri(mask_uri, api_key=api_key) + mask = mask_client["mask"][slice_idx] + + meta_data = mask_client.metadata + mask_idx = meta_data["mask_idx"] + + if slice_idx > len(mask_idx): + raise ValueError("Slice index out of range") + + class_meta_data = meta_data["classes"] + max_class_id = len(class_meta_data.keys()) - 1 + + colors = [ + annotation_class["color"] for _, annotation_class in class_meta_data.items() + ] + labels = [ + annotation_class["label"] for _, annotation_class in class_meta_data.items() + ] + # Add color for unlabeled pixels + colors = ["#D3D3D3"] + colors + labels = ["Unlabeled"] + labels + + plt.imshow( + mask, + cmap=ListedColormap(colors), + vmin=-1.5, + vmax=max_class_id + 0.5, + ) + plt.title(meta_data["project_name"] + ", slice: " + mask_idx[slice_idx]) + + # create a patch for every color + patches = [ + mpatches.Patch(color=colors[i], label=labels[i]) for i in range(len(labels)) + ] + # Plot legend below the image + plt.legend( + handles=patches, loc="upper center", bbox_to_anchor=(0.5, -0.075), ncol=3 + ) + plt.savefig(output_path, bbox_inches="tight") + + +if __name__ == "__main__": + """ + Example usage: python3 plot_mask.py http://localhost:8000/api/v1/metadata/mlex_store/mlex_store/username/dataset/uuid + """ + + load_dotenv() + api_key = os.getenv("MASK_TILED_API_KEY", None) + + if len(sys.argv) < 2: + print("Usage: python3 plot_mask.py [slice_idx] [output_path]") + sys.exit(1) + + mask_uri = sys.argv[1] + slice_idx = int(sys.argv[2]) if len(sys.argv) > 2 else 0 + output_path = sys.argv[3] if len(sys.argv) > 3 else "mask.png" + + plot_mask(mask_uri, api_key, slice_idx, output_path) diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..618af5e --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1,2 @@ +matplotlib +tiled[client] diff --git a/requirements.txt b/requirements.txt index 07675be..2dcbdbe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,3 +34,4 @@ scipy dash-extensions==1.0.1 dash-bootstrap-components==1.5.0 dash_auth==2.0.0 +canonicaljson diff --git a/utils/annotations.py b/utils/annotations.py index 765ac73..e7875f1 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -1,6 +1,8 @@ +import hashlib import io import zipfile +import canonicaljson import numpy as np import scipy.sparse as sp from matplotlib.path import Path @@ -9,32 +11,45 @@ class Annotations: - def __init__(self, annotation_store, global_store): + def __init__(self, annotation_store, image_shape): if annotation_store: slices = [] for annotation_class in annotation_store: slices.extend(list(annotation_class["annotations"].keys())) slices = set(slices) - annotations = {key: [] for key in slices} + # Slices need to be sorted to ensure that the exported mask slices + # have the same order as the original data set + annotations = {key: [] for key in sorted(slices, key=int)} + + all_class_labels = [ + annotation_class["class_id"] for annotation_class in annotation_store + ] + annotation_classes = {} for annotation_class in annotation_store: + condensed_id = str(all_class_labels.index(annotation_class["class_id"])) + annotation_classes[condensed_id] = { + "label": annotation_class["label"], + "color": annotation_class["color"], + } for image_idx, slice_data in annotation_class["annotations"].items(): for shape in slice_data: self._set_annotation_type(shape) self._set_annotation_svg(shape) annotation = { - "id": annotation_class["class_id"], + "class_id": condensed_id, "type": self.annotation_type, - "class": annotation_class["label"], - # TODO: This is the same across all images in a dataset - "image_shape": global_store["image_shapes"][0], "svg_data": self.svg_data, } annotations[image_idx].append(annotation) else: - annotations = [] + annotations = None + annotation_classes = None + self.annotation_classes = annotation_classes self.annotations = annotations + self.annotations_hash = self.get_annotations_hash() + self.image_shape = image_shape def get_annotations(self): return self.annotations @@ -42,6 +57,15 @@ def get_annotations(self): def get_annotation_mask(self): return self.annotation_mask + def get_annotation_classes(self): + return self.annotation_classes + + def get_annotations_hash(self): + hash_object = hashlib.md5() + hash_object.update(canonicaljson.encode_canonical_json(self.annotations)) + hash_object.update(canonicaljson.encode_canonical_json(self.annotation_classes)) + return hash_object.hexdigest() + def get_annotation_mask_as_bytes(self): buffer = io.BytesIO() zip_buffer = io.BytesIO() @@ -81,26 +105,29 @@ def create_annotation_mask(self, sparse=False): self.sparse = sparse annotation_mask = [] + image_height = self.image_shape[0] + image_width = self.image_shape[1] + for slice_idx, slice_data in self.annotations.items(): - image_height = slice_data[0]["image_shape"][0] - image_width = slice_data[0]["image_shape"][1] - slice_mask = np.zeros([image_height, image_width], dtype=np.uint8) + slice_mask = np.full( + [image_height, image_width], fill_value=-1, dtype=np.int8 + ) for shape in slice_data: if shape["type"] == "Closed Freeform": shape_mask = ShapeConversion.closed_path_to_array( - shape["svg_data"], shape["image_shape"], shape["id"] + shape["svg_data"], self.image_shape, shape["class_id"] ) elif shape["type"] == "Rectangle": shape_mask = ShapeConversion.rectangle_to_array( - shape["svg_data"], shape["image_shape"], shape["id"] + shape["svg_data"], self.image_shape, shape["class_id"] ) elif shape["type"] == "Ellipse": shape_mask = ShapeConversion.ellipse_to_array( - shape["svg_data"], shape["image_shape"], shape["id"] + shape["svg_data"], self.image_shape, shape["class_id"] ) else: continue - slice_mask[shape_mask > 0] = shape_mask[shape_mask > 0] + slice_mask[shape_mask >= 0] = shape_mask[shape_mask >= 0] annotation_mask.append(slice_mask) if sparse: @@ -154,7 +181,7 @@ def ellipse_to_array(self, svg_data, image_shape, mask_class): c_radius = abs(svg_data["y0"] - svg_data["y1"]) / 2 # Vertical radius # Create mask and draw ellipse - mask = np.zeros((image_height, image_width), dtype=np.uint8) + mask = np.full((image_height, image_width), fill_value=-1, dtype=np.int8) rr, cc = draw.ellipse( cy, cx, c_radius, r_radius ) # Vertical radius first, then horizontal @@ -181,7 +208,7 @@ def rectangle_to_array(self, svg_data, image_shape, mask_class): y1 = max(min(y1, image_height - 1), 0) # # Draw the rectangle - mask = np.zeros((image_height, image_width), dtype=np.uint8) + mask = np.full((image_height, image_width), fill_value=-1, dtype=np.int8) rr, cc = draw.rectangle(start=(y0, x0), end=(y1, x1)) # Convert coordinates to integers @@ -217,8 +244,9 @@ def closed_path_to_array(self, svg_data, image_shape, mask_class): is_inside = polygon_path.contains_points(points) # Reshape the result back into the 2D shape - mask = is_inside.reshape(image_height, image_width).astype(int) + mask = is_inside.reshape(image_height, image_width).astype(np.int8) - # Set the class value for the pixels inside the polygon + # Set the class value for the pixels inside the polygon, -1 for the rest + mask[mask == 0] = -1 mask[mask == 1] = mask_class return mask diff --git a/utils/data_utils.py b/utils/data_utils.py index bb54d29..3311019 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -14,6 +14,11 @@ DATA_TILED_URI = os.getenv("DATA_TILED_URI") DATA_TILED_API_KEY = os.getenv("DATA_TILED_API_KEY") +MASK_TILED_URI = os.getenv("MASK_TILED_URI") +MASK_TILED_API_KEY = os.getenv("MASK_TILED_API_KEY") +SEG_TILED_URI = os.getenv("SEG_TILED_URI") +SEG_TILED_API_KEY = os.getenv("SEG_TILED_API_KEY") +USER_NAME = os.getenv("USER_NAME", "user1") class TiledDataLoader: @@ -22,14 +27,14 @@ def __init__( ): self.data_tiled_uri = data_tiled_uri self.data_tiled_api_key = data_tiled_api_key - self.data = from_uri( + self.data_client = from_uri( self.data_tiled_uri, api_key=self.data_tiled_api_key, timeout=httpx.Timeout(30.0), ) - def refresh_data(self): - self.data = from_uri( + def refresh_data_client(self): + self.data_client = from_uri( self.data_tiled_uri, api_key=self.data_tiled_api_key, timeout=httpx.Timeout(30.0), @@ -42,8 +47,8 @@ def get_data_project_names(self): """ project_names = [ project - for project in list(self.data) - if isinstance(self.data[project], (Container, ArrayClient)) + for project in list(self.data_client) + if isinstance(self.data_client[project], (Container, ArrayClient)) ] return project_names @@ -53,7 +58,7 @@ def get_data_sequence_by_name(self, project_name): but can also be additionally encapsulated in a folder, multiple container or in a .nxs file. We make use of specs to figure out the path to the 3d data. """ - project_client = self.data[project_name] + project_client = self.data_client[project_name] # If the project directly points to an array, directly return it if isinstance(project_client, ArrayClient): return project_client @@ -84,6 +89,37 @@ def get_data_shape_by_name(self, project_name): return project_container.shape return None + def get_data_uri_by_name(self, project_name): + """ + Retrieve uri of the data + """ + project_container = self.get_data_sequence_by_name(project_name) + if project_container: + return project_container.uri + return None + + +tiled_datasets = TiledDataLoader( + data_tiled_uri=DATA_TILED_URI, data_tiled_api_key=DATA_TILED_API_KEY +) + + +class TiledMaskHandler: + """ + This class is used to handle the masks that are generated from the annotations. + """ + + def __init__( + self, mask_tiled_uri=MASK_TILED_URI, mask_tiled_api_key=MASK_TILED_API_KEY + ): + self.mask_tiled_uri = mask_tiled_uri + self.mask_tiled_api_key = mask_tiled_api_key + self.mask_client = from_uri( + self.mask_tiled_uri, + api_key=self.mask_tiled_api_key, + timeout=httpx.Timeout(30.0), + ) + @staticmethod def get_annotated_segmented_results(json_file_path="exported_annotation_data.json"): annotated_slices = [] @@ -130,32 +166,113 @@ def DEV_filter_json_data_by_timestamp(data, timestamp): def save_annotations_data(self, global_store, all_annotations, project_name): """ - Transforms annotations data to a pixelated mask and outputs to - the Tiled server - - # TODO: Save data to Tiled server after transformation + Transforms annotations data to a pixelated mask and outputs to the Tiled server """ - annotations = Annotations(all_annotations, global_store) - annotations.create_annotation_mask(sparse=True) # TODO: Check sparse status + if "image_shapes" in global_store: + image_shape = global_store["image_shapes"][0] + else: + print("Global store was not filled.") + data_shape = ( + tiled_datasets.get_data_shape_by_name(project_name) + if project_name + else None + ) + image_shape = (data_shape[1], data_shape[2]) + + annotations = Annotations(all_annotations, image_shape) + # TODO: Check sparse status, it may be worthwhile to store the mask as a sparse array + # if our machine learning models can handle sparse arrays + annotations.create_annotation_mask(sparse=False) # Get metadata and annotation data - metadata = annotations.get_annotations() + annnotations_per_slice = annotations.get_annotations() + annotation_classes = annotations.get_annotation_classes() + annotations_hash = annotations.get_annotations_hash() + + metadata = { + "project_name": project_name, + "data_uri": tiled_datasets.get_data_uri_by_name(project_name), + "image_shape": global_store["image_shapes"][0], + "mask_idx": list(annnotations_per_slice.keys()), + "classes": annotation_classes, + "annotations": annnotations_per_slice, + "unlabeled_class_id": -1, + } + mask = annotations.get_annotation_mask() - # Get raw images associated with each annotated slice - img_idx = list(metadata.keys()) - img = self.data[project_name] - raw = [] - for idx in img_idx: - ar = img[int(idx)] - raw.append(ar) try: - raw = np.stack(raw) mask = np.stack(mask) except ValueError: return "No annotations to process." - return + # Store the mask in the Tiled server under /username/project_name/uuid/mask" + container_keys = [USER_NAME, project_name] + last_container = self.mask_client + for key in container_keys: + if key not in last_container.keys(): + last_container = last_container.create_container(key=key) + else: + last_container = last_container[key] + + # Add json metadata to a container with the md5 hash as key + # if a mask with that hash does not already exist + if annotations_hash not in last_container.keys(): + last_container = last_container.create_container( + key=annotations_hash, metadata=metadata + ) + mask = last_container.write_array(key="mask", array=mask) + else: + last_container = last_container[annotations_hash] + return last_container.uri + + +tiled_masks = TiledMaskHandler( + mask_tiled_uri=MASK_TILED_URI, mask_tiled_api_key=MASK_TILED_API_KEY +) + +tiled_results = TiledDataLoader( + data_tiled_uri=SEG_TILED_URI, data_tiled_api_key=SEG_TILED_API_KEY +) + + +class Models: + def __init__(self, modelfile_path="./assets/models.json"): + self.path = modelfile_path + f = open(self.path) + + contents = json.load(f)["contents"] + self.modelname_list = [content["model_name"] for content in contents] + self.models = {} + + for i, n in enumerate(self.modelname_list): + self.models[n] = contents[i] + + def __getitem__(self, key): + try: + return self.models[key] + except KeyError: + raise KeyError(f"A model with name {key} does not exist.") + + +models = Models() -tiled_dataset = TiledDataLoader() +def extract_parameters_from_html(model_parameters_html): + """ + Extracts parameters from the children component of a + """ + input_params = {} + for param in model_parameters_html["props"]["children"]: + # param["props"]["children"][0] is the label + # param["props"]["children"][1] is the input + parameter_container = param["props"]["children"][1] + # The achtual parameter item is the first and only child of the parameter container + parameter_item = parameter_container["props"]["children"]["props"] + key = parameter_item["id"]["param_key"] + if "value" in parameter_item: + value = parameter_item["value"] + elif "checked" in parameter_item: + value = parameter_item["checked"] + input_params[key] = value + return input_params diff --git a/utils/plot_utils.py b/utils/plot_utils.py index 55a7dd1..f42ccb3 100644 --- a/utils/plot_utils.py +++ b/utils/plot_utils.py @@ -1,6 +1,7 @@ import random import dash_mantine_components as dmc +import numpy as np import plotly.express as px import plotly.graph_objects as go from dash_iconify import DashIconify @@ -170,6 +171,39 @@ def resize_canvas(h, w, H, W, figure): return figure, image_center_coor +def generate_segmentation_colormap(all_annotations_data): + """ + Generates a discrete colormap for the segmentation overlay + based on the color information per class. + + The discrete colormap maps values from 0 to 1 to colors, + but is meant to be applied to images with class ids as values, + with these varying from 0 to the number of classes - 1. + To account for numerical inaccuracies, it is best to center the plot range + around the class ids, by setting cmin=-0.5 and cmax=max_class_id+0.5. + """ + max_class_id = max( + [annotation_class["class_id"] for annotation_class in all_annotations_data] + ) + # heatmap requires a normalized range from 0 to 1 + # We need to specify color for at least the range limits (0 and 1) + # as well for every additional class + # due to using zero-based class ids, we need to add 2 to the max class id + normalized_range = np.linspace(0, 1, max_class_id + 2) + color_list = [ + annotation_class["color"] for annotation_class in all_annotations_data + ] + # We need to repeat each color twice, to create discrete color segments + # This loop contains the range limits 0 and 1 once, + # but every other value in between twice + colorscale = [ + [normalized_range[i + j], color_list[i % len(color_list)]] + for i in range(0, normalized_range.size - 1) + for j in range(2) + ] + return colorscale, max_class_id + + def generate_notification(title, color, icon, message=""): return dmc.Notification( title=title,