Commit

adds dataloader and optimiser
lphoogenboom committed Oct 9, 2024
1 parent 7a1cb97 commit 9331e0c
Showing 5 changed files with 64 additions and 0 deletions.
Binary file modified data/datasplits/datasplit.npz
50 changes: 50 additions & 0 deletions dataloading.py
@@ -0,0 +1,50 @@
from pathlib import Path
import numpy as np
import torch as tt
import sys

class Dataset(tt.utils.data.Dataset):

def __init__(self,images,labels):
self.images = images # (Batch, Height, Width)
self.images = self.images[:, None] # (Batch, 1, Height, Width)
self.images = tt.from_numpy(self.images) # Convert to tensor

self.labels = tt.from_numpy(labels) # (Batch)
self.labels = tt.nn.functional.one_hot(self.labels.long(), num_classes=10) # (Batch, 10)

def __len__(self):
return len(self.images)

def __getitem__(self, index): # get image-label pair
return dict(image=self.images[index], label=self.labels[index])

if __name__ == "__main__":

here = Path(__file__).resolve().parent
path_splits = here/"data"/"datasplits"
path_processed = here/"data"/"processed"

# split contains the test indices, the train indices, and a 5-fold split of the train indices into overlapping (train_##, val_##) pairs
split = np.load(path_splits/"datasplit.npz")

# load processed images and labels
images = np.load(path_processed/"images.npy")
labels = np.load(path_processed/"labels.npy")

# select the images and labels for this split using its indices
split_images_data = images[split["train_00"]]
split_labels_data = labels[split["train_00"]]

# create Dataset for specific split
dataset_train = Dataset(split_images_data,split_labels_data)

# pass to dataset loader
loader_train = tt.utils.data.DataLoader(dataset_train, batch_size=5, shuffle=True)

for data in loader_train:
image = data["image"]
label = data["label"]
print(f"Image shape: {image.shape}, Label shape: {label.shape}")
print(f"Image min: {image.min()}, Image max: {image.max()}")
sys.exit() # Exit after first iteration
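
For reference, the matching validation loader for fold 00 could be built the same way. This is a minimal sketch that reuses the split, images, labels and Dataset names from the script above, and it assumes validation data should not be shuffled:

# select the validation indices of fold 00 and wrap them in a Dataset
images_val = images[split["val_00"]]
labels_val = labels[split["val_00"]]
dataset_val = Dataset(images_val, labels_val)

# mirror the training loader, but without shuffling
loader_val = tt.utils.data.DataLoader(dataset_val, batch_size=5, shuffle=False)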
5 changes: 5 additions & 0 deletions datasplitting.py
@@ -15,21 +15,26 @@
images = np.load(path_images)
labels = np.load(path_labels)

# Create index lists for splitting data
split = dict()
idx_train, idx_test, _, _ = train_test_split(
range(len(labels)),
labels,
test_size=0.2,
stratify=labels,
)

# Indices in images & labels of training and test data
split["train"] = idx_train
split["test"] = idx_test

# Use the train indices to build a stratified 5-fold split of the training data into
# 5 overlapping (train_##, val_##) sets; fold positions are mapped back to indices
# into the full dataset so they can index images and labels directly
skf_train_val = StratifiedKFold(n_splits=5, shuffle=True)
idx_train_full = np.asarray(split["train"])
for i, (idx_tr, idx_va) in enumerate(skf_train_val.split(idx_train_full, labels[idx_train_full])):
    split[f"train_{i:02d}"] = idx_train_full[idx_tr]
    split[f"val_{i:02d}"] = idx_train_full[idx_va]

# save the test indices, the train indices, and the 5 folds of the train indices (train_## + val_##)
np.savez(path_split/"datasplit.npz", **split)

"""Print datasplit information"""
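
As a sanity check on the saved split file, each fold's train and validation indices can be verified to be disjoint and to cover the full training set together. A minimal standalone sketch, assuming datasplit.npz lives at the path used above and was produced by the script in this commit:

from pathlib import Path
import numpy as np

here = Path(__file__).resolve().parent
split = np.load(here / "data" / "datasplits" / "datasplit.npz")

train = set(split["train"].tolist())
for i in range(5):
    tr = set(split[f"train_{i:02d}"].tolist())
    va = set(split[f"val_{i:02d}"].tolist())
    assert tr.isdisjoint(va)  # no overlap between a fold's train and val indices
    assert tr | va == train   # each fold covers the whole training set
    print(f"fold {i:02d}: {len(tr)} train / {len(va)} val indices")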
9 changes: 9 additions & 0 deletions optimiser.py
@@ -0,0 +1,9 @@
import torch.optim


class Solver:
    """Thin wrapper that builds an AdamW optimiser for a model's parameters."""

    def __init__(self, model, lr):
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=lr,
        )
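
A possible way to use the Solver wrapper in a single training step. This is a minimal sketch in which the Sequential model and the 28x28 image size are hypothetical stand-ins, since the actual model is not part of this commit:

import torch
from optimiser import Solver

# hypothetical stand-in model: flatten a (1, 28, 28) image and map it to 10 class logits
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
solver = Solver(model, lr=1e-3)

# one dummy optimisation step on random data
images = torch.randn(5, 1, 28, 28)
targets = torch.randint(0, 10, (5,))
loss = torch.nn.functional.cross_entropy(model(images), targets)

solver.optimizer.zero_grad()
loss.backward()
solver.optimizer.step()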
Binary file modified out/images/histogram_labels_train_val_00.png
