-
Notifications
You must be signed in to change notification settings - Fork 62
/
conformer.py
463 lines (358 loc) · 14.9 KB
/
conformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
"""
EEG Conformer
Convolutional Transformer for EEG decoding.
Couples a convolutional (CNN) front end with a Transformer encoder in a compact architecture.
"""
# NOTE: update the hard-coded data path (ExP.root) and the ./results output paths before running.
import argparse
import os
gpus = [0]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
import numpy as np
import math
import glob
import random
import itertools
import datetime
import time
import datetime
import sys
import scipy.io
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchsummary import summary
import torch.autograd as autograd
from torchvision.models import vgg19
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn.init as init
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from sklearn.decomposition import PCA
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
# from common_spatial_pattern import csp
import matplotlib.pyplot as plt
# from torch.utils.tensorboard import SummaryWriter
from torch.backends import cudnn
cudnn.benchmark = False
cudnn.deterministic = True
# writer = SummaryWriter('./TensorBoardX/')
# Convolution module
# Uses convolution to capture local features instead of a positional embedding.
class PatchEmbedding(nn.Module):
    """Convolutional "patch" embedding for raw EEG trials.

    A temporal convolution (kernel 1x25) is followed by a convolution spanning
    all 22 electrode rows (kernel 22x1), batch norm, ELU, and average pooling
    along time. Each pooled time step is then projected to ``emb_size`` and
    becomes one token.

    Args:
        emb_size: width of each output token.
    """

    def __init__(self, emb_size=40):
        super().__init__()
        temporal_conv = nn.Conv2d(1, 40, (1, 25), (1, 1))
        spatial_conv = nn.Conv2d(40, 40, (22, 1), (1, 1))
        self.shallownet = nn.Sequential(
            temporal_conv,
            spatial_conv,
            nn.BatchNorm2d(40),
            nn.ELU(),
            # Overlapping windows (width 75, stride 15) along time play the
            # role of ViT patches.
            nn.AvgPool2d((1, 75), (1, 15)),
            nn.Dropout(0.5),
        )
        self.projection = nn.Sequential(
            # 1x1 conv as a learnable per-token projection.
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Map a 4-D batch of trials to a (batch, tokens, emb_size) sequence."""
        # Unpacking enforces a 4-D input: (batch, conv channel, electrodes, time).
        _batch, _channels, _electrodes, _samples = x.shape
        features = self.shallownet(x)
        return self.projection(features)
class MultiHeadAttention(nn.Module):
    """Multi-head dot-product self-attention over a token sequence.

    Args:
        emb_size: token embedding width. It is split evenly across heads, so it
            must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Attend over the token axis of ``x``.

        Args:
            x: input tokens of shape (b, n, emb_size).
            mask: optional boolean tensor broadcastable to the (b, heads, n, n)
                attention scores. Positions where it is ``False`` are excluded
                from attention.

        Returns:
            Tensor of shape (b, n, emb_size).
        """
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # ``masked_fill`` is out-of-place, so its result must be reassigned.
            # Use the score dtype's minimum so the fill also works under fp16.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # NOTE(review): this scales by sqrt(emb_size) rather than the per-head
        # sqrt(emb_size // num_heads) of standard attention. It is kept as-is so
        # existing checkpoints behave the same.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.projection(out)
class ResidualAdd(nn.Module):
    """Wrap a module ``fn`` with a residual (skip) connection: ``fn(x) + x``.

    Args:
        fn: module applied to the input; its output must be broadcast-compatible
            with the input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``.

        The sum is computed out of place. An in-place ``+=`` on ``fn``'s output
        would mutate the caller's tensor whenever ``fn`` returns its argument
        (e.g. ``nn.Identity``). It would also break autograd whenever ``fn``'s
        last op saved its output for backward (e.g. ``torch.sigmoid``).
        """
        return self.fn(x, **kwargs) + x
class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Each token is widened from ``emb_size`` to ``expansion * emb_size``
    features and projected back, so input and output widths match.

    Args:
        emb_size: token width (input and output).
        expansion: hidden-layer width multiplier.
        drop_p: dropout probability after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_width = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_width),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_width, emb_size),
        ]
        super().__init__(*layers)
class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit: ``x * Phi(x)``."""

    def forward(self, input: Tensor) -> Tensor:
        # Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) is the standard normal CDF.
        gate = 1.0 + torch.erf(input / math.sqrt(2.0))
        return input * 0.5 * gate
class TransformerEncoderBlock(nn.Sequential):
    """Pre-norm transformer encoder layer made of two residual branches.

    1. LayerNorm -> multi-head self-attention -> dropout, added to the input.
    2. LayerNorm -> feed-forward MLP -> dropout, added to the output of (1).

    Args:
        emb_size: token width.
        num_heads: attention heads (``emb_size`` must be divisible by it).
        drop_p: dropout used in attention and after each branch.
        forward_expansion: hidden-width multiplier of the MLP.
        forward_drop_p: dropout inside the MLP.
    """

    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        attention_residual = ResidualAdd(attention_branch)

        mlp_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(
                emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        mlp_residual = ResidualAdd(mlp_branch)

        super().__init__(attention_residual, mlp_residual)
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` encoder blocks, each with default hyper-parameters.

    Args:
        depth: number of stacked ``TransformerEncoderBlock`` layers.
        emb_size: token width passed to every block.
    """

    def __init__(self, depth, emb_size):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(emb_size))
        super().__init__(*blocks)
class ClassificationHead(nn.Sequential):
    """Classifier applied to the transformer's token sequence.

    ``forward`` flattens all token embeddings of a sample into one vector and
    feeds it to a small MLP (``self.fc``).

    Args:
        emb_size: width of each token embedding.
        n_classes: number of output classes, i.e. the width of the logits.
        flatten_dim: length of the flattened input (tokens * emb_size). The
            default 2440 equals 61 tokens * 40 features, which is what
            ``PatchEmbedding`` yields for a 1x22x1000 trial.
    """

    def __init__(self, emb_size, n_classes, flatten_dim=2440):
        super().__init__()
        # NOTE(review): this global-average-pooling head is built but never
        # used by ``forward``. It is kept so existing checkpoints (state_dict
        # keys ``clshead.*``) still load.
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes),
        )
        self.fc = nn.Sequential(
            nn.Linear(flatten_dim, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            # Output width follows ``n_classes`` instead of a hard-coded 4.
            nn.Linear(32, n_classes),
        )

    def forward(self, x):
        """Return ``(flattened_features, logits)`` for a batch of token sequences."""
        features = x.contiguous().view(x.size(0), -1)
        logits = self.fc(features)
        return features, logits
class Conformer(nn.Sequential):
    """EEG Conformer: conv patch embedding -> transformer encoder -> classifier.

    The final stage (``ClassificationHead``) returns a
    ``(flattened features, logits)`` pair, so calling the model yields that pair.

    Args:
        emb_size: token width used throughout.
        depth: number of transformer encoder blocks.
        n_classes: number of output classes passed to the head.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, emb_size=40, depth=6, n_classes=4, **kwargs):
        embedding = PatchEmbedding(emb_size)
        encoder = TransformerEncoder(depth, emb_size)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)
class ExP:
    """Single-subject Conformer experiment: data loading, augmentation, training.

    Hyper-parameters and the data path are hard-coded in ``__init__``. A CUDA
    device is required, because the model, losses and tensor types are created
    on the GPU. The ``./results`` directory must already exist, since the
    per-subject log is opened there with mode "w".

    Args:
        nsub: subject number. It selects ``A0<nsub>T.mat`` / ``A0<nsub>E.mat``
            under ``self.root`` and names the log file.
    """

    def __init__(self, nsub):
        self.batch_size = 72
        self.n_epochs = 2000
        self.c_dim = 4
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.dimension = (190, 50)
        self.nSub = nsub

        self.start_epoch = 0
        self.root = '/Data/strict_TE/'

        # One "<epoch> <test accuracy>" line per epoch; flushed at the end of train().
        self.log_write = open("./results/log_subject%d.txt" % self.nSub, "w")

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_l1 = torch.nn.L1Loss().cuda()
        self.criterion_l2 = torch.nn.MSELoss().cuda()
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.model = Conformer().cuda()
        self.model = nn.DataParallel(self.model, device_ids=list(range(len(gpus))))
        self.model = self.model.cuda()

    def interaug(self, timg, label):
        """Segmentation & Reconstruction (S&R) data augmentation.

        For each of the 4 classes, this builds ``batch_size / 4`` synthetic
        trials. Each synthetic trial is stitched from 8 consecutive 125-sample
        time segments, and each segment is copied from a randomly chosen trial
        of that class.

        Args:
            timg: training trials. The copy below requires each trial to be
                indexable as (1, 22, >=1000), i.e. timg is assumed
                (trials, 1, 22, >=1000).
            label: 1-based class labels aligned with ``timg``.

        Returns:
            (aug_data, aug_label): shuffled float32 trials and 0-based int64
            labels, both on the GPU.
        """
        aug_data = []
        aug_label = []
        for cls4aug in range(4):
            cls_idx = np.where(label == cls4aug + 1)
            tmp_data = timg[cls_idx]
            tmp_label = label[cls_idx]

            tmp_aug_data = np.zeros((int(self.batch_size / 4), 1, 22, 1000))
            for ri in range(int(self.batch_size / 4)):
                for rj in range(8):
                    # NOTE(review): 8 indices are drawn but only one is used.
                    # This is kept so the seeded random stream (and results)
                    # stay unchanged.
                    rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
                    tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :,
                                                                      rj * 125:(rj + 1) * 125]

            aug_data.append(tmp_aug_data)
            # NOTE(review): assumes each class has at least batch_size/4 trials;
            # otherwise labels and data lengths would differ.
            aug_label.append(tmp_label[:int(self.batch_size / 4)])
        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle, :, :]
        aug_label = aug_label[aug_shuffle]

        aug_data = torch.from_numpy(aug_data).cuda()
        aug_data = aug_data.float()
        aug_label = torch.from_numpy(aug_label - 1).cuda()
        aug_label = aug_label.long()
        return aug_data, aug_label

    def get_source_data(self):
        """Load, shuffle and standardize one subject's training and test trials.

        Both .mat files are read via their 'data' and 'label' keys. The data
        axes are reversed and a singleton axis is inserted at position 1.
        Training trials are shuffled. Both sets are standardized with the scalar
        mean/std of the training data.

        Returns:
            (allData, allLabel, testData, testLabel), with labels still 1-based.
        """
        # NOTE(review): there is no validation split, and train() selects the
        # best epoch on the test set. Recheck this against the data segmentation
        # used by the compared methods.

        # --- training data ---
        self.total_data = scipy.io.loadmat(self.root + 'A0%dT.mat' % self.nSub)
        self.train_data = self.total_data['data']
        self.train_label = self.total_data['label']

        self.train_data = np.transpose(self.train_data, (2, 1, 0))
        self.train_data = np.expand_dims(self.train_data, axis=1)
        self.train_label = np.transpose(self.train_label)

        self.allData = self.train_data
        self.allLabel = self.train_label[0]

        shuffle_num = np.random.permutation(len(self.allData))
        self.allData = self.allData[shuffle_num, :, :, :]
        self.allLabel = self.allLabel[shuffle_num]

        # --- test data ---
        self.test_tmp = scipy.io.loadmat(self.root + 'A0%dE.mat' % self.nSub)
        self.test_data = self.test_tmp['data']
        self.test_label = self.test_tmp['label']

        self.test_data = np.transpose(self.test_data, (2, 1, 0))
        self.test_data = np.expand_dims(self.test_data, axis=1)
        self.test_label = np.transpose(self.test_label)

        self.testData = self.test_data
        self.testLabel = self.test_label[0]

        # Standardize with training statistics only.
        target_mean = np.mean(self.allData)
        target_std = np.std(self.allData)
        self.allData = (self.allData - target_mean) / target_std
        self.testData = (self.testData - target_mean) / target_std

        # data shape: (trial, conv channel, electrode channel, time samples)
        return self.allData, self.allLabel, self.testData, self.testLabel

    def train(self):
        """Train for ``n_epochs``, evaluating on the full test set after every epoch.

        Returns:
            (bestAcc, averAcc, Y_true, Y_pred): the best and mean per-epoch test
            accuracy, plus the test labels and predictions of the best epoch.
        """
        img, label, test_data, test_label = self.get_source_data()

        img = torch.from_numpy(img)
        label = torch.from_numpy(label - 1)  # stored labels are 1-based
        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        # NOTE(review): not used below; the whole test set is evaluated in one pass.
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        test_data = test_data.type(self.Tensor)
        test_label = test_label.type(self.LongTensor)

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        for e in range(self.n_epochs):
            self.model.train()
            for img, label in self.dataloader:
                img = img.cuda().type(self.Tensor)
                label = label.cuda().type(self.LongTensor)

                # S&R augmentation drawn from the full standardized training set.
                aug_data, aug_label = self.interaug(self.allData, self.allLabel)
                img = torch.cat((img, aug_data))
                label = torch.cat((label, aug_label))

                _tok, outputs = self.model(img)
                loss = self.criterion_cls(outputs, label)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Evaluation: no_grad avoids building an autograd graph for the
            # whole-test-set forward pass.
            self.model.eval()
            with torch.no_grad():
                _Tok, Cls = self.model(test_data)
                loss_test = self.criterion_cls(Cls, test_label)
                y_pred = torch.max(Cls, 1)[1]
            acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
            # Training accuracy is measured on the last mini-batch only.
            train_pred = torch.max(outputs, 1)[1]
            train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))

            print('Epoch:', e,
                  ' Train loss: %.6f' % loss.detach().cpu().numpy(),
                  ' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
                  ' Train accuracy %.6f' % train_acc,
                  ' Test accuracy is %.6f' % acc)

            self.log_write.write(str(e) + "    " + str(acc) + "\n")
            num = num + 1
            averAcc = averAcc + acc
            # NOTE(review): best-epoch selection uses the test set, and every
            # subject overwrites the same 'model.pth'.
            if acc > bestAcc:
                bestAcc = acc
                Y_true = test_label
                Y_pred = y_pred
                torch.save(self.model.module.state_dict(), 'model.pth')

        averAcc = averAcc / num
        print('The average accuracy is:', averAcc)
        print('The best accuracy is:', bestAcc)
        self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
        self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
        self.log_write.flush()

        return bestAcc, averAcc, Y_true, Y_pred
# writer.close()
def main():
    """Run the experiment for subjects 1-9 and write a summary file.

    Each subject gets a fresh seed, drawn from NumPy's global RNG, which then
    seeds ``random``, NumPy and PyTorch (CPU and CUDA) before training.
    Per-subject and averaged best/mean test accuracies are written to
    ``./results/sub_result.txt``.
    """
    best = 0
    aver = 0
    # ExP also writes its per-subject log into ./results; create it up front.
    os.makedirs("./results", exist_ok=True)
    # ``with`` closes the summary file even if a subject's training fails.
    with open("./results/sub_result.txt", "w") as result_write:
        for i in range(9):
            starttime = datetime.datetime.now()

            seed_n = np.random.randint(2021)
            print('seed is ' + str(seed_n))
            random.seed(seed_n)
            np.random.seed(seed_n)
            torch.manual_seed(seed_n)
            torch.cuda.manual_seed(seed_n)
            torch.cuda.manual_seed_all(seed_n)

            print('Subject %d' % (i+1))
            exp = ExP(i + 1)

            try:
                bestAcc, averAcc, _Y_true, _Y_pred = exp.train()
            finally:
                # Release this subject's log file handle.
                exp.log_write.close()
            print('THE BEST ACCURACY IS ' + str(bestAcc))
            result_write.write('Subject ' + str(i + 1) + ' : ' + 'Seed is: ' + str(seed_n) + "\n")
            result_write.write('Subject ' + str(i + 1) + ' : ' + 'The best accuracy is: ' + str(bestAcc) + "\n")
            result_write.write('Subject ' + str(i + 1) + ' : ' + 'The average accuracy is: ' + str(averAcc) + "\n")
            # Keep partial results on disk if a later subject crashes.
            result_write.flush()

            endtime = datetime.datetime.now()
            print('subject %d duration: '%(i+1) + str(endtime - starttime))
            best = best + bestAcc
            aver = aver + averAcc

        best = best / 9
        aver = aver / 9

        result_write.write('**The average Best accuracy is: ' + str(best) + "\n")
        result_write.write('The average Aver accuracy is: ' + str(aver) + "\n")
if __name__ == "__main__":
    # Bracket the whole run with local wall-clock timestamps.
    print(time.ctime())
    main()
    print(time.ctime())