Harrison/quilt #85

Open · wants to merge 4 commits into base: main
48 changes: 37 additions & 11 deletions sparsecoding/transforms/images.py
@@ -410,10 +410,11 @@ def quilt(
     height: int,
     width: int,
     patches: torch.Tensor,
+    stride: int = None,
 ):
-    """Gather square patches into an image.
+    """Gather square patches into an image. When patches overlap, overlapping pixels are averaged.
 
     Inverse of `patchify()`.
 
     Parameters
     ----------
@@ -422,11 +423,14 @@ def quilt(
     width : int
         Width for the reconstructed image.
     patches : Tensor, shape [*, N, C, P, P]
-        Non-overlapping patches from an input image,
+        Potentially overlapping patches from an input image,
         where:
             P is the patch size,
            N is the number of patches,
            C is the number of channels in the image.
+    stride : int, optional
+        Stride used when creating the patches. If None, assumes non-overlapping
+        patches (stride = patch_size).
 
     Returns
     -------
@@ -438,29 +442,51 @@ def quilt(
     H = height
     W = width
 
-    if int(H / P) * int(W / P) != N:
+    if stride is None:
+        stride = P
+
+    expected_N = (
+        ((H - P) // stride + 1)
+        * ((W - P) // stride + 1)
+    )
+    if expected_N != N:
         raise ValueError(
-            f"Expected {N} patches per image, "
-            f"got int(H/P) * int(W/P) = {int(H / P) * int(W / P)}."
+            f"Expected {expected_N} patches per image based on stride {stride}, "
+            f"got {N} patches."
         )
 
+    if stride > P:
+        raise RuntimeError(
+            "Stride cannot be larger than the size of a patch when quilting."
+        )
+
     if (
-        H % P != 0
-        or W % P != 0
+        H % stride != 0
+        or W % stride != 0
     ):
         warnings.warn(
-            f"Image size ({H, W}) not evenly divisible by `patch_size` ({P}), "
-            f"parts on the bottom and/or right will be zeroed.",
+            f"Image size ({H, W}) not evenly divisible by stride ({stride}), "
+            f"parts on the bottom and/or right may be affected.",
             UserWarning,
         )
 
     patches = patches.reshape(-1, N, C*P*P)  # [prod(*), N, C*P*P]
     patches = torch.permute(patches, (0, 2, 1))  # [prod(*), C*P*P, N]
 
     image = torch.nn.functional.fold(
         input=patches,
         output_size=(H, W),
         kernel_size=P,
-        stride=P,
+        stride=stride,
     )  # [prod(*), C, H, W]
 
+    # Count how many patches cover each pixel, then average the overlaps.
+    normalization = torch.nn.functional.fold(
+        torch.ones_like(patches),
+        output_size=(H, W),
+        kernel_size=P,
+        stride=stride,
+    )
+
+    image = image / (normalization + 1e-6)
+
     return image.reshape(*leading_dims, C, H, W)
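The core trick in this diff is calling `torch.nn.functional.fold` twice: once to sum the overlapping patches back onto the image grid, and once on a tensor of ones to count how many patches cover each pixel, so dividing the two averages the overlaps. A minimal self-contained sketch of that technique (standalone variable names, not the PR's `quilt()` itself):

```python
import torch
import torch.nn.functional as F

H, W, P, stride = 4, 4, 2, 1  # image size, patch size, overlapping stride

image = torch.arange(H * W, dtype=torch.float32).reshape(1, 1, H, W)

# Extract overlapping P x P patches: shape [1, P*P, N].
patches = F.unfold(image, kernel_size=P, stride=stride)

# fold() SUMS overlapping patch values onto the output grid...
summed = F.fold(patches, output_size=(H, W), kernel_size=P, stride=stride)

# ...so fold() over ones counts how many patches touched each pixel
# (1 at corners, 2 on edges, 4 in the interior for this configuration).
counts = F.fold(torch.ones_like(patches),
                output_size=(H, W), kernel_size=P, stride=stride)

# Dividing averages the overlaps; patches cut from the image recombine exactly.
reconstructed = summed / counts
print(torch.allclose(reconstructed, image))
```

Each pixel's summed value is its original value times the number of covering patches, so the division recovers the image exactly; when the patches disagree (e.g. after sparse-coding reconstruction), the same division yields their per-pixel mean.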