diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py index 45f85220..9e668fc5 100644 --- a/napari_cellseg3d/_tests/test_plugin_utils.py +++ b/napari_cellseg3d/_tests/test_plugin_utils.py @@ -13,7 +13,7 @@ def test_utils_plugin(make_napari_viewer_proxy): view = make_napari_viewer_proxy() widget = Utilities(view) - image = rand_gen.random((10, 10, 10)).astype(np.uint8) + image = rand_gen.random((10, 10, 10)) # .astype(np.uint8) image_layer = view.add_image(image, name="image") label_layer = view.add_labels(image.astype(np.uint8), name="labels") diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py index 0197c3a2..c095b4e9 100644 --- a/napari_cellseg3d/_tests/test_training.py +++ b/napari_cellseg3d/_tests/test_training.py @@ -17,6 +17,8 @@ ) from napari_cellseg3d.config import MODEL_LIST +WANDB_MODE = "disabled" + im_path = Path(__file__).resolve().parent / "res/test.tif" im_path_str = str(im_path) lab_path = Path(__file__).resolve().parent / "res/test_labels.tif" diff --git a/napari_cellseg3d/code_plugins/plugin_convert.py b/napari_cellseg3d/code_plugins/plugin_convert.py index b723e61c..dcc18c05 100644 --- a/napari_cellseg3d/code_plugins/plugin_convert.py +++ b/napari_cellseg3d/code_plugins/plugin_convert.py @@ -363,6 +363,12 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.container = self._build() self.function = clear_small_objects + self._set_tooltips() + + def _set_tooltips(self): + self.size_for_removal_counter.setToolTip( + "Size of the objects to remove, in pixels." + ) def _build(self): container = ui.ContainerWidget() @@ -647,6 +653,15 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent=None): self.container = self._build() self.function = threshold + self._set_tooltips() + + def _set_tooltips(self): + self.binarize_counter.setToolTip( + "Value to use as threshold for binarization." + "For labels, use the highest ID you want to keep. All lower IDs will be removed." + "For images, use the intensity value (pixel value) to threshold the image." + ) + def _build(self): container = ui.ContainerWidget() diff --git a/napari_cellseg3d/code_plugins/plugin_model_inference.py b/napari_cellseg3d/code_plugins/plugin_model_inference.py index 88effabc..27ff2d5e 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_inference.py +++ b/napari_cellseg3d/code_plugins/plugin_model_inference.py @@ -281,6 +281,7 @@ def _set_tooltips(self): ) thresh_desc = ( + "NOT RECOMMENDED ON FIRST RUN - check results without first!\n" "Thresholding : all values in the image below the chosen probability" " threshold will be set to 0, and all others to 1." ) @@ -301,6 +302,7 @@ def _set_tooltips(self): "If enabled, data will be kept on the RAM rather than the VRAM.\nCan avoid out of memory issues with CUDA" ) self.use_instance_choice.setToolTip( + "NOT RECOMMENDED ON FIRST RUN - check results without first!\n" "Instance segmentation will convert instance (0/1) labels to labels" " that attempt to assign an unique ID to each cell." ) @@ -653,6 +655,8 @@ def _display_results(self, result: InferenceResult): if result.semantic_segmentation[channel].sum() > 0: index_channel_least_labelled = channel break + # if no channel has any label, use the first one + index_channel_least_labelled = 0 viewer.dims.set_point( 0, index_channel_least_labelled ) # TODO(cyril: check if this is always the right axis diff --git a/napari_cellseg3d/code_plugins/plugin_model_training.py b/napari_cellseg3d/code_plugins/plugin_model_training.py index 5d727956..7d43265e 100644 --- a/napari_cellseg3d/code_plugins/plugin_model_training.py +++ b/napari_cellseg3d/code_plugins/plugin_model_training.py @@ -1247,7 +1247,7 @@ def on_finish(self): self.log.print_and_log("*" * 10) try: self._make_csv() - except ValueError as e: + except (ValueError, KeyError) as e: logger.warning(f"Error while saving CSV report: {e}") self.start_btn.setText("Start") @@ -1375,11 +1375,11 @@ def _make_csv(self): try: self.loss_1_values["Loss"] supervised = True - except KeyError("Loss"): + except KeyError: try: self.loss_1_values["SoftNCuts"] supervised = False - except KeyError("SoftNCuts") as e: + except KeyError as e: raise KeyError( "Error when making csv. Check loss dict keys ?" ) from e @@ -1398,8 +1398,8 @@ def _make_csv(self): "validation": val, } ) - if len(val) != len(self.loss_1_values): - err = f"Validation and loss values don't have the same length ! Got {len(val)} and {len(self.loss_1_values)}" + if len(val) != len(self.loss_1_values["Loss"]): + err = f"Validation and loss values don't have the same length ! Got {len(val)} and {len(self.loss_1_values['Loss'])}" logger.error(err) raise ValueError(err) else: diff --git a/napari_cellseg3d/interface.py b/napari_cellseg3d/interface.py index 0b4da851..d06a5554 100644 --- a/napari_cellseg3d/interface.py +++ b/napari_cellseg3d/interface.py @@ -1,7 +1,9 @@ """User interface functions and aliases.""" +import contextlib import threading from functools import partial from typing import List, Optional +from warnings import warn import napari @@ -804,9 +806,12 @@ def __init__( self.layer_description.setVisible(False) # self.layer_list.setSizeAdjustPolicy(QComboBox.AdjustToContents) # use tooltip instead ? + # connect to LayerList events self._viewer.layers.events.inserted.connect(partial(self._add_layer)) self._viewer.layers.events.removed.connect(partial(self._remove_layer)) + self._viewer.layers.events.changed.connect(self._check_for_layers) + # update self.layer_list when layers are added or removed self.layer_list.currentIndexChanged.connect(self._update_tooltip) self.layer_list.currentTextChanged.connect(self._update_description) @@ -816,10 +821,68 @@ def __init__( ) self._check_for_layers() + def _get_all_layers(self): + return [ + self.layer_list.itemText(i) for i in range(self.layer_list.count()) + ] + def _check_for_layers(self): + """Check for layers of the correct type and update the dropdown menu. + + Also removes layers that have been removed from the viewer. + """ for layer in self._viewer.layers: - if isinstance(layer, self.layer_type): + layer.events.name.connect(self._rename_layer) + + if ( + isinstance(layer, self.layer_type) + and layer.name not in self._get_all_layers() + ): + logger.debug( + f"Layer {layer.name} - List : {self._get_all_layers()}" + ) + # add new layers of correct type self.layer_list.addItem(layer.name) + logger.debug(f"Layer {layer.name} has been added to the menu") + # break + # once added, check again for previously renamed layers + self._check_for_removed_layer(layer) + + if layer.name in self._get_all_layers() and not isinstance( + layer, self.layer_type + ): + # remove layers of incorrect type + index = self.layer_list.findText(layer.name) + self.layer_list.removeItem(index) + logger.debug( + f"Layer {layer.name} has been removed from the menu" + ) + + self._check_for_removed_layers() + self._update_tooltip() + self._update_description() + + def _check_for_removed_layer(self, layer): + """Check if a specific layer has been removed from the viewer and must be removed from the menu.""" + if isinstance(layer, str): + name = layer + elif isinstance(layer, self.layer_type): + name = layer.name + else: + logger.warning("Layer is not a string or a valid napari layer") + return + + if name in self._get_all_layers() and name not in [ + l.name for l in self._viewer.layers + ]: + index = self.layer_list.findText(name) + self.layer_list.removeItem(index) + logger.debug(f"Layer {name} has been removed from the menu") + + def _check_for_removed_layers(self): + """Check for layers that have been removed from the viewer and must be removed from the menu.""" + for layer in self._get_all_layers(): + self._check_for_removed_layer(layer) def _update_tooltip(self): self.layer_list.setToolTip(self.layer_list.currentText()) @@ -827,9 +890,12 @@ def _update_tooltip(self): def _update_description(self): try: if self.layer_list.currentText() != "": - self.layer_description.setVisible(True) - shape_desc = f"Shape : {self.layer_data().shape}" - self.layer_description.setText(shape_desc) + try: + shape_desc = f"Shape : {self.layer_data().shape}" + self.layer_description.setText(shape_desc) + self.layer_description.setVisible(True) + except AttributeError: + self.layer_description.setVisible(False) else: self.layer_description.setVisible(False) except KeyError: @@ -841,6 +907,13 @@ def _add_layer(self, event): if isinstance(inserted_layer, self.layer_type): self.layer_list.addItem(inserted_layer.name) + # check for renaming + inserted_layer.events.name.connect(self._rename_layer) + + def _rename_layer(self, _): + # on layer rename, check for removed/new layers + self._check_for_layers() + def _remove_layer(self, event): removed_layer = event.value @@ -867,15 +940,24 @@ def layer(self): def layer_name(self): """Returns the name of the layer selected in the dropdown menu.""" - return self.layer_list.currentText() + try: + return self.layer_list.currentText() + except (KeyError, ValueError): + logger.warning("Layer list is empty") + return None def layer_data(self): """Returns the data of the layer selected in the dropdown menu.""" if self.layer_list.count() < 1: logger.debug("Layer list is empty") return None - - return self.layer().data + try: + return self.layer().data + except (KeyError, ValueError): + msg = f"Layer {self.layer_name()} has no data. Layer might have been renamed or removed." + logger.warning(msg) + warn(msg, stacklevel=1) + return None class FilePathWidget(QWidget): # TODO include load as folder diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index b340f859..45f50582 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -15,6 +15,7 @@ LOGGER = logging.getLogger(__name__) ############### # Global logging level setting +# SET TO INFO FOR RELEASE # LOGGER.setLevel(logging.DEBUG) LOGGER.setLevel(logging.INFO) ###############