-
Notifications
You must be signed in to change notification settings - Fork 0
/
filterbank.py
199 lines (170 loc) · 7.7 KB
/
filterbank.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# Author : Jan Schlüter
# Copyright (c) 2017 Jan Schlüter
# https://github.com/f0k/ismir2015
"""
MIT License
Copyright (c) 2017 Jan Schlüter
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from torch import nn
import torch
import numpy as np
def create_mel_filterbank(sample_rate, frame_len, num_bands, min_freq, max_freq,
                          norm=True, crop=False, break_freq=7000.0):
    """
    Creates a filterbank of `num_bands` triangular filters, with the first
    filter starting at `min_freq` and the last one stopping at `max_freq`.
    Returns the filterbank as a (bins, num_bands) matrix suitable for a dot
    product against magnitude spectra created from samples at a sample rate
    of `sample_rate` with a window length of `frame_len` samples, where bins
    is ``frame_len // 2 + 1`` (or fewer, see `crop`).

    Filter peaks are spaced evenly on the warped frequency scale
    ``1127 * log(1 + f / break_freq)``. The textbook mel scale uses
    ``break_freq=700``; the default of 7000 reproduces the filterbanks this
    function has always built and gives a more linear spacing.

    If `norm`, will normalize each filter by its area (its sum over bins);
    a filter so narrow that it covers no FFT bin has zero area and is left
    all-zero instead of becoming NaN. If `crop`, will exclude rows that
    exceed the maximum frequency and are therefore zero.
    """
    # peak frequencies, evenly spaced on the warped scale
    min_mel = 1127 * np.log1p(min_freq / float(break_freq))
    max_mel = 1127 * np.log1p(max_freq / float(break_freq))
    peaks_mel = torch.linspace(min_mel, max_mel, num_bands + 2)
    peaks_hz = break_freq * torch.expm1(peaks_mel / 1127)
    # fractional FFT bin index of each peak
    peaks_bin = peaks_hz * frame_len / sample_rate
    # create filterbank
    input_bins = (frame_len // 2) + 1
    if crop:
        input_bins = min(input_bins,
                         int(np.ceil(max_freq * frame_len /
                                     float(sample_rate))))
    x = torch.arange(input_bins, dtype=peaks_bin.dtype).unsqueeze(1)
    l, c, r = peaks_bin[0:-2], peaks_bin[1:-1], peaks_bin[2:]
    # triangles are the minimum of two linear functions f(x) = a*x + b
    # left side of triangles: f(l) = 0, f(c) = 1 -> a=1/(c-l), b=-a*l
    tri_left = (x - l) / (c - l)
    # right side of triangles: f(c) = 1, f(r) = 0 -> a=1/(c-r), b=-a*r
    tri_right = (x - r) / (c - r)
    # combine by taking the minimum of the left and right sides
    tri = torch.min(tri_left, tri_right)
    # and clip to only keep positive values
    filterbank = torch.clamp(tri, min=0)
    # normalize by area
    if norm:
        area = filterbank.sum(0)
        # a filter falling entirely between two FFT bins has zero area;
        # dividing it by one keeps it all-zero instead of producing 0/0 = NaN
        filterbank = filterbank / torch.where(area > 0, area,
                                              torch.ones_like(area))
    return filterbank
class MelFilter(nn.Module):
    """
    Maps linear-frequency spectrograms to mel-spaced bands.

    The filterbank is built once by `create_mel_filterbank` (with
    ``crop=True``) and kept as the buffer `bank`. It is treated as a cached
    constant: it is left out of `state_dict()` and any stored copy is ignored
    when loading, so it always comes from the constructor arguments.
    """
    def __init__(self, sample_rate, winsize, num_bands, min_freq, max_freq):
        super(MelFilter, self).__init__()
        melbank = create_mel_filterbank(sample_rate, winsize, num_bands,
                                        min_freq, max_freq, crop=True)
        self.register_buffer('bank', melbank)

    def forward(self, x):
        """
        Applies the filterbank to `x`, whose second-to-last dimension holds
        the FFT bins and whose last dimension is time. FFT bins beyond the
        rows of `bank` are dropped; the output has one row per column of
        `bank` in place of the FFT bins.
        """
        x = x.transpose(-1, -2)  # put fft bands last
        x = x[..., :self.bank.shape[0]]  # remove unneeded fft bands
        x = x.matmul(self.bank)  # turn fft bands into mel bands
        x = x.transpose(-1, -2)  # put time last
        return x

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Returns the module state without any buffers (cached constants)."""
        result = super(MelFilter, self).state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars)
        # remove all buffers; we use them as cached constants
        for k in self._buffers:
            result.pop(prefix + k, None)
        return result

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # ignore stored buffers for backwards compatibility
        for k in self._buffers:
            state_dict.pop(prefix + k, None)
        # temporarily hide the buffers so the base class neither restores
        # them nor reports them missing; put them back even if loading fails
        buffers = self._buffers
        self._buffers = {}
        try:
            return super(MelFilter, self)._load_from_state_dict(
                state_dict, prefix, *args, **kwargs)
        finally:
            self._buffers = buffers
