Commit

rename files to match new convention
dpaiton committed Dec 31, 2024
1 parent 087ec50 commit b45be9e
Showing 4 changed files with 109 additions and 131 deletions.
48 changes: 48 additions & 0 deletions sparsecoding/transforms/images_patch_test.py
@@ -0,0 +1,48 @@
import torch

from sparsecoding.transforms import patchify, quilt, sample_random_patches


def test_patchify_quilt_cycle():
    X, Y, Z = 3, 4, 5
    C = 3
    P = 8
    H = 6 * P
    W = 8 * P

    images = torch.rand((X, Y, Z, C, H, W), dtype=torch.float32)

    patches = patchify(P, images)
    assert patches.shape == (X, Y, Z, (H // P) * (W // P), C, P, P)

    quilted_images = quilt(H, W, patches)
    assert torch.allclose(
        images,
        quilted_images,
    ), "Quilted images should be equal to input images."

def test_sample_random_patches():
    X, Y, Z = 3, 4, 5
    C = 3
    P = 8
    H = 4 * P
    W = 8 * P
    N = 10

    images = torch.rand((X, Y, Z, C, H, W), dtype=torch.float32)

    random_patches = sample_random_patches(P, N, images)
    assert random_patches.shape == (N, C, P, P)

    # Check that patches are actually taken from one of the images.
    all_patches = torch.nn.functional.unfold(
        input=images.reshape(-1, C, H, W),
        kernel_size=P,
    )  # [prod(*), C*P*P, L]
    all_patches = torch.permute(all_patches, (0, 2, 1))  # [prod(*), L, C*P*P]
    all_patches = torch.reshape(all_patches, (-1, C*P*P))
    for n in range(N):
        patch = random_patches[n].reshape(1, C*P*P)
        delta = torch.abs(patch - all_patches)  # [-1, C*P*P]
        patchwise_delta = torch.sum(delta, dim=1)  # [-1]
        assert torch.min(patchwise_delta) == 0.
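
For reference, here is a minimal sketch of how patchify and quilt could be expressed with torch.nn.functional.unfold and fold, assuming non-overlapping P x P patches and the leading-batch-dimension convention exercised by the tests above. The names patchify_sketch and quilt_sketch are hypothetical; this is an illustration of the round-trip property being tested, not the library's actual implementation.

import torch
import torch.nn.functional as F


def patchify_sketch(P, images):
    # images: [*, C, H, W] -> patches: [*, (H // P) * (W // P), C, P, P]
    *leading, C, H, W = images.shape
    flat = images.reshape(-1, C, H, W)
    cols = F.unfold(flat, kernel_size=P, stride=P)  # [B, C*P*P, L]
    cols = cols.transpose(1, 2)  # [B, L, C*P*P]
    return cols.reshape(*leading, -1, C, P, P)


def quilt_sketch(H, W, patches):
    # patches: [*, L, C, P, P] -> images: [*, C, H, W]
    *leading, L, C, P, _ = patches.shape
    cols = patches.reshape(-1, L, C * P * P).transpose(1, 2)  # [B, C*P*P, L]
    flat = F.fold(cols, output_size=(H, W), kernel_size=P, stride=P)
    return flat.reshape(*leading, C, H, W)

Because the patches do not overlap, fold simply places each patch back at its original location, so quilt_sketch(H, W, patchify_sketch(P, images)) reproduces images exactly, which is the round-trip property test_patchify_quilt_cycle checks.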
58 changes: 0 additions & 58 deletions sparsecoding/transforms/test_patch.py

This file was deleted.

73 changes: 0 additions & 73 deletions sparsecoding/transforms/test_whiten.py

This file was deleted.

61 changes: 61 additions & 0 deletions sparsecoding/transforms/whiten_test.py
@@ -0,0 +1,61 @@
import torch

from sparsecoding.transforms import whiten


def test_zca():
    N = 5000
    D = 32*32

    X = torch.rand((N, D), dtype=torch.float32)

    X_whitened = whiten(X)

    assert torch.allclose(
        torch.mean(X_whitened, dim=0),
        torch.zeros(D, dtype=torch.float32),
        atol=1e-3,
    ), "Whitened data should have zero mean."
    assert torch.allclose(
        torch.cov(X_whitened.T),
        torch.eye(D, dtype=torch.float32),
        atol=1e-3,
    ), "Whitened data should have unit (identity) covariance."

def test_pca():
    N = 5000
    D = 32*32

    X = torch.rand((N, D), dtype=torch.float32)

    X_whitened = whiten(X, algorithm='pca')

    assert torch.allclose(
        torch.mean(X_whitened, dim=0),
        torch.zeros(D, dtype=torch.float32),
        atol=1e-3,
    ), "Whitened data should have zero mean."
    assert torch.allclose(
        torch.cov(X_whitened.T),
        torch.eye(D, dtype=torch.float32),
        atol=1e-3,
    ), "Whitened data should have unit (identity) covariance."

def test_cholesky():
    N = 5000
    D = 32*32

    X = torch.rand((N, D), dtype=torch.float32)

    X_whitened = whiten(X, algorithm='cholesky')

    assert torch.allclose(
        torch.mean(X_whitened, dim=0),
        torch.zeros(D, dtype=torch.float32),
        atol=1e-3,
    ), "Whitened data should have zero mean."
    assert torch.allclose(
        torch.cov(X_whitened.T),
        torch.eye(D, dtype=torch.float32),
        atol=1e-3,
    ), "Whitened data should have unit (identity) covariance."
