Skip to content

Commit

Permalink
Merge pull request #10 from NINAnor/ablation_study
Browse files Browse the repository at this point in the history
Ablation study
  • Loading branch information
BenCretois authored Jan 11, 2024
2 parents 16d9fd0 + aee2784 commit 98ada32
Show file tree
Hide file tree
Showing 8 changed files with 605 additions and 25 deletions.
4 changes: 2 additions & 2 deletions CONFIG.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ data:
status: train # used in preprocessing, train or validate or evaluate



#################################
# PARAMETERS FOR MODEL TRAINING #
#################################
Expand All @@ -40,7 +39,8 @@ trainer:
model:
distance: euclidean # other option is mahalanobis
lr: 1.0e-05
model_path: /data/BEATs/BEATs_iter3_plus_AS2M.pt # Or FOR INFERENCE: /data/lightning_logs/version_X/checkpoints/epoch=X-step=1500.ckpt
model_type: pann # beats, pann or baseline
model_path: /data/model/PANN/Cnn14_mAP=0.431.pth # /data/model/BEATs/BEATs_iter3_plus_AS2M.pt # # Or FOR INFERENCE: /data/lightning_logs/version_X/checkpoints/epoch=X-step=1500.ckpt
specaugment_params: null
# specaugment_params:
# application_ratio: 1.0
Expand Down
34 changes: 34 additions & 0 deletions Models/baseline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
from collections import OrderedDict

def conv_block(in_channels,out_channels):

return nn.Sequential(
nn.Conv2d(in_channels,out_channels,3,padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.MaxPool2d(2)
)

class ProtoNet(nn.Module):
def __init__(self):
super(ProtoNet,self).__init__()
self.encoder = nn.Sequential(
conv_block(1,64),
conv_block(64,64),
conv_block(64,64),
conv_block(64,64)
)
def forward(self,x):
(num_samples,seq_len,mel_bins) = x.shape
x = x.view(-1,1,seq_len,mel_bins)
print(x.shape)
x = self.encoder(x)
x = nn.MaxPool2d(2)(x)

return x.view(x.size(0),-1)

def extract_features(self, x, padding_mask=None):
return self.forward(x)
Loading

0 comments on commit 98ada32

Please sign in to comment.