From 32848c30ed42ed6cd1e740fbafa36ba82a3e44da Mon Sep 17 00:00:00 2001 From: belsten Date: Fri, 22 Nov 2024 15:00:30 -0800 Subject: [PATCH] restructure repo st transforms, datasets, dictionaries are all in base level, not in data dir --- sparsecoding/data/datasets/field.py | 63 ------------------- .../{data/datasets/bars.py => datasets.py} | 56 +++++++++++++++++ sparsecoding/dictionaries.py | 26 ++++++++ sparsecoding/{data => }/transforms/patch.py | 0 sparsecoding/{data => }/transforms/whiten.py | 0 5 files changed, 82 insertions(+), 63 deletions(-) delete mode 100644 sparsecoding/data/datasets/field.py rename sparsecoding/{data/datasets/bars.py => datasets.py} (51%) create mode 100644 sparsecoding/dictionaries.py rename sparsecoding/{data => }/transforms/patch.py (100%) rename sparsecoding/{data => }/transforms/whiten.py (100%) diff --git a/sparsecoding/data/datasets/field.py b/sparsecoding/data/datasets/field.py deleted file mode 100644 index 4a0aae8..0000000 --- a/sparsecoding/data/datasets/field.py +++ /dev/null @@ -1,63 +0,0 @@ -import os - -from scipy.io import loadmat -import torch -from torch.utils.data import Dataset - -from sparsecoding.data.transforms.patch import patchify - - -class FieldDataset(Dataset): - """Dataset used in Olshausen & Field (1996). - - Paper: - https://courses.cs.washington.edu/courses/cse528/11sp/Olshausen-nature-paper.pdf - Emergence of simple-cell receptive field properties - by learning a sparse code for natural images. - - Parameters - ---------- - root : str - Location to download the dataset to. - patch_size : int - Side length of patches for sparse dictionary learning. - stride : int, optional - Stride for sampling patches. If not specified, set to `patch_size` - (non-overlapping patches). - """ - - B = 10 - C = 1 - H = 512 - W = 512 - - def __init__( - self, - root: str, - patch_size: int = 8, - stride: int = None, - ): - self.P = patch_size - if stride is None: - stride = patch_size - - root = os.path.expanduser(root) - os.system(f"mkdir -p {root}") - if not os.path.exists(f"{root}/field.mat"): - os.system("wget https://rctn.org/bruno/sparsenet/IMAGES.mat") - os.system(f"mv IMAGES.mat {root}/field.mat") - - self.images = torch.tensor(loadmat(f"{root}/field.mat")["IMAGES"]) # [H, W, B] - assert self.images.shape == (self.H, self.W, self.B) - - self.images = torch.permute(self.images, (2, 0, 1)) # [B, H, W] - self.images = torch.reshape(self.images, (self.B, self.C, self.H, self.W)) # [B, C, H, W] - - self.patches = patchify(patch_size, self.images, stride) # [B, N, C, P, P] - self.patches = torch.reshape(self.patches, (-1, self.C, self.P, self.P)) # [B*N, C, P, P] - - def __len__(self): - return self.patches.shape[0] - - def __getitem__(self, idx): - return self.patches[idx] diff --git a/sparsecoding/data/datasets/bars.py b/sparsecoding/datasets.py similarity index 51% rename from sparsecoding/data/datasets/bars.py rename to sparsecoding/datasets.py index 242e1d5..f25ed51 100644 --- a/sparsecoding/data/datasets/bars.py +++ b/sparsecoding/datasets.py @@ -61,3 +61,59 @@ def __len__(self): def __getitem__(self, idx: int): return self.data[idx] + + +class FieldDataset(Dataset): + """Dataset used in Olshausen & Field (1996). + + Paper: + https://courses.cs.washington.edu/courses/cse528/11sp/Olshausen-nature-paper.pdf + Emergence of simple-cell receptive field properties + by learning a sparse code for natural images. + + Parameters + ---------- + root : str + Location to download the dataset to. + patch_size : int + Side length of patches for sparse dictionary learning. + stride : int, optional + Stride for sampling patches. If not specified, set to `patch_size` + (non-overlapping patches). + """ + + B = 10 + C = 1 + H = 512 + W = 512 + + def __init__( + self, + root: str, + patch_size: int = 8, + stride: int = None, + ): + self.P = patch_size + if stride is None: + stride = patch_size + + root = os.path.expanduser(root) + os.system(f"mkdir -p {root}") + if not os.path.exists(f"{root}/field.mat"): + os.system("wget https://rctn.org/bruno/sparsenet/IMAGES.mat") + os.system(f"mv IMAGES.mat {root}/field.mat") + + self.images = torch.tensor(loadmat(f"{root}/field.mat")["IMAGES"]) # [H, W, B] + assert self.images.shape == (self.H, self.W, self.B) + + self.images = torch.permute(self.images, (2, 0, 1)) # [B, H, W] + self.images = torch.reshape(self.images, (self.B, self.C, self.H, self.W)) # [B, C, H, W] + + self.patches = patchify(patch_size, self.images, stride) # [B, N, C, P, P] + self.patches = torch.reshape(self.patches, (-1, self.C, self.P, self.P)) # [B*N, C, P, P] + + def __len__(self): + return self.patches.shape[0] + + def __getitem__(self, idx): + return self.patches[idx] diff --git a/sparsecoding/dictionaries.py b/sparsecoding/dictionaries.py new file mode 100644 index 0000000..2547e3d --- /dev/null +++ b/sparsecoding/dictionaries.py @@ -0,0 +1,26 @@ +import os +import torch +import numpy as np +import pickle as pkl + +MODULE_PATH = os.path.dirname(__file__) +DICTIONARY_PATH = os.path.join(MODULE_PATH, "data/dictionaries") + + +def load_dictionary_from_pickle(path): + dictionary_file = open(path, 'rb') + numpy_dictionary = pkl.load(dictionary_file) + dictionary_file.close() + dictionary = torch.tensor(numpy_dictionary.astype(np.float32)) + return dictionary + + +def load_bars_dictionary(): + path = os.path.join(DICTIONARY_PATH, "bars", "bars-16_by_16.p") + return load_dictionary_from_pickle(path) + + +def load_olshausen_dictionary(): + path = os.path.join(DICTIONARY_PATH, "olshausen", "olshausen-1.5x_overcomplete.p") + return load_dictionary_from_pickle(path) + \ No newline at end of file diff --git a/sparsecoding/data/transforms/patch.py b/sparsecoding/transforms/patch.py similarity index 100% rename from sparsecoding/data/transforms/patch.py rename to sparsecoding/transforms/patch.py diff --git a/sparsecoding/data/transforms/whiten.py b/sparsecoding/transforms/whiten.py similarity index 100% rename from sparsecoding/data/transforms/whiten.py rename to sparsecoding/transforms/whiten.py