From a09c0b0855d5dff6ab7dac3bc1cc780f529220bd Mon Sep 17 00:00:00 2001
From: Femke Gelderblom
Date: Wed, 10 Apr 2024 10:00:37 +0200
Subject: [PATCH] FEATURE: finetuned models can be evaluated + black

---
 dcase_fine_tune/FTBeats.py                    | 15 +++++----
 dcase_fine_tune/FTDataModule.py               | 14 +++++----
 dcase_fine_tune/FTevaluate.py                 |  7 ++++-
 dcase_fine_tune/FTtrain.py                    | 12 ++++---
 evaluate/_utils_compute.py                    | 30 +++++++++++-------
 evaluate/evaluateDCASE.py                     | 31 ++++++++++++++-----
 evaluate/evaluation_metrics/evaluation_all.py |  9 ------
 prototypicalbeats/prototraining.py            | 18 ++++++++---
 8 files changed, 85 insertions(+), 51 deletions(-)

diff --git a/dcase_fine_tune/FTBeats.py b/dcase_fine_tune/FTBeats.py
index 82572c9..6d90507 100644
--- a/dcase_fine_tune/FTBeats.py
+++ b/dcase_fine_tune/FTBeats.py
@@ -46,12 +46,10 @@ def __init__(
         self._build_model()
 
         self.train_acc = Accuracy(
-            task="multiclass",
-            num_classes=self.num_target_classes
+            task="multiclass", num_classes=self.num_target_classes
         )
         self.valid_acc = Accuracy(
-            task="multiclass",
-            num_classes=self.num_target_classes
+            task="multiclass", num_classes=self.num_target_classes
         )
         self.save_hyperparameters()
 
@@ -65,7 +63,9 @@ def _build_model(self):
         # 2. Classifier
         print(f"Classifier has {self.num_target_classes} output neurons")
         self.beats.predictor_dropout = nn.Dropout(self.cfg.predictor_dropout)
-        self.beats.predictor = nn.Linear(self.cfg.encoder_embed_dim, self.cfg.predictor_class)
+        self.beats.predictor = nn.Linear(
+            self.cfg.encoder_embed_dim, self.cfg.predictor_class
+        )
         # self.fc = nn.Linear(self.cfg.encoder_embed_dim, self.num_target_classes)
 
     def extract_features(self, x, padding_mask=None):
@@ -132,7 +132,10 @@ def configure_optimizers(self):
             )
         else:
             optimizer = optim.AdamW(
-                self.beats.predictor.parameters(), lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01
+                self.beats.predictor.parameters(),
+                lr=self.lr,
+                betas=(0.9, 0.98),
+                weight_decay=0.01,
             )
 
         return optimizer
diff --git a/dcase_fine_tune/FTDataModule.py b/dcase_fine_tune/FTDataModule.py
index b96375e..affb335 100644
--- a/dcase_fine_tune/FTDataModule.py
+++ b/dcase_fine_tune/FTDataModule.py
@@ -60,18 +60,20 @@ def __init__(
             self.data_frame["category"]
         )
         # remove classes with too few samples
-        removals = self.data_frame['category'].value_counts().reset_index()
-        removals = removals[removals['category'] < min_sample_per_category]['index'].values
-        self.data_frame.drop(self.data_frame[self.data_frame["category"].isin(removals)].index, inplace=True)
+        removals = self.data_frame["category"].value_counts().reset_index()
+        removals = removals[removals["category"] < min_sample_per_category][
+            "index"
+        ].values
+        self.data_frame.drop(
+            self.data_frame[self.data_frame["category"].isin(removals)].index,
+            inplace=True,
+        )
         self.data_frame.reset_index(inplace=True)
 
         # init dataset and divide into train&val
         self.complete_dataset = TrainAudioDatasetDCASE(data_frame=self.data_frame)
         self.divide_train_val()
 
-
-
-
     def divide_train_val(self):
         # Separate into training and validation set
         train_indices, validation_indices, _, _ = train_test_split(
diff --git a/dcase_fine_tune/FTevaluate.py b/dcase_fine_tune/FTevaluate.py
index 2948233..0718820 100644
--- a/dcase_fine_tune/FTevaluate.py
+++ b/dcase_fine_tune/FTevaluate.py
@@ -363,7 +363,12 @@ def main(cfg: DictConfig):
         filename = (
             os.path.basename(support_spectrograms).split("data_")[1].split(".")[0]
         )
-        (result, pred_labels, gt_labels, result_raw,) = train_predict(
+        (
+            result,
+            pred_labels,
+            gt_labels,
+            result_raw,
+        ) = train_predict(
             param,
             meta_df,
             support_spectrograms,
diff --git a/dcase_fine_tune/FTtrain.py b/dcase_fine_tune/FTtrain.py
index f6b2bc5..43a3bbc 100644
--- a/dcase_fine_tune/FTtrain.py
+++ b/dcase_fine_tune/FTtrain.py
@@ -3,7 +3,7 @@
 import hashlib
 import pandas as pd
 import json
-from datetime import datetime 
+from datetime import datetime
 import pytorch_lightning as pl
 
 from dcase_fine_tune.FTBeats import BEATsTransferLearningModel
@@ -33,14 +33,18 @@ def train_model(
             pl.callbacks.LearningRateMonitor(logging_interval="step"),
             pl.callbacks.EarlyStopping(
                 monitor="train_loss", mode="min", patience=patience
-            ), 
+            ),
             pl.callbacks.ModelCheckpoint(
-                os.path.join("lightning_logs", "finetuning","{date:%Y%m%d_%H%M%S}".format(date=datetime.now())),
+                os.path.join(
+                    "lightning_logs",
+                    "finetuning",
+                    "{date:%Y%m%d_%H%M%S}".format(date=datetime.now()),
+                ),
                 monitor="val_loss",
                 mode="min",
                 save_top_k=1,
                 verbose=True,
-            )
+            ),
         ],
         default_root_dir=root_dir,
         enable_checkpointing=True,
diff --git a/evaluate/_utils_compute.py b/evaluate/_utils_compute.py
index b55a867..239f6a7 100644
--- a/evaluate/_utils_compute.py
+++ b/evaluate/_utils_compute.py
@@ -25,7 +25,9 @@ def to_dataframe(features, labels):
 
 def get_proto_coordinates(model, model_type, support_data, support_labels, n_way):
     if model_type == "beats":
-        z_supports, _ = model.get_embeddings(support_data, padding_mask=None, skip_dropout=True)
+        z_supports, _ = model.get_embeddings(
+            support_data, padding_mask=None, skip_dropout=True
+        )
     else:
         z_supports = model.get_embeddings(support_data, padding_mask=None)
 
@@ -81,7 +83,7 @@ def merge_preds(df, tolerence, tensor_length, frame_shift, occurence_threshold):
         ).shift()
     ).cumsum()
     ids, occurence = np.unique(df["group"], return_counts=True)
-    ids_too_short_segments = ids[occurence<occurence_threshold]
+    ids_too_short_segments = ids[occurence < occurence_threshold]
diff --git a/evaluate/evaluateDCASE.py b/evaluate/evaluateDCASE.py
--- a/evaluate/evaluateDCASE.py
+++ b/evaluate/evaluateDCASE.py
-        detected_pos_indices = np.where(p_values_pos >= cfg["predict"]["self_detect_threshold_p_value"])[
+        detected_pos_indices = np.where(
+            p_values_pos >= cfg["predict"]["self_detect_threshold_p_value"]
+        )[
             0
         ]  # We need to be sure that it is POS samples
         print(f"[INFO] SELF DETECTED {len(detected_pos_indices)} POS SAMPLES")
@@ -233,18 +239,25 @@ def compute(
                 support_samples_pos.to("cuda"), padding_mask=None, skip_dropout=True
             )
             _, d_supports_to_POS_prototypes = calculate_distance(
-                model_type, z_pos_supports.to("cuda"), prototypes[pos_index].to("cuda")
+                model_type,
+                z_pos_supports.to("cuda"),
+                prototypes[pos_index].to("cuda"),
             )
         else:
             z_pos_supports = model.get_embeddings(
                 support_samples_pos.to("cuda"), padding_mask=None
             )
             _, d_supports_to_POS_prototypes = calculate_distance(
-                model_type, z_pos_supports.to("cuda"), prototypes[pos_index].to("cuda")
+                model_type,
+                z_pos_supports.to("cuda"),
+                prototypes[pos_index].to("cuda"),
             )
 
         d_supports_to_POS_prototypes = d_supports_to_POS_prototypes.squeeze()
-        cdf = obtain_cdf(d_supports_to_POS_prototypes.to("cpu").detach().numpy(), cfg["distribution"])
+        cdf = obtain_cdf(
+            d_supports_to_POS_prototypes.to("cpu").detach().numpy(),
+            cfg["distribution"],
+        )
 
         ##############################################################
         # CHANGE THE NUMBER OF SUPPORTS AND QUERY TO A CERTAIN RATIO #
@@ -266,11 +279,13 @@ def compute(
 
         # Calculate n_shot and n_query based on the desired ratio
         # The total should not exceed the total_samples_per_class for each category
-        total_slots_per_class = total_samples_per_class / (1 + support_to_query_ratio)
+        total_slots_per_class = total_samples_per_class / (
+            1 + support_to_query_ratio
+        )
         n_query = int(
             total_slots_per_class
         )  # Round down to ensure we do not exceed the available samples
-        n_shot = total_samples_per_class-n_query
+        n_shot = total_samples_per_class - n_query
 
         print(f"[INFO] Retraining with n_support={n_shot} and n_query={n_query}")
 
diff --git a/evaluate/evaluation_metrics/evaluation_all.py b/evaluate/evaluation_metrics/evaluation_all.py
index a676872..6c26e23 100644
--- a/evaluate/evaluation_metrics/evaluation_all.py
+++ b/evaluate/evaluation_metrics/evaluation_all.py
@@ -19,7 +19,6 @@
 
 
 def remove_shots_from_ref(ref_df, number_shots=5):
-
     ref_pos_indexes = select_events_with_value(ref_df, value=POS_VALUE)
     ref_n_shot_index = ref_pos_indexes[number_shots - 1]
     # remove all events (pos and UNK) that happen before this 5th event
@@ -31,14 +30,12 @@
 
 
 def select_events_with_value(data_frame, value=POS_VALUE):
-
     indexes_list = data_frame.index[data_frame["Q"] == value].tolist()
 
     return indexes_list
 
 
 def build_matrix_from_selected_rows(data_frame, selected_indexes_list):
-
     matrix_data = np.ones((2, len(selected_indexes_list))) * -1
     for n, idx in enumerate(selected_indexes_list):
         matrix_data[0, n] = data_frame.loc[idx].Starttime  # start time for event n
@@ -98,7 +95,6 @@ def compute_tp_fp_fn(pred_events_df, ref_events_df):
 
 
 def compute_scores_per_class(counts_per_class):
-
     scores_per_class = {}
     for cl in counts_per_class.keys():
         tp = counts_per_class[cl]["TP"]
@@ -143,7 +139,6 @@ def compute_scores_from_counts(counts):
 
 
 def evaluate(pred_file_path, ref_file_path, team_name, dataset, savepath, metadata=[]):
-
     print("\nEvaluation for:", team_name, dataset)
     # read Gt file structure: get subsets and paths for ref csvs make an inverted dictionary with audiofilenames as keys and folder as value
     gt_file_structure = {}
@@ -173,7 +168,6 @@ def evaluate(pred_file_path, ref_file_path, team_name, dataset, savepath, metada
 
     counts_per_audiofile = {}
     for audiofilename in list(pred_events_by_audiofile.keys()):
-
         # for each audiofile, load correcponding GT File (audiofilename.csv)
         ref_events_this_audiofile_all = pd.read_csv(
             os.path.join(
@@ -322,7 +316,6 @@ def evaluate(pred_file_path, ref_file_path, team_name, dataset, savepath, metada
         fp = 0
         total_n_pos_events_this_set = 0
         for audiofile in list_audiofiles_in_set:
-
             scores_per_audiofile[audiofile] = compute_scores_from_counts(
                 counts_per_audiofile[audiofile]
             )
@@ -361,7 +354,6 @@ def evaluate(pred_file_path, ref_file_path, team_name, dataset, savepath, metada
 
 
 if __name__ == "__main__":
-
     all_files = glob.glob(
         "/data/DCASEfewshot/validate/d8f698b184e75c3ef4e830f9da4f148071fb4c56/results/baseline/version_0/**/eval_out.csv",
         recursive=True,
@@ -370,7 +362,6 @@ def evaluate(pred_file_path, ref_file_path, team_name, dataset, savepath, metada
     )
     l_fscores = []
     for file in all_files:
-
         fscore = evaluate(
             pred_file_path=file,
             ref_file_path="/data/DCASE/Development_Set_annotations/Validation_Set",
diff --git a/prototypicalbeats/prototraining.py b/prototypicalbeats/prototraining.py
index 5f63b6d..cb77fff 100644
--- a/prototypicalbeats/prototraining.py
+++ b/prototypicalbeats/prototraining.py
@@ -51,11 +51,18 @@ def __init__(
         if model_path != "None":
             self.checkpoint = torch.load(model_path)
 
-            if self.state == "validate":
+            if self.state == "validate" or self.state == "finetuned":
+                if self.state == "validate":
+                    model_name_to_replace = "model."
+                elif self.state == "finetuned":
+                    model_name_to_replace = "beats."
+
                 self.adjusted_state_dict = OrderedDict()
                 for k, v in self.checkpoint["state_dict"].items():
                     # Check if the key starts with 'module.' and remove it only then
-                    name = k[6:] if k.startswith("model.") else k
+                    if "predictor" in k:
+                        continue
+                    name = k[6:] if k.startswith(model_name_to_replace) else k
                     self.adjusted_state_dict[name] = v
 
         self._build_model()
@@ -91,12 +98,11 @@ def _build_model(self):
             }
         )
         self.model = BEATs(self.cfg)
-
         if self.state == "train":
             print("LOADING AUDIOSET PRE-TRAINED MODEL")
             self.model.load_state_dict(self.checkpoint["model"])
 
-        if self.state == "validate":
+        if self.state == "validate" or self.state == "finetuned":
             print("LOADING THE FINE-TUNED MODEL")
             self.model.load_state_dict(self.adjusted_state_dict, strict=True)
 
@@ -166,7 +172,9 @@ def get_prototypes(self, z_support, support_labels, n_way):
     def get_embeddings(self, input, padding_mask, skip_dropout=False):
         """Return the embeddings and the padding mask"""
        self.model.training = self.training
-        return self.model.extract_features(input, padding_mask, skip_dropout=skip_dropout)
+        return self.model.extract_features(
+            input, padding_mask, skip_dropout=skip_dropout
+        )
 
     def forward(
         self,
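
For reviewers: a minimal sketch (not part of the patch) of the checkpoint-key
remapping that the prototraining.py hunks above introduce, which is what lets a
fine-tuned model be loaded for evaluation. The checkpoint path below is a
hypothetical placeholder; the prefix and predictor handling mirror the diff.

    from collections import OrderedDict

    import torch

    # Hypothetical checkpoint produced by dcase_fine_tune/FTtrain.py.
    checkpoint = torch.load("lightning_logs/finetuning/example.ckpt")

    # "beats." when state == "finetuned", "model." when state == "validate".
    # Both prefixes are six characters long, which is why k[6:] works.
    model_name_to_replace = "beats."

    adjusted_state_dict = OrderedDict()
    for k, v in checkpoint["state_dict"].items():
        # The fine-tuning classifier head has no counterpart in the
        # prototypical model, so its weights are skipped entirely.
        if "predictor" in k:
            continue
        # Strip the wrapper prefix so keys match the bare BEATs module.
        name = k[6:] if k.startswith(model_name_to_replace) else k
        adjusted_state_dict[name] = v

The remapped weights can then be loaded strictly, as in _build_model() above:
model.load_state_dict(adjusted_state_dict, strict=True).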