Skip to content

Commit

Permalink
fix shapes for batched outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
alisterburt committed Jun 16, 2024
1 parent cdc56f0 commit 8dc7928
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
17 changes: 17 additions & 0 deletions examples/extract_from_single_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import numpy as np
import torch

from torch_subpixel_crop import subpixel_crop_2d
from skimage import data

image = torch.tensor(data.binary_blobs(length=512, n_dim=2)).float()
positions = torch.tensor(np.random.uniform(low=0, high=511, size=(100, 2))).float()

crops = subpixel_crop_2d(
image=image,
positions=positions,
sidelength=32
)

# (100, 32, 32)
print(crops.shape)
26 changes: 18 additions & 8 deletions src/torch_subpixel_crop/subpixel_crop_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def subpixel_crop_2d(
images: torch.Tensor, positions: torch.Tensor, sidelength: int,
image: torch.Tensor, positions: torch.Tensor, sidelength: int,
):
"""Extract square patches from 2D images at positions with subpixel precision.
Expand All @@ -18,7 +18,7 @@ def subpixel_crop_2d(
Parameters
----------
images: torch.Tensor
image: torch.Tensor
`(b, h, w)` or `(h, w)` array of 2D images.
positions: torch.Tensor
`(..., b, 2)` or `(..., 2)` array of coordinates for patch centers.
Expand All @@ -28,12 +28,18 @@ def subpixel_crop_2d(
Returns
-------
patches: torch.Tensor
`(..., b, sidelength, sidelength)` array of patches from `images` with their
centers (DC components of fftshifted DFTs at `positions`.
`(..., b, sidelength, sidelength)` or `(..., sidelength, sidelength)` array
of patches from `images` with their centers at `positions`.
"""
if images.ndim == 2:
images = einops.rearrange(images, 'h w -> 1 h w')
# handling batched input
if image.ndim == 2:
input_images_are_batched = False
image = einops.rearrange(image, 'h w -> 1 h w')
positions = einops.rearrange(positions, '... yx -> ... 1 yx')
else:
input_images_are_batched = True

# setup coordinates and extract
positions, ps = einops.pack([positions], pattern='* t yx')
positions = einops.rearrange(positions, 'b t yx -> t b yx')
patches = einops.rearrange(
Expand All @@ -44,11 +50,15 @@ def subpixel_crop_2d(
output_image_sidelength=sidelength
)
for _image, _positions
in zip(images, positions)
in zip(image, positions)
],
pattern='t b h w -> b t h w'
)
[patches] = einops.unpack(patches, pattern='* t h w', packed_shapes=ps)

# unbatch output if input images weren't batched
if input_images_are_batched is False:
patches = einops.rearrange(patches, pattern='... 1 h w -> ... h w')
return patches


Expand Down Expand Up @@ -86,5 +96,5 @@ def _extract_patches_from_single_image(
patches = einops.rearrange(patches, 'b 1 h w -> b h w')

# phase shift to center images
patches = fourier_shift_image_2d(images=patches, shifts=shifts)
patches = fourier_shift_image_2d(image=patches, shifts=shifts)
return patches

0 comments on commit 8dc7928

Please sign in to comment.