# sed.py
from __future__ import print_function
import os
import numpy as np
import time
import sys
import matplotlib.pyplot as plot
from keras.layers import Bidirectional, TimeDistributed, Conv2D, MaxPooling2D, Input, GRU, Dense, Activation, Dropout, Reshape, Permute, BatchNormalization
from keras.models import Model
from sklearn.metrics import confusion_matrix
import metrics
import utils
from IPython import embed
import keras.backend as K
K.set_image_data_format('channels_first')
plot.switch_backend('agg')
sys.setrecursionlimit(10000)

def load_data(_feat_folder, _mono, _fold=None):
    feat_file_fold = os.path.join(_feat_folder, 'mbe_{}_fold{}.npz'.format('mon' if _mono else 'bin', _fold))
    dmp = np.load(feat_file_fold)
    _X_train, _Y_train, _X_test, _Y_test = dmp['arr_0'], dmp['arr_1'], dmp['arr_2'], dmp['arr_3']
    return _X_train, _Y_train, _X_test, _Y_test
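
# A minimal sketch of what load_data() returns, assuming the .npz files were
# written by the companion feature-extraction script with 40 mel bands per
# channel (the band count and class count are assumptions, not read here):
#   X, Y, X_test, Y_test = load_data(feat_folder, is_mono, 1)
#   # X      -> (n_train_frames, 40)         per-frame log mel-band energies
#   # Y      -> (n_train_frames, n_classes)  binary frame-level labels
#   # X_test -> (n_test_frames, 40)
#   # Y_test -> (n_test_frames, n_classes)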

def get_model(data_in, data_out, _cnn_nb_filt, _cnn_pool_size, _rnn_nb, _fc_nb):
    spec_start = Input(shape=(data_in.shape[-3], data_in.shape[-2], data_in.shape[-1]))
    spec_x = spec_start
    for _i, _cnt in enumerate(_cnn_pool_size):
        spec_x = Conv2D(filters=_cnn_nb_filt, kernel_size=(3, 3), padding='same')(spec_x)
        spec_x = BatchNormalization(axis=1)(spec_x)
        spec_x = Activation('relu')(spec_x)
        spec_x = MaxPooling2D(pool_size=(1, _cnn_pool_size[_i]))(spec_x)
        spec_x = Dropout(dropout_rate)(spec_x)
    spec_x = Permute((2, 1, 3))(spec_x)
    spec_x = Reshape((data_in.shape[-2], -1))(spec_x)

    for _r in _rnn_nb:
        spec_x = Bidirectional(
            GRU(_r, activation='tanh', dropout=dropout_rate, recurrent_dropout=dropout_rate, return_sequences=True),
            merge_mode='mul')(spec_x)

    for _f in _fc_nb:
        spec_x = TimeDistributed(Dense(_f))(spec_x)
        spec_x = Dropout(dropout_rate)(spec_x)

    spec_x = TimeDistributed(Dense(data_out.shape[-1]))(spec_x)
    out = Activation('sigmoid', name='strong_out')(spec_x)

    _model = Model(inputs=spec_start, outputs=out)
    _model.compile(optimizer='Adam', loss='binary_crossentropy')
    _model.summary()
    return _model
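
# Rough shape walk-through for the defaults set in the main script below
# (seq_len=256, 128 CNN filters, pool sizes [5, 2, 2]), assuming 40 mel bands
# per channel as produced by the companion feature script (an assumption):
#   Input             -> (batch, nb_ch, 256, 40)
#   Conv/pool stack   -> (batch, 128, 256, 2)    frequency: 40 -> 8 -> 4 -> 2
#   Permute + Reshape -> (batch, 256, 256)       128 filters * 2 bands per frame
#   BiGRU + FC stack  -> (batch, 256, n_classes) per-frame sigmoid posteriors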

def plot_functions(_nb_epoch, _tr_loss, _val_loss, _f1, _er, extension=''):
    plot.figure()

    plot.subplot(211)
    plot.plot(range(_nb_epoch), _tr_loss, label='train loss')
    plot.plot(range(_nb_epoch), _val_loss, label='val loss')
    plot.legend()
    plot.grid(True)

    plot.subplot(212)
    plot.plot(range(_nb_epoch), _f1, label='f')
    plot.plot(range(_nb_epoch), _er, label='er')
    plot.legend()
    plot.grid(True)

    plot.savefig(__models_dir + __fig_name + extension)
    plot.close()
    print('figure name : {}'.format(__fig_name))

def preprocess_data(_X, _Y, _X_test, _Y_test, _seq_len, _nb_ch):
    # Split the frame streams into fixed-length sequences
    _X = utils.split_in_seqs(_X, _seq_len)
    _Y = utils.split_in_seqs(_Y, _seq_len)
    _X_test = utils.split_in_seqs(_X_test, _seq_len)
    _Y_test = utils.split_in_seqs(_Y_test, _seq_len)

    # Split the concatenated per-channel features into separate channels
    _X = utils.split_multi_channels(_X, _nb_ch)
    _X_test = utils.split_multi_channels(_X_test, _nb_ch)
    return _X, _Y, _X_test, _Y_test
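
# Example of the reshaping above for binaural input, assuming 40 mel bands per
# channel (so 80 feature columns in the raw matrices; the band count is an
# assumption based on the companion feature script):
#   X: (n_frames, 80)
#     -> split_in_seqs        -> (n_seq, 256, 80)
#     -> split_multi_channels -> (n_seq, 2, 256, 40)   i.e. channels first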
#######################################################################################
# MAIN SCRIPT STARTS HERE
#######################################################################################
is_mono = True # True: mono-channel input, False: binaural input
feat_folder = '/scratch/asignal/sharath/DCASE2017/TUT-sound-events-2017-development/feat/'
__fig_name = '{}_{}'.format('mon' if is_mono else 'bin', time.strftime("%Y_%m_%d_%H_%M_%S"))
nb_ch = 1 if is_mono else 2
batch_size = 128 # Decrease this if you want to run on smaller GPUs
seq_len = 256 # Frame sequence length. Input to the CRNN.
nb_epoch = 500 # Training epochs
patience = int(0.25 * nb_epoch) # Patience for early stopping
# Number of frames in 1 second, required to calculate F and ER for 1 sec segments.
# Make sure the nfft and sr are the same as in feature.py
sr = 44100
nfft = 2048
frames_1_sec = int(sr/(nfft/2.0))
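# With sr=44100 and nfft=2048 (and a hop of nfft/2 = 1024 samples, assumed to
# match feature.py), frames_1_sec = int(44100 / 1024) = 43 frames per
# one-second evaluation segment.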
print('\n\nUNIQUE ID: {}'.format(__fig_name))
print('TRAINING PARAMETERS: nb_ch: {}, seq_len: {}, batch_size: {}, nb_epoch: {}, frames_1_sec: {}'.format(
nb_ch, seq_len, batch_size, nb_epoch, frames_1_sec))
# Folder for saving model and training curves
__models_dir = 'models/'
utils.create_folder(__models_dir)
# CRNN model definition
cnn_nb_filt = 128 # Number of CNN filters in each CNN layer
cnn_pool_size = [5, 2, 2] # Maxpooling across frequency. Length of cnn_pool_size = number of CNN layers
rnn_nb = [32, 32] # Number of RNN nodes. Length of rnn_nb = number of RNN layers
fc_nb = [32] # Number of FC nodes. Length of fc_nb = number of FC layers
dropout_rate = 0.5 # Dropout after each layer
print('MODEL PARAMETERS:\n cnn_nb_filt: {}, cnn_pool_size: {}, rnn_nb: {}, fc_nb: {}, dropout_rate: {}'.format(
cnn_nb_filt, cnn_pool_size, rnn_nb, fc_nb, dropout_rate))
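
# A quick way to sanity-check the architecture without features on disk is a
# dry run on random arrays shaped like one batch (a sketch only; the 40 mel
# bands and 6 classes are illustrative assumptions, not dataset values):
#   dummy_X = np.random.rand(8, nb_ch, seq_len, 40)
#   dummy_Y = np.random.rand(8, seq_len, 6) > 0.5
#   get_model(dummy_X, dummy_Y, cnn_nb_filt, cnn_pool_size, rnn_nb, fc_nb)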
avg_er = list()
avg_f1 = list()

for fold in [1, 2, 3, 4]:
    print('\n\n----------------------------------------------')
    print('FOLD: {}'.format(fold))
    print('----------------------------------------------\n')

    # Load features and labels, then pre-process them
    X, Y, X_test, Y_test = load_data(feat_folder, is_mono, fold)
    X, Y, X_test, Y_test = preprocess_data(X, Y, X_test, Y_test, seq_len, nb_ch)

    # Build the model
    model = get_model(X, Y, cnn_nb_filt, cnn_pool_size, rnn_nb, fc_nb)

    # Training
    best_epoch, pat_cnt, best_er, f1_for_best_er, best_conf_mat = 0, 0, 99999, None, None
    tr_loss, val_loss, f1_overall_1sec_list, er_overall_1sec_list = [0] * nb_epoch, [0] * nb_epoch, [0] * nb_epoch, [0] * nb_epoch
    posterior_thresh = 0.5
    for i in range(nb_epoch):
        print('Epoch : {} '.format(i), end='')
        hist = model.fit(
            X, Y,
            batch_size=batch_size,
            validation_data=(X_test, Y_test),
            epochs=1,
            verbose=2
        )
        val_loss[i] = hist.history.get('val_loss')[-1]
        tr_loss[i] = hist.history.get('loss')[-1]

        # Predict on the test data in order to calculate ER and F scores
        pred = model.predict(X_test)
        pred_thresh = pred > posterior_thresh
        score_list = metrics.compute_scores(pred_thresh, Y_test, frames_in_1_sec=frames_1_sec)

        f1_overall_1sec_list[i] = score_list['f1_overall_1sec']
        er_overall_1sec_list[i] = score_list['er_overall_1sec']
        pat_cnt = pat_cnt + 1

        # Confusion matrix over the number of simultaneously active classes per frame
        test_pred_cnt = np.sum(pred_thresh, 2)
        Y_test_cnt = np.sum(Y_test, 2)
        conf_mat = confusion_matrix(Y_test_cnt.reshape(-1), test_pred_cnt.reshape(-1))
        conf_mat = conf_mat / (utils.eps + np.sum(conf_mat, 1)[:, None].astype('float'))
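
        # After the row normalization above, row k of conf_mat holds, for the
        # frames where k events were truly active, the fraction predicted as
        # 0, 1, 2, ... active events; np.diag(conf_mat) is thus a per-polyphony
        # accuracy.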
        if er_overall_1sec_list[i] < best_er:
            best_conf_mat = conf_mat
            best_er = er_overall_1sec_list[i]
            f1_for_best_er = f1_overall_1sec_list[i]
            model.save(os.path.join(__models_dir, '{}_fold_{}_model.h5'.format(__fig_name, fold)))
            best_epoch = i
            pat_cnt = 0

        print('tr loss : {}, val loss : {}, F1_overall : {}, ER_overall : {} Best ER : {}, best_epoch: {}'.format(
            tr_loss[i], val_loss[i], f1_overall_1sec_list[i], er_overall_1sec_list[i], best_er, best_epoch))
        plot_functions(nb_epoch, tr_loss, val_loss, f1_overall_1sec_list, er_overall_1sec_list, '_fold_{}'.format(fold))
        if pat_cnt > patience:
            break

    avg_er.append(best_er)
    avg_f1.append(f1_for_best_er)
    print('saved model for the best_epoch: {} with best_er: {} f1_for_best_er: {}'.format(
        best_epoch, best_er, f1_for_best_er))
    print('best_conf_mat: {}'.format(best_conf_mat))
    print('best_conf_mat_diag: {}'.format(np.diag(best_conf_mat)))

print('\n\nMETRICS FOR ALL FOUR FOLDS: avg_er: {}, avg_f1: {}'.format(avg_er, avg_f1))
print('MODEL AVERAGE OVER FOUR FOLDS: avg_er: {}, avg_f1: {}'.format(np.mean(avg_er), np.mean(avg_f1)))