
minor formatting edits to fix Sphinx formatting warnings
belsten committed Nov 20, 2024
1 parent afc126e · commit c5ed4b1
Showing 2 changed files with 24 additions and 20 deletions.
40 changes: 22 additions & 18 deletions sparsecoding/data/transforms/patch.py
@@ -16,11 +16,12 @@ def sample_random_patches(
     Patch side length.
 num_patches : int
     Number of patches to sample.
-image : Tensor, shape [*, C, H, W]
+image : Tensor, shape [M, C, H, W]
     where:
-        C is the number of channels,
-        H is the image height,
-        W is the image width.
+        M is the number of images,
+        C is the number of channels,
+        H is the image height,
+        W is the image width.

 Returns
 -------
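Below is a minimal usage sketch for sample_random_patches, based only on the parameters listed in this hunk. The positional argument order, the PyTorch import path, and the returned shape are assumptions; the Returns section is cut off in this excerpt.

    # Sketch only: argument order (patch_size, num_patches, image) follows the
    # docstring ordering above; the output shape is an assumption, since the
    # Returns section is not shown in this excerpt.
    import torch
    from sparsecoding.data.transforms.patch import sample_random_patches

    images = torch.randn(10, 3, 128, 128)             # [M, C, H, W]
    patches = sample_random_patches(16, 500, images)
    print(patches.shape)                               # assumed: [500, 3, 16, 16]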
@@ -75,23 +76,25 @@ def patchify(
 ----------
 patch_size : int
     Patch side length.
-image : Tensor, shape [*, C, H, W]
-    where:
-        C is the number of channels,
-        H is the image height,
-        W is the image width.
+image : Tensor, shape [M, C, H, W]
+    where:
+        M is the number of images,
+        C is the number of channels,
+        H is the image height,
+        W is the image width.
 stride : int, optional
     Stride between patches in pixel space. If not specified, set to
     `patch_size` (non-overlapping patches).

 Returns
 -------
-patches : Tensor, shape [*, N, C, P, P]
+patches : Tensor, shape [M, N, C, P, P]
     Non-overlapping patches taken from the input image,
     where:
-        P is the patch size,
-        N is the number of patches, equal to H//P * W//P,
-        C is the number of channels of the input image.
+        M is the number of images,
+        P is the patch size,
+        N is the number of patches, equal to H//P * W//P,
+        C is the number of channels of the input image.
 """
 leading_dims = image.shape[:-3]
 C, H, W = image.shape[-3:]
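A minimal sketch of how the documented shapes fit together for patchify. The import path is inferred from the file path above, the positional argument order (patch_size, image) from the docstring ordering, and PyTorch tensors are assumed.

    # Sketch: non-overlapping patches, default stride = patch_size.
    import torch
    from sparsecoding.data.transforms.patch import patchify

    images = torch.randn(8, 3, 64, 64)   # [M, C, H, W]
    patches = patchify(8, images)        # expected [M, N, C, P, P]
    # N = (64 // 8) * (64 // 8) = 64 patches per image
    print(patches.shape)                 # expected: [8, 64, 3, 8, 8]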
@@ -141,16 +144,17 @@ def quilt(
     Height for the reconstructed image.
 width : int
     Width for the reconstructed image.
-patches : Tensor, shape [*, N, C, P, P]
+patches : Tensor, shape [M, N, C, P, P]
     Non-overlapping patches from an input image,
     where:
-        P is the patch size,
-        N is the number of patches,
-        C is the number of channels in the image.
+        M is the number of images
+        P is the patch size,
+        N is the number of patches,
+        C is the number of channels in the image.

 Returns
 -------
-image : Tensor, shape [*, C, height, width]
+image : Tensor, shape [M, C, height, width]
     Image reconstructed by stitching together input patches.
 """
 leading_dims = patches.shape[:-4]
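A sketch of the round trip the two docstrings describe: quilting the patches produced by patchify back into images of the documented height and width. The argument order (height, width, patches) follows the docstring; that the round trip is exact is an assumption, not something this diff states.

    # Sketch: patchify followed by quilt should reassemble the image when the
    # patches tile it exactly (assumption based on the docstrings above).
    import torch
    from sparsecoding.data.transforms.patch import patchify, quilt

    images = torch.randn(4, 1, 32, 32)   # [M, C, H, W]
    patches = patchify(8, images)        # [M, N, C, P, P] = [4, 16, 1, 8, 8]
    recon = quilt(32, 32, patches)       # [M, C, height, width]
    print(torch.allclose(recon, images)) # expected True if quilt inverts patchify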
4 changes: 2 additions & 2 deletions sparsecoding/inference.py
@@ -856,8 +856,8 @@ class OMP(InferenceMethod):
 """
 Infer coefficients for each image in data using elements dictionary.
 Method description can be traced to:
-"Orthogonal Matching Pursuit: Recursive Function Approximation with Application to Wavelet Decomposition"
-(Y. Pati & R. Rezaiifar & P. Krishnaprasad, 1993)
+"Orthogonal Matching Pursuit: Recursive Function Approximation with Application to Wavelet Decomposition"
+(Y. Pati & R. Rezaiifar & P. Krishnaprasad, 1993)
 """

 def __init__(self, sparsity, solver=None, return_all_coefficients=False):
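For reference, a sketch constructing OMP with the constructor signature visible in this hunk; how the resulting object is then applied to data and a dictionary is not shown in this excerpt.

    # Sketch: only the constructor shown above is exercised here.
    from sparsecoding.inference import OMP

    omp = OMP(sparsity=5)                                     # select 5 active coefficients per sample
    omp_all = OMP(sparsity=5, return_all_coefficients=True)   # also return intermediate coefficients (assumed meaning)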
