# utils.py
# Forked from Quandela/HybridAIQuantum-Challenge.
from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd
import re
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
################
## DATA UTILS ##
################
# load the correct train, val dataset for the challenge, from the csv files
class MNIST_partial(Dataset):
    """Dataset backed by the challenge CSV files (train.csv / val.csv).

    The CSV file for the requested split is read once, in the constructor,
    into a pandas DataFrame. Two columns are used: ``image``, a string holding
    a bracketed, comma-separated list of pixel values, and ``label``.
    """

    # Shape every parsed image is reshaped to (784 values per image).
    _IMAGE_SHAPE = (1, 28, 28)

    def __init__(self, data='.\\data', transform=None, split='train'):
        """Read the CSV file of the requested split.

        Args:
            data: path to dataset folder which contains train.csv and val.csv.
                The default is a Windows-style relative path to a ``data``
                folder; pass an explicit path on other platforms.
            transform (callable, optional): Optional transform to be applied
                on a sample (e.g., data augmentation or normalization).
            split: 'train' or 'val' to determine which file is loaded.

        Raises:
            AttributeError: if ``split`` is neither 'train' nor 'val'.
        """
        super().__init__()
        self.data_dir = data
        self.transform = transform
        # Not used by this class; kept so the attribute remains available.
        self.data = []
        # Tuple membership compares with ==, exactly like the original
        # if/elif chain, so any value equal to 'train' or 'val' is accepted.
        if split not in ('train', 'val'):
            raise AttributeError("split!='train' and split!='val': split must be train or val")
        filename = os.path.join(self.data_dir, f"{split}.csv")
        self.df = pd.read_csv(filename)

    def __len__(self):
        """Return the number of rows (samples) in the loaded CSV file."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image string is parsed into floats and reshaped to (1, 28, 28);
        ``transform`` is applied to the image tensor when one was given.
        """
        raw = self.df['image'].iloc[idx]
        label = self.df['label'].iloc[idx]
        # Drop surrounding whitespace and the enclosing '[' / ']' before
        # splitting, so padded strings parse as well as exact "[...]" ones.
        tokens = raw.strip().strip('[]').split(',')
        pixels = [float(tok) for tok in tokens]
        img_square = torch.tensor(pixels).reshape(self._IMAGE_SHAPE)
        if self.transform is not None:
            img_square = self.transform(img_square)
        return img_square, label
# Uncomment the lines below to download and use the whole MNIST dataset instead.
# NOTE: this needs `from torchvision.datasets import MNIST`, and `root` must be changed to a path on your machine.
# dataset = MNIST(root = '/home/jupyter-pemeriau/scaleway_demo/mnist-data/', download = True)
# print(f"Total length of dataset = {len(dataset)}")
####################
## TRAINING UTILS ##
####################
# plot the training curves (accuracy and loss) and save them in 'training_curves.png'
def plot_training_metrics(train_acc, val_acc, train_loss, val_loss,
                          save_path="training_curves.png"):
    """Plot accuracy and loss curves side by side and save the figure.

    The left panel shows training/validation accuracy and the right panel
    training/validation loss. The x positions are taken from the length of
    ``train_acc`` and labelled starting at 1.

    Args:
        train_acc: sequence of training accuracies; its length sets the
            number of x positions (presumably one value per epoch).
        val_acc: sequence of validation accuracies.
        train_loss: sequence of training losses.
        val_loss: sequence of validation losses.
        save_path: where the figure is written. Defaults to
            'training_curves.png', the file name that was previously
            hard-coded.
    """
    epochs = list(range(len(train_acc)))
    tick_labels = [str(epoch + 1) for epoch in epochs]
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    # One entry per panel: (axis, training series, validation series,
    # y-axis label, title).
    panels = (
        (axes[0], train_acc, val_acc, "ACC", "Training and validation accuracies"),
        (axes[1], train_loss, val_loss, "Loss", "Training and validation losses"),
    )
    for ax, train_values, val_values, y_label, title in panels:
        ax.plot(epochs, train_values, label='training')
        ax.plot(epochs, val_values, label='validation')
        ax.set_xlabel("Epochs")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(visible=True)
        ax.legend()
        ax.set_xticks(ticks=epochs, labels=tick_labels)
    # The figure is intentionally left open after saving, as before, so the
    # behavior seen by existing callers does not change.
    fig.savefig(save_path)
# compute the accuracy of the model
def accuracy(outputs, labels):
    """Return the fraction of predictions that match ``labels``.

    The prediction for each row is the index of the maximum of ``outputs``
    along dim 1 (presumably ``outputs`` is (batch, classes) -- confirm
    against the caller). The result is a 0-dim tensor built from a Python
    float.
    """
    predicted = torch.max(outputs, dim=1).indices
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))