Skip to content

Commit

Permalink
Merge pull request #6 from funkelab/v0.0.3
Browse files Browse the repository at this point in the history
Update Version of Plugin
  • Loading branch information
lmanan authored Oct 30, 2023
2 parents 3160603 + f0831ae commit 8127332
Show file tree
Hide file tree
Showing 8 changed files with 811 additions and 906 deletions.
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ install_requires =
matplotlib
torch
gunpowder



python_requires = >=3.8
Expand Down
14 changes: 2 additions & 12 deletions src/napari_cellulus/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,5 @@
__version__ = "0.0.1"
__version__ = "0.0.3"

from ._sample_data import tissuenet_sample
from .widgets._widget import (
TrainWidget,
model_config_widget,
train_config_widget,
)

__all__ = (
"tissuenet_sample",
"train_config_widget",
"model_config_widget",
"TrainWidget",
)
__all__ = ("tissuenet_sample",)
113 changes: 78 additions & 35 deletions src/napari_cellulus/dataset.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import math
from typing import Tuple, List
from typing import List

import gunpowder as gp
import numpy as np
from napari.layers import Image
from torch.utils.data import IterableDataset

from .gp.nodes.napari_image_source import NapariImageSource

from cellulus.datasets import DatasetMetaData
from .meta_data import NapariDatasetMetaData


