diff --git a/examples/configs/fit_example.yml b/examples/configs/fit_example.yml index 85ff4c67..58bf263c 100644 --- a/examples/configs/fit_example.yml +++ b/examples/configs/fit_example.yml @@ -37,6 +37,7 @@ data: batch_size: 32 num_workers: 16 yx_patch_size: [256, 256] + pyramid_resolution: "0" normalizations: - class_path: viscy.transforms.NormalizeSampled init_args: @@ -87,4 +88,4 @@ data: sigma_z: [0.25, 1.5] sigma_y: [0.25, 1.5] sigma_x: [0.25, 1.5] - caching: false \ No newline at end of file + caching: false diff --git a/examples/configs/predict_example.yml b/examples/configs/predict_example.yml index b2556139..aaa69628 100644 --- a/examples/configs/predict_example.yml +++ b/examples/configs/predict_example.yml @@ -63,5 +63,6 @@ predict: - 256 caching: false predict_scale_source: null + pyramid_resolution: "0" return_predictions: false ckpt_path: null diff --git a/examples/configs/test_example.yml b/examples/configs/test_example.yml index 6c7130a2..9dd0b872 100644 --- a/examples/configs/test_example.yml +++ b/examples/configs/test_example.yml @@ -62,5 +62,6 @@ data: - 256 caching: false ground_truth_masks: null + pyramid_resolution: "0" ckpt_path: null verbose: true diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index e8ba12fa..17081408 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -104,6 +104,8 @@ class SlidingWindowDataset(Dataset): :param ChannelMap channels: source and target channel names, e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D + :param str pyramid_resolution: pyramid level. + defaults to 0 (full resolution) :param DictTransform | None transform: a callable that transforms data, defaults to None """ @@ -113,6 +115,7 @@ def __init__( positions: list[Position], channels: ChannelMap, z_window_size: int, + pyramid_resolution: str = "0", transform: DictTransform | None = None, ) -> None: super().__init__() @@ -128,6 +131,7 @@ def __init__( ) self.z_window_size = z_window_size self.transform = transform + self.pyramid_resolution = pyramid_resolution self._get_windows() def _get_windows(self) -> None: @@ -138,7 +142,7 @@ def _get_windows(self) -> None: self.window_arrays = [] self.window_norm_meta: list[NormMeta | None] = [] for fov in self.positions: - img_arr: ImageArray = fov["0"] + img_arr: ImageArray = fov[str(self.pyramid_resolution)] ts = img_arr.frames zs = img_arr.slices - self.z_window_size + 1 w += ts * zs @@ -219,7 +223,7 @@ def __getitem__(self, index: int) -> Sample: sample = { "index": sample_index, "source": self._stack_channels(sample_images, "source"), - "norm_meta": norm_meta, + # "norm_meta": norm_meta, } if self.target_ch_idx is not None: sample["target"] = self._stack_channels(sample_images, "target") @@ -301,6 +305,8 @@ class HCSDataModule(LightningDataModule): :param Path | None ground_truth_masks: path to the ground truth masks, used in the test stage to compute segmentation metrics, defaults to None + :param str pyramid_resolution: pyramid resolution level. + defaults to 0 (full resolution) """ def __init__( @@ -318,6 +324,7 @@ def __init__( augmentations: list[MapTransform] = [], caching: bool = False, ground_truth_masks: Path | None = None, + pyramid_resolution: str = "0", ): super().__init__() self.data_path = Path(data_path) @@ -334,6 +341,7 @@ def __init__( self.caching = caching self.ground_truth_masks = ground_truth_masks self.prepare_data_per_node = True + self.pyramid_resolution = pyramid_resolution @property def cache_path(self): @@ -390,6 +398,7 @@ def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]: return { "channels": {"source": self.source_channel}, "z_window_size": self.z_window_size, + "pyramid_resolution": self.pyramid_resolution, } def setup(self, stage: Literal["fit", "validate", "test", "predict"]):