-
Notifications
You must be signed in to change notification settings - Fork 122
/
Copy path plot_darts.py
26 lines (20 loc) · 928 Bytes
/
plot_darts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
import logging
from naslib.defaults.trainer import Trainer
from naslib.optimizers import DARTSOptimizer, GDASOptimizer, RandomSearch
from naslib.search_spaces import NasBench301SearchSpace, SimpleCellSearchSpace
from naslib.utils import set_seed, setup_logger, get_config_from_args
# Demo script: run a short DARTS search over the NAS-Bench-301 search space
# with saving and plotting of the architecture weights switched on.

# Build the run configuration from the command-line arguments
# (run with --help to see the available options).
config = get_config_from_args()
# The search batch size can also be overridden here if needed (e.g. to 128)
# by assigning config.search.batch_size.

# Force a single search epoch, overriding whatever the CLI/config file set —
# presumably to keep this demo short.
config.search.epochs = 1

# Enable saving and plotting of the architecture weights. These flags are
# presumably read by Trainer / the optimizer (defined in naslib, not visible
# here); the save path is a "save_arch" sub-path of the run's output dir.
config.save_arch_weights = True
config.plot_arch_weights = True
config.save_arch_weights_path = os.path.join(config.save, "save_arch")

# Seed everything from the configured seed for reproducibility.
set_seed(config.seed)

# Log to <config.save>/log.log; path built the same way as the one above.
logger = setup_logger(os.path.join(config.save, "log.log"))
logger.setLevel(logging.INFO)  # default DEBUG is very verbose

# NAS-Bench-301 search space. Swap in SimpleCellSearchSpace() for a less
# heavy search.
search_space = NasBench301SearchSpace()

# DARTS optimizer bound to the chosen search space.
optimizer = DARTSOptimizer(config)
optimizer.adapt_search_space(search_space)

# Run the architecture search.
trainer = Trainer(optimizer, config)
trainer.search()