class NapariDataset(IterableDataset): # type: ignore
def __init__(
self,
layer: Image,
axis_names: List[str],
crop_size: Tuple[int, ...],
crop_size: int,
control_point_spacing: int,
control_point_jitter: float,
):
Expand Down Expand Up @@ -58,40 +57,61 @@ def __init__(

self.layer = layer
self.axis_names = axis_names
self.crop_size = crop_size
self.control_point_spacing = control_point_spacing
self.control_point_jitter = control_point_jitter
self.__read_meta_data()

assert len(crop_size) == self.num_spatial_dims, (
f'"crop_size" must have the same dimension as the '
f'spatial(temporal) dimensions of the "{self.layer.name}"'
f"layer which is {self.num_spatial_dims}, but it is {crop_size}"
)

print(f"Number of spatial dims is {self.num_spatial_dims}")
self.crop_size = (crop_size,) * self.num_spatial_dims
self.__setup_pipeline()

def __iter__(self):
return iter(self.__yield_sample())

def __setup_pipeline(self):
self.raw = gp.ArrayKey("RAW")
# treat all dimensions as spatial, with a voxel size = 1
voxel_size = gp.Coordinate((1,) * self.num_dims)
offset = gp.Coordinate((0,) * self.num_dims)
shape = gp.Coordinate(self.layer.data.shape)
raw_spec = gp.ArraySpec(
roi=gp.Roi(offset, voxel_size * shape),
dtype=np.float32,
interpolatable=True,
voxel_size=voxel_size,
)

self.pipeline = (
NapariImageSource(self.layer, self.raw)
+ gp.RandomLocation()
+ gp.ElasticAugment(
control_point_spacing=(self.control_point_spacing,)
* self.num_spatial_dims,
jitter_sigma=(self.control_point_jitter,)
* self.num_spatial_dims,
rotation_interval=(0, math.pi / 2),
scale_interval=(0.9, 1.1),
subsample=4,
spatial_dims=self.num_spatial_dims,
if self.num_channels == 0 and self.num_samples == 0:
self.pipeline = (
NapariImageSource(
self.layer, self.raw, raw_spec, self.spatial_dims
)
+ gp.RandomLocation()
+ gp.Unsqueeze([self.raw], 0)
+ gp.Unsqueeze([self.raw], 0)
)
elif self.num_channels == 0 and self.num_samples != 0:
self.pipeline = (
NapariImageSource(
self.layer, self.raw, raw_spec, self.spatial_dims
)
+ gp.RandomLocation()
+ gp.Unsqueeze([self.raw], 1)
)
elif self.num_channels != 0 and self.num_samples == 0:
self.pipeline = (
NapariImageSource(
self.layer, self.raw, raw_spec, self.spatial_dims
)
+ gp.RandomLocation()
+ gp.Unsqueeze([self.raw], 0)
)
elif self.num_channels != 0 and self.num_samples != 0:
self.pipeline = (
NapariImageSource(
self.layer, self.raw, raw_spec, self.spatial_dims
)
+ gp.RandomLocation()
)
# + gp.SimpleAugment(mirror_only=spatial_dims, transpose_only=spatial_dims)
)

def __yield_sample(self):
"""An infinite generator of crops."""
Expand All @@ -100,18 +120,40 @@ def __yield_sample(self):
while True:
# request one sample, all channels, plus crop dimensions
request = gp.BatchRequest()
request[self.raw] = gp.ArraySpec(
roi=gp.Roi(
(0,) * self.num_dims,
(1, self.num_channels, *self.crop_size),
if self.num_channels == 0 and self.num_samples == 0:
request[self.raw] = gp.ArraySpec(
roi=gp.Roi(
(0,) * (self.num_dims),
self.crop_size,
)
)
elif self.num_channels == 0 and self.num_samples != 0:
request[self.raw] = gp.ArraySpec(
roi=gp.Roi(
(0,) * (self.num_dims), (1, *self.crop_size)
)
)
elif self.num_channels != 0 and self.num_samples == 0:
request[self.raw] = gp.ArraySpec(
roi=gp.Roi(
(0,) * (self.num_dims),
(self.num_channels, *self.crop_size),
)
)
elif self.num_channels != 0 and self.num_samples != 0:
request[self.raw] = gp.ArraySpec(
roi=gp.Roi(
(0,) * (self.num_dims),
(1, self.num_channels, *self.crop_size),
)
)
)

sample = self.pipeline.request_batch(request)
yield sample[self.raw].data[0]

def __read_meta_data(self):
meta_data = DatasetMetaData(self.layer.data.shape, self.axis_names)
meta_data = NapariDatasetMetaData(
self.layer.data.shape, self.axis_names
)

self.num_dims = meta_data.num_dims
self.num_spatial_dims = meta_data.num_spatial_dims
Expand All @@ -120,9 +162,10 @@ def __read_meta_data(self):
self.sample_dim = meta_data.sample_dim
self.channel_dim = meta_data.channel_dim
self.time_dim = meta_data.time_dim
self.spatial_dims = meta_data.spatial_dims

def get_num_channels(self):
return self.num_channels
return 1 if self.num_channels == 0 else self.num_channels

def get_num_spatial_dims(self):
return self.num_spatial_dims
66 changes: 13 additions & 53 deletions src/napari_cellulus/gp/nodes/napari_image_source.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Optional

import gunpowder as gp
import numpy as np
from csbdeep.utils import normalize
from gunpowder.array_spec import ArraySpec
from gunpowder.profiling import Timing
from napari.layers import Image


Expand All @@ -11,70 +10,31 @@ class NapariImageSource(gp.BatchProvider):
A gunpowder interface to a napari Image
Args:
image (Image):
The napari Image to pull data from
The napari image layer to pull data from
key (``gp.ArrayKey``):
The key to provide data into
"""

def __init__(
self, image: Image, key: gp.ArrayKey, spec: Optional[ArraySpec] = None
self, image: Image, key: gp.ArrayKey, spec: ArraySpec, spatial_dims
):
if spec is None:
self.array_spec = self._read_metadata(image)
else:
self.array_spec = spec
self.array_spec = spec
self.image = gp.Array(
self._remove_leading_dims(image.data), self.array_spec
normalize(
image.data.astype(np.float32),
pmin=1,
pmax=99.8,
axis=spatial_dims,
),
self.array_spec,
)
self.spatial_dims = spatial_dims
self.key = key

def setup(self):
self.provides(self.key, self.array_spec.copy())

def provide(self, request):
output = gp.Batch()

timing_provide = Timing(self, "provide")
timing_provide.start()

output[self.key] = self.image.crop(request[self.key].roi)

timing_provide.stop()

output.profiling_stats.add(timing_provide)

return output

def _remove_leading_dims(self, data):
while data.shape[0] == 1:
data = data[0]
return data

def _read_metadata(self, image):
# offset assumed to be in world coordinates
# TODO: read from metadata
data_shape = image.data.shape
# strip leading singleton dimensions (2D data is often given a leading singleton 3rd dimension)
while data_shape[0] == 1:
data_shape = data_shape[1:]
axes = image.metadata.get("axes")
if axes is not None:
ndims = len(axes)
assert ndims <= len(
data_shape
), f"{axes} incompatible with shape: {data_shape}"
else:
ndims = len(data_shape)

offset = gp.Coordinate(image.metadata.get("offset", (0,) * ndims))
voxel_size = gp.Coordinate(
image.metadata.get("resolution", (1,) * ndims)
)
shape = gp.Coordinate(image.data.shape[-offset.dims :])

return gp.ArraySpec(
roi=gp.Roi(offset, voxel_size * shape),
dtype=image.dtype,
interpolatable=True,
voxel_size=voxel_size,
)
2 changes: 1 addition & 1 deletion src/napari_cellulus/gui_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
NavigationToolbar2QT as NavigationToolbar,
)
from matplotlib.figure import Figure
from PyQt5 import QtWidgets
from qtpy import QtWidgets


class MplCanvas(FigureCanvasQTAgg):
Expand Down
36 changes: 36 additions & 0 deletions src/napari_cellulus/meta_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Tuple


class NapariDatasetMetaData:
def __init__(self, shape, axis_names):
self.num_dims = len(axis_names)
self.num_spatial_dims: int = 0
self.num_samples: int = 0
self.num_channels: int = 0
self.sample_dim = None
self.channel_dim = None
self.time_dim = None
self.spatial_array: Tuple[int, ...] = ()
self.spatial_dims = ()
for dim, axis_name in enumerate(axis_names):
if axis_name == "s":
self.sample_dim = dim
self.num_samples = shape[dim]
elif axis_name == "c":
self.channel_dim = dim
self.num_channels = shape[dim]
elif axis_name == "t":
self.num_spatial_dims += 1
self.time_dim = dim
elif axis_name == "z":
self.num_spatial_dims += 1
self.spatial_array += (shape[dim],)
self.spatial_dims += (-3,)
elif axis_name == "y":
self.num_spatial_dims += 1
self.spatial_array += (shape[dim],)
self.spatial_dims += (-2,)
elif axis_name == "x":
self.num_spatial_dims += 1
self.spatial_array += (shape[dim],)
self.spatial_dims += (-1,)
20 changes: 5 additions & 15 deletions src/napari_cellulus/napari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,13 @@ contributions:
- id: napari-cellulus.tissuenet_sample
python_name: napari_cellulus._sample_data:tissuenet_sample
title: Load sample data from Cellulus
- id: napari-cellulus.train_config
python_name: napari_cellulus.widgets._widget:train_config_widget
title: Make the training config widget
- id: napari-cellulus.model_config
python_name: napari_cellulus.widgets._widget:model_config_widget
title: Make the model config widget
- id: napari-cellulus.train_widget
python_name: napari_cellulus.widgets._widget:TrainWidget
title: Make the train widget
- id: napari-cellulus.SegmentationWidget
python_name: napari_cellulus.widgets._widget:SegmentationWidget
title: Cellulus
sample_data:
- command: napari-cellulus.tissuenet_sample
display_name: Cellulus
key: tissuenet_sample
widgets:
- command: napari-cellulus.train_config
display_name: Train Config Widget
- command: napari-cellulus.model_config
display_name: Model Config Widget
- command: napari-cellulus.train_widget
display_name: Train Widget
- command: napari-cellulus.SegmentationWidget
display_name: Cellulus
Loading

0 comments on commit 8127332

Please sign in to comment.