From 8dc7928ce69c2b4cdc80f1bcc63a3c27df8e01c0 Mon Sep 17 00:00:00 2001 From: Alister Burt Date: Sat, 15 Jun 2024 18:06:46 -0700 Subject: [PATCH] fix shapes for batched outputs --- examples/extract_from_single_2d.py | 17 ++++++++++++++ src/torch_subpixel_crop/subpixel_crop_2d.py | 26 ++++++++++++++------- 2 files changed, 35 insertions(+), 8 deletions(-) create mode 100644 examples/extract_from_single_2d.py diff --git a/examples/extract_from_single_2d.py b/examples/extract_from_single_2d.py new file mode 100644 index 0000000..f431dad --- /dev/null +++ b/examples/extract_from_single_2d.py @@ -0,0 +1,17 @@ +import numpy as np +import torch + +from torch_subpixel_crop import subpixel_crop_2d +from skimage import data + +image = torch.tensor(data.binary_blobs(length=512, n_dim=2)).float() +positions = torch.tensor(np.random.uniform(low=0, high=511, size=(100, 2))).float() + +crops = subpixel_crop_2d( + image=image, + positions=positions, + sidelength=32 +) + +# (100, 32, 32) +print(crops.shape) \ No newline at end of file diff --git a/src/torch_subpixel_crop/subpixel_crop_2d.py b/src/torch_subpixel_crop/subpixel_crop_2d.py index 3d33091..b743fb5 100644 --- a/src/torch_subpixel_crop/subpixel_crop_2d.py +++ b/src/torch_subpixel_crop/subpixel_crop_2d.py @@ -9,7 +9,7 @@ def subpixel_crop_2d( - images: torch.Tensor, positions: torch.Tensor, sidelength: int, + image: torch.Tensor, positions: torch.Tensor, sidelength: int, ): """Extract square patches from 2D images at positions with subpixel precision. @@ -18,7 +18,7 @@ def subpixel_crop_2d( Parameters ---------- - images: torch.Tensor + image: torch.Tensor `(b, h, w)` or `(h, w)` array of 2D images. positions: torch.Tensor `(..., b, 2)` or `(..., 2)` array of coordinates for patch centers. @@ -28,12 +28,18 @@ def subpixel_crop_2d( Returns ------- patches: torch.Tensor - `(..., b, sidelength, sidelength)` array of patches from `images` with their - centers (DC components of fftshifted DFTs at `positions`. + `(..., b, sidelength, sidelength)` or `(..., sidelength, sidelength)` array + of patches from `images` with their centers at `positions`. """ - if images.ndim == 2: - images = einops.rearrange(images, 'h w -> 1 h w') + # handling batched input + if image.ndim == 2: + input_images_are_batched = False + image = einops.rearrange(image, 'h w -> 1 h w') positions = einops.rearrange(positions, '... yx -> ... 1 yx') + else: + input_images_are_batched = True + + # setup coordinates and extract positions, ps = einops.pack([positions], pattern='* t yx') positions = einops.rearrange(positions, 'b t yx -> t b yx') patches = einops.rearrange( @@ -44,11 +50,15 @@ def subpixel_crop_2d( output_image_sidelength=sidelength ) for _image, _positions - in zip(images, positions) + in zip(image, positions) ], pattern='t b h w -> b t h w' ) [patches] = einops.unpack(patches, pattern='* t h w', packed_shapes=ps) + + # unbatch output if input images weren't batched + if input_images_are_batched is False: + patches = einops.rearrange(patches, pattern='... 1 h w -> ... h w') return patches @@ -86,5 +96,5 @@ def _extract_patches_from_single_image( patches = einops.rearrange(patches, 'b 1 h w -> b h w') # phase shift to center images - patches = fourier_shift_image_2d(images=patches, shifts=shifts) + patches = fourier_shift_image_2d(image=patches, shifts=shifts) return patches