forked from facebookresearch/ijepa
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_distributed.py
104 lines (84 loc) · 2.8 KB
/
main_distributed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
import logging
import os
import pprint
import sys
import yaml
import submitit
from src.train import main as app_main
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()
parser = argparse.ArgumentParser()
parser.add_argument(
'--folder', type=str,
help='location to save submitit logs')
parser.add_argument(
'--batch-launch', action='store_true',
help='whether fname points to a file to batch-lauch several config files')
parser.add_argument(
'--fname', type=str,
help='yaml file containing config file names to launch',
default='configs.yaml')
parser.add_argument(
'--partition', type=str,
help='cluster partition to submit jobs on')
parser.add_argument(
'--nodes', type=int, default=1,
help='num. nodes to request for job')
parser.add_argument(
'--tasks-per-node', type=int, default=1,
help='num. procs to per node')
parser.add_argument(
'--time', type=int, default=4300,
help='time in minutes to run job')
class Trainer:
def __init__(self, fname='configs.yaml', load_model=None):
self.fname = fname
self.load_model = load_model
def __call__(self):
fname = self.fname
load_model = self.load_model
logger.info(f'called-params {fname}')
# -- load script params
params = None
with open(fname, 'r') as y_file:
params = yaml.load(y_file, Loader=yaml.FullLoader)
logger.info('loaded params...')
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(params)
resume_preempt = False if load_model is None else load_model
app_main(args=params, resume_preempt=resume_preempt)
def checkpoint(self):
fb_trainer = Trainer(self.fname, True)
return submitit.helpers.DelayedSubmission(fb_trainer,)
def launch():
executor = submitit.AutoExecutor(
folder=os.path.join(args.folder, 'job_%j'),
slurm_max_num_timeout=20)
executor.update_parameters(
slurm_partition=args.partition,
slurm_mem_per_gpu='55G',
timeout_min=args.time,
nodes=args.nodes,
tasks_per_node=args.tasks_per_node,
cpus_per_task=10,
gpus_per_node=args.tasks_per_node)
config_fnames = [args.fname]
jobs, trainers = [], []
with executor.batch():
for cf in config_fnames:
fb_trainer = Trainer(cf)
job = executor.submit(fb_trainer,)
trainers.append(fb_trainer)
jobs.append(job)
for job in jobs:
print(job.job_id)
if __name__ == '__main__':
args = parser.parse_args()
launch()