From 2502a0d764c6b6b217c21d340e1398d1187893c5 Mon Sep 17 00:00:00 2001
From: HarrisonSantiago
Date: Sat, 23 Nov 2024 11:54:21 -0500
Subject: [PATCH] flaking

---
 sparsecoding/transforms/__init__.py |  2 +-
 sparsecoding/transforms/images.py   | 60 ++++++++++++++---------------
 sparsecoding/transforms/whiten.py   |  2 +-
 3 files changed, 32 insertions(+), 32 deletions(-)

diff --git a/sparsecoding/transforms/__init__.py b/sparsecoding/transforms/__init__.py
index 16072ba..8b77c4d 100644
--- a/sparsecoding/transforms/__init__.py
+++ b/sparsecoding/transforms/__init__.py
@@ -4,4 +4,4 @@
 __all__ = ['quilt', 'patchify', 'sample_random_patches',
            'whiten', 'compute_whitening_stats',
            'compute_image_whitening_stats',
-           'WhiteningTransform', 'whiten_images']
\ No newline at end of file
+           'WhiteningTransform', 'whiten_images']
diff --git a/sparsecoding/transforms/images.py b/sparsecoding/transforms/images.py
index dc83e2f..3b63e47 100644
--- a/sparsecoding/transforms/images.py
+++ b/sparsecoding/transforms/images.py
@@ -7,21 +7,21 @@
 def check_images(images: torch.Tensor,
                  algorithm: str = 'zca'):
-    """Verify that tensor is in the shape [N, C, H, W] and C == 1 when using the Fourier based method 
+    """Verify that tensor is in the shape [N, C, H, W] and C == 1 when using the Fourier based method
     """
     if len(images.shape) != 4:
         raise ValueError('Images must be in shape [N, C, H, W]')
-    
+
     if images.shape[1] != 1 and algorithm == 'frequency':
         raise ValueError(f'When using frequency based decorrelation, images must \
                          be grayscale, received {images.shape[1]} channels')
-    
+
     # Running cov based methods on large images can eat memory
     if algorithm in ['zca', 'pca', 'cholesky'] and (images.shape[2] > 64 or images.shape[3] > 64):
         print(f'WARNING: Running covariance based whitening for images of size {images.shape[2]}x{images.shape[3]}. \
              It is not recommended to use this for images larger than 64x64')
-    
+
     # Frequency based whitening is intended for larger images
     if algorithm == 'frequency' and (images.shape[2] <= 64 or images.shape[3] <= 64):
         print(f'WARNING: Running frequency based whitening for images of size {images.shape[2]}x{images.shape[3]}. \
@@ -60,7 +60,7 @@ def whiten_images(images: torch.Tensor,
     else:
         raise ValueError(f"Unknown whitening algorithm: {algorithm}, \
                          must be one of ['frequency', 'pca', 'zca', 'cholesky']")
-    
+

 def compute_image_whitening_stats(images: torch.Tensor,
                                   n_components=None) -> Dict:
@@ -86,12 +86,12 @@ def compute_image_whitening_stats(images: torch.Tensor,
 def create_frequency_filter(image_size: int, f0_factor: float = 0.4) -> torch.Tensor:
     """
     Create a frequency domain filter for image whitening.
-    
+
     Parameters
     ----------
     image_size: Size of the square image
     f0_factor: Factor for determining the cutoff frequency (default 0.4)
-    
+
     Returns
     ----------
     torch.Tensor: Frequency domain filter
@@ -99,11 +99,11 @@ def create_frequency_filter(image_size: int, f0_factor: float = 0.4) -> torch.Te
     fx = torch.linspace(-image_size/2, image_size/2-1, image_size)
     fy = torch.linspace(-image_size/2, image_size/2-1, image_size)
     fx, fy = torch.meshgrid(fx, fy, indexing='xy')
-    
+
     rho = torch.sqrt(fx**2 + fy**2)
     f_0 = f0_factor * image_size
     filt = rho * torch.exp(-(rho/f_0)**4)
-    
+
     return fft.fftshift(filt)
@@ -111,12 +111,12 @@ def create_frequency_filter(image_size: int, f0_factor: float = 0.4) -> torch.Te
 def get_cached_filter(image_size: int, f0_factor: float = 0.4) -> torch.Tensor:
     """
     Get a cached frequency filter for the given image size.
-    
+
     Parameters
     ----------
     image_size: Size of the square image
     f0_factor: Factor for determining the cutoff frequency
-    
+
     Returns
     ----------
     torch.Tensor: Cached frequency domain filter
@@ -127,20 +127,20 @@ def get_cached_filter(image_size: int, f0_factor: float = 0.4) -> torch.Tensor:
 def normalize_variance(tensor: torch.Tensor, target_variance: float = 1.) -> torch.Tensor:
     """
     Normalize the variance of a tensor to a target value.
-    
+
     Parameters
     ----------
     tensor: Input tensor
     target_variance: Desired variance after normalization
-    
+
     Returns
     ----------
     torch.Tensor: Normalized tensor
     """
-    
+
     centered = tensor - tensor.mean()
     current_variance = torch.var(centered)
-    
+
     if current_variance > 0:
         scale_factor = torch.sqrt(torch.tensor(target_variance) / current_variance)
         return centered * scale_factor
@@ -154,13 +154,13 @@ def whiten_channel(
 ) -> torch.Tensor:
     """
     Apply frequency domain whitening to a single channel.
-    
+
     Parameters
     ----------
     channel: Single channel image tensor
     filt: Frequency domain filter
     target_variance: Target variance for normalization
-    
+
     Returns
     ----------
     torch.Tensor: Whitened channel
@@ -168,11 +168,11 @@ def whiten_channel(
     if torch.var(channel) < 1e-8:
         return channel
-    
+
     # Convert to frequency domain and apply filter
     If = fft.fft2(channel)
     If_whitened = If * filt.to(channel.device)
-    
+
     # Convert back to spatial domain and normalize
     whitened = torch.real(fft.ifft2(If_whitened))
@@ -191,13 +191,13 @@ def frequency_whitening(
     Apply frequency domain decorrelation to batched images.
     Method used in original sparsenet in Olshausen and Field in Nature
     and http://www.rctn.org/bruno/sparsenet/
-    
+
     Parameters
     ----------
     images: Input images of shape (N, C, H, W)
     target_variance: Target variance for normalization
     f0_factor: Factor for determining filter cutoff frequency
-    
+
     Returns
     ----------
     torch.Tensor: Whitened images
@@ -205,17 +205,17 @@ def frequency_whitening(
     _, _, H, W = images.shape
     if H != W:
         raise ValueError("Images must be square")
-    
+
     # Get cached filter
     filt = get_cached_filter(H, f0_factor)
-    
+
     # Process each image in the batch
     whitened_batch = []
     for img in images:
         whitened_batch.append(
             whiten_channel(img[0], filt, target_variance)
         )
-    
+
     return torch.stack(whitened_batch).unsqueeze(1)
@@ -233,7 +233,7 @@ def __init__(
     ):
         """
         Initialize whitening transform.
-        
+
         Parameters
         ----------
         algorithm: One of ['frequency', 'pca', 'zca', 'cholesky']
@@ -245,15 +245,15 @@ def __init__(
         self.stats = stats
         self.compute_stats = compute_stats
         self.kwargs = kwargs
-        
+
     def __call__(self, images: torch.Tensor) -> torch.Tensor:
         """
         Apply whitening transform to images.
-        
+
         Parameters
         ----------
         images: Input images of shape [N, C, H, W] or [C, H, W]
-        
+
         Returns
         ----------
         Whitened images of same shape as input
@@ -273,11 +273,11 @@ def __call__(self, images: torch.Tensor) -> torch.Tensor:
             self.stats,
             **self.kwargs
         )
-        
+
         # Remove batch dimension if input was single image
         if single_image:
             whitened = whitened.squeeze(0)
-        
+
         return whitened

     def __repr__(self):
diff --git a/sparsecoding/transforms/whiten.py b/sparsecoding/transforms/whiten.py
index 1dce173..8dd4429 100644
--- a/sparsecoding/transforms/whiten.py
+++ b/sparsecoding/transforms/whiten.py
@@ -76,7 +76,7 @@ def whiten(X: torch.Tensor,

     Parameters
     ----------
-    X: Input data of shape [N, D] where N are unique data elements of dimensionality D 
+    X: Input data of shape [N, D] where N are unique data elements of dimensionality D
     algorithm: Whitening transform we want to apply, one of ['zca', 'pca', or 'cholesky']
     stats: Dict containing precomputed whitening statistics (mean, eigenvectors, eigenvalues)
     n_components: number of components to retain if computing stats
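
Usage sketch (editor's note, not part of the patch): the snippet below shows how the
whitening utilities touched by this diff might be exercised end to end. It assumes
WhiteningTransform and compute_image_whitening_stats are importable from
sparsecoding.transforms, as __init__.py's __all__ suggests, and that the constructor
accepts algorithm, stats, and compute_stats as keywords with the remaining arguments
defaulted; those defaults are not visible in this diff.

    # Hedged usage sketch: import path and constructor defaults are assumptions,
    # only the parameter names shown in the diff's docstrings are taken as given.
    import torch

    from sparsecoding.transforms import WhiteningTransform, compute_image_whitening_stats

    # Frequency whitening expects square, single-channel images [N, 1, H, W]
    # (check_images raises otherwise) and warns for sizes at or below 64x64.
    images = torch.rand(16, 1, 128, 128)

    freq_transform = WhiteningTransform(algorithm='frequency')
    whitened = freq_transform(images)
    print(whitened.shape)  # expected: torch.Size([16, 1, 128, 128]), same as input

    # The covariance-based path ('zca'/'pca'/'cholesky') can reuse precomputed
    # statistics; stats and compute_stats mirror the attributes set in __init__.
    patches = torch.rand(1000, 1, 16, 16)
    stats = compute_image_whitening_stats(patches)
    zca_transform = WhiteningTransform(algorithm='zca', stats=stats, compute_stats=False)
    zca_whitened = zca_transform(patches)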