-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
34 lines (30 loc) · 1.28 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os, sys, time
import argparse
import torch
import wandb
import logging
import torch.distributed as dist
from utils import logging_utils
logging_utils.config_logger()
from utils.YParams import YParams
from utils.trainer import Trainer
def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which matches plain CLI usage.

    Returns:
        argparse.Namespace with the attributes ``yaml_config``, ``config``,
        ``root_dir``, ``run_num`` and ``sweep_id``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--yaml_config", default='./config/operators.yaml', type=str)
    parser.add_argument("--config", default='default', type=str)
    parser.add_argument("--root_dir", default='./', type=str, help='root dir to store results')
    parser.add_argument("--run_num", default='0', type=str, help='sub run config')
    parser.add_argument("--sweep_id", default=None, type=str, help='wandb sweep config from configs/sweep_config.yaml')
    return parser.parse_args(argv)


def main(argv=None):
    """Build a Trainer from the YAML config and run training or one wandb sweep trial.

    Steps:
      1. Load the ``args.config`` section of ``args.yaml_config`` via ``YParams``.
      2. Construct a ``Trainer``.
      3. Launch it:
         * If ``--sweep_id`` is given, rank 0 starts a wandb sweep agent that
           calls ``trainer.launch`` for a single trial.
         * Otherwise (and on every non-zero rank), ``trainer.launch()`` is
           called directly.
      4. If ``torch.distributed`` is initialised, synchronise all ranks and
         tear down the process group.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
            ``None`` means ``sys.argv``.
    """
    args = parse_args(argv)
    params = YParams(os.path.abspath(args.yaml_config), args.config)
    trainer = Trainer(params, args)

    if args.sweep_id and trainer.world_rank == 0:
        # Only rank 0 talks to the wandb sweep controller; count=1 executes a single trial.
        wandb.agent(args.sweep_id, function=trainer.launch, count=1,
                    entity=trainer.params.entity, project=trainer.params.project)
    else:
        # NOTE(review): when --sweep_id is set, non-zero ranks also land here and
        # call launch() directly; presumably Trainer shares the sweep config
        # across ranks. Confirm in utils.trainer.
        trainer.launch()

    if dist.is_initialized():
        # Make every rank finish before the communicator is torn down.
        dist.barrier()
        # Release the default process group explicitly instead of leaking it
        # at interpreter exit (recent PyTorch versions warn about that).
        dist.destroy_process_group()
    logging.info('DONE')


if __name__ == '__main__':
    main()