Skip to content

Commit

Permalink
doc string
Browse files Browse the repository at this point in the history
  • Loading branch information
HarrisonSantiago committed Dec 3, 2024
1 parent de2d5ad commit 89074a7
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions sparsecoding/transforms/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def quilt(
patches: torch.Tensor,
stride: int = None,
):
"""Gather square patches into an image.
"""Gather square patches into an image. When patches overlap, take the average of overlapping pixels
Inverse of `patchify()`
Expand Down Expand Up @@ -445,7 +445,6 @@ def quilt(
if stride is None:
stride = P

# Calculate expected number of patches based on stride
expected_N = (
int((H - P + 1 + stride) // stride)
* int((W - P + 1 + stride) // stride)
Expand All @@ -471,27 +470,23 @@ def quilt(
UserWarning,
)

# Reshape patches for folding operation
patches = patches.reshape(-1, N, C*P*P) # [prod(*), N, C*P*P]
patches = torch.permute(patches, (0, 2, 1)) # [prod(*), C*P*P, N]

# Use fold operation with the specified stride
image = torch.nn.functional.fold(
input=patches,
output_size=(H, W),
kernel_size=P,
stride=stride,
) # [prod(*), C, H, W]

# Create a ones tensor of the same shape as patches to track overlapping regions
normalization = torch.nn.functional.fold(
torch.ones_like(patches),
output_size=(H, W),
kernel_size=P,
stride=stride,
)

# Normalize by the count of overlapping patches
image = image / (normalization + 1e-6) # Add small epsilon to avoid division by zero
image = image / (normalization + 1e-6)

return image.reshape(*leading_dims, C, H, W)

0 comments on commit 89074a7

Please sign in to comment.