forked from facebookresearch/suncet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
main_distributed.py
133 lines (114 loc) · 3.73 KB
/
main_distributed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import submitit
import argparse
import logging
import pprint
import yaml
import sys
from src.paws_train import main as paws
from src.suncet_train import main as suncet
from src.fine_tune import main as fine_tune
from src.snn_fine_tune import main as snn_fine_tune
# -- log INFO and above to stdout through the root logger
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()

# -- command-line options for the launcher; flags, types, defaults and
# -- choices are unchanged, only typos in the help texts are corrected
parser = argparse.ArgumentParser()
parser.add_argument(
    '--folder', type=str,
    help='location to save submitit logs',
    default='/checkpoint/submitit/')
parser.add_argument(
    '--fname', type=str,
    help='yaml file containing config file names to launch',
    default='configs.yaml')
parser.add_argument(
    '--batch-launch', action='store_true',
    help='whether fname points to a file to batch-launch several config files')
parser.add_argument(
    '--device', type=str,
    help='type of GPU to use')
parser.add_argument(
    '--partition', type=str,
    help='partition to submit jobs on')
parser.add_argument(
    '--nodes', type=int, default=1,
    help='num. nodes to request for job')
parser.add_argument(
    '--tasks-per-node', type=int, default=1,
    help='num. procs per node')
parser.add_argument(
    '--time', type=int,
    help='time in minutes to run job',
    default=5)
# -- no default is given, so args.sel is None when the flag is omitted
parser.add_argument(
    '--sel', type=str,
    help='which script to run',
    choices=[
        'paws_train',
        'suncet_train',
        'fine_tune',
        'snn_fine_tune'
    ])
class Trainer:
    """Callable that runs one training script on one yaml config file.

    Calling an instance loads the yaml file at ``fname`` and passes the
    resulting params to the entry point selected by ``sel``.
    """

    # -- values of `sel` that __call__ can dispatch (same as the --sel choices)
    _SCRIPTS = ('paws_train', 'suncet_train', 'fine_tune', 'snn_fine_tune')

    def __init__(self, sel, fname='configs.yaml', load_model=None):
        """Store the job description; no file is read here.

        :param sel: name of the script to run, one of ``_SCRIPTS``
        :param fname: path to the yaml config file for the script
        :param load_model: when not None, written into
            ``params['meta']['load_checkpoint']`` before the script runs
        """
        self.sel = sel
        self.fname = fname
        self.load_model = load_model

    def __call__(self):
        """Load the config and run the selected script.

        :returns: whatever the selected entry point returns
        :raises ValueError: if ``sel`` is not one of ``_SCRIPTS``
        """
        sel = self.sel
        fname = self.fname
        load_model = self.load_model

        # -- fail loudly on an unknown script (e.g. sel=None because --sel was
        # -- omitted) instead of loading the config and silently returning None
        if sel not in self._SCRIPTS:
            raise ValueError(
                f'unknown script {sel!r}; expected one of {self._SCRIPTS}')

        logger.info(f'called-params {sel} {fname} {load_model}')

        # -- load script params
        # NOTE(review): yaml.load is only appropriate for trusted config
        # files; consider yaml.safe_load if configs can come from elsewhere.
        with open(fname, 'r') as y_file:
            params = yaml.load(y_file, Loader=yaml.FullLoader)
        if load_model is not None:
            # -- override the config's own checkpoint-loading setting
            params['meta']['load_checkpoint'] = load_model
        logger.info('loaded params...')
        pprint.pprint(params, indent=4)

        logger.info('Running %s', sel)
        # -- keys must stay in sync with _SCRIPTS
        entry_points = {
            'paws_train': paws,
            'suncet_train': suncet,
            'fine_tune': fine_tune,
            'snn_fine_tune': snn_fine_tune,
        }
        return entry_points[sel](params)

    def checkpoint(self):
        """Build the follow-up submission with checkpoint loading turned on.

        Creates a new Trainer for the same script and config whose
        ``load_model`` is True, so its ``__call__`` sets
        ``params['meta']['load_checkpoint'] = True``, and wraps it in a
        ``submitit.helpers.DelayedSubmission``.
        NOTE(review): presumably invoked by submitit's checkpointing protocol
        when a job is preempted or times out -- confirm against submitit docs.
        """
        fb_trainer = Trainer(self.sel, self.fname, True)
        return submitit.helpers.DelayedSubmission(fb_trainer,)
def launch():
    """Submit one job per config file through submitit and print the job ids.

    Reads its options from the module-level ``args`` namespace (set in the
    ``__main__`` guard). Without ``--batch-launch`` a single job is submitted
    for ``args.fname``; with it, ``args.fname`` is read as a yaml list of
    config file names and one job is submitted per entry.

    :raises ValueError: if ``--batch-launch`` is set and the yaml file does
        not contain a list of config file names
    """
    executor = submitit.AutoExecutor(folder=args.folder)
    executor.update_parameters(
        slurm_partition=args.partition,
        slurm_constraint=args.device,
        slurm_comment='running PAWS code',
        slurm_mem='450G',
        timeout_min=args.time,
        nodes=args.nodes,
        tasks_per_node=args.tasks_per_node,
        cpus_per_task=10,
        # -- GPU count per node is tied to the number of tasks per node
        gpus_per_node=args.tasks_per_node)

    config_fnames = [args.fname]
    if args.batch_launch:
        with open(args.fname, 'r') as y_file:
            config_fnames = yaml.load(y_file, Loader=yaml.FullLoader)
        # -- an empty file loads as None, and a mapping or plain string would
        # -- be iterated key-by-key / char-by-char and submitted as bogus
        # -- config names, so insist on an actual list
        if not isinstance(config_fnames, (list, tuple)):
            raise ValueError(
                f'{args.fname} must contain a yaml list of config file names '
                f'when --batch-launch is set, got {type(config_fnames).__name__}')

    jobs = []
    with executor.batch():
        for cf in config_fnames:
            jobs.append(executor.submit(Trainer(args.sel, cf)))

    # -- ids are printed only after the batch context has exited
    for job in jobs:
        print(job.job_id)
if __name__ == '__main__':
    # -- launch() reads its options from the module-level name `args`,
    # -- so the parsed namespace must be bound under exactly that name
    args = parser.parse_args()
    launch()