-
Notifications
You must be signed in to change notification settings - Fork 8
/
ckpt_manager.py
170 lines (140 loc) · 6.66 KB
/
ckpt_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import os
from shutil import *
import torch
import numpy as np
class CKPT_Manager:
    """Save, rotate and restore network checkpoints and training states.

    Files kept under ``root_dir``:

    * ``ckpt/<model_name>_<epoch>.pytorch``: ``network.state_dict()``.
    * ``state/<model_name>_<epoch>.pytorch``: the ``state`` object given to
      :meth:`save`.
    * ``checkpoints.txt``: one ``"<file_name> <score> [<score> ...]"`` entry
      per line. All lines but the last are sorted by their first score, best
      first; the last line always names the most recently saved checkpoint
      (it may repeat one of the sorted lines).

    Only the ``max_files_to_keep`` best-scoring checkpoints plus the most
    recent one are kept on disk. With ``is_descending=False`` a lower first
    score ranks better; with ``is_descending=True`` a higher one does.
    """

    def __init__(self, root_dir, model_name, max_files_to_keep=10, is_descending=False):
        """Create a manager rooted at ``root_dir``.

        Args:
            root_dir: Directory holding ``ckpt/``, ``state/`` and
                ``checkpoints.txt``.
            model_name: Prefix of every checkpoint file name.
            max_files_to_keep: Number of best-scoring checkpoints retained.
            is_descending: Rank higher scores as better when true.
        """
        self.root_dir = root_dir
        self.root_dir_ckpt = os.path.join(root_dir, 'ckpt')
        self.root_dir_state = os.path.join(root_dir, 'state')
        self.model_name = model_name
        self.max_files = max_files_to_keep
        self.ckpt_list = os.path.join(self.root_dir, 'checkpoints.txt')
        self.is_descending = is_descending

    def load_ckpt(self, network, by_score=True, name=None, abs_name=None, epoch=None):
        """Load a checkpoint's weights into ``network`` (``strict=False``).

        The file is chosen as follows (when several selectors are given, the
        later one in this list wins):

        * ``name``: ``ckpt/<name>``;
        * ``abs_name``: that exact path;
        * ``epoch``: ``ckpt/<model_name>_<epoch>.pytorch`` (int epochs are
          zero-padded to 5 digits, str epochs are used verbatim, as in
          :meth:`save`);
        * none of them: the best-scoring entry of ``checkpoints.txt`` when
          ``by_score`` is true, otherwise the most recently saved one.

        Returns:
            ``(result of network.load_state_dict, checkpoint file name)``, or
            ``None`` when ``checkpoints.txt`` is missing or empty.
        """
        if name is None and abs_name is None and epoch is None:
            try:
                lines = self._read_lines(self.ckpt_list)
            except OSError:
                print('ckpt_list does not exists')
                return None
            if not lines:
                print('ckpt_list is empty')
                return None
            # lines[0] is the best-scoring entry, lines[-1] the most recent one.
            file_name = self._name_of(lines[0] if by_score else lines[-1])
            file_path = os.path.join(self.root_dir_ckpt, file_name)
        else:
            if name is not None:
                file_name = name
                file_path = os.path.join(self.root_dir_ckpt, file_name)
            if abs_name is not None:
                file_name = os.path.basename(abs_name)
                file_path = abs_name
            if epoch is not None:
                file_name = self._file_name(epoch)
                file_path = os.path.join(self.root_dir_ckpt, file_name)
        state_dict = torch.load(file_path, map_location=self._map_location())
        return network.load_state_dict(state_dict, strict=False), os.path.basename(file_name)

    def resume(self, network, resume_name, rank=-1):
        """Restore ``network`` and return the training state saved at an epoch.

        Only when ``rank <= 0`` (the default is -1) are the files touched:
        checkpoints newer than the resumed epoch are deleted from disk and from
        ``checkpoints.txt``, and the resumed checkpoint becomes the most recent
        entry.

        Args:
            network: Module whose weights are loaded (``strict=False``).
            resume_name: Epoch to resume from, as an int or a numeric string.
            rank: Process rank; only ranks ``<= 0`` rewrite the files.

        Returns:
            The object stored as ``state`` by :meth:`save` for that epoch.
        """
        epoch_to_resume = int(resume_name)
        resume_name = self._file_name(epoch_to_resume)
        result, _ = self.load_ckpt(network, name=resume_name)
        print('Rank[{}]: '.format(rank), result)
        state_path = os.path.join(self.root_dir_state, resume_name)
        resume_state = torch.load(state_path, map_location=self._map_location())
        if rank <= 0:
            self._trim_after(epoch_to_resume)
        return resume_state

    def save(self, network, state, epoch, score):
        """Save ``network``'s weights and ``state`` as ``epoch``, then rotate.

        Args:
            network: Module whose ``state_dict()`` is saved under ``ckpt/``.
            state: Arbitrary object saved under ``state/``.
            epoch: An int (zero-padded to 5 digits) or a str tag used verbatim
                in the file name.
            score: A sequence of scores written after the file name; only the
                first one ranks the checkpoint. A single scalar is accepted too.
        """
        file_name = self._file_name(epoch)
        os.makedirs(self.root_dir_ckpt, exist_ok=True)
        os.makedirs(self.root_dir_state, exist_ok=True)
        torch.save(network.state_dict(), os.path.join(self.root_dir_ckpt, file_name))
        torch.save(state, os.path.join(self.root_dir_state, file_name))

        lines = self._read_lines(self.ckpt_list) if os.path.exists(self.ckpt_list) else []
        if lines:
            # The last line only marks the previous "most recent" checkpoint.
            # Drop it, and delete its files unless a ranked entry still refers
            # to them or this save has just overwritten them.
            previous_name = self._name_of(lines.pop())
            still_listed = previous_name in {self._name_of(line) for line in lines}
            if not still_listed and previous_name != file_name:
                self._remove_files(previous_name)
            # Entries for the file just overwritten carry stale scores.
            lines = [line for line in lines if self._name_of(line) != file_name]

        if isinstance(score, str):
            scores = [score]
        else:
            try:
                scores = list(score)
            except TypeError:  # a scalar score
                scores = [score]
        line_to_add = ' '.join([file_name] + [str(s) for s in scores])
        # Added twice: once to be ranked, once (last) as the most recent checkpoint.
        self._write_lines(lines + [line_to_add, line_to_add])
        self._update_files()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _map_location():
        """Return the ``torch.load`` map_location: the current CUDA device, else ``'cpu'``."""
        if torch.cuda.is_available():
            return 'cuda:{}'.format(torch.cuda.current_device())
        return 'cpu'

    def _file_name(self, epoch):
        """Return ``<model_name>_<epoch>.pytorch`` (int epochs zero-padded to 5 digits)."""
        if isinstance(epoch, str):
            return '{}_{}.pytorch'.format(self.model_name, epoch)
        return '{}_{:05d}.pytorch'.format(self.model_name, epoch)

    @staticmethod
    def _name_of(line):
        """Return the file-name field of a ``checkpoints.txt`` entry."""
        return line.split(' ')[0]

    @staticmethod
    def _epoch_of(file_name):
        """Parse the epoch from ``<model_name>_<epoch>.pytorch``; ``None`` if not an integer."""
        tag = os.path.splitext(file_name)[0].rsplit('_', 1)[-1]
        try:
            return int(tag)
        except ValueError:
            return None

    @staticmethod
    def _read_lines(path):
        """Return the non-blank lines of ``path`` (raises ``OSError`` if unreadable)."""
        with open(path, 'r') as file:
            # Skipping blank lines tolerates lists written with doubled line
            # endings (os.linesep in text mode on Windows).
            return [line for line in file.read().splitlines() if line.strip()]

    def _write_lines(self, lines):
        """Overwrite ``checkpoints.txt`` with ``lines``, one per line."""
        # '\n' (not os.linesep): text mode already translates it per platform.
        with open(self.ckpt_list, 'w') as file:
            file.writelines(line + '\n' for line in lines)

    def _remove_files(self, file_name):
        """Delete the checkpoint and state files named ``file_name``; missing files are ignored."""
        for directory in (self.root_dir_ckpt, self.root_dir_state):
            try:
                os.remove(os.path.join(directory, file_name))
            except FileNotFoundError:
                pass

    def _trim_after(self, epoch_to_resume):
        """Delete checkpoints newer than ``epoch_to_resume`` and mark it as the most recent."""
        lines = self._read_lines(self.ckpt_list)
        if not lines:
            return
        *ranked, last = lines
        kept = []
        line_recent = None
        stale = set()
        for line in ranked:
            file_name = self._name_of(line)
            epoch = self._epoch_of(file_name)
            if epoch is not None and epoch > epoch_to_resume:
                stale.add(file_name)
                continue
            if epoch == epoch_to_resume:
                line_recent = line
            kept.append(line)
        if line_recent is None:
            # The resumed checkpoint is not a ranked entry, so it is expected to
            # be the most recent (last) one.
            line_recent = last
        else:
            # The previous most recent checkpoint may be newer yet unranked;
            # delete it as well instead of leaving orphaned files behind.
            last_name = self._name_of(last)
            last_epoch = self._epoch_of(last_name)
            if last_epoch is not None and last_epoch > epoch_to_resume:
                stale.add(last_name)
        for file_name in stale:
            self._remove_files(file_name)
        self._write_lines(kept + [line_recent])
        self._update_files()

    def _update_files(self):
        """Re-sort the ranked entries, prune beyond ``max_files`` and delete unreferenced files."""
        lines = self._read_lines(self.ckpt_list)
        if not lines:
            return
        line_recent = lines[-1]
        ranked = self._sort(lines[:-1])
        limit = max(self.max_files, 0)
        kept, dropped = ranked[:limit], ranked[limit:]
        # Only delete files no surviving entry (ranked or most recent) refers to.
        referenced = {self._name_of(line) for line in kept}
        referenced.add(self._name_of(line_recent))
        for file_name in {self._name_of(line) for line in dropped} - referenced:
            self._remove_files(file_name)
        self._write_lines(kept + [line_recent])

    def _sort(self, lines):
        """Order entries by their first score (best first); ties fall back to the line text."""
        keyed = sorted(((float(line.split(' ')[1]), line) for line in lines),
                       reverse=self.is_descending)
        return [line for _, line in keyed]