Skip to content

Commit

Permalink
[ADD] seed_everything for reproducibility
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Mar 12, 2024
1 parent e82a107 commit 5da85d6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
9 changes: 6 additions & 3 deletions evaluate/_utils_compute.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch

from tqdm import tqdm
Expand All @@ -9,6 +8,10 @@
from prototypicalbeats.prototraining import ProtoBEATsModel
from datamodules.TestDCASEDataModule import DCASEDataModule, AudioDatasetDCASE

import pytorch_lightning as pl
from pl.utilities.seed import seed_everything
seed_everything(42, workers=True)

def to_dataframe(features, labels):
# Load the saved array and map the features and labels into a single dataframe
input_features = np.load(features)
Expand Down Expand Up @@ -92,7 +95,7 @@ def reshape_support(support_samples, tensor_length=128, n_subsample=1):
def train_model(
model_type=None,
datamodule_class=DCASEDataModule,
max_epochs=1,
max_epochs=5,
enable_model_summary=False,
num_sanity_val_steps=0,
seed=42,
Expand All @@ -102,7 +105,7 @@ def train_model(
):
# create the lightning trainer object
trainer = pl.Trainer(
max_epochs=1,
max_epochs=max_epochs,
enable_model_summary=enable_model_summary,
num_sanity_val_steps=num_sanity_val_steps,
deterministic=True,
Expand Down
16 changes: 12 additions & 4 deletions evaluate/evaluateDCASE.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from datamodules.TestDCASEDataModule import DCASEDataModule, AudioDatasetDCASE

import pytorch_lightning as pl
from pl.utilities.seed import seed_everything
seed_everything(42, workers=True)

from callbacks.callbacks import MilestonesFinetuning

Expand Down Expand Up @@ -151,6 +153,12 @@ def compute(

# Detect POS samples
detected_pos_indices = np.where(p_values_pos == 1)[0]
print(f"[INFO] SELF DETECTED {detected_pos_indices} POS SAMPLES")

# BECAUSE CUDA ERROR WHEN RESAMPLING TOO MANY SAMPLES
if len(detected_pos_indices) > 40:
detected_pos_indices = np.random.choice(detected_pos_indices, size=40, replace=False)

df_extension_pos = df_query.iloc[detected_pos_indices].copy()
df_extension_pos["category"] = "POS"

Expand All @@ -159,6 +167,7 @@ def compute(

# Randomly sample NEG samples to match the number of POS samples
num_pos_samples = len(detected_pos_indices)

if num_pos_samples > 0 and len(detected_neg_indices) > num_pos_samples:
sampled_neg_indices = np.random.choice(detected_neg_indices, size=num_pos_samples, replace=False)
else:
Expand Down Expand Up @@ -190,21 +199,19 @@ def compute(
# CHANGE THE NUMBER OF SUPPORTS AND QUERY TO A CERTAIN RATIO #
##############################################################

support_to_query_ratio = 2
support_to_query_ratio = 1.5

# Calculate the total number of samples for each category (POS and NEG)
num_pos_samples = len(df_support_extended[df_support_extended['category'] == 'POS'])
num_neg_samples = len(df_support_extended[df_support_extended['category'] == 'NEG'])
print(num_pos_samples)
print(num_neg_samples)

# Assuming an equal distribution of samples between support and query for simplicity
total_samples_per_class = min(num_pos_samples, num_neg_samples)
print(f"TOTAL SLOT PER CLASS {total_samples_per_class}")

# Calculate n_shot and n_query based on the desired ratio
# The total should not exceed the total_samples_per_class for each category
total_slots_per_class = total_samples_per_class / (1 + support_to_query_ratio)
print(f"TOTAL SLOT PER CLASS {total_samples_per_class}")
n_query = int(total_slots_per_class) # Round down to ensure we do not exceed the available samples
n_shot = int(total_slots_per_class * support_to_query_ratio)

Expand Down Expand Up @@ -484,6 +491,7 @@ def main(cfg: DictConfig):

results = results.append(result)
results_raw = results_raw.append(result_raw)

if cfg["predict"]["wav_save"]:
write_wav(
files,
Expand Down

0 comments on commit 5da85d6

Please sign in to comment.