class STFT(nn.Module):
    """
    Short-time Fourier transform of a batch of signals.

    Input is either (batch, samples) or (batch, channels, samples); each
    channel is transformed independently. The output is
    (batch, channels, bins, frames) magnitudes -- with channels = 1 for 2D
    input and bins = winsize // 2 + 1 -- or, if `complex`,
    (batch, channels, bins, frames, 2) holding real and imaginary parts.
    Frames are taken every `hopsize` samples with a symmetric Hann window of
    `winsize` samples and no padding (``center=False``). The window is a
    cached constant: it is left out of `state_dict()` and any stored copy is
    ignored when loading.
    """
    def __init__(self, winsize, hopsize, complex=False):
        super(STFT, self).__init__()
        self.winsize = winsize
        self.hopsize = hopsize
        self.register_buffer('window',
                             torch.hann_window(winsize, periodic=False))
        self.complex = complex

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Returns the module state without any buffers (cached constants)."""
        result = super(STFT, self).state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars)
        # remove all buffers; we use them as cached constants
        for k in self._buffers:
            result.pop(prefix + k, None)
        return result

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # ignore stored buffers for backwards compatibility
        for k in self._buffers:
            state_dict.pop(prefix + k, None)
        # temporarily hide the buffers so the base class neither restores
        # them nor reports them missing; put them back even if loading fails
        buffers = self._buffers
        self._buffers = {}
        try:
            return super(STFT, self)._load_from_state_dict(
                state_dict, prefix, *args, **kwargs)
        finally:
            self._buffers = buffers

    def forward(self, x):
        # add a singleton channel dimension if the input has none
        if x.dim() < 3:
            x = x.unsqueeze(1)
        # we want each channel to be treated separately, so we mash
        # up the channels and batch size and split them up afterwards
        batchsize, channels = x.shape[:2]
        x = x.reshape((-1,) + x.shape[2:])
        # we apply the STFT (complex output; real output is deprecated)
        x = torch.stft(x, self.winsize, self.hopsize, window=self.window,
                       center=False, return_complex=True)
        if self.complex:
            # same (..., 2) real/imaginary layout as return_complex=False
            x = torch.view_as_real(x)
        else:
            x = x.abs()
        # restore original batchsize and channels
        x = x.reshape((batchsize, channels, -1) + x.shape[2:])
        return x
class MedFilt(nn.Module):
"""
Withdraw the median of each frequency band
"""
def __init__(self):
super(MedFilt, self).__init__()
def forward(self, x):
return x - torch.quantile(x, 0.2, dim=-1, keepdim=True)[0]
class TemporalBatchNorm(nn.Module):
    """
    Batch normalization of a (batch, channels, bands, time) tensor.

    Statistics are kept per frequency band (the second-to-last dimension)
    and pooled over every other dimension: all leading dimensions are folded
    into one batch axis and passed to a `BatchNorm1d` over `num_bands`
    features, which also pools over the last (time) axis. The output has the
    same shape as the input.
    """
    def __init__(self, num_bands):
        super(TemporalBatchNorm, self).__init__()
        self.bn = nn.BatchNorm1d(num_bands)

    def forward(self, x):
        original_shape = x.shape
        # fold batch and channels together; keep (bands, time) intact
        folded = x.reshape(-1, *original_shape[-2:])
        return self.bn(folded).reshape(original_shape)
class Log1p(nn.Module):
    """
    Magnitude compression ``log(1 + 10**a * x)``.

    With ``trainable=True`` the exponent `a` becomes a learnable scalar
    parameter initialized to the given value; otherwise it stays a fixed
    number. A fixed ``a == 0`` (the default) makes the module the identity:
    the input is returned unchanged instead of being passed through log1p.
    """
    def __init__(self, a=0, trainable=False):
        super(Log1p, self).__init__()
        self.a = (nn.Parameter(torch.tensor(a, dtype=torch.get_default_dtype()))
                  if trainable else a)
        self.trainable = trainable

    def forward(self, x):
        # a fixed exponent of zero disables the compression entirely
        if not self.trainable and self.a == 0:
            return x
        return torch.log1p(10 ** self.a * x)

    def extra_repr(self):
        return 'trainable={!r}'.format(self.trainable)