Skip to content

Commit c4bfd70

Browse files
committed
testing on rotated MNIST
0 parents  commit c4bfd70

36 files changed

+2517
-0
lines changed

.flake8

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[flake8]
2+
ignore = F401, F403
3+
max-line-length = 120
4+
exclude =
5+
.git,
6+
__pycache__,

.gitignore

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
env/
12+
build/
13+
develop-eggs/
14+
dist/
15+
downloads/
16+
eggs/
17+
.eggs/
18+
lib/
19+
lib64/
20+
parts/
21+
sdist/
22+
var/
23+
wheels/
24+
*.egg-info/
25+
.installed.cfg
26+
*.egg
27+
28+
# PyInstaller
29+
# Usually these files are written by a python script from a template
30+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
31+
*.manifest
32+
*.spec
33+
34+
# Installer logs
35+
pip-log.txt
36+
pip-delete-this-directory.txt
37+
38+
# Unit test / coverage reports
39+
htmlcov/
40+
.tox/
41+
.coverage
42+
.coverage.*
43+
.cache
44+
nosetests.xml
45+
coverage.xml
46+
*.cover
47+
.hypothesis/
48+
49+
# Translations
50+
*.mo
51+
*.pot
52+
53+
# Django stuff:
54+
*.log
55+
local_settings.py
56+
57+
# Flask stuff:
58+
instance/
59+
.webassets-cache
60+
61+
# Scrapy stuff:
62+
.scrapy
63+
64+
# Sphinx documentation
65+
docs/_build/
66+
67+
# PyBuilder
68+
target/
69+
70+
# Jupyter Notebook
71+
.ipynb_checkpoints
72+
73+
# pyenv
74+
.python-version
75+
76+
# celery beat schedule file
77+
celerybeat-schedule
78+
79+
# SageMath parsed files
80+
*.sage.py
81+
82+
# dotenv
83+
.env
84+
85+
# virtualenv
86+
.venv
87+
venv/
88+
ENV/
89+
90+
# Spyder project settings
91+
.spyderproject
92+
.spyproject
93+
94+
# Rope project settings
95+
.ropeproject
96+
97+
# mkdocs documentation
98+
/site
99+
100+
# mypy
101+
.mypy_cache/
102+
103+
# input data, saved log, checkpoints
104+
data/
105+
input/
106+
saved/
107+
datasets/
108+
109+
# editor, os cache directory
110+
.vscode/
111+
.idea/
112+
__MACOSX/

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# equivariance

base/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .base_data_loader import *
2+
from .base_model import *
3+
from .base_trainer import *

base/base_data_loader.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import numpy as np
2+
from torch.utils.data import DataLoader
3+
from torch.utils.data.dataloader import default_collate
4+
from torch.utils.data.sampler import SubsetRandomSampler
5+
6+
7+
class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders.

    Wraps ``torch.utils.data.DataLoader`` and optionally carves a validation
    subset out of ``dataset`` via index samplers, so one dataset object backs
    both the training loader (this instance) and the validation loader
    returned by :meth:`split_validation`.
    """
    def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
        """
        :param dataset: dataset to load from.
        :param batch_size: samples per mini-batch.
        :param shuffle: whether to shuffle the data; forced off when a
            validation split is used, because ``sampler`` and ``shuffle`` are
            mutually exclusive in ``DataLoader``.
        :param validation_split: float in (0, 1) — fraction of samples held
            out for validation — or int — absolute number of validation
            samples. 0 disables the split.
        :param num_workers: number of worker subprocesses for loading.
        :param collate_fn: merges a list of samples into a batch.
        """
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        # Must run before init_kwargs is captured: it may flip self.shuffle
        # off and shrink self.n_samples to the training-subset size.
        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

        self.init_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers
        }
        super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        """
        Build ``(train_sampler, valid_sampler)`` for the requested split.

        :param split: int (absolute validation-sample count) or float
            (fraction of the dataset); 0 means no split.
        :return: ``(None, None)`` when no split is requested, otherwise a pair
            of ``SubsetRandomSampler`` over disjoint index sets.
        """
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        # Fixed seed -> reproducible train/validation partition across runs.
        # NOTE: this seeds the *global* NumPy RNG as a side effect.
        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            # An int is an absolute number of validation samples.
            assert split > 0
            assert split < self.n_samples, \
                "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            # A float is a fraction of the dataset.
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        """
        Return a ``DataLoader`` over the held-out validation subset, or
        ``None`` when no validation split was configured.
        """
        if self.valid_sampler is None:
            return None
        else:
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
57+

base/base_model.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import logging
2+
import torch.nn as nn
3+
import numpy as np
4+
5+
6+
class BaseModel(nn.Module):
    """
    Base class for all models.

    Provides a per-class logger plus trainable-parameter-count reporting via
    :meth:`summary` and ``__str__``. Subclasses must override :meth:`forward`.
    """
    def __init__(self):
        super(BaseModel, self).__init__()
        # Logger named after the concrete subclass for readable log output.
        self.logger = logging.getLogger(self.__class__.__name__)

    def forward(self, *input):
        """
        Forward pass logic.

        :return: Model output
        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def _trainable_params(self):
        """Return the total element count of parameters with requires_grad."""
        return sum(np.prod(p.size()) for p in self.parameters() if p.requires_grad)

    def summary(self):
        """
        Log the trainable-parameter count and the full module structure.
        """
        self.logger.info('Trainable parameters: {}'.format(self._trainable_params()))
        self.logger.info(self)

    def __str__(self):
        """
        Model prints with number of trainable parameters.
        """
        return super(BaseModel, self).__str__() + '\nTrainable parameters: {}'.format(self._trainable_params())

0 commit comments

Comments
 (0)