Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
HarrisonSantiago committed Dec 3, 2024
1 parent ebb5a87 commit de2d5ad
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions sparsecoding/transforms/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,9 +412,10 @@ def quilt(
patches: torch.Tensor,
stride: int = None,
):
"""Gather square patches into an image, supporting overlapping patches.
Works with patches created by `patchify()` with a custom stride.
"""Gather square patches into an image.
Inverse of `patchify()`
Parameters
----------
height : int
Expand All @@ -430,7 +431,7 @@ def quilt(
stride : int, optional
Stride used when creating patches. If None, assumes non-overlapping patches
(stride = patch_size).
Returns
-------
image : Tensor, shape [*, C, height, width]
Expand All @@ -440,10 +441,10 @@ def quilt(
N, C, P = patches.shape[-4:-1]
H = height
W = width

if stride is None:
stride = P

# Calculate expected number of patches based on stride
expected_N = (
int((H - P + 1 + stride) // stride)
Expand All @@ -454,7 +455,12 @@ def quilt(
f"Expected {expected_N} patches per image based on stride {stride}, "
f"got {N} patches."
)


if stride > P:
raise RuntimeError(
"Stride cannot be larger than the size of a patch when quilting"
)

if (
H % stride != 0
or W % stride != 0
Expand All @@ -464,28 +470,28 @@ def quilt(
f"parts on the bottom and/or right may be affected.",
UserWarning,
)

# Reshape patches for folding operation
patches = patches.reshape(-1, N, C*P*P) # [prod(*), N, C*P*P]
patches = torch.permute(patches, (0, 2, 1)) # [prod(*), C*P*P, N]

# Use fold operation with the specified stride
image = torch.nn.functional.fold(
input=patches,
output_size=(H, W),
kernel_size=P,
stride=stride,
) # [prod(*), C, H, W]

# Create a ones tensor of the same shape as patches to track overlapping regions
normalization = torch.nn.functional.fold(
torch.ones_like(patches),
output_size=(H, W),
kernel_size=P,
stride=stride,
)

# Normalize by the count of overlapping patches
image = image / (normalization + 1e-6) # Add small epsilon to avoid division by zero

return image.reshape(*leading_dims, C, H, W)

0 comments on commit de2d5ad

Please sign in to comment.