FEATURE: finetuned models can be evaluated + black
femke-sintef committed Apr 10, 2024
1 parent 8440511 commit a09c0b0
Showing 8 changed files with 85 additions and 51 deletions.
15 changes: 9 additions & 6 deletions dcase_fine_tune/FTBeats.py
@@ -46,12 +46,10 @@ def __init__(
         self._build_model()
 
         self.train_acc = Accuracy(
-            task="multiclass",
-            num_classes=self.num_target_classes
+            task="multiclass", num_classes=self.num_target_classes
         )
         self.valid_acc = Accuracy(
-            task="multiclass",
-            num_classes=self.num_target_classes
+            task="multiclass", num_classes=self.num_target_classes
         )
         self.save_hyperparameters()

@@ -65,7 +63,9 @@ def _build_model(self):
         # 2. Classifier
         print(f"Classifier has {self.num_target_classes} output neurons")
         self.beats.predictor_dropout = nn.Dropout(self.cfg.predictor_dropout)
-        self.beats.predictor = nn.Linear(self.cfg.encoder_embed_dim, self.cfg.predictor_class)
+        self.beats.predictor = nn.Linear(
+            self.cfg.encoder_embed_dim, self.cfg.predictor_class
+        )
         # self.fc = nn.Linear(self.cfg.encoder_embed_dim, self.num_target_classes)
 
     def extract_features(self, x, padding_mask=None):
@@ -132,7 +132,10 @@ def configure_optimizers(self):
             )
         else:
             optimizer = optim.AdamW(
-                self.beats.predictor.parameters(), lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01
+                self.beats.predictor.parameters(),
+                lr=self.lr,
+                betas=(0.9, 0.98),
+                weight_decay=0.01,
             )
 
         return optimizer
14 changes: 8 additions & 6 deletions dcase_fine_tune/FTDataModule.py
@@ -60,18 +60,20 @@ def __init__(
             self.data_frame["category"]
         )
         # remove classes with too few samples
-        removals = self.data_frame['category'].value_counts().reset_index()
-        removals = removals[removals['category'] < min_sample_per_category]['index'].values
-        self.data_frame.drop(self.data_frame[self.data_frame["category"].isin(removals)].index, inplace=True)
+        removals = self.data_frame["category"].value_counts().reset_index()
+        removals = removals[removals["category"] < min_sample_per_category][
+            "index"
+        ].values
+        self.data_frame.drop(
+            self.data_frame[self.data_frame["category"].isin(removals)].index,
+            inplace=True,
+        )
         self.data_frame.reset_index(inplace=True)
 
         # init dataset and divide into train&val
         self.complete_dataset = TrainAudioDatasetDCASE(data_frame=self.data_frame)
         self.divide_train_val()
 
-
-
-
     def divide_train_val(self):
         # Separate into training and validation set
         train_indices, validation_indices, _, _ = train_test_split(
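Note on the removals logic above: the frame returned by Series.value_counts().reset_index() changed in pandas 2.0 (the counts column is now named "count", and the former "index" column takes the series' name), so the ["category"] / ["index"] lookups assume pandas < 2.0. A version-independent sketch of the same filtering step, with a hypothetical helper name:

    import pandas as pd

    def drop_rare_categories(df: pd.DataFrame, min_sample_per_category: int) -> pd.DataFrame:
        """Drop rows whose 'category' occurs fewer than min_sample_per_category times."""
        counts = df["category"].value_counts()  # category values live in the index
        rare = counts[counts < min_sample_per_category].index
        return df[~df["category"].isin(rare)].reset_index(drop=True)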
7 changes: 6 additions & 1 deletion dcase_fine_tune/FTevaluate.py
@@ -363,7 +363,12 @@ def main(cfg: DictConfig):
             filename = (
                 os.path.basename(support_spectrograms).split("data_")[1].split(".")[0]
             )
-            (result, pred_labels, gt_labels, result_raw,) = train_predict(
+            (
+                result,
+                pred_labels,
+                gt_labels,
+                result_raw,
+            ) = train_predict(
                 param,
                 meta_df,
                 support_spectrograms,
12 changes: 8 additions & 4 deletions dcase_fine_tune/FTtrain.py
@@ -3,7 +3,7 @@
 import hashlib
 import pandas as pd
 import json
-from datetime import datetime
+from datetime import datetime
 import pytorch_lightning as pl
 
 from dcase_fine_tune.FTBeats import BEATsTransferLearningModel
@@ -33,14 +33,18 @@ def train_model(
             pl.callbacks.LearningRateMonitor(logging_interval="step"),
             pl.callbacks.EarlyStopping(
                 monitor="train_loss", mode="min", patience=patience
-            ),
+            ),
             pl.callbacks.ModelCheckpoint(
-                os.path.join("lightning_logs", "finetuning","{date:%Y%m%d_%H%M%S}".format(date=datetime.now())),
+                os.path.join(
+                    "lightning_logs",
+                    "finetuning",
+                    "{date:%Y%m%d_%H%M%S}".format(date=datetime.now()),
+                ),
                 monitor="val_loss",
                 mode="min",
                 save_top_k=1,
                 verbose=True,
-            )
+            ),
         ],
         default_root_dir=root_dir,
         enable_checkpointing=True,
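For reference, the ModelCheckpoint directory built above resolves to a timestamped run folder such as lightning_logs/finetuning/20240410_153045; the "{date:%Y%m%d_%H%M%S}" format spec is forwarded to datetime.strftime. A minimal standalone sketch:

    import os
    from datetime import datetime

    # equivalent f-string form: f"{datetime.now():%Y%m%d_%H%M%S}"
    ckpt_dir = os.path.join(
        "lightning_logs", "finetuning", "{date:%Y%m%d_%H%M%S}".format(date=datetime.now())
    )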
30 changes: 18 additions & 12 deletions evaluate/_utils_compute.py
@@ -25,7 +25,9 @@ def to_dataframe(features, labels):
 
 def get_proto_coordinates(model, model_type, support_data, support_labels, n_way):
     if model_type == "beats":
-        z_supports, _ = model.get_embeddings(support_data, padding_mask=None, skip_dropout=True)
+        z_supports, _ = model.get_embeddings(
+            support_data, padding_mask=None, skip_dropout=True
+        )
     else:
         z_supports = model.get_embeddings(support_data, padding_mask=None)

@@ -81,7 +83,7 @@ def merge_preds(df, tolerence, tensor_length, frame_shift, occurence_threshold):
         ).shift()
     ).cumsum()
     ids, occurence = np.unique(df["group"], return_counts=True)
-    ids_too_short_segments = ids[occurence<occurence_threshold]
+    ids_too_short_segments = ids[occurence < occurence_threshold]
     df.drop(df[df.group.isin(ids_too_short_segments)].index, inplace=True)
     result = df.groupby("group").agg({"Starttime": "min", "Endtime": "max"})
     return result
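For intuition, the occurence_threshold filter above removes groups of consecutive positive frames that are shorter than the threshold before the per-group min/max aggregation. A toy illustration of that step with hypothetical data, assuming the "group" column has already been computed by the cumsum above:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({
        "Starttime": [0.0, 0.2, 1.0, 2.0, 2.2],
        "Endtime":   [0.2, 0.4, 1.2, 2.2, 2.4],
        "group":     [1, 1, 2, 3, 3],  # group 2 spans a single frame
    })
    occurence_threshold = 2

    ids, occurence = np.unique(df["group"], return_counts=True)
    ids_too_short_segments = ids[occurence < occurence_threshold]  # -> [2]
    df = df[~df["group"].isin(ids_too_short_segments)]
    result = df.groupby("group").agg({"Starttime": "min", "Endtime": "max"})
    # group 1 -> (0.0, 0.4); group 3 -> (2.0, 2.4); group 2 is discarded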
@@ -172,7 +174,7 @@ def predict_labels_query(
     frame_shift,
     overlap,
     pos_index,
-    repetitions=1
+    repetitions=1,
 ):
     """
     - l_segment to know the length of the segment
@@ -195,11 +197,13 @@
         feature, label = data
         feature = feature.to("cuda")
 
-        l_dists = torch.empty(size=(repetitions,2))
-        l_classification_scores = torch.empty(size=(repetitions,2))
+        l_dists = torch.empty(size=(repetitions, 2))
+        l_classification_scores = torch.empty(size=(repetitions, 2))
         for rep_i in range(repetitions):
             if model_type == "beats":
-                q_embedding, _ = model.get_embeddings(feature, padding_mask=None, skip_dropout=True)
+                q_embedding, _ = model.get_embeddings(
+                    feature, padding_mask=None, skip_dropout=True
+                )
             else:
                 q_embedding = model.get_embeddings(feature, padding_mask=None)

@@ -211,11 +215,11 @@
             if model_type != "beats":
                 dists = dists.squeeze()
                 classification_scores = classification_scores.squeeze()
-
+
             l_dists[rep_i] = dists
             l_classification_scores[rep_i] = classification_scores
         dists = torch.mean(l_dists, dim=0)
-        classification_scores=torch.mean(l_classification_scores, dim=0)
+        classification_scores = torch.mean(l_classification_scores, dim=0)
 
         # Calculate beginTime and endTime for each segment
         # We multiply by 1000 to get the time in seconds
@@ -226,8 +230,6 @@
             begin = i * tensor_length * frame_shift * overlap / 1000
             end = begin + tensor_length * frame_shift / 1000
 
-
-
             # Get the labels (either POS or NEG):
             predicted_labels = torch.max(classification_scores, 0)[
                 pos_index
@@ -263,11 +265,15 @@ def filter_outliers_by_p_values(Y, p_values, target_class=1, upper_threshold=0.0
 
     return Y
 
+
 def obtain_cdf(d_supports_to_POS_prototypes, distribution_name):
     if distribution_name == "ecdf":
         cdf = ECDF(d_supports_to_POS_prototypes)
     elif distribution_name == "norm":
-        cdf = norm(loc=np.mean(d_supports_to_POS_prototypes), scale=np.std(d_supports_to_POS_prototypes)).cdf
+        cdf = norm(
+            loc=np.mean(d_supports_to_POS_prototypes),
+            scale=np.std(d_supports_to_POS_prototypes),
+        ).cdf
     else:
         raise
-    return cdf
+    return cdf
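The CDF returned by obtain_cdf is fitted on the distances between POS support embeddings and the POS prototype; downstream (see evaluateDCASE.py below) it is used to score queries via p-values. A minimal sketch of that idea, assuming the p-value is taken as one minus the fitted CDF at the query's distance, which is consistent with the p_values_pos >= threshold test below:

    import numpy as np
    from scipy.stats import norm
    from statsmodels.distributions.empirical_distribution import ECDF

    rng = np.random.default_rng(0)
    d_supports = rng.normal(loc=5.0, scale=1.0, size=50)  # hypothetical support distances

    cdf = ECDF(d_supports)                                           # distribution_name == "ecdf"
    # cdf = norm(loc=d_supports.mean(), scale=d_supports.std()).cdf  # distribution_name == "norm"

    d_query = 4.2
    p_value = 1 - cdf(d_query)
    # A large p_value means the query sits closer to the prototype than most
    # supports, so it is plausibly a POS sample.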
31 changes: 23 additions & 8 deletions evaluate/evaluateDCASE.py
@@ -109,7 +109,9 @@ def compute(
     )
 
     if cfg["model"]["model_type"] == "beats":
-        z_pos_supports, _ = model.get_embeddings(support_samples_pos, padding_mask=None, skip_dropout=True)
+        z_pos_supports, _ = model.get_embeddings(
+            support_samples_pos, padding_mask=None, skip_dropout=True
+        )
         _, d_supports_to_POS_prototypes = calculate_distance(
             model_type, z_pos_supports, prototypes[pos_index]
         )
@@ -137,7 +139,9 @@ def compute(
     # z_neg_supports, _ = model.get_embeddings(support_samples_neg, padding_mask=None)
 
     if cfg["model"]["model_type"] == "beats":
-        z_neg_supports, _ = model.get_embeddings(support_samples_neg, padding_mask=None, skip_dropout=True)
+        z_neg_supports, _ = model.get_embeddings(
+            support_samples_neg, padding_mask=None, skip_dropout=True
+        )
     else:
         z_neg_supports = model.get_embeddings(support_samples_neg, padding_mask=None)

@@ -181,7 +185,9 @@ def compute(
     #########################################################
 
     # Detect POS samples
-    detected_pos_indices = np.where(p_values_pos >= cfg["predict"]["self_detect_threshold_p_value"])[
+    detected_pos_indices = np.where(
+        p_values_pos >= cfg["predict"]["self_detect_threshold_p_value"]
+    )[
         0
     ]  # We need to be sure that it is POS samples
     print(f"[INFO] SELF DETECTED {len(detected_pos_indices)} POS SAMPLES")
@@ -233,18 +239,25 @@ def compute(
                 support_samples_pos.to("cuda"), padding_mask=None, skip_dropout=True
             )
             _, d_supports_to_POS_prototypes = calculate_distance(
-                model_type, z_pos_supports.to("cuda"), prototypes[pos_index].to("cuda")
+                model_type,
+                z_pos_supports.to("cuda"),
+                prototypes[pos_index].to("cuda"),
             )
         else:
             z_pos_supports = model.get_embeddings(
                 support_samples_pos.to("cuda"), padding_mask=None
             )
             _, d_supports_to_POS_prototypes = calculate_distance(
-                model_type, z_pos_supports.to("cuda"), prototypes[pos_index].to("cuda")
+                model_type,
+                z_pos_supports.to("cuda"),
+                prototypes[pos_index].to("cuda"),
             )
         d_supports_to_POS_prototypes = d_supports_to_POS_prototypes.squeeze()
 
-        cdf = obtain_cdf(d_supports_to_POS_prototypes.to("cpu").detach().numpy(), cfg["distribution"])
+        cdf = obtain_cdf(
+            d_supports_to_POS_prototypes.to("cpu").detach().numpy(),
+            cfg["distribution"],
+        )
 
         ##############################################################
         # CHANGE THE NUMBER OF SUPPORTS AND QUERY TO A CERTAIN RATIO #
@@ -266,11 +279,13 @@
 
         # Calculate n_shot and n_query based on the desired ratio
         # The total should not exceed the total_samples_per_class for each category
-        total_slots_per_class = total_samples_per_class / (1 + support_to_query_ratio)
+        total_slots_per_class = total_samples_per_class / (
+            1 + support_to_query_ratio
+        )
         n_query = int(
             total_slots_per_class
         )  # Round down to ensure we do not exceed the available samples
-        n_shot = total_samples_per_class-n_query
+        n_shot = total_samples_per_class - n_query
 
         print(f"[INFO] Retraining with n_support={n_shot} and n_query={n_query}")
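A quick worked example of the ratio split above, with hypothetical numbers:

    total_samples_per_class = 30
    support_to_query_ratio = 2  # two supports per query

    total_slots_per_class = total_samples_per_class / (1 + support_to_query_ratio)  # 10.0
    n_query = int(total_slots_per_class)        # 10
    n_shot = total_samples_per_class - n_query  # 20, so n_shot / n_query == 2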
9 changes: 0 additions & 9 deletions evaluate/evaluation_metrics/evaluation_all.py
@@ -19,7 +19,6 @@
 
 
 def remove_shots_from_ref(ref_df, number_shots=5):
-
     ref_pos_indexes = select_events_with_value(ref_df, value=POS_VALUE)
     ref_n_shot_index = ref_pos_indexes[number_shots - 1]
     # remove all events (pos and UNK) that happen before this 5th event
@@ -31,14 +30,12 @@ def remove_shots_from_ref(ref_df, number_shots=5):
 
 
 def select_events_with_value(data_frame, value=POS_VALUE):
-
     indexes_list = data_frame.index[data_frame["Q"] == value].tolist()
 
     return indexes_list
 
 
 def build_matrix_from_selected_rows(data_frame, selected_indexes_list):
-
     matrix_data = np.ones((2, len(selected_indexes_list))) * -1
     for n, idx in enumerate(selected_indexes_list):
         matrix_data[0, n] = data_frame.loc[idx].Starttime  # start time for event n
@@ -98,7 +95,6 @@ def compute_tp_fp_fn(pred_events_df, ref_events_df):
 
 
 def compute_scores_per_class(counts_per_class):
-
     scores_per_class = {}
     for cl in counts_per_class.keys():
         tp = counts_per_class[cl]["TP"]
@@ -143,7 +139,6 @@ def compute_scores_from_counts(counts):
 
 
 def evaluate(pred_file_path, ref_file_path, team_name, dataset, savepath, metadata=[]):
-
     print("\nEvaluation for:", team_name, dataset)
     # read Gt file structure: get subsets and paths for ref csvs make an inverted dictionary with audiofilenames as keys and folder as value
     gt_file_structure = {}
@@ -173,7 +168,6 @@ def evaluate(pred_file_path, ref_file_path, team_name, dataset, savepath, metada
 
     counts_per_audiofile = {}
     for audiofilename in list(pred_events_by_audiofile.keys()):
-
         # for each audiofile, load corresponding GT File (audiofilename.csv)
         ref_events_this_audiofile_all = pd.read_csv(
             os.path.join(
@@ -322,7 +316,6 @@ def evaluate(pred_file_path, ref_file_path, team_name, dataset, savepath, metada
         fp = 0
         total_n_pos_events_this_set = 0
         for audiofile in list_audiofiles_in_set:
-
             scores_per_audiofile[audiofile] = compute_scores_from_counts(
                 counts_per_audiofile[audiofile]
             )
@@ -361,7 +354,6 @@ def evaluate(pred_file_path, ref_file_path, team_name, dataset, savepath, metada
 
 
 if __name__ == "__main__":
-
     all_files = glob.glob(
         "/data/DCASEfewshot/validate/d8f698b184e75c3ef4e830f9da4f148071fb4c56/results/baseline/version_0/**/eval_out.csv",
         recursive=True,
@@ -370,7 +362,6 @@ def evaluate(pred_file_path, ref_file_path, team_name, dataset, savepath, metada
     l_fscores = []
 
     for file in all_files:
-
         fscore = evaluate(
             pred_file_path=file,
             ref_file_path="/data/DCASE/Development_Set_annotations/Validation_Set",
18 changes: 13 additions & 5 deletions prototypicalbeats/prototraining.py
@@ -51,11 +51,18 @@ def __init__(
 
         if model_path != "None":
            self.checkpoint = torch.load(model_path)
-            if self.state == "validate":
+            if self.state == "validate" or self.state == "finetuned":
+                if self.state == "validate":
+                    model_name_to_replace = "model."
+                elif self.state == "finetuned":
+                    model_name_to_replace = "beats."
+
                 self.adjusted_state_dict = OrderedDict()
                 for k, v in self.checkpoint["state_dict"].items():
                     # Check if the key starts with 'module.' and remove it only then
-                    name = k[6:] if k.startswith("model.") else k
+                    if "predictor" in k:
+                        continue
+                    name = k[6:] if k.startswith(model_name_to_replace) else k
                     self.adjusted_state_dict[name] = v
 
         self._build_model()
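The new branch above lets checkpoints saved by the fine-tuning wrapper (keys prefixed "beats.") be loaded alongside prototypical ones (keys prefixed "model."), skipping the classifier head whose weights have no counterpart in the bare encoder. Note that k[6:] only works because both prefixes happen to be six characters long; a standalone sketch with a hypothetical helper, using len(prefix) as the general form:

    from collections import OrderedDict

    def remap_checkpoint(state_dict, prefix):
        """Strip the wrapper prefix and skip the classifier head."""
        remapped = OrderedDict()
        for k, v in state_dict.items():
            if "predictor" in k:
                continue  # the fine-tuned head is not part of the encoder
            name = k[len(prefix):] if k.startswith(prefix) else k
            remapped[name] = v
        return remapped

    # remap_checkpoint(checkpoint["state_dict"], "beats.")  # state == "finetuned"
    # remap_checkpoint(checkpoint["state_dict"], "model.")  # state == "validate"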
@@ -91,12 +98,11 @@ def _build_model(self):
             }
         )
         self.model = BEATs(self.cfg)
-
         if self.state == "train":
             print("LOADING AUDIOSET PRE-TRAINED MODEL")
             self.model.load_state_dict(self.checkpoint["model"])
 
-        if self.state == "validate":
+        if self.state == "validate" or self.state == "finetuned":
             print("LOADING THE FINE-TUNED MODEL")
             self.model.load_state_dict(self.adjusted_state_dict, strict=True)

@@ -166,7 +172,9 @@ def get_prototypes(self, z_support, support_labels, n_way):
     def get_embeddings(self, input, padding_mask, skip_dropout=False):
         """Return the embeddings and the padding mask"""
         self.model.training = self.training
-        return self.model.extract_features(input, padding_mask, skip_dropout=skip_dropout)
+        return self.model.extract_features(
+            input, padding_mask, skip_dropout=skip_dropout
+        )
 
     def forward(
         self,
