Skip to content

Commit

Permalink
improve formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Aug 10, 2023
1 parent 8c6f362 commit c1fae1f
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 59 deletions.
2 changes: 1 addition & 1 deletion src/napari_cellulus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from ._sample_data import tissuenet_sample
from .widgets._widget import (
train_config_widget,
TrainWidget,
model_config_widget,
train_config_widget,
)

__all__ = (
Expand Down
4 changes: 2 additions & 2 deletions src/napari_cellulus/_sample_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from pathlib import Path

import numpy as np

TISSUENET_SAMPLE = Path(__file__).parent / "sample_data/tissuenet-sample.npy"


Expand Down
6 changes: 2 additions & 4 deletions src/napari_cellulus/dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from .gp.nodes.napari_image_source import NapariImageSource

import math
from typing import Tuple

import gunpowder as gp
from napari.layers import Image
from torch.utils.data import IterableDataset

import numpy as np
from napari.layers import Image
from .gp.nodes.napari_image_source import NapariImageSource


class NapariDataset(IterableDataset): # type: ignore
Expand Down
1 change: 0 additions & 1 deletion src/napari_cellulus/gp/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from .napari_image_source import NapariImageSource
7 changes: 3 additions & 4 deletions src/napari_cellulus/gp/nodes/napari_image_source.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from napari.layers import Image
from typing import Optional

import gunpowder as gp
from gunpowder.profiling import Timing
from gunpowder.array_spec import ArraySpec

from typing import Optional
from gunpowder.profiling import Timing
from napari.layers import Image


class NapariImageSource(gp.BatchProvider):
Expand Down
12 changes: 6 additions & 6 deletions src/napari_cellulus/gui_helpers.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
from magicgui.widgets import create_widget, FunctionGui

from magicgui.widgets import FunctionGui, create_widget
from matplotlib.backends.backend_qt5agg import (
FigureCanvasQTAgg,
)
from matplotlib.backends.backend_qt5agg import (
NavigationToolbar2QT as NavigationToolbar,
)
from matplotlib.figure import Figure

from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5 import QtWidgets


class MplCanvas(FigureCanvasQTAgg):
def __init__(self, parent=None, width=5, height=4, dpi=100):
fig = Figure(figsize=(width, height), dpi=dpi)
self.axes = fig.add_subplot(111)
super(MplCanvas, self).__init__(fig)
super().__init__(fig)
fig.set_tight_layout(True)


class MainWindow(QtWidgets.QMainWindow):
def __init__(self, *args, **kwargs):
super(MainWindow, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)

sc = MplCanvas(self, width=5, height=4, dpi=100)

Expand Down
72 changes: 31 additions & 41 deletions src/napari_cellulus/widgets/_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,54 +6,46 @@
Replace code below according to your needs.
"""
from typing import TYPE_CHECKING, Optional, List
import time

from magicgui import magic_factory
from magicgui.widgets import Container
from qtpy.QtWidgets import QPushButton, QWidget

from cellulus.configs.model_config import ModelConfig
from cellulus.configs.train_config import TrainConfig
from cellulus.models import get_model
from cellulus.criterions import get_loss
from cellulus.utils.mean_shift import mean_shift_segmentation

import dataclasses
import contextlib

# local package imports
from ..gui_helpers import layer_choice_widget, MplCanvas
from ..dataset import NapariDataset
from ..gp.nodes import NapariImageSource
# python built in libraries
from pathlib import Path
from typing import List, Optional

# github repo libraries
import gunpowder as gp

# pip installed libraries
import napari
from napari.qt.threading import FunctionWorker, thread_worker
import torch
import numpy as np
import torch
from cellulus.configs.model_config import ModelConfig
from cellulus.configs.train_config import TrainConfig
from cellulus.criterions import get_loss
from cellulus.models import get_model
from cellulus.utils.mean_shift import mean_shift_segmentation
from magicgui import magic_factory
from magicgui.widgets import Container

# widget stuff
from matplotlib.backends.backend_qt5agg import (
NavigationToolbar2QT as NavigationToolbar,
)
from magicgui.widgets import Container
from qtpy.QtCore import QEvent, QObject
from napari.qt.threading import FunctionWorker, thread_worker
from qtpy.QtWidgets import (
QWidget,
QVBoxLayout,
QPushButton,
QVBoxLayout,
QWidget,
)
from superqt import QCollapsible

# python built in libraries
from pathlib import Path
import dataclasses
from tqdm import tqdm

if TYPE_CHECKING:
import napari
from ..dataset import NapariDataset
from ..gp.nodes import NapariImageSource

# local package imports
from ..gui_helpers import MplCanvas, layer_choice_widget


@dataclasses.dataclass
Expand All @@ -64,8 +56,8 @@ class TrainingStats:

def reset(self):
self.iteration = 0
self.losses = list([])
self.iterations = list([])
self.losses = []
self.iterations = []

def load(self, other):
self.iteration = other.iteration
Expand Down Expand Up @@ -102,7 +94,7 @@ def get_train_config(**kwargs):

@magic_factory(call_button="Save")
def train_config_widget(
crop_size: list[int] = list([256, 256]),
crop_size: list[int] = [256, 256],
batch_size: int = 8,
max_iterations: int = 100_000,
initial_learning_rate: float = 4e-5,
Expand Down Expand Up @@ -150,7 +142,7 @@ def model_config_widget(
num_fmaps: int = 256,
fmap_inc_factor: int = 3,
features_in_last_layer: int = 64,
downsampling_factors: list[list[int]] = list([list([2, 2])]),
downsampling_factors: list[list[int]] = [[2, 2]],
):
get_model_config(
num_fmaps=num_fmaps,
Expand Down Expand Up @@ -187,7 +179,7 @@ def get_training_state(dataset: Optional[NapariDataset] = None):

# Weight initialization
# TODO: move weight initialization to funlib.learn.torch
for name, layer in _model.named_modules():
for _name, layer in _model.named_modules():
if isinstance(layer, torch.nn.modules.conv._ConvNd):
torch.nn.init.kaiming_normal_(
layer.weight, nonlinearity="relu"
Expand Down Expand Up @@ -323,7 +315,7 @@ def segment_widget(self):
@magic_factory(call_button="Segment")
def segment(
raw: napari.layers.Image,
crop_size: list[int] = list([252, 252]),
crop_size: list[int] = [252, 252],
p_salt_pepper: float = 0.1,
num_infer_iterations: int = 16,
bandwidth: int = 7,
Expand Down Expand Up @@ -578,7 +570,7 @@ def reset_training_state(self, keep_stats=False):
label="Training Loss",
)[0]
self.progress_plot.axes.legend()
self.progress_plot.axes.set_title(f"Training Progress")
self.progress_plot.axes.set_title("Training Progress")
self.progress_plot.axes.set_xlabel("Iterations")
self.progress_plot.axes.set_ylabel("Loss")
self.update_progress_plot()
Expand All @@ -589,14 +581,12 @@ def update_progress_plot(self):
self.loss_plot.set_ydata(training_stats.losses)
self.progress_plot.axes.relim()
self.progress_plot.axes.autoscale_view()
try:
self.progress_plot.draw()
except np.linalg.LinAlgError as e:
with contextlib.suppress(np.linalg.LinAlgError):
# matplotlib seems to throw a LinAlgError on draw sometimes. Not sure
# why yet. Seems to only happen when initializing models without any
# layers loaded. No idea whats going wrong.
# For now just avoid drawing. Seems to work as soon as there is data to plot
pass
self.progress_plot.draw()

def start_training_loop(self):
self.reset_training_state(keep_stats=True)
Expand Down Expand Up @@ -694,7 +684,7 @@ def add_layers(self, layers):
assert batch_dim in [
-1,
0,
], f"Batch dim must be first"
], "Batch dim must be first"
if batch_dim == 0:
data = data[0]

Expand Down

0 comments on commit c1fae1f

Please sign in to comment.