forked from optuna/optuna-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfastaiv1_simple.py
104 lines (79 loc) · 3.29 KB
/
fastaiv1_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Optuna example that optimizes convolutional neural network and data augmentation using FastAI V1.
In this example, we optimize the hyperparameters of a convolutional neural network and
data augmentation for hand-written digit recognition in terms of validation accuracy.
The network is implemented by fastai and
evaluated on MNIST dataset. Throughout the training of neural networks, a pruner observes
intermediate results and stops unpromising trials.
Note that this example will take longer than the other examples
as this uses the entire MNIST dataset.
You can run this example as follows, pruning can be turned on and off with the `--pruning`
argument.
$ python fastaiv1_simple.py [--pruning]
"""
import argparse
from functools import partial
import urllib
import optuna
from optuna.integration import FastAIV1PruningCallback
from fastai import vision
# Register a global custom opener to avoid HTTP Error 403: Forbidden when downloading MNIST.
opener = urllib.request.build_opener()
opener.addheaders = [("User-agent", "Mozilla/5.0")]
urllib.request.install_opener(opener)
BATCHSIZE = 128
EPOCHS = 10
path = vision.untar_data(vision.URLs.MNIST_SAMPLE)
def objective(trial):
# Data Augmentation
apply_tfms = trial.suggest_categorical("apply_tfms", [True, False])
if apply_tfms:
# MNIST is a hand-written digit dataset. Thus horizontal and vertical flipping are
# disabled. However, the two flipping will be important when the dataset is CIFAR or
# ImageNet.
tfms = vision.get_transforms(
do_flip=False,
flip_vert=False,
max_rotate=trial.suggest_int("max_rotate", -45, 45),
max_zoom=trial.suggest_float("max_zoom", 1, 2),
p_affine=trial.suggest_float("p_affine", 0.1, 1.0, step=0.1),
)
data = vision.ImageDataBunch.from_folder(
path, bs=BATCHSIZE, ds_tfms=tfms if apply_tfms else None
)
n_layers = trial.suggest_int("n_layers", 2, 5)
n_channels = [3]
for i in range(n_layers):
out_channels = trial.suggest_int("n_channels_{}".format(i), 3, 32)
n_channels.append(out_channels)
n_channels.append(2)
model = vision.simple_cnn(n_channels)
learn = vision.Learner(
data,
model,
silent=True,
metrics=[vision.accuracy],
callback_fns=[partial(FastAIV1PruningCallback, trial=trial, monitor="accuracy")],
)
learn.fit(EPOCHS)
return learn.recorder.metrics[-1][0].item()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="FastAI V1 example.")
parser.add_argument(
"--pruning",
"-p",
action="store_true",
help="Activate the pruning feature. `MedianPruner` stops unpromising "
"trials at the early stages of training.",
)
args = parser.parse_args()
pruner = optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
study = optuna.create_study(direction="maximize", pruner=pruner)
study.optimize(objective, n_trials=100, timeout=600)
print("Number of finished trials: {}".format(len(study.trials)))
print("Best trial:")
trial = study.best_trial
print(" Value: {}".format(trial.value))
print(" Params: ")
for key, value in trial.params.items():
print(" {}: {}".format(key, value))