Skip to content

Commit

Permalink
flaking
Browse files Browse the repository at this point in the history
  • Loading branch information
HarrisonSantiago committed Nov 23, 2024
1 parent 2c58d49 commit 2502a0d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 32 deletions.
2 changes: 1 addition & 1 deletion sparsecoding/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

__all__ = ['quilt', 'patchify', 'sample_random_patches', 'whiten',
'compute_whitening_stats', 'compute_image_whitening_stats',
'WhiteningTransform', 'whiten_images']
'WhiteningTransform', 'whiten_images']
60 changes: 30 additions & 30 deletions sparsecoding/transforms/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@


def check_images(images: torch.Tensor, algorithm: str = 'zca'):
"""Verify that tensor is in the shape [N, C, H, W] and C != when using fourier based method
"""Verify that tensor is in the shape [N, C, H, W] and C != when using fourier based method
"""

if len(images.shape) != 4:
raise ValueError('Images must be in shape [N, C, H, W]')

if images.shape[1] != 1 and algorithm == 'frequency':
raise ValueError(f'When using frequency based decorrelation, images must \
be grayscale, received {images.shape[1]} channels')

# Running cov based methods on large images can eat memory
if algorithm in ['zca', 'pca', 'cholesky'] and (images.shape[2] > 64 or images.shape[3] > 64):
print(f'WARNING: Running covaraince based whitening for images of size {images.shape[2]}x{images.shape[3]}. \
It is not recommended to use this for images smaller than 64x64')

# Running cov based methods on large images can eat memory
if algorithm == 'frequency' and (images.shape[2] <= 64 or images.shape[3] <= 64):
print(f'WARNING: Running covaraince based whitening for images of size {images.shape[2]}x{images.shape[3]}. \
Expand Down Expand Up @@ -60,7 +60,7 @@ def whiten_images(images: torch.Tensor,
else:
raise ValueError(f"Unknown whitening algorithm: {algorithm}, \
must be one of ['frequency', 'pca', 'zca', 'cholesky]")


def compute_image_whitening_stats(images: torch.Tensor,
n_components=None) -> Dict:
Expand All @@ -86,37 +86,37 @@ def compute_image_whitening_stats(images: torch.Tensor,
def create_frequency_filter(image_size: int, f0_factor: float = 0.4) -> torch.Tensor:
"""
Create a frequency domain filter for image whitening.
Parameters
----------
image_size: Size of the square image
f0_factor: Factor for determining the cutoff frequency (default 0.4)
Returns
----------
torch.Tensor: Frequency domain filter
"""
fx = torch.linspace(-image_size/2, image_size/2-1, image_size)
fy = torch.linspace(-image_size/2, image_size/2-1, image_size)
fx, fy = torch.meshgrid(fx, fy, indexing='xy')

rho = torch.sqrt(fx**2 + fy**2)
f_0 = f0_factor * image_size
filt = rho * torch.exp(-(rho/f_0)**4)

return fft.fftshift(filt)


@lru_cache(maxsize=32)
def get_cached_filter(image_size: int, f0_factor: float = 0.4) -> torch.Tensor:
"""
Get a cached frequency filter for the given image size.
Parameters
----------
image_size: Size of the square image
f0_factor: Factor for determining the cutoff frequency
Returns
----------
torch.Tensor: Cached frequency domain filter
Expand All @@ -127,20 +127,20 @@ def get_cached_filter(image_size: int, f0_factor: float = 0.4) -> torch.Tensor:
def normalize_variance(tensor: torch.Tensor, target_variance: float = 1.) -> torch.Tensor:
"""
Normalize the variance of a tensor to a target value.
Parameters
----------
tensor: Input tensor
target_variance: Desired variance after normalization
Returns
----------
torch.Tensor: Normalized tensor
"""

centered = tensor - tensor.mean()
current_variance = torch.var(centered)

if current_variance > 0:
scale_factor = torch.sqrt(torch.tensor(target_variance) / current_variance)
return centered * scale_factor
Expand All @@ -154,25 +154,25 @@ def whiten_channel(
) -> torch.Tensor:
"""
Apply frequency domain whitening to a single channel.
Parameters
----------
channel: Single channel image tensor
filt: Frequency domain filter
target_variance: Target variance for normalization
Returns
----------
torch.Tensor: Whitened channel
"""

if torch.var(channel) < 1e-8:
return channel

# Convert to frequency domain and apply filter
If = fft.fft2(channel)
If_whitened = If * filt.to(channel.device)

# Convert back to spatial domain and normalize
whitened = torch.real(fft.ifft2(If_whitened))

Expand All @@ -191,31 +191,31 @@ def frequency_whitening(
Apply frequency domain decorrelation to batched images.
Method used in original sparsenet in Olshausen and Field in Nature
and http://www.rctn.org/bruno/sparsenet/
Parameters
----------
images: Input images of shape (N, C, H, W)
target_variance: Target variance for normalization
f0_factor: Factor for determining filter cutoff frequency
Returns
----------
torch.Tensor: Whitened images
"""
_, _, H, W = images.shape
if H != W:
raise ValueError("Images must be square")

# Get cached filter
filt = get_cached_filter(H, f0_factor)

# Process each image in the batch
whitened_batch = []
for img in images:
whitened_batch.append(
whiten_channel(img[0], filt, target_variance)
)

return torch.stack(whitened_batch).unsqueeze(1)


Expand All @@ -233,7 +233,7 @@ def __init__(
):
"""
Initialize whitening transform.
Parameters
----------
algorithm: One of ['frequency', 'pca', 'zca', 'cholesky]
Expand All @@ -245,15 +245,15 @@ def __init__(
self.stats = stats
self.compute_stats = compute_stats
self.kwargs = kwargs

def __call__(self, images: torch.Tensor) -> torch.Tensor:
"""
Apply whitening transform to images.
Parameters
----------
images: Input images of shape [N, C, H, W] or [C, H, W]
Returns
----------
Whitened images of same shape as input
Expand All @@ -273,11 +273,11 @@ def __call__(self, images: torch.Tensor) -> torch.Tensor:
self.stats,
**self.kwargs
)

# Remove batch dimension if input was single image
if single_image:
whitened = whitened.squeeze(0)

return whitened

def __repr__(self):
Expand Down
2 changes: 1 addition & 1 deletion sparsecoding/transforms/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def whiten(X: torch.Tensor,
Parameters
----------
X: Input data of shape [N, D] where N are unique data elements of dimensionality D
X: Input data of shape [N, D] where N are unique data elements of dimensionality D
algorithm: Whitening transform we want to apply, one of ['zca', 'pca', or 'cholesky']
stats: Dict containing precomputed whitening statistics (mean, eigenvectors, eigenvalues)
n_components: number of components to retain if computing stats
Expand Down

0 comments on commit 2502a0d

Please sign in to comment.