diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 2b163433..26271438 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -115,26 +115,23 @@ def __init__( positions: list[Position], channels: ChannelMap, z_window_size: int, - pyramid_resolution: int = 0, + pyramid_resolution: str = "0", transform: DictTransform | None = None, ) -> None: super().__init__() self.positions = positions self.channels = {k: _ensure_channel_list(v) for k, v in channels.items()} self.source_ch_idx = [ - positions[pyramid_resolution].get_channel_index(c) - for c in channels["source"] + positions[0].get_channel_index(c) for c in channels["source"] ] self.target_ch_idx = ( - [ - positions[pyramid_resolution].get_channel_index(c) - for c in channels["target"] - ] + [positions[0].get_channel_index(c) for c in channels["target"]] if "target" in channels else None ) self.z_window_size = z_window_size self.transform = transform + self.pyramid_resolution = pyramid_resolution self._get_windows() def _get_windows(self) -> None: @@ -145,7 +142,7 @@ def _get_windows(self) -> None: self.window_arrays = [] self.window_norm_meta: list[NormMeta | None] = [] for fov in self.positions: - img_arr: ImageArray = fov["0"] + img_arr: ImageArray = fov[str(self.pyramid_resolution)] ts = img_arr.frames zs = img_arr.slices - self.z_window_size + 1 w += ts * zs @@ -226,7 +223,7 @@ def __getitem__(self, index: int) -> Sample: sample = { "index": sample_index, "source": self._stack_channels(sample_images, "source"), - "norm_meta": norm_meta, + # "norm_meta": norm_meta, } if self.target_ch_idx is not None: sample["target"] = self._stack_channels(sample_images, "target") @@ -327,7 +324,7 @@ def __init__( augmentations: list[MapTransform] = [], caching: bool = False, ground_truth_masks: Path | None = None, - pyramid_resolution: int = 0, + pyramid_resolution: str = "0", ): super().__init__() self.data_path = Path(data_path) @@ -344,6 +341,7 @@ def __init__( self.caching = caching self.ground_truth_masks = ground_truth_masks self.prepare_data_per_node = True + self.pyramid_resolution = pyramid_resolution @property def cache_path(self): @@ -400,6 +398,7 @@ def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]: return { "channels": {"source": self.source_channel}, "z_window_size": self.z_window_size, + "pyramid_resolution": self.pyramid_resolution, } def setup(self, stage: Literal["fit", "validate", "test", "predict"]):