Skip to content

Commit

Permalink
Restructure repo so that transforms, datasets, and dictionaries are all at the base level, not in the data dir
Browse files Browse the repository at this point in the history
  • Loading branch information
belsten committed Nov 22, 2024
1 parent 828dc9d commit 32848c3
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 63 deletions.
63 changes: 0 additions & 63 deletions sparsecoding/data/datasets/field.py

This file was deleted.

56 changes: 56 additions & 0 deletions sparsecoding/data/datasets/bars.py → sparsecoding/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,59 @@ def __len__(self):

def __getitem__(self, idx: int):
    """Return the entry of ``self.data`` at position ``idx``."""
    return self.data[idx]


class FieldDataset(Dataset):
    """Dataset used in Olshausen & Field (1996).

    Paper:
        https://courses.cs.washington.edu/courses/cse528/11sp/Olshausen-nature-paper.pdf
        Emergence of simple-cell receptive field properties
        by learning a sparse code for natural images.

    Parameters
    ----------
    root : str
        Location to download the dataset to.
    patch_size : int
        Side length of patches for sparse dictionary learning.
    stride : int, optional
        Stride for sampling patches. If not specified, set to `patch_size`
        (non-overlapping patches).

    Raises
    ------
    ValueError
        If the loaded image array does not have shape ``(H, W, B)``.
    """

    # Expected layout of the downloaded image array.
    B = 10  # batch size (number of images)
    C = 1  # channels
    H = 512  # image height
    W = 512  # image width

    # Source of the raw image file; stored locally as ``field.mat``.
    _URL = "https://rctn.org/bruno/sparsenet/IMAGES.mat"

    def __init__(
        self,
        root: str,
        patch_size: int = 8,
        stride: int = None,
    ):
        self.P = patch_size
        if stride is None:
            stride = patch_size

        mat_path = self._ensure_downloaded(root)

        self.images = torch.tensor(loadmat(mat_path)["IMAGES"])  # [H, W, B]
        expected_shape = (self.H, self.W, self.B)
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if tuple(self.images.shape) != expected_shape:
            raise ValueError(
                f"Expected IMAGES of shape {expected_shape}, "
                f"got {tuple(self.images.shape)}"
            )

        self.images = torch.permute(self.images, (2, 0, 1))  # [B, H, W]
        self.images = torch.reshape(
            self.images, (self.B, self.C, self.H, self.W)
        )  # [B, C, H, W]

        self.patches = patchify(patch_size, self.images, stride)  # [B, N, C, P, P]
        self.patches = torch.reshape(
            self.patches, (-1, self.C, self.P, self.P)
        )  # [B*N, C, P, P]

    @staticmethod
    def _ensure_downloaded(root: str) -> str:
        """Return the path of ``field.mat`` under `root`, fetching it if absent.

        The directory is created when missing. The download goes to a
        temporary ``.part`` file that is renamed only after it completes, so
        an interrupted transfer is never mistaken for a finished one.
        """
        root = os.path.expanduser(root)
        # No shell involved: safe for paths with spaces or metacharacters.
        os.makedirs(root, exist_ok=True)

        mat_path = f"{root}/field.mat"
        if not os.path.exists(mat_path):
            # Local import: only needed the first time the data is fetched.
            import urllib.request

            partial_path = f"{mat_path}.part"
            try:
                urllib.request.urlretrieve(FieldDataset._URL, partial_path)
                os.replace(partial_path, mat_path)
            finally:
                # Clean up the leftover if the download or rename failed.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
        return mat_path

    def __len__(self):
        """Return the number of patches (first dimension of ``self.patches``)."""
        return self.patches.shape[0]

    def __getitem__(self, idx):
        """Return the patch stored at position `idx` of ``self.patches``."""
        return self.patches[idx]
26 changes: 26 additions & 0 deletions sparsecoding/dictionaries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import torch
import numpy as np
import pickle as pkl

# Directory that contains this module file.
MODULE_PATH = os.path.dirname(__file__)
# Base directory under which the pickled dictionary files are looked up.
DICTIONARY_PATH = os.path.join(MODULE_PATH, "data/dictionaries")


def load_dictionary_from_pickle(path):
    """Load a pickled NumPy array and return it as a float32 torch tensor.

    Parameters
    ----------
    path : str
        Path to a pickle file whose payload is a NumPy array.

    Returns
    -------
    torch.Tensor
        Copy of the unpickled array, cast to ``float32``.

    Notes
    -----
    Unpickling can execute arbitrary code; only call this on trusted files.
    """
    # The context manager closes the file even if unpickling raises.
    with open(path, "rb") as dictionary_file:
        numpy_dictionary = pkl.load(dictionary_file)
    return torch.tensor(numpy_dictionary.astype(np.float32))


def load_bars_dictionary():
    """Load the pickled dictionary stored at ``bars/bars-16_by_16.p``.

    The file is resolved relative to ``DICTIONARY_PATH`` and read with
    ``load_dictionary_from_pickle``, whose result is returned unchanged.
    """
    return load_dictionary_from_pickle(
        os.path.join(DICTIONARY_PATH, "bars", "bars-16_by_16.p")
    )


def load_olshausen_dictionary():
    """Load the pickled dictionary ``olshausen/olshausen-1.5x_overcomplete.p``.

    The file is resolved relative to ``DICTIONARY_PATH`` and read with
    ``load_dictionary_from_pickle``, whose result is returned unchanged.
    """
    return load_dictionary_from_pickle(
        os.path.join(
            DICTIONARY_PATH, "olshausen", "olshausen-1.5x_overcomplete.p"
        )
    )

File renamed without changes.
File renamed without changes.

0 comments on commit 32848c3

Please sign in to comment.