-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
162 lines (156 loc) · 4.73 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# ===========================================================================
# Project: Compression-aware Training of Neural Networks using Frank-Wolfe
# File: main.py
# Description: Starts up a run for the comparison between sparsification strategies
# ===========================================================================
import socket
import sys
import os
import shutil
import torch
import wandb
from runners.scratchRunner import scratchRunner
from runners.pretrainedRunner import pretrainedRunner
# Default wandb parameters. This dict doubles as the run-config schema:
# entries left at None are expected to be filled in by a wandb sweep (or by
# the '--debug' overrides below).
defaults = dict(
    # System
    run_id=None,
    computer=socket.gethostname(),
    fixed_init=None,  # presumably fixes the model-init seed -- confirm in runners
    extensive_metrics=False,  # when True, log additional (more expensive) metrics
    # Setup
    dataset=None,
    model=None,
    nepochs=None,
    batch_size=None,
    # Efficiency
    use_amp=True,  # automatic mixed precision
    channels_last=False,  # Disabled by default, since pruning will make problems. Enable this for Dense training
    # Optimizer
    optimizer=None,
    learning_rate=None,
    n_epochs_warmup=None,  # number of epochs to warmup the lr, should be an int
    momentum=None,
    nesterov=None,
    weight_decay=None,
    weight_decay_schedule=None,
    decouple_wd=None,
    group_penalty=None,
    group_penalty_type=None,
    prox_threshold=None,
    # Constraints (Frank-Wolfe LMO / feasible-region settings)
    lmo=None,
    lmo_mode=None,
    lmo_ord=None,
    lmo_value=None,
    lmo_k=None,
    lmo_rescale=None,
    lmo_global=None,
    lmo_delay=None,
    lmo_nuc_method=None,
    lmo_adjust_diameter=None,
    # Sparsifying strategy
    strategy=None,
    goal_sparsity=None,
    decomp_type=None,
    # IMP
    IMP_selector=None,  # must be in ['global', 'uniform', 'uniform_plus', 'ERK', 'LAMP']
    # Retraining
    n_phases=None,  # Should be 1, except when using IMP
    n_epochs_per_phase=None,
    retrain_schedule=None,
    retrain_wd=None,
    retrain_schedule_wd=None,
    retrain_schedule_warmup=None,
    retrain_schedule_init=None,
    dynamic_retrain_length=None,
)
# A '--debug' flag on the command line replaces the None defaults with a
# concrete small-scale configuration (SimpleCNN on MNIST, one epoch) so the
# whole pipeline can be exercised quickly without a sweep.
debug_mode = '--debug' in sys.argv
if debug_mode:
    defaults.update(
        # System
        run_id=1,
        computer=socket.gethostname(),
        fixed_init=True,
        extensive_metrics=True,
        # Setup
        dataset='mnist',
        model='SimpleCNN',
        nepochs=1,
        batch_size=1024,
        # Efficiency
        use_amp=True,
        channels_last=False,
        # Optimizer
        optimizer='SFW',
        learning_rate='(Linear, 0.1)',
        n_epochs_warmup=None,  # number of epochs to warmup the lr, should be an int
        momentum=0.9,
        nesterov=False,
        weight_decay=0.0001,
        weight_decay_schedule=None,
        decouple_wd=False,
        group_penalty=None,
        group_penalty_type='Filter_Conv',
        prox_threshold=1,
        # Constraints
        lmo='SpectralKSupportNormBall',
        lmo_mode='initialization',
        lmo_ord=2,
        lmo_value=10,
        lmo_k=0.3,
        lmo_rescale='fast_gradient',
        lmo_global=False,
        lmo_delay=None,
        lmo_nuc_method='qrpartial',
        lmo_adjust_diameter=True,
        # Sparsifying strategy
        strategy='struct_Dense',
        goal_sparsity=0.5,  # Functions as (1 - goal_energy) simultaneously
        decomp_type='conv',
        # IMP
        IMP_selector='global',  # must be in ['global', 'uniform', 'uniform_plus', 'ERK', 'LAMP']
        # Retraining
        n_phases=1,  # Should be 1, except when using IMP
        n_epochs_per_phase=0,
        retrain_schedule='LLR',
        retrain_wd=0.0005,
        retrain_schedule_wd=None,
        retrain_schedule_warmup=None,
        retrain_schedule_init=None,
        dynamic_retrain_length=None,
    )
# Configure wandb logging
wandb.init(
    config=defaults,
    project='test-000',  # automatically changed in sweep
    entity=None,  # automatically changed in sweep
)
config = wandb.config

# Pick the compute device: all visible GPUs for multi-GPU ImageNet runs,
# a single GPU otherwise, CPU as the fallback.
n_gpus = torch.cuda.device_count()
if n_gpus == 0:
    device_str = 'cpu'
elif n_gpus > 1 and config.dataset == 'imagenet':
    device_str = 'cuda:' + ','.join(str(i) for i in range(n_gpus))
else:
    device_str = 'cuda:0'
config.update(dict(device=device_str))
# At the moment, IMP is the only strategy that requires a pretrained model,
# all others start from scratch -- dispatch to the matching runner class.
needs_pretrained = config.strategy in ('IMP', 'struct_IMP', 'struct_Decomp')
runner_cls = pretrainedRunner if needs_pretrained else scratchRunner
runner = runner_cls(config=config, debug_mode=debug_mode)
runner.run()
# Close the wandb run. wandb.finish() flushes and syncs all remaining data;
# it supersedes the deprecated wandb.join() alias.
# Capture the run directory first, before the run object is closed.
wandb_dir_path = wandb.run.dir
wandb.finish()

# Delete the local wandb files now that they have been synced.
if os.path.exists(wandb_dir_path):
    shutil.rmtree(wandb_dir_path)
# Delete the runner's temporary working directory.
if os.path.exists(runner.tmp_dir):
    shutil.rmtree(runner.tmp_dir)