-
Notifications
You must be signed in to change notification settings - Fork 5
/
main.py
73 lines (59 loc) · 2.33 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import warnings
warnings.simplefilter("ignore")
import argparse
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from configs.default import get_cfg_defaults
from core.pipelines import get_pipeline
def setup_config():
    """Parse command-line arguments and build the frozen experiment config.

    The config starts from ``get_cfg_defaults()``, is overridden by the file
    given with ``--config_file``, then by any trailing command-line options
    collected into ``opts``, and is finally frozen.

    Returns:
        tuple: ``(args, cfg)`` -- the parsed ``argparse.Namespace`` and the
        frozen config node.

    Raises:
        SystemExit: via ``parser.error`` when ``--config_file`` is missing or
        empty (argparse's standard usage-error exit).
    """
    parser = argparse.ArgumentParser(description="voice2pose main program")
    parser.add_argument("--config_file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--resume_from", type=str, default=None, help="the checkpoint to resume from")
    parser.add_argument("--test_only", action="store_true", help="perform testing and evaluation only")
    parser.add_argument("--demo_input", type=str, default=None, help="path to input for demo")
    parser.add_argument("--checkpoint", type=str, default=None, help="the checkpoint to test with")
    parser.add_argument("--tag", type=str, default='', help="tag for the experiment")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    # The default of "" would otherwise be handed straight to merge_from_file,
    # which presumably fails trying to open an empty path with an unhelpful
    # traceback; report a proper usage error instead.
    if not args.config_file:
        parser.error("--config_file is required")

    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config_file)
    # REMAINDER normally yields [], but the declared default is None; never
    # pass None through to merge_from_list.
    cfg.merge_from_list(args.opts or [])
    cfg.freeze()
    return args, cfg
def run(args, cfg):
    """Build the configured pipeline and run it in demo, test, or train mode.

    Args:
        args: parsed command-line namespace (as returned by ``setup_config``);
            reads ``config_file``, ``tag``, ``demo_input``, ``test_only``,
            ``checkpoint`` and ``resume_from``.
        cfg: config node; ``cfg.PIPELINE_TYPE`` selects the pipeline class via
            ``get_pipeline``.

    Mode precedence: a non-empty ``--demo_input`` wins over ``--test_only``;
    with neither set, the pipeline trains.
    """
    # Fixed seed and deterministic cuDNN kernels for reproducible runs.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    pipeline = get_pipeline(cfg.PIPELINE_TYPE)(cfg)

    # Experiment name = config file name up to its first '.', e.g.
    # "configs/voice2pose.yaml" -> "voice2pose". os.path.basename honours the
    # platform's path separators; splitting on '/' alone kept the directory
    # prefix of Windows-style paths such as "configs\\voice2pose.yaml".
    cfg_name = os.path.basename(args.config_file).split('.')[0]

    if args.demo_input:
        exp_tag = cfg_name + '-DEMO-' + args.tag
        pipeline.demo(cfg, exp_tag, args.checkpoint, args.demo_input)
    elif args.test_only:
        exp_tag = cfg_name + '-TEST-' + args.tag
        pipeline.test(cfg, exp_tag, args.checkpoint)
    else:
        exp_tag = cfg_name + '-TRAIN-' + args.tag
        pipeline.train(cfg, exp_tag, args.resume_from)
def run_distributed(rank, args, cfg):
    """Per-process worker entry point used with ``torch.multiprocessing.spawn``.

    Joins the NCCL process group for ``rank``, runs the experiment, and always
    tears the process group down afterwards.

    Args:
        rank: process index passed in by ``mp.spawn`` as the first argument.
        args: parsed command-line namespace, forwarded to ``run``.
        cfg: config node; reads ``SYS.MASTER_ADDR``, ``SYS.MASTER_PORT`` and
            ``SYS.WORLD_SIZE``.
    """
    # No init_method is given, so init_process_group uses env:// and reads the
    # rendezvous address/port from these environment variables.
    os.environ['MASTER_ADDR'] = cfg.SYS.MASTER_ADDR
    os.environ['MASTER_PORT'] = str(cfg.SYS.MASTER_PORT)
    dist.init_process_group("nccl", rank=rank, world_size=cfg.SYS.WORLD_SIZE)
    try:
        run(args, cfg)
    finally:
        # Release the process group even when run() raises. The guard covers
        # the case where the pipeline already destroyed it, since destroying
        # an uninitialized default group raises.
        if dist.is_initialized():
            dist.destroy_process_group()
def main():
    """Program entry: load the config, then run in one process or one per rank.

    When ``cfg.SYS.DISTRIBUTED`` is false, the experiment runs directly in this
    process. Otherwise ``cfg.SYS.WORLD_SIZE`` worker processes are spawned, each
    calling ``run_distributed(rank, args, cfg)``, and this call blocks until all
    of them finish.
    """
    args, cfg = setup_config()

    if not cfg.SYS.DISTRIBUTED:
        run(args, cfg)
        return

    # mp.spawn prepends the process index (rank) to the args tuple;
    # join=True waits for every worker to exit.
    worker_count = cfg.SYS.WORLD_SIZE
    mp.spawn(run_distributed, args=(args, cfg), nprocs=worker_count, join=True)


if __name__ == "__main__":
    main()