-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
85 lines (64 loc) · 2.84 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import numpy as np
from scipy.io.wavfile import read
import torch
def get_mask_from_lengths(lengths, max_len=-1):
    """Build a boolean mask of valid (non-padded) positions.

    Args:
        lengths: 1-D integer tensor of sequence lengths, shape (batch,).
        max_len: desired mask width; if smaller than max(lengths), the
            batch maximum is used instead (so the default -1 means
            "exactly the batch maximum").

    Returns:
        Bool tensor of shape (batch, max_len) on ``lengths.device``;
        True where the position index is < the sequence length.
    """
    max_len = max(torch.max(lengths).item(), max_len)
    # torch.arange with an explicit device/dtype replaces the original
    # `out=torch.cuda.LongTensor(max_len)`, which forced CUDA allocation
    # and crashed on CPU-only machines; this follows the input's device.
    ids = torch.arange(0, max_len, dtype=torch.long, device=lengths.device)
    mask = ids < lengths.unsqueeze(1)
    return mask
def make_pad_mask(lengths, xs=None, length_dim=-1):
    """Return a mask that is True at every padded (out-of-length) position.

    Adapted from
    https://github.com/espnet/espnet/blob/e962a3c609ad535cd7fb9649f9f9e9e0a2a27291/espnet/nets/pytorch_backend/nets_utils.py#L64

    Args:
        lengths: per-sequence lengths, as a list of ints or a 1-D tensor.
        xs: optional reference tensor; when given, the mask is broadcast
            to ``xs``'s shape and moved to its device.
        length_dim: dimension of ``xs`` that indexes time. Must not be 0
            (dimension 0 is the batch dimension).

    Returns:
        Bool tensor — shape (batch, maxlen) when ``xs`` is None,
        otherwise the same shape as ``xs``.
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    batch = int(len(lengths))

    maxlen = int(max(lengths)) if xs is None else xs.size(length_dim)

    # positions[b, t] = t; a position is padding once t >= lengths[b].
    positions = torch.arange(0, maxlen, dtype=torch.int64)
    positions = positions.unsqueeze(0).expand(batch, maxlen)
    limits = positions.new(lengths).unsqueeze(-1)
    mask = positions >= limits

    if xs is not None:
        assert xs.size(0) == batch, (xs.size(0), batch)
        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # Insert singleton axes everywhere except batch and the time axis,
        # e.g. for a 3-D xs with length_dim=2: mask[:, None, :].
        index = tuple(
            slice(None) if axis in (0, length_dim) else None
            for axis in range(xs.dim())
        )
        mask = mask[index].expand_as(xs).to(xs.device)
    return mask
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Return a mask that is True at every valid (in-length) position.

    See https://github.com/espnet/espnet/blob/e962a3c609ad535cd7fb9649f9f9e9e0a2a27291/espnet/nets/pytorch_backend/nets_utils.py#L179

    This is simply the logical complement of :func:`make_pad_mask`.
    """
    pad_mask = make_pad_mask(lengths, xs, length_dim)
    return pad_mask.logical_not()
def get_drop_frame_mask_from_lengths(lengths, drop_frame_rate, r_len_pad):
    """Sample a per-frame Bernoulli(drop_frame_rate) drop mask.

    Args:
        lengths: 1-D tensor of valid frame counts per batch element.
        drop_frame_rate: probability in [0, 1] of dropping each frame.
        r_len_pad: extra padding added on top of the batch max length.

    Returns:
        Float tensor of shape (batch, max(lengths) + r_len_pad); 1.0 marks
        a frame to drop. Frames beyond each sequence's length are zeroed.
    """
    batch = lengths.size(0)
    width = torch.max(lengths).item() + r_len_pad
    valid = get_mask_from_lengths(lengths, width).float()
    noise = torch.empty([batch, width], device=lengths.device).uniform_(0., 1.)
    return (noise < drop_frame_rate).float() * valid
def dropout_frame(mels, global_mean, mel_lengths, drop_frame_rate, r_len_pad):
    """Randomly replace whole mel frames with the global mean frame.

    Args:
        mels: mel spectrograms, shaped (batch, n_mels, frames) given how
            the mask is unsqueezed on dim 1 below.
        global_mean: per-channel mean mel values, shape (n_mels,).
        mel_lengths: 1-D tensor of valid frame counts per batch element.
        drop_frame_rate: probability of replacing each valid frame.
        r_len_pad: extra frame padding passed through to the mask builder.

    Returns:
        Tensor like ``mels`` where dropped frames hold the global mean.
    """
    drop = get_drop_frame_mask_from_lengths(mel_lengths, drop_frame_rate, r_len_pad)
    keep = (1.0 - drop).unsqueeze(1)
    mean_fill = global_mean[None, :, None] * drop.unsqueeze(1)
    return mels * keep + mean_fill
def load_wav_to_torch(full_path):
    """Load a WAV file as a float32 torch tensor.

    Args:
        full_path: path to the WAV file.

    Returns:
        Tuple of (samples as a float32 tensor, sampling rate in Hz).
        Samples keep the raw integer scale of the file; no normalization.
    """
    sampling_rate, samples = read(full_path)
    audio = torch.from_numpy(samples.astype(np.float32))
    return audio, sampling_rate
def load_filepaths_and_text(filename, split="|"):
    """Parse a metadata file of delimiter-separated records.

    Args:
        filename: path to a UTF-8 text file, one record per line.
        split: field delimiter within each line (default "|").

    Returns:
        List of records, each a list of stripped-line fields.
    """
    entries = []
    with open(filename, encoding='utf-8') as f:
        for line in f:
            entries.append(line.strip().split(split))
    return entries
def to_gpu(x):
    """Return ``x`` as a contiguous tensor, moved to CUDA when available.

    Args:
        x: input torch.Tensor (any device).

    Returns:
        Contiguous tensor on the GPU if CUDA is available, else on CPU.
    """
    x = x.contiguous()
    if torch.cuda.is_available():
        x = x.cuda(non_blocking=True)
    # torch.autograd.Variable has been a deprecated no-op alias for Tensor
    # since PyTorch 0.4, so return the tensor directly.
    return x