Skip to content

Commit

Permalink
Merge branch 'main' of github.com:NINAnor/rare_species_detections
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Jan 29, 2024
2 parents ccb4636 + c9c79a6 commit 52d4eb9
Show file tree
Hide file tree
Showing 12 changed files with 769 additions and 93 deletions.
27 changes: 16 additions & 11 deletions CONFIG.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,38 @@
data:
n_task_train: 100
n_task_val: 100
target_fs: 16000
resample: true
denoise: true
normalize: true
frame_length: 25.0
tensor_length: 128
target_fs: 22050 # used in preprocessing
resample: true # used in preprocessing
denoise: true # used in preprocessing
normalize: true # used in preprocessing
frame_length: 25.0 # used in preprocessing
tensor_length: 128 # used in preprocessing
n_shot: 5
n_query: 10
overlap: 0.5
n_subsample: 1 # ask Femke what this stands for
status: train # train or validate or evaluate
overlap: 0.5 # used in preprocessing
n_subsample: 1
num_mel_bins: 128 # used in preprocessing
max_segment_length: 1.0 # used in preprocessing
status: train # used in preprocessing, train or validate or evaluate


#################################
# PARAMETERS FOR MODEL TRAINING #
#################################
# Be sure the parameters match the ones in data processing

trainer:
max_epochs: 1
max_epochs: 25
default_root_dir: /data
accelerator: gpu
gpus: 1

model:
distance: euclidean # other option is mahalanobis
lr: 1.0e-05
model_path: /data/BEATs/BEATs_iter3_plus_AS2M.pt # Or FOR INFERENCE: /data/lightning_logs/version_X/checkpoints/epoch=X-step=1500.ckpt
model_type: beats # beats, pann or baseline
state: None # train or validate or None if using beats or baseline // since we remove layers from the original PANN we can't load the ckpts normally (see _build_model in prototraining.py)
model_path: /data/lightning_logs/version_96/checkpoints/epoch=85-step=8600.ckpt #/data/lightning_logs/pann/lightning_logs/version_1/checkpoints/epoch=99-step=10000.ckpt #/data/lightning_logs/pann/lightning_logs/version_1/checkpoints/epoch=99-step=10000.ckpt #/data/model/PANN/Cnn14_mAP=0.431.pth # /data/model/BEATs/BEATs_iter3_plus_AS2M.pt # # Or FOR INFERENCE: /data/lightning_logs/version_X/checkpoints/epoch=X-step=1500.ckpt
specaugment_params: null
# specaugment_params:
# application_ratio: 1.0
Expand Down
34 changes: 34 additions & 0 deletions Models/baseline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
from collections import OrderedDict

def conv_block(in_channels, out_channels):
    """Return one encoder stage: Conv2d(3x3) -> BatchNorm -> ReLU -> MaxPool(2).

    The max-pool halves the spatial resolution of the feature map.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]
    return nn.Sequential(*layers)

class ProtoNet(nn.Module):
    """Baseline prototypical-network encoder for mel-spectrogram inputs.

    Four `conv_block` stages (each halving the spatial resolution) followed
    by one extra max-pool; the result is flattened to one embedding vector
    per input sample.
    """

    def __init__(self):
        super(ProtoNet, self).__init__()
        self.encoder = nn.Sequential(
            conv_block(1, 64),
            conv_block(64, 64),
            conv_block(64, 64),
            conv_block(64, 64),
        )

    def forward(self, x):
        """Encode a batch of spectrograms.

        Args:
            x: tensor of shape (num_samples, seq_len, mel_bins).

        Returns:
            Tensor of shape (num_samples, embedding_dim) with the flattened
            encoder output for each sample.
        """
        (num_samples, seq_len, mel_bins) = x.shape
        # Insert a singleton channel dimension for the 2-D convolutions.
        x = x.view(-1, 1, seq_len, mel_bins)
        # NOTE: removed leftover debug `print(x.shape)` from the hot path.
        x = self.encoder(x)
        # One extra pooling step on top of the four blocks' own max-pools.
        x = nn.MaxPool2d(2)(x)
        return x.view(x.size(0), -1)

    def extract_features(self, x, padding_mask=None):
        """Mirror the BEATs/PANN feature-extraction interface.

        `padding_mask` is accepted for API compatibility but unused here.
        """
        return self.forward(x)
Loading

0 comments on commit 52d4eb9

Please sign in to comment.