diff --git a/MANIFEST.in b/MANIFEST.in index f3155af..411b69a 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,3 @@ -include LICENSE -include README.md - +include src/napari_cellulus/sample_data/*.npy +include src/napari_cellulus/napari.yaml recursive-exclude * __pycache__ -recursive-exclude * *.py[co] diff --git a/README.md b/README.md index 352e0aa..3123f90 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ +

A napari plugin for cellulus

+ + - **[Introduction](#introduction)** - **[Installation](#installation)** - **[Getting Started](#getting-started)** @@ -55,8 +58,10 @@ Run the following commands in a terminal window: conda activate napari-cellulus napari ``` +[demo_cellulus.webm](https://github.com/funkelab/napari-cellulus/assets/34229641/35cb09de-c875-487d-9890-86082dcd95b2) + + -Next, select `Cellulus` from the `Plugins` drop-down menu. ### Citation diff --git a/setup.cfg b/setup.cfg index ce0031b..8fd5c50 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,6 +46,10 @@ package_dir = [options.packages.find] where = src +[options.package_data] +napari_cellulus = *.npy +* = *.yaml + [options.entry_points] napari.manifest = napari-cellulus = napari_cellulus:napari.yaml @@ -58,7 +62,3 @@ testing = pytest-qt # https://pytest-qt.readthedocs.io/en/latest/ napari pyqt5 - - -[options.package_data] -* = *.yaml diff --git a/src/napari_cellulus/__init__.py b/src/napari_cellulus/__init__.py index bf80da0..b666730 100644 --- a/src/napari_cellulus/__init__.py +++ b/src/napari_cellulus/__init__.py @@ -1,5 +1,5 @@ __version__ = "0.1.0" -from .sample_data import tissue_net_sample +from .load_sample_data import load_fluo_c2dl_huh7_sample -__all__ = ("tissue_net_sample",) +__all__ = ("load_fluo_c2dl_huh7_sample",) diff --git a/src/napari_cellulus/datasets/napari_dataset.py b/src/napari_cellulus/datasets/napari_dataset.py index 3952ab8..0a4b324 100644 --- a/src/napari_cellulus/datasets/napari_dataset.py +++ b/src/napari_cellulus/datasets/napari_dataset.py @@ -71,6 +71,7 @@ def __setup_pipeline(self): ) + gp.Unsqueeze([self.raw], 0) + gp.RandomLocation() + + gp.Normalize(self.raw, factor=self.normalization_factor) ) else: self.pipeline = ( @@ -81,6 +82,7 @@ def __setup_pipeline(self): spatial_dims=self.spatial_dims, ) + gp.RandomLocation() + + gp.Normalize(self.raw, factor=self.normalization_factor) ) def __iter__(self): @@ -236,9 +238,20 @@ def sample_coordinates(self): return anchor_samples, reference_samples def get_num_anchors(self): - return int( - self.density * self.unbiased_shape[0] * self.unbiased_shape[1] - ) + if self.num_spatial_dims == 2: + return int( + self.density * self.unbiased_shape[0] * self.unbiased_shape[1] + ) + elif self.num_spatial_dims == 3: + return int( + self.density + * self.unbiased_shape[0] + * self.unbiased_shape[1] + * self.unbiased_shape[2] + ) def get_num_references(self): - return int(self.density * self.kappa**2 * np.pi) + if self.num_spatial_dims == 2: + return int(self.density * np.pi * self.kappa**2) + elif self.num_spatial_dims == 3: + return int(self.density * 4 / 3 * np.pi * self.kappa**3) diff --git a/src/napari_cellulus/datasets/napari_image_source.py b/src/napari_cellulus/datasets/napari_image_source.py index bee332f..2956af2 100644 --- a/src/napari_cellulus/datasets/napari_image_source.py +++ b/src/napari_cellulus/datasets/napari_image_source.py @@ -19,6 +19,7 @@ def __init__( self, image: Image, key: gp.ArrayKey, spec: ArraySpec, spatial_dims ): self.array_spec = spec + self.image = gp.Array( data=normalize( image.data.astype(np.float32), 1, 99.8, axis=spatial_dims diff --git a/src/napari_cellulus/load_sample_data.py b/src/napari_cellulus/load_sample_data.py new file mode 100644 index 0000000..61c4e45 --- /dev/null +++ b/src/napari_cellulus/load_sample_data.py @@ -0,0 +1,25 @@ +from pathlib import Path + +import numpy as np + +FLUO_C2DL_HUH7_SAMPLE_PATH = ( + Path(__file__).parent / "sample_data/Fluo-C2DL-Huh7-sample.npy" +) + + +def load_fluo_c2dl_huh7_sample(): + raw = np.load(FLUO_C2DL_HUH7_SAMPLE_PATH) + num_samples = raw.shape[0] + indices = np.random.choice(np.arange(num_samples), 5, replace=False) + raw = raw[indices] + + return [ + ( + raw, + { + "name": "Raw", + "metadata": {"axes": ["s", "c", "y", "x"]}, + }, + "image", + ) + ] diff --git a/src/napari_cellulus/model.py b/src/napari_cellulus/model.py new file mode 100644 index 0000000..f262e84 --- /dev/null +++ b/src/napari_cellulus/model.py @@ -0,0 +1,48 @@ +import torch + + +class Model(torch.nn.Module): + """ + This class is a wrapper on the model object returned by cellulus. + It updates the `forward` function and handles cases when the input raw + image is not (S, C, (Z), Y, X) type. + """ + + def __init__(self, model, selected_axes): + super().__init__() + self.model = model + self.selected_axes = selected_axes + + def forward(self, raw): + if "s" in self.selected_axes and "c" in self.selected_axes: + pass + elif "s" in self.selected_axes and "c" not in self.selected_axes: + + raw = torch.unsqueeze(raw, 1) + elif "s" not in self.selected_axes and "c" in self.selected_axes: + pass + elif "s" not in self.selected_axes and "c" not in self.selected_axes: + raw = torch.unsqueeze(raw, 1) + return self.model(raw) + + @staticmethod + def select_and_add_coordinates(outputs, coordinates): + selections = [] + # outputs.shape = (b, c, h, w) or (b, c, d, h, w) + for output, coordinate in zip(outputs, coordinates): + if output.ndim == 3: + selection = output[:, coordinate[:, 1], coordinate[:, 0]] + elif output.ndim == 4: + selection = output[ + :, coordinate[:, 2], coordinate[:, 1], coordinate[:, 0] + ] + selection = selection.transpose(1, 0) + selection += coordinate + selections.append(selection) + + # selection.shape = (b, c, p) where p is the number of selected positions + return torch.stack(selections, dim=0) + + def set_infer(self, p_salt_pepper, num_infer_iterations, device): + self.model.eval() + self.model.set_infer(p_salt_pepper, num_infer_iterations, device) diff --git a/src/napari_cellulus/napari.yaml b/src/napari_cellulus/napari.yaml index 39ac969..0c1c0ae 100644 --- a/src/napari_cellulus/napari.yaml +++ b/src/napari_cellulus/napari.yaml @@ -1,23 +1,17 @@ name: napari-cellulus -display_name: Cellulus +display_name: napari-cellulus contributions: commands: - - id: napari-cellulus.tissue_net_sample - python_name: napari_cellulus.sample_data:tissue_net_sample - title: Load sample data from Cellulus - - id: napari-cellulus.fluo_n2dl_hela_sample - python_name: napari_cellulus.sample_data:fluo_n2dl_hela_sample - title: Load sample data from Cellulus + - id: napari-cellulus.load_fluo_c2dl_huh7_sample + python_name: napari_cellulus.load_sample_data:load_fluo_c2dl_huh7_sample + title: Load sample data - id: napari-cellulus.Widget python_name: napari_cellulus.widget:Widget title: Cellulus sample_data: - - command: napari-cellulus.tissue_net_sample - display_name: TissueNet - key: tissue_net_sample - - command: napari-cellulus.fluo_n2dl_hela_sample - display_name: Fluo-N2DL-HeLa - key: fluo_n2dl_hela_sample + - command: napari-cellulus.load_fluo_c2dl_huh7_sample + display_name: Fluo-C2DL-Huh7 + key: load_fluo_c2dl_huh7_sample widgets: - command: napari-cellulus.Widget display_name: Cellulus diff --git a/src/napari_cellulus/sample_data.py b/src/napari_cellulus/sample_data.py deleted file mode 100644 index 2a4efc0..0000000 --- a/src/napari_cellulus/sample_data.py +++ /dev/null @@ -1,45 +0,0 @@ -from pathlib import Path - -import numpy as np -import tifffile - -TISSUE_NET_SAMPLE = Path(__file__).parent / "sample_data/tissue_net_sample.npy" -FLUO_N2DL_HELA = Path(__file__).parent / "sample_data/fluo_n2dl_hela.tif" - - -def fluo_n2dl_hela_sample(): - x = tifffile.imread(FLUO_N2DL_HELA) - return [ - ( - x, - { - "name": "Raw", - "metadata": {"axes": ["s", "c", "y", "x"]}, - }, - "image", - ) - ] - - -def tissue_net_sample(): - (x, y) = np.load(TISSUE_NET_SAMPLE, "r") - x = x.transpose(0, 3, 1, 2) - y = y.transpose(0, 3, 1, 2).astype(np.uint8) - return [ - ( - x, - { - "name": "Raw", - "metadata": {"axes": ["s", "c", "y", "x"]}, - }, - "image", - ), - ( - y, - { - "name": "Labels", - "metadata": {"axes": ["s", "c", "y", "x"]}, - }, - "Labels", - ), - ] diff --git a/src/napari_cellulus/sample_data/Fluo-C2DL-Huh7-sample.npy b/src/napari_cellulus/sample_data/Fluo-C2DL-Huh7-sample.npy new file mode 100644 index 0000000..1fe3ad3 Binary files /dev/null and b/src/napari_cellulus/sample_data/Fluo-C2DL-Huh7-sample.npy differ diff --git a/src/napari_cellulus/sample_data/__init__.py b/src/napari_cellulus/sample_data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/napari_cellulus/sample_data/fluo_n2dl_hela.tif b/src/napari_cellulus/sample_data/fluo_n2dl_hela.tif deleted file mode 100644 index f28870c..0000000 Binary files a/src/napari_cellulus/sample_data/fluo_n2dl_hela.tif and /dev/null differ diff --git a/src/napari_cellulus/sample_data/tissue_net_sample.npy b/src/napari_cellulus/sample_data/tissue_net_sample.npy deleted file mode 100644 index a9191d4..0000000 Binary files a/src/napari_cellulus/sample_data/tissue_net_sample.npy and /dev/null differ diff --git a/src/napari_cellulus/widget.py b/src/napari_cellulus/widget.py index bbad8b8..d5f8f46 100644 --- a/src/napari_cellulus/widget.py +++ b/src/napari_cellulus/widget.py @@ -21,6 +21,7 @@ QButtonGroup, QCheckBox, QComboBox, + QFileDialog, QGridLayout, QLabel, QLineEdit, @@ -29,6 +30,7 @@ QRadioButton, QScrollArea, QVBoxLayout, + QWidget, ) from scipy.ndimage import binary_fill_holes from scipy.ndimage import distance_transform_edt as dtedt @@ -37,47 +39,7 @@ from .datasets.napari_dataset import NapariDataset from .datasets.napari_image_source import NapariImageSource - - -class Model(torch.nn.Module): - def __init__(self, model, selected_axes): - super().__init__() - self.model = model - self.selected_axes = selected_axes - - def forward(self, x): - if "s" in self.selected_axes and "c" in self.selected_axes: - pass - elif "s" in self.selected_axes and "c" not in self.selected_axes: - - x = torch.unsqueeze(x, 1) - elif "s" not in self.selected_axes and "c" in self.selected_axes: - pass - elif "s" not in self.selected_axes and "c" not in self.selected_axes: - x = torch.unsqueeze(x, 1) - return self.model(x) - - @staticmethod - def select_and_add_coordinates(outputs, coordinates): - selections = [] - # outputs.shape = (b, c, h, w) or (b, c, d, h, w) - for output, coordinate in zip(outputs, coordinates): - if output.ndim == 3: - selection = output[:, coordinate[:, 1], coordinate[:, 0]] - elif output.ndim == 4: - selection = output[ - :, coordinate[:, 2], coordinate[:, 1], coordinate[:, 0] - ] - selection = selection.transpose(1, 0) - selection += coordinate - selections.append(selection) - - # selection.shape = (b, c, p) where p is the number of selected positions - return torch.stack(selections, dim=0) - - def set_infer(self, p_salt_pepper, num_infer_iterations, device): - self.model.eval() - self.model.set_infer(p_salt_pepper, num_infer_iterations, device) +from .model import Model class Widget(QMainWindow): @@ -85,6 +47,7 @@ def __init__(self, napari_viewer): super().__init__() self.viewer = napari_viewer self.scroll = QScrollArea() + self.widget = QWidget() # initialize outer layout layout = QVBoxLayout() @@ -105,10 +68,6 @@ def __init__(self, napari_viewer): self.set_grid_6() self.grid_7 = QGridLayout() # feedback self.set_grid_7() - self.create_configs() # configs - self.viewer.dims.events.current_step.connect( - self.update_inference_widgets - ) # listen to viewer slider layout.addLayout(self.grid_0) layout.addLayout(self.grid_1) @@ -118,11 +77,18 @@ def __init__(self, napari_viewer): layout.addLayout(self.grid_5) layout.addLayout(self.grid_6) layout.addLayout(self.grid_7) - self.set_scroll_area(layout) + self.widget.setLayout(layout) + self.set_scroll_area() self.viewer.layers.events.inserted.connect(self.update_raw_selector) self.viewer.layers.events.removed.connect(self.update_raw_selector) def update_raw_selector(self, event): + """ + Whenever a new image is added or removed by the user, + this function is called. + It updates the `raw_selector` attribute. + + """ count = 0 for i in range(self.raw_selector.count() - 1, -1, -1): if self.raw_selector.itemText(i) == f"{event.value}": @@ -133,6 +99,9 @@ def update_raw_selector(self, event): self.raw_selector.addItems([f"{event.value}"]) def set_grid_0(self): + """ + Specifies the title of the plugin. + """ text_label = QLabel("

Cellulus

") method_description_label = QLabel( 'Unsupervised Learning of Object-Centric Embeddings
for Cell Instance Segmentation in Microscopy Images.
If you are using this in your research, please cite us.

https://github.com/funkelab/cellulus' @@ -141,6 +110,9 @@ def set_grid_0(self): self.grid_0.addWidget(method_description_label, 1, 0, 2, 1) def set_grid_1(self): + """ + Specifies the device used for training and inference. + """ device_label = QLabel(self) device_label.setText("Device") self.device_combo_box = QComboBox(self) @@ -152,6 +124,10 @@ def set_grid_1(self): self.grid_1.addWidget(self.device_combo_box, 0, 1, 1, 1) def set_grid_2(self): + """ + Specifies the raw_selector attribute. + This is needed to identify which axes does the image contain. + """ self.raw_selector = QComboBox(self) for layer in self.viewer.layers: self.raw_selector.addItem(f"{layer}") @@ -169,6 +145,9 @@ def set_grid_2(self): self.grid_2.addWidget(self.x_check_box, 1, 4, 1, 1) def set_grid_3(self): + """ + Specifies the configuration parameters for training. + """ crop_size_label = QLabel(self) crop_size_label.setText("Crop Size") self.crop_size_line = QLineEdit(self) @@ -183,7 +162,7 @@ def set_grid_3(self): max_iterations_label.setText("Max iterations") self.max_iterations_line = QLineEdit(self) self.max_iterations_line.setAlignment(Qt.AlignCenter) - self.max_iterations_line.setText("100000") + self.max_iterations_line.setText("5000") self.grid_3.addWidget(crop_size_label, 0, 0, 1, 1) self.grid_3.addWidget(self.crop_size_line, 0, 1, 1, 1) self.grid_3.addWidget(batch_size_label, 1, 0, 1, 1) @@ -192,6 +171,9 @@ def set_grid_3(self): self.grid_3.addWidget(self.max_iterations_line, 2, 1, 1, 1) def set_grid_4(self): + """ + Specifies the configuration parameters for the model. + """ feature_maps_label = QLabel(self) feature_maps_label.setText("Number of feature maps") self.feature_maps_line = QLineEdit(self) @@ -203,59 +185,80 @@ def set_grid_4(self): self.feature_maps_increase_line.setAlignment(Qt.AlignCenter) self.feature_maps_increase_line.setText("3") self.train_model_from_scratch_checkbox = QCheckBox( - "Train model from scratch" + "Train from scratch" ) - + self.train_model_from_scratch_checkbox.stateChanged.connect( + self.affect_load_weights + ) + self.load_model_button = QPushButton("Load weights") + self.load_model_button.clicked.connect(self.load_weights) self.train_model_from_scratch_checkbox.setChecked(False) self.grid_4.addWidget(feature_maps_label, 0, 0, 1, 1) self.grid_4.addWidget(self.feature_maps_line, 0, 1, 1, 1) self.grid_4.addWidget(feature_maps_increase_label, 1, 0, 1, 1) self.grid_4.addWidget(self.feature_maps_increase_line, 1, 1, 1, 1) self.grid_4.addWidget( - self.train_model_from_scratch_checkbox, 2, 0, 1, 2 + self.train_model_from_scratch_checkbox, 2, 0, 1, 1 ) + self.grid_4.addWidget(self.load_model_button, 2, 1, 1, 1) def set_grid_5(self): + """ + Specifies the loss widget. + """ self.losses_widget = pg.PlotWidget() self.losses_widget.setBackground((37, 41, 49)) styles = {"color": "white", "font-size": "16px"} self.losses_widget.setLabel("left", "Loss", **styles) self.losses_widget.setLabel("bottom", "Iterations", **styles) self.start_training_button = QPushButton("Start training") - self.start_training_button.setFixedSize(140, 30) + self.start_training_button.setFixedSize(88, 30) self.stop_training_button = QPushButton("Stop training") - self.stop_training_button.setFixedSize(140, 30) + self.stop_training_button.setFixedSize(88, 30) + self.save_weights_button = QPushButton("Save weights") + self.save_weights_button.setFixedSize(88, 30) self.grid_5.addWidget(self.losses_widget, 0, 0, 4, 4) - self.grid_5.addWidget(self.start_training_button, 5, 0, 1, 2) - self.grid_5.addWidget(self.stop_training_button, 5, 2, 1, 2) + self.grid_5.addWidget(self.start_training_button, 5, 0, 1, 1) + self.grid_5.addWidget(self.stop_training_button, 5, 1, 1, 1) + self.grid_5.addWidget(self.save_weights_button, 5, 2, 1, 1) + self.start_training_button.clicked.connect( self.prepare_for_start_training ) self.stop_training_button.clicked.connect( self.prepare_for_stop_training ) + self.save_weights_button.clicked.connect(self.save_weights) def set_grid_6(self): + """ + Specifies the inference configuration parameters. + """ threshold_label = QLabel("Threshold") self.threshold_line = QLineEdit(self) + self.threshold_line.textChanged.connect(self.prepare_thresholds) self.threshold_line.setAlignment(Qt.AlignCenter) self.threshold_line.setText(None) bandwidth_label = QLabel("Bandwidth") self.bandwidth_line = QLineEdit(self) self.bandwidth_line.setAlignment(Qt.AlignCenter) + self.bandwidth_line.textChanged.connect(self.prepare_bandwidths) self.radio_button_group = QButtonGroup(self) self.radio_button_cell = QRadioButton("Cell") self.radio_button_nucleus = QRadioButton("Nucleus") self.radio_button_group.addButton(self.radio_button_nucleus) self.radio_button_group.addButton(self.radio_button_cell) - - self.radio_button_nucleus.setChecked(True) + self.radio_button_cell.toggled.connect(self.update_post_processing) + self.radio_button_nucleus.toggled.connect(self.update_post_processing) + self.radio_button_cell.setChecked(True) self.min_size_label = QLabel("Minimum Size") self.min_size_line = QLineEdit(self) self.min_size_line.setAlignment(Qt.AlignCenter) + self.min_size_line.textChanged.connect(self.prepare_min_sizes) + self.start_inference_button = QPushButton("Start inference") self.start_inference_button.setFixedSize(140, 30) self.stop_inference_button = QPushButton("Stop inference") @@ -263,6 +266,7 @@ def set_grid_6(self): self.grid_6.addWidget(threshold_label, 0, 0, 1, 1) self.grid_6.addWidget(self.threshold_line, 0, 1, 1, 1) + self.grid_6.addWidget(bandwidth_label, 1, 0, 1, 1) self.grid_6.addWidget(self.bandwidth_line, 1, 1, 1, 1) self.grid_6.addWidget(self.radio_button_cell, 2, 0, 1, 1) @@ -279,22 +283,33 @@ def set_grid_6(self): ) def set_grid_7(self): - # Initialize Feedback Button + """ + Specifies the feedback URL. + """ + feedback_label = QLabel( 'Please share any feedback here.' ) self.grid_7.addWidget(feedback_label, 0, 0, 2, 1) - def set_scroll_area(self, layout): - self.scroll.setLayout(layout) + def set_scroll_area(self): + """ + Creates a scroll area. + In case the main napari window is resized, the scroll area + would appear. + """ + self.scroll.setWidget(self.widget) self.scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) self.scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) self.scroll.setWidgetResizable(True) - self.setFixedWidth(300) + self.setFixedWidth(320) self.setCentralWidget(self.scroll) def get_selected_axes(self): + """ + Returns the axes based on which of the checkboxes were selected. + """ names = [] for name, check_box in zip( "sczyx", @@ -312,57 +327,78 @@ def get_selected_axes(self): return names def create_configs(self): - self.train_config = TrainConfig( - crop_size=[int(self.crop_size_line.text())], - batch_size=int(self.batch_size_line.text()), - max_iterations=int(self.max_iterations_line.text()), - device=self.device_combo_box.currentText(), - ) - self.model_config = ModelConfig( - num_fmaps=int(self.feature_maps_line.text()), - fmap_inc_factor=int(self.feature_maps_increase_line.text()), - ) + """ + This reads from the various line edits and initializes config objects. + """ + if not hasattr(self, "train_config"): + self.train_config = TrainConfig( + crop_size=[int(self.crop_size_line.text())], + batch_size=int(self.batch_size_line.text()), + max_iterations=int(self.max_iterations_line.text()), + device=self.device_combo_box.currentText(), + ) + if not hasattr(self, "model_config"): + self.model_config = ModelConfig( + num_fmaps=int(self.feature_maps_line.text()), + fmap_inc_factor=int(self.feature_maps_increase_line.text()), + ) + if not hasattr(self, "experiment_config"): + self.experiment_config = ExperimentConfig( + train_config=asdict(self.train_config), + model_config=asdict(self.model_config), + normalization_factor=1.0, + ) + if not hasattr(self, "losses"): + self.losses = [] + if not hasattr(self, "iterations"): + self.iterations = [] + if not hasattr(self, "start_iteration"): + self.start_iteration = 0 - self.experiment_config = ExperimentConfig( - train_config=asdict(self.train_config), - model_config=asdict(self.model_config), - ) - self.losses, self.iterations = [], [] - self.start_iteration = 0 self.model_dir = "/tmp/models" - self.thresholds = [] - self.band_widths = [] - self.min_sizes = [] - if len(self.thresholds) == 0: - self.threshold_line.setEnabled(False) - if len(self.band_widths) == 0: - self.bandwidth_line.setEnabled(False) - if len(self.min_sizes) == 0: - self.min_size_line.setEnabled(False) + self.threshold_line.setEnabled(False) + self.bandwidth_line.setEnabled(False) + self.min_size_line.setEnabled(False) def update_inference_widgets(self, event: Event): + """ + This function listens to which sample the viewer is currently on, + and displays the corresponding inference config parameter for the + present sample. + """ if self.s_check_box.isChecked(): shape = event.value sample_index = shape[0] - if len(self.thresholds) == self.napari_dataset.get_num_samples(): - if self.thresholds[sample_index]!=None: - self.threshold_line.setText( - str(round(self.thresholds[sample_index], 3)) - ) - if len(self.band_widths) == self.napari_dataset.get_num_samples(): - if self.band_widths[sample_index]!=None: - self.bandwidth_line.setText( - str(round(self.band_widths[sample_index], 3)) - ) - if len(self.min_sizes) == self.napari_dataset.get_num_samples(): - if self.min_sizes[sample_index]!=None: - self.min_size_line.setText( - str(round(self.min_sizes[sample_index], 3)) - ) + if ( + hasattr(self, "thresholds") + and self.thresholds[sample_index] is not None + ): + self.threshold_line.setText( + str(round(self.thresholds[sample_index], 3)) + ) + if ( + hasattr(self, "bandwidths") + and self.bandwidths[sample_index] is not None + ): + self.bandwidth_line.setText( + str(round(self.bandwidths[sample_index], 3)) + ) + if ( + hasattr(self, "min_sizes") + and self.min_sizes[sample_index] is not None + ): + self.min_size_line.setText( + str(round(self.min_sizes[sample_index], 3)) + ) def prepare_for_start_training(self): + """ + Each time the `train` button is clicked, + the inference config line edits and other buttons are disabled. + """ self.start_training_button.setEnabled(False) self.stop_training_button.setEnabled(True) + self.save_weights_button.setEnabled(False) self.threshold_line.setEnabled(False) self.bandwidth_line.setEnabled(False) self.radio_button_nucleus.setEnabled(False) @@ -373,10 +409,48 @@ def prepare_for_start_training(self): self.train_worker = self.train() self.train_worker.yielded.connect(self.on_yield_training) + self.train_worker.returned.connect(self.prepare_for_stop_training) self.train_worker.start() + def remove_inference_attributes(self): + """ + When training is initiated, then existing attributes such as + `embeddings`, `detection` etc are removed. + """ + if hasattr(self, "embeddings"): + delattr(self, "embeddings") + if hasattr(self, "detection"): + delattr(self, "detection") + if hasattr(self, "segmentation"): + delattr(self, "segmentation") + if hasattr(self, "thresholds"): + delattr(self, "thresholds") + if hasattr(self, "thresholds_last"): + delattr(self, "thresholds_last") + if hasattr(self, "bandwidths"): + delattr(self, "bandwidths") + if hasattr(self, "bandwidths_last"): + delattr(self, "bandwidths_last") + if hasattr(self, "min_sizes"): + delattr(self, "min_sizes") + if hasattr(self, "min_sizes_last"): + delattr(self, "min_sizes_last") + if hasattr(self, "post_processing"): + delattr(self, "post_processing") + if hasattr(self, "post_processing_last"): + delattr(self, "post_processing_last") + @thread_worker def train(self): + """ + The main function where training happens! + """ + self.create_configs() # configs + self.remove_inference_attributes() + self.viewer.dims.events.current_step.connect( + self.update_inference_widgets + ) # listen to viewer slider + for layer in self.viewer.layers: if f"{layer}" == self.raw_selector.currentText(): raw_image_layer = layer @@ -420,6 +494,7 @@ def train(self): # Set device self.device = torch.device(self.train_config.device) + model = Model( model=model_original, selected_axes=self.get_selected_axes() ) @@ -448,6 +523,8 @@ def train(self): lr=self.train_config.initial_learning_rate, weight_decay=0.01, ) + if hasattr(self, "pre_trained_model_checkpoint"): + self.model_config.checkpoint = self.pre_trained_model_checkpoint # Resume training if self.train_model_from_scratch_checkbox.isChecked(): @@ -474,6 +551,7 @@ def train(self): ) # Call Train Iteration + for iteration, batch in tqdm( zip( range(self.start_iteration, self.train_config.max_iterations), @@ -488,8 +566,12 @@ def train(self): device=self.device, ) yield loss, iteration + return def on_yield_training(self, loss_iteration): + """ + The loss plot is updated every training iteration. + """ loss, iteration = loss_iteration print(f"===> Iteration: {iteration}, loss: {loss:.6f}") self.iterations.append(iteration) @@ -497,19 +579,23 @@ def on_yield_training(self, loss_iteration): self.losses_widget.plot(self.iterations, self.losses) def prepare_for_stop_training(self): + """ + This function defines the sequence of events once training is stopped. + """ self.start_training_button.setEnabled(True) self.stop_training_button.setEnabled(True) - if len(self.thresholds) == 0: + self.save_weights_button.setEnabled(True) + if not hasattr(self, "thresholds"): self.threshold_line.setEnabled(False) else: self.threshold_line.setEnabled(True) - if len(self.band_widths) == 0: + if not hasattr(self, "bandwidths"): self.bandwidth_line.setEnabled(False) else: self.bandwidth_line.setEnabled(True) self.radio_button_nucleus.setEnabled(True) self.radio_button_cell.setEnabled(True) - if len(self.min_sizes) == 0: + if not hasattr(self, "min_sizes"): self.min_size_line.setEnabled(False) else: self.min_size_line.setEnabled(True) @@ -528,8 +614,12 @@ def prepare_for_stop_training(self): self.model_config.checkpoint = checkpoint_file_name def prepare_for_start_inference(self): + """ + When the inference begins, then training-related buttons are disabled. + """ self.start_training_button.setEnabled(False) self.stop_training_button.setEnabled(False) + self.save_weights_button.setEnabled(False) self.threshold_line.setEnabled(False) self.bandwidth_line.setEnabled(False) self.radio_button_nucleus.setEnabled(False) @@ -546,13 +636,17 @@ def prepare_for_start_inference(self): ) self.inference_worker = self.infer() - # self.inference_worker.yielded.connect(self.on_yield_infer) self.inference_worker.returned.connect(self.on_return_infer) self.inference_worker.start() def prepare_for_stop_inference(self): + """ + This function defines the sequence of events which ensue once inference is stopped. + """ self.start_training_button.setEnabled(True) self.stop_training_button.setEnabled(True) + self.save_weights_button.setEnabled(True) + self.threshold_line.setEnabled(True) self.bandwidth_line.setEnabled(True) self.radio_button_nucleus.setEnabled(True) @@ -562,33 +656,48 @@ def prepare_for_stop_inference(self): self.stop_inference_button.setEnabled(True) if self.napari_dataset.get_num_samples() == 0: self.threshold_line.setText(str(round(self.thresholds[0], 3))) - self.bandwidth_line.setText(str(round(self.band_widths[0], 3))) + self.bandwidth_line.setText(str(round(self.bandwidths[0], 3))) self.min_size_line.setText(str(round(self.min_sizes[0], 3))) + if self.inference_worker is not None: + self.inference_worker.quit() @thread_worker def infer(self): + """ + The main inference function. + """ for layer in self.viewer.layers: if f"{layer}" == self.raw_selector.currentText(): raw_image_layer = layer break - self.thresholds = ( - [None] * self.napari_dataset.get_num_samples() - if self.napari_dataset.get_num_samples() != 0 - else [None] * 1 - ) - if ( + if not hasattr(self, "thresholds"): + self.thresholds = ( + [None] * self.napari_dataset.get_num_samples() + if self.napari_dataset.get_num_samples() != 0 + else [None] * 1 + ) + + if not hasattr(self, "thresholds_last"): + self.thresholds_last = self.thresholds.copy() + + if not hasattr(self, "bandwidths") and ( self.inference_config.bandwidth is None - and len(self.band_widths) == 0 ): - self.band_widths = ( + self.bandwidths = ( [0.5 * self.experiment_config.object_size] * self.napari_dataset.get_num_samples() if self.napari_dataset.get_num_samples() != 0 else [0.5 * self.experiment_config.object_size] ) - if self.inference_config.min_size is None and len(self.min_sizes) == 0: + if not hasattr(self, "bandwidths_last"): + self.bandwidths_last = self.bandwidths.copy() + + if ( + not hasattr(self, "min_sizes") + and self.inference_config.min_size is None + ): if self.napari_dataset.get_num_spatial_dims() == 2: self.min_sizes = ( [ @@ -639,7 +748,20 @@ def infer(self): ] ) + if not hasattr(self, "min_sizes_last"): + self.min_sizes_last = self.min_sizes.copy() + + if not hasattr(self, "post_processing"): + self.post_processing = ( + "cell" if self.radio_button_cell.isChecked() else "nucleus" + ) + + if not hasattr(self, "post_processing_last"): + self.post_processing_last = self.post_processing + # set in eval mode + self.model = self.model.to(self.device) + self.model.eval() self.model.set_infer( p_salt_pepper=self.inference_config.p_salt_pepper, @@ -649,9 +771,14 @@ def infer(self): if self.napari_dataset.get_num_spatial_dims() == 2: crop_size_tuple = (self.inference_config.crop_size[0],) * 2 - + predicted_crop_size_tuple = ( + self.inference_config.crop_size[0] - 16, + ) * 2 elif self.napari_dataset.get_num_spatial_dims() == 3: crop_size_tuple = (self.inference_config.crop_size[0],) * 3 + predicted_crop_size_tuple = ( + self.inference_config.crop_size[0] - 16, + ) * 3 input_shape = gp.Coordinate( ( @@ -675,7 +802,12 @@ def infer(self): output_shape = gp.Coordinate( self.model( torch.zeros( - (1, 1, *crop_size_tuple), dtype=torch.float32 + ( + 1, + self.napari_dataset.get_num_channels(), + *crop_size_tuple, + ), + dtype=torch.float32, ).to(self.device) ).shape ) @@ -693,17 +825,20 @@ def infer(self): context = (input_size - output_size) // 2 raw = gp.ArrayKey("RAW") prediction = gp.ArrayKey("PREDICT") - scan_request = gp.BatchRequest() - # scan_request.add(raw, input_size) + scan_request = gp.BatchRequest() scan_request[raw] = gp.Roi( - (-8,) * (self.napari_dataset.get_num_spatial_dims()), + (-8,) * self.napari_dataset.get_num_spatial_dims(), crop_size_tuple, ) - scan_request.add(prediction, output_size) + scan_request[prediction] = gp.Roi( + (0,) * self.napari_dataset.get_num_spatial_dims(), + predicted_crop_size_tuple, + ) + predict = gp.torch.Predict( self.model, - inputs={"x": raw}, + inputs={"raw": raw}, outputs={0: prediction}, array_specs={prediction: gp.ArraySpec(voxel_size=voxel_size)}, ) @@ -748,25 +883,30 @@ def infer(self): # Obtain Embeddings print("Predicting Embeddings ...") - with gp.build(pipeline): - batch = pipeline.request_batch(request) + if hasattr(self, "embeddings"): + pass + else: + with gp.build(pipeline): + batch = pipeline.request_batch(request) - embeddings = batch.arrays[prediction].data - embeddings_centered = np.zeros_like(embeddings) - foreground_mask = np.zeros_like(embeddings[:, 0:1, ...], dtype=bool) + self.embeddings = batch.arrays[prediction].data + embeddings_centered = np.zeros_like(self.embeddings) + foreground_mask = np.zeros_like( + self.embeddings[:, 0:1, ...], dtype=bool + ) colormaps = ["red", "green", "blue"] # Obtain Object Centered Embeddings - for sample in tqdm(range(embeddings.shape[0])): - embeddings_sample = embeddings[sample] + for sample in tqdm(range(self.embeddings.shape[0])): + embeddings_sample = self.embeddings[sample] embeddings_std = embeddings_sample[-1, ...] embeddings_mean = embeddings_sample[ np.newaxis, : self.napari_dataset.get_num_spatial_dims(), ... ].copy() - threshold = threshold_otsu(embeddings_std) - - self.thresholds[sample] = threshold - binary_mask = embeddings_std < threshold + if self.thresholds[sample] is None: + threshold = threshold_otsu(embeddings_std) + self.thresholds[sample] = round(threshold, 3) + binary_mask = embeddings_std < self.thresholds[sample] foreground_mask[sample] = binary_mask[np.newaxis, ...] embeddings_centered_sample = embeddings_sample.copy() embeddings_mean_masked = ( @@ -812,31 +952,51 @@ def infer(self): ) for i in range(self.napari_dataset.get_num_spatial_dims() + 1) ] + print("Clustering Objects in the obtained Foreground Mask ...") - detection = np.zeros_like(embeddings[:, 0:1, ...], dtype=np.uint16) - for sample in tqdm(range(embeddings.shape[0])): - embeddings_sample = embeddings[sample] + if hasattr(self, "detection"): + pass + else: + self.detection = np.zeros_like( + self.embeddings[:, 0:1, ...], dtype=np.uint16 + ) + for sample in tqdm(range(self.embeddings.shape[0])): + embeddings_sample = self.embeddings[sample] embeddings_std = embeddings_sample[-1, ...] embeddings_mean = embeddings_sample[ np.newaxis, : self.napari_dataset.get_num_spatial_dims(), ... ].copy() - detection_sample = mean_shift_segmentation( - embeddings_mean, - embeddings_std, - bandwidth=self.band_widths[sample], - min_size=self.inference_config.min_size, - reduction_probability=self.inference_config.reduction_probability, - threshold=self.thresholds[sample], - seeds=None, - ) - detection[sample, 0, ...] = detection_sample + if ( + self.thresholds[sample] != self.thresholds_last[sample] + or self.bandwidths[sample] != self.bandwidths_last[sample] + ): + detection_sample = mean_shift_segmentation( + embeddings_mean, + embeddings_std, + bandwidth=self.bandwidths[sample], + min_size=self.inference_config.min_size, + reduction_probability=self.inference_config.reduction_probability, + threshold=self.thresholds[sample], + seeds=None, + ) + self.detection[sample, 0, ...] = detection_sample + self.thresholds_last[sample] = self.thresholds[sample] + self.bandwidths_last[sample] = self.bandwidths[sample] print("Converting Detections to Segmentations ...") - segmentation = np.zeros_like(embeddings[:, 0:1, ...], dtype=np.uint16) + if ( + hasattr(self, "segmentation") + and self.post_processing == self.post_processing_last + ): + pass + else: + self.segmentation = np.zeros_like( + self.embeddings[:, 0:1, ...], dtype=np.uint16 + ) if self.radio_button_cell.isChecked(): - for sample in tqdm(range(embeddings.shape[0])): - segmentation_sample = detection[sample, 0].copy() + for sample in tqdm(range(self.embeddings.shape[0])): + segmentation_sample = self.detection[sample, 0].copy() distance_foreground = dtedt(segmentation_sample == 0) expanded_mask = ( distance_foreground < self.inference_config.grow_distance @@ -845,11 +1005,11 @@ def infer(self): segmentation_sample[ distance_background < self.inference_config.shrink_distance ] = 0 - segmentation[sample, 0, ...] = segmentation_sample + self.segmentation[sample, 0, ...] = segmentation_sample elif self.radio_button_nucleus.isChecked(): raw_image = raw_image_layer.data - for sample in tqdm(range(embeddings.shape[0])): - segmentation_sample = detection[sample, 0] + for sample in tqdm(range(self.embeddings.shape[0])): + segmentation_sample = self.detection[sample, 0].copy() if ( self.napari_dataset.get_num_samples() == 0 and self.napari_dataset.get_num_channels() == 0 @@ -903,7 +1063,7 @@ def infer(self): ) mask[y_min : y_max + 1, x_min : x_max + 1] = mask_small y, x = np.where(mask) - segmentation[sample, 0, y, x] = id_ + self.segmentation[sample, 0, y, x] = id_ elif self.napari_dataset.get_num_spatial_dims() == 3: mask_small = binary_fill_holes( mask[ @@ -918,28 +1078,66 @@ def infer(self): x_min : x_max + 1, ] = mask_small z, y, x = np.where(mask) - segmentation[sample, 0, z, y, x] = id_ + self.segmentation[sample, 0, z, y, x] = id_ print("Removing small objects ...") # size filter - remove small objects - for sample in tqdm(range(embeddings.shape[0])): - segmentation[sample, 0, ...] = size_filter( - segmentation[sample, 0], self.min_sizes[sample] - ) + for sample in tqdm(range(self.embeddings.shape[0])): + if ( + self.min_sizes[sample] != self.min_sizes_last[sample] + or self.post_processing_last != self.post_processing + ): + self.segmentation[sample, 0, ...] = size_filter( + self.segmentation[sample, 0], self.min_sizes[sample] + ) + self.min_sizes_last[sample] = self.min_sizes[sample] + self.post_processing_last = self.post_processing + return ( embeddings_layers + [(foreground_mask, {"name": "Foreground Mask"}, "labels")] - + [(detection, {"name": "Detection"}, "labels")] - + [(segmentation, {"name": "Segmentation"}, "labels")] + + [(self.detection, {"name": "Detection"}, "labels")] + + [(self.segmentation, {"name": "Segmentation"}, "labels")] ) def on_return_infer(self, layers): + """ + Once inference is over, the old result layers are removed + and the new output layers are displayed. + + Args: + layers: Tuple + (embedding_layers, foreground layer, detection layer, segmentation layer) + + """ + + if "Offset (x)" in self.viewer.layers: + del self.viewer.layers["Offset (x)"] + if "Offset (y)" in self.viewer.layers: + del self.viewer.layers["Offset (y)"] + if "Offset (z)" in self.viewer.layers: + del self.viewer.layers["Offset (z)"] + if "Uncertainty" in self.viewer.layers: + del self.viewer.layers["Uncertainty"] + if "Foreground Mask" in self.viewer.layers: + del self.viewer.layers["Foreground Mask"] + if "Segmentation" in self.viewer.layers: + del self.viewer.layers["Segmentation"] + if "Detection" in self.viewer.layers: + del self.viewer.layers["Detection"] + for data, metadata, layer_type in layers: if layer_type == "image": self.viewer.add_image(data, **metadata) elif layer_type == "labels": - self.viewer.add_labels(data.astype(int), **metadata) + if ( + self.napari_dataset.get_num_samples() != 0 + and self.napari_dataset.get_num_channels() != 0 + ): + self.viewer.add_labels(data.astype(int), **metadata) + else: + self.viewer.add_labels(data[:, 0].astype(int), **metadata) self.viewer.layers["Offset (x)"].visible = False self.viewer.layers["Offset (y)"].visible = False self.viewer.layers["Uncertainty"].visible = False @@ -948,3 +1146,80 @@ def on_return_infer(self, layers): self.viewer.layers["Segmentation"].visible = True self.inference_worker.quit() self.prepare_for_stop_inference() + + def prepare_thresholds(self): + """ + In case, the `Threshold` lineedit is changed by the user, + the attribute `thresholds` is updated. + """ + sample_index = self.viewer.dims.current_step[0] + self.thresholds[sample_index] = float(self.threshold_line.text()) + + def prepare_bandwidths(self): + """ + In case, the `Bandwidth` lineedit is changed by the user, + the attribute `bandwidths` is updated. + """ + sample_index = self.viewer.dims.current_step[0] + self.bandwidths[sample_index] = float(self.bandwidth_line.text()) + + def prepare_min_sizes(self): + """ + In case, the `Minimum Size` lineedit is changed by the user, + the attribute `min_sizes` is updated. + + """ + sample_index = self.viewer.dims.current_step[0] + self.min_sizes[sample_index] = float(self.min_size_line.text()) + + def load_weights(self): + """ + Describes sequence of actions, which ensue after `Load Weights` button is pressed + + """ + file_name, _ = QFileDialog.getOpenFileName( + caption="Load Model Weights" + ) + self.pre_trained_model_checkpoint = file_name + print( + f"Model weights will be loaded from {self.pre_trained_model_checkpoint}" + ) + + def update_post_processing(self): + self.post_processing = ( + "cell" if self.radio_button_nucleus.isChecked() else "nucleus" + ) + + def affect_load_weights(self): + """ + In case `train from scratch` checkbox is selected, + the `Load weights` is disabled, and vice versa. + + """ + if self.train_model_from_scratch_checkbox.isChecked(): + self.load_model_button.setEnabled(False) + else: + self.load_model_button.setEnabled(True) + + def save_weights(self): + """ + Describes sequence of actions which ensue, after `Save weights` button is pressed + + """ + checkpoint_file_name, _ = QFileDialog.getSaveFileName( + caption="Save Model Weights" + ) + if ( + hasattr(self, "model") + and hasattr(self, "optimizer") + and hasattr(self, "iterations") + and hasattr(self, "losses") + ): + state = { + "model_state_dict": self.model.state_dict(), + "optim_state_dict": self.optimizer.state_dict(), + "iterations": self.iterations, + "losses": self.losses, + } + torch.save(state, checkpoint_file_name) + print(f"Model weights will be saved at {checkpoint_file_name}") diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 0808e73..0000000 --- a/tox.ini +++ /dev/null @@ -1,32 +0,0 @@ -# For more information about tox, see https://tox.readthedocs.io/en/latest/ -[tox] -envlist = py{38,39,310}-{linux,macos,windows} -isolated_build=true - -[gh-actions] -python = - 3.8: py38 - 3.9: py39 - 3.10: py310 - -[gh-actions:env] -PLATFORM = - ubuntu-latest: linux - macos-latest: macos - windows-latest: windows - -[testenv] -platform = - macos: darwin - linux: linux - windows: win32 -passenv = - CI - GITHUB_ACTIONS - DISPLAY - XAUTHORITY - NUMPY_EXPERIMENTAL_ARRAY_FUNCTION - PYVISTA_OFF_SCREEN -extras = - testing -commands = pytest -v --color=yes --cov=napari_cellulus --cov-report=xml