-
Notifications
You must be signed in to change notification settings - Fork 8
/
dmlc_ssh.py
executable file
·119 lines (100 loc) · 4.22 KB
/
dmlc_ssh.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python
"""
DMLC submission script by ssh
All slave machines must be reachable via ssh from this machine.
"""
import argparse
import sys
import os
import subprocess
import tracker
import logging
from threading import Thread
class SSHLauncher(object):
    """Launch a DMLC job's servers and workers on remote hosts over ssh.

    Host names are read from ``args.hostfile`` (one per line, blank lines
    ignored) and jobs are assigned to them round-robin.
    """

    def __init__(self, args, unknown):
        """Store the options, build the user command and read the hostfile.

        args    -- parsed launcher options; ``command`` and ``hostfile`` are
                   read here, ``sync_dir``, ``interface``, ``activation``,
                   ``num_workers`` and ``num_servers`` are read later.
        unknown -- extra command-line tokens appended to the user command.

        Raises ValueError if no hostfile is given or it lists no hosts.
        """
        self.args = args
        # the user command followed by every argument argparse did not recognise
        self.cmd = (' '.join(args.command) + ' ' + ' '.join(unknown))
        if args.hostfile is None:
            raise ValueError('a hostfile of slave nodes must be given (-H/--hostfile)')
        with open(args.hostfile) as f:
            # surrounding whitespace is not part of a host name; skip blank lines
            self.hosts = [h.strip() for h in f if h.strip()]
        # Checked after filtering: a file holding only blank lines must be
        # rejected here, otherwise the round-robin `i % len(self.hosts)` in
        # submit() fails later with ZeroDivisionError.
        if not self.hosts:
            raise ValueError('hostfile %s lists no hosts' % args.hostfile)

    @staticmethod
    def _shell_quote(s):
        """Quote string s as a single POSIX-shell word."""
        try:
            from shlex import quote
        except ImportError:  # Python 2 keeps quote() in pipes
            from pipes import quote
        return quote(s)

    def sync_dir(self, local_dir, slave_node, slave_dir):
        """
        sync the working directory from root node into slave node

        Runs rsync (through the local shell) to copy local_dir to
        slave_node:slave_dir; raises subprocess.CalledProcessError if rsync
        exits with a non-zero status.
        """
        remote = slave_node + ':' + slave_dir
        logging.info('rsync %s -> %s', local_dir, remote)
        # local_dir is a plain local path, so quoting it only protects spaces
        # and metacharacters; `remote` is passed verbatim as before so the
        # remote path keeps whatever expansion rsync/the remote shell applies.
        prog = 'rsync -az --rsh="ssh -o StrictHostKeyChecking=no" %s %s' % (
            self._shell_quote(local_dir), remote)
        subprocess.check_call(prog, shell=True)

    def get_env(self, pass_envs):
        """Build the ``export`` statements prepended to the remote command.

        LD_LIBRARY_PATH and the AWS credentials are forwarded from the local
        environment when set, followed by every key/value of pass_envs.
        Values are shell-quoted so spaces and quotes survive.

        Returns the statements (each ``export K=V;``) joined by single spaces.
        """
        # NOTE(review): the AWS credentials end up on the ssh command line,
        # where other users of either machine may see them in the process list.
        envs = []
        # forward selected variables from the local environment
        for k in ('LD_LIBRARY_PATH', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'):
            v = os.getenv(k)
            if v is not None:
                envs.append('export ' + k + '=' + self._shell_quote(v) + ';')
        # then everything supplied through pass_envs
        for k, v in pass_envs.items():
            envs.append('export ' + str(k) + '=' + self._shell_quote(str(v)) + ';')
        return ' '.join(envs)

    def _activation_prefix(self):
        """Return the user's activation snippet, terminated for use before the command.

        Without a separator the snippet would be fused with the command
        (``conda activate env`` + ``python ...`` -> ``conda activate envpython``),
        so ``;`` is appended unless the snippet already ends with ``;`` or ``&&``.
        An empty (or missing/None) activation yields ''.
        """
        activation = (getattr(self.args, 'activation', '') or '').strip()
        if not activation:
            return ''
        if not activation.endswith((';', '&&')):
            activation += ';'
        return activation + ' '

    def _remote_command(self, node, pass_envs, working_dir):
        """Return the local shell command that runs the user command on node via ssh."""
        prog = (self.get_env(pass_envs) + ' cd ' + working_dir + '; '
                + self._activation_prefix() + self.cmd)
        # Quote the whole remote command as one ssh argument; unlike wrapping
        # it in bare single quotes this survives single quotes inside it.
        # node is inserted verbatim (unquoted), as before.
        return 'ssh -o StrictHostKeyChecking=no ' + node + ' ' + self._shell_quote(prog)

    def submit(self):
        """Return the launch callback handed to tracker.submit as fun_submit.

        The callback takes (nworker, nserver, pass_envs). It optionally rsyncs
        the current directory to every host, then starts one ssh job per
        server/worker: the first nserver jobs are servers, the rest workers,
        assigned to hosts round-robin. pass_envs is mutated (DMLC_ROLE and,
        when an interface is set, DMLC_INTERFACE). Each job runs in a daemon
        thread that is not joined, so a failing ssh raises
        CalledProcessError only inside its own thread.
        """
        def ssh_submit(nworker, nserver, pass_envs):
            """
            customized submit script
            """
            # thread func to run one job
            def _launch(prog):
                subprocess.check_call(prog, shell=True)

            # sync programs if necessary
            local_dir = os.getcwd() + '/'
            working_dir = local_dir
            if self.args.sync_dir is not None:
                working_dir = self.args.sync_dir
                for h in self.hosts:
                    self.sync_dir(local_dir, h, working_dir)

            # an empty or missing interface means "do not export DMLC_INTERFACE"
            interface = getattr(self.args, 'interface', '')
            if interface:
                pass_envs['DMLC_INTERFACE'] = interface

            # launch jobs
            for i in range(nworker + nserver):
                pass_envs['DMLC_ROLE'] = 'server' if i < nserver else 'worker'
                node = self.hosts[i % len(self.hosts)]
                # The command string is built right here, before pass_envs is
                # changed for the next job, so each thread gets its own role.
                prog = self._remote_command(node, pass_envs, working_dir)
                thread = Thread(target=_launch, args=(prog,))
                thread.daemon = True  # setDaemon() is deprecated since 3.10
                thread.start()
        return ssh_submit

    def run(self):
        """Configure logging via tracker and start the job through tracker.submit."""
        tracker.config_logger(self.args)
        tracker.submit(self.args.num_workers,
                       self.args.num_servers,
                       fun_submit=self.submit(),
                       pscmd=self.cmd)
def main():
    """Parse the command line and launch the job on the hosts of the hostfile.

    Arguments that argparse does not recognise are appended to the user
    command rather than rejected.
    """
    parser = argparse.ArgumentParser(description='DMLC script to submit dmlc job using ssh')
    parser.add_argument('-n', '--num-workers', default=0, type=int,
                        help='number of worker nodes to be launched')
    parser.add_argument('-s', '--num-servers', default=0, type=int,
                        help='number of server nodes to be launched')
    parser.add_argument('-i', '--interface', default="", type=str,
                        help='the desired network interface')
    parser.add_argument('-a', '--activation', default="", type=str,
                        help='custom activation for local environments (e.g., to set specific conda environment)')
    # required: without a hostfile there is nowhere to launch, so report it as
    # a usage error instead of failing later inside the launcher
    parser.add_argument('-H', '--hostfile', type=str, required=True,
                        help='the hostfile of all slave nodes')
    parser.add_argument('command', nargs='+',
                        help='command for dmlc program')
    parser.add_argument('--sync-dir', type=str,
                        help='if specified, sync the current directory into '
                             'SYNC_DIR on every slave machine')
    # unknown options are forwarded to the user command
    args, unknown = parser.parse_known_args()
    launcher = SSHLauncher(args, unknown)
    launcher.run()


if __name__ == '__main__':
    main()