Skip to content

Commit

Permalink
added the pyramid option to the slidingwindow
Browse files Browse the repository at this point in the history
  • Loading branch information
edyoshikun committed Aug 15, 2024
1 parent b07f0ad commit 8a137d4
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions viscy/data/hcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,23 @@ def __init__(
positions: list[Position],
channels: ChannelMap,
z_window_size: int,
pyramid_resolution: int = 0,
pyramid_resolution: str = "0",
transform: DictTransform | None = None,
) -> None:
super().__init__()
self.positions = positions
self.channels = {k: _ensure_channel_list(v) for k, v in channels.items()}
self.source_ch_idx = [
positions[pyramid_resolution].get_channel_index(c)
for c in channels["source"]
positions[0].get_channel_index(c) for c in channels["source"]
]
self.target_ch_idx = (
[
positions[pyramid_resolution].get_channel_index(c)
for c in channels["target"]
]
[positions[0].get_channel_index(c) for c in channels["target"]]
if "target" in channels
else None
)
self.z_window_size = z_window_size
self.transform = transform
self.pyramid_resolution = pyramid_resolution
self._get_windows()

def _get_windows(self) -> None:
Expand All @@ -145,7 +142,7 @@ def _get_windows(self) -> None:
self.window_arrays = []
self.window_norm_meta: list[NormMeta | None] = []
for fov in self.positions:
img_arr: ImageArray = fov["0"]
img_arr: ImageArray = fov[str(self.pyramid_resolution)]
ts = img_arr.frames
zs = img_arr.slices - self.z_window_size + 1
w += ts * zs
Expand Down Expand Up @@ -226,7 +223,7 @@ def __getitem__(self, index: int) -> Sample:
sample = {
"index": sample_index,
"source": self._stack_channels(sample_images, "source"),
"norm_meta": norm_meta,
# "norm_meta": norm_meta,
}
if self.target_ch_idx is not None:
sample["target"] = self._stack_channels(sample_images, "target")
Expand Down Expand Up @@ -327,7 +324,7 @@ def __init__(
augmentations: list[MapTransform] = [],
caching: bool = False,
ground_truth_masks: Path | None = None,
pyramid_resolution: int = 0,
pyramid_resolution: str = "0",
):
super().__init__()
self.data_path = Path(data_path)
Expand All @@ -344,6 +341,7 @@ def __init__(
self.caching = caching
self.ground_truth_masks = ground_truth_masks
self.prepare_data_per_node = True
self.pyramid_resolution = pyramid_resolution

@property
def cache_path(self):
Expand Down Expand Up @@ -400,6 +398,7 @@ def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]:
return {
"channels": {"source": self.source_channel},
"z_window_size": self.z_window_size,
"pyramid_resolution": self.pyramid_resolution,
}

def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
Expand Down

0 comments on commit 8a137d4

Please sign in to comment.