diff --git a/src/napari_cellulus/__init__.py b/src/napari_cellulus/__init__.py index 6049029..60ccee5 100644 --- a/src/napari_cellulus/__init__.py +++ b/src/napari_cellulus/__init__.py @@ -2,9 +2,9 @@ from ._sample_data import tissuenet_sample from .widgets._widget import ( - train_config_widget, TrainWidget, model_config_widget, + train_config_widget, ) __all__ = ( diff --git a/src/napari_cellulus/_sample_data.py b/src/napari_cellulus/_sample_data.py index 04287d6..159abd1 100644 --- a/src/napari_cellulus/_sample_data.py +++ b/src/napari_cellulus/_sample_data.py @@ -1,7 +1,7 @@ -import numpy as np - from pathlib import Path +import numpy as np + TISSUENET_SAMPLE = Path(__file__).parent / "sample_data/tissuenet-sample.npy" diff --git a/src/napari_cellulus/dataset.py b/src/napari_cellulus/dataset.py index 7e93c6d..b54e855 100644 --- a/src/napari_cellulus/dataset.py +++ b/src/napari_cellulus/dataset.py @@ -1,13 +1,11 @@ -from .gp.nodes.napari_image_source import NapariImageSource - import math from typing import Tuple import gunpowder as gp +from napari.layers import Image from torch.utils.data import IterableDataset -import numpy as np -from napari.layers import Image +from .gp.nodes.napari_image_source import NapariImageSource class NapariDataset(IterableDataset): # type: ignore diff --git a/src/napari_cellulus/gp/nodes/__init__.py b/src/napari_cellulus/gp/nodes/__init__.py index b0530bb..e69de29 100644 --- a/src/napari_cellulus/gp/nodes/__init__.py +++ b/src/napari_cellulus/gp/nodes/__init__.py @@ -1 +0,0 @@ -from .napari_image_source import NapariImageSource diff --git a/src/napari_cellulus/gp/nodes/napari_image_source.py b/src/napari_cellulus/gp/nodes/napari_image_source.py index 8f57f93..a3abc4c 100644 --- a/src/napari_cellulus/gp/nodes/napari_image_source.py +++ b/src/napari_cellulus/gp/nodes/napari_image_source.py @@ -1,10 +1,9 @@ -from napari.layers import Image +from typing import Optional import gunpowder as gp -from gunpowder.profiling import Timing from gunpowder.array_spec import ArraySpec - -from typing import Optional +from gunpowder.profiling import Timing +from napari.layers import Image class NapariImageSource(gp.BatchProvider): diff --git a/src/napari_cellulus/gui_helpers.py b/src/napari_cellulus/gui_helpers.py index 1fb89ef..84fd39f 100644 --- a/src/napari_cellulus/gui_helpers.py +++ b/src/napari_cellulus/gui_helpers.py @@ -1,25 +1,25 @@ -from magicgui.widgets import create_widget, FunctionGui - +from magicgui.widgets import FunctionGui, create_widget from matplotlib.backends.backend_qt5agg import ( FigureCanvasQTAgg, +) +from matplotlib.backends.backend_qt5agg import ( NavigationToolbar2QT as NavigationToolbar, ) from matplotlib.figure import Figure - -from PyQt5 import QtCore, QtGui, QtWidgets +from PyQt5 import QtWidgets class MplCanvas(FigureCanvasQTAgg): def __init__(self, parent=None, width=5, height=4, dpi=100): fig = Figure(figsize=(width, height), dpi=dpi) self.axes = fig.add_subplot(111) - super(MplCanvas, self).__init__(fig) + super().__init__(fig) fig.set_tight_layout(True) class MainWindow(QtWidgets.QMainWindow): def __init__(self, *args, **kwargs): - super(MainWindow, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) sc = MplCanvas(self, width=5, height=4, dpi=100) diff --git a/src/napari_cellulus/widgets/_widget.py b/src/napari_cellulus/widgets/_widget.py index 7f1471e..8a9ab98 100644 --- a/src/napari_cellulus/widgets/_widget.py +++ b/src/napari_cellulus/widgets/_widget.py @@ -6,54 +6,46 @@ Replace code below according to your needs. """ -from typing import TYPE_CHECKING, Optional, List -import time - -from magicgui import magic_factory -from magicgui.widgets import Container -from qtpy.QtWidgets import QPushButton, QWidget - -from cellulus.configs.model_config import ModelConfig -from cellulus.configs.train_config import TrainConfig -from cellulus.models import get_model -from cellulus.criterions import get_loss -from cellulus.utils.mean_shift import mean_shift_segmentation - +import dataclasses +import contextlib -# local package imports -from ..gui_helpers import layer_choice_widget, MplCanvas -from ..dataset import NapariDataset -from ..gp.nodes import NapariImageSource +# python built in libraries +from pathlib import Path +from typing import List, Optional # github repo libraries import gunpowder as gp # pip installed libraries import napari -from napari.qt.threading import FunctionWorker, thread_worker -import torch import numpy as np +import torch +from cellulus.configs.model_config import ModelConfig +from cellulus.configs.train_config import TrainConfig +from cellulus.criterions import get_loss +from cellulus.models import get_model +from cellulus.utils.mean_shift import mean_shift_segmentation +from magicgui import magic_factory +from magicgui.widgets import Container # widget stuff from matplotlib.backends.backend_qt5agg import ( NavigationToolbar2QT as NavigationToolbar, ) -from magicgui.widgets import Container -from qtpy.QtCore import QEvent, QObject +from napari.qt.threading import FunctionWorker, thread_worker from qtpy.QtWidgets import ( - QWidget, - QVBoxLayout, QPushButton, + QVBoxLayout, + QWidget, ) from superqt import QCollapsible - -# python built in libraries -from pathlib import Path -import dataclasses from tqdm import tqdm -if TYPE_CHECKING: - import napari +from ..dataset import NapariDataset +from ..gp.nodes import NapariImageSource + +# local package imports +from ..gui_helpers import MplCanvas, layer_choice_widget @dataclasses.dataclass @@ -64,8 +56,8 @@ class TrainingStats: def reset(self): self.iteration = 0 - self.losses = list([]) - self.iterations = list([]) + self.losses = [] + self.iterations = [] def load(self, other): self.iteration = other.iteration @@ -102,7 +94,7 @@ def get_train_config(**kwargs): @magic_factory(call_button="Save") def train_config_widget( - crop_size: list[int] = list([256, 256]), + crop_size: list[int] = [256, 256], batch_size: int = 8, max_iterations: int = 100_000, initial_learning_rate: float = 4e-5, @@ -150,7 +142,7 @@ def model_config_widget( num_fmaps: int = 256, fmap_inc_factor: int = 3, features_in_last_layer: int = 64, - downsampling_factors: list[list[int]] = list([list([2, 2])]), + downsampling_factors: list[list[int]] = [[2, 2]], ): get_model_config( num_fmaps=num_fmaps, @@ -187,7 +179,7 @@ def get_training_state(dataset: Optional[NapariDataset] = None): # Weight initialization # TODO: move weight initialization to funlib.learn.torch - for name, layer in _model.named_modules(): + for _name, layer in _model.named_modules(): if isinstance(layer, torch.nn.modules.conv._ConvNd): torch.nn.init.kaiming_normal_( layer.weight, nonlinearity="relu" @@ -323,7 +315,7 @@ def segment_widget(self): @magic_factory(call_button="Segment") def segment( raw: napari.layers.Image, - crop_size: list[int] = list([252, 252]), + crop_size: list[int] = [252, 252], p_salt_pepper: float = 0.1, num_infer_iterations: int = 16, bandwidth: int = 7, @@ -578,7 +570,7 @@ def reset_training_state(self, keep_stats=False): label="Training Loss", )[0] self.progress_plot.axes.legend() - self.progress_plot.axes.set_title(f"Training Progress") + self.progress_plot.axes.set_title("Training Progress") self.progress_plot.axes.set_xlabel("Iterations") self.progress_plot.axes.set_ylabel("Loss") self.update_progress_plot() @@ -589,14 +581,12 @@ def update_progress_plot(self): self.loss_plot.set_ydata(training_stats.losses) self.progress_plot.axes.relim() self.progress_plot.axes.autoscale_view() - try: - self.progress_plot.draw() - except np.linalg.LinAlgError as e: + with contextlib.suppress(np.linalg.LinAlgError): # matplotlib seems to throw a LinAlgError on draw sometimes. Not sure # why yet. Seems to only happen when initializing models without any # layers loaded. No idea whats going wrong. # For now just avoid drawing. Seems to work as soon as there is data to plot - pass + self.progress_plot.draw() def start_training_loop(self): self.reset_training_state(keep_stats=True) @@ -694,7 +684,7 @@ def add_layers(self, layers): assert batch_dim in [ -1, 0, - ], f"Batch dim must be first" + ], "Batch dim must be first" if batch_dim == 0: data = data[0]