-
Notifications
You must be signed in to change notification settings - Fork 0
/
helper.py
29 lines (26 loc) · 1.2 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import logging

import torch
def write_weight_grad_stats(values, val_type, step, tb_writer):
    """Write a histogram and absolute-value summary stats for a tensor to TensorBoard.

    Args:
        values: tensor of parameter or gradient values (any shape).
        val_type: label used in the TensorBoard tags, e.g. 'Weights' or 'Gradients'.
        step: global step (batch index) recorded with each entry.
        tb_writer: SummaryWriter-like object exposing add_histogram and add_scalars.
    """
    # Plain bool so the scalar dict gets a number, not a 0-dim tensor.
    has_nans = bool(torch.isnan(values).any())
    if has_nans:
        # BUGFIX: original called undefined `linfo` here, raising NameError
        # precisely when a NaN was detected; use the logging module instead.
        logging.warning(f"************ WARNING: {val_type} contain a nan value at batch # {step}")
    tb_writer.add_histogram(f"{val_type}", values, global_step=step)
    # Stats are taken over absolute values so min/mean/std reflect magnitudes.
    values = values.abs()
    tb_writer.add_scalars(main_tag=f"{val_type} Stats",
                          tag_scalar_dict={'abs_min': torch.min(values),
                                           'abs_max': torch.max(values),
                                           'abs_mean': torch.mean(values),
                                           'abs_std': torch.std(values),
                                           'has_nans': has_nans},
                          global_step=step)
def write_model_params(mod, step, tb_writer):
    """Log weight and gradient statistics for every parameter of a module.

    Flattens and concatenates all parameters (and, where present, their
    gradients) into single 1-D tensors, then delegates to
    write_weight_grad_stats for TensorBoard output.

    Args:
        mod: torch.nn.Module whose parameters are inspected.
        step: global step (batch index) for the TensorBoard entries.
        tb_writer: SummaryWriter-like object passed through to the stats helper.
    """
    weights = []
    grads = []
    for p in mod.parameters():
        weights.append(p.data.flatten())
        # Gradients only exist after a backward pass; skip params without one.
        if p.grad is not None:
            grads.append(p.grad.flatten())
    # BUGFIX: torch.cat raises on an empty list, so a module with no
    # parameters used to crash here; skip logging instead.
    if weights:
        write_weight_grad_stats(torch.cat(weights), 'Weights', step, tb_writer)
    if grads:
        write_weight_grad_stats(torch.cat(grads), 'Gradients', step, tb_writer)