FEAT: exposed max_segment_length as config param + black
femke-sintef committed Jan 11, 2024
1 parent ff311e7 commit 47e5d51
Showing 4 changed files with 81 additions and 39 deletions.
1 change: 1 addition & 0 deletions CONFIG.yaml
@@ -21,6 +21,7 @@ data:
overlap: 0.5 # used in preprocessing
n_subsample: 1
num_mel_bins: 128 # used in preprocessing
max_segment_length: 1.0 # used in preprocessing
status: train # used in preprocessing, train or validate or evaluate


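For context, the new key sits next to the other preprocessing parameters in the data block and defaults to the value that was previously hard-coded. A minimal sketch of reading it, following the yaml.load(..., Loader=FullLoader) pattern already used in data_utils/DCASEfewshot.py (running it standalone against a local CONFIG.yaml is an assumption for illustration):

import yaml
from yaml import FullLoader

# Load the config and read the newly exposed preprocessing parameter.
with open("CONFIG.yaml") as f:
    cfg = yaml.load(f, Loader=FullLoader)

# 1.0 s by default, matching the old hard-coded MAX_SEGMENT_LENGTH constant.
max_segment_length = cfg["data"]["max_segment_length"]
print(max_segment_length)  # -> 1.0 with the value shipped in this commit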
40 changes: 25 additions & 15 deletions data_utils/DCASEfewshot.py
@@ -33,7 +33,6 @@
PLOT = False
PLOT_TOO_SHORT_SAMPLES = False
PLOT_SUPPORT = False
MAX_SEGMENT_LENGTH = 1.0


def normalize_mono(samples):
@@ -95,6 +94,7 @@ def prepare_training_val_data(
target_fs=16000,
overlap=0.5,
num_mel_bins=128,
max_segment_length=1.0,
):
"""Prepare the Training_Set
@@ -120,7 +120,7 @@ def preprocess_df(df):
continue
# obtain a segment with large margins around event
extra_time = 3
segment_length_here = min(min_segment_lengths[label], MAX_SEGMENT_LENGTH)
segment_length_here = min(min_segment_lengths[label], max_segment_length)
frame_shift = np.round(segment_length_here / tensor_length * 1000)
frame_shift = 1 if frame_shift < 1 else frame_shift
start_waveform = int((df["Starttime"][ind] - extra_time) * fs)
@@ -152,7 +152,7 @@ def preprocess_df(df):
sample_frequency=target_fs,
frame_length=frame_length,
frame_shift=frame_shift,
num_mel_bins=num_mel_bins
num_mel_bins=num_mel_bins,
)
data = fbank.data[0].T
# select the relevant segment (without the large margins)
@@ -262,7 +262,8 @@ def preprocess_df(df):
"tensor_length": tensor_length,
"set_type": set_type,
"overlap": overlap,
"num_mel_bins": num_mel_bins
"num_mel_bins": num_mel_bins,
"max_segment_length": max_segment_length,
}
if resample:
my_hash_dict["target_fs"] = target_fs
@@ -282,7 +283,7 @@ def preprocess_df(df):
os.makedirs(os.path.join(target_path, "plots"))

# Save my_hash_dict as a metadata file
with open(os.path.join(target_path, 'metadata.json'), 'w') as f:
with open(os.path.join(target_path, "metadata.json"), "w") as f:
json.dump(my_hash_dict, f)

print("=== Processing data ===")
@@ -313,9 +314,9 @@ def preprocess_df(df):
audio_path = file.replace("csv", "wav")
print("Processing file name {}".format(audio_path))
y, fs = librosa.load(audio_path, sr=None, mono=True)
if not resample or my_hash_dict["target_fs"]> fs:
if not resample or my_hash_dict["target_fs"] > fs:
target_fs = fs
else:
else:
target_fs = my_hash_dict["target_fs"]
df = df[(df == "POS").any(axis=1)]
df = df.reset_index()
@@ -372,7 +373,7 @@ def preprocess_df(df):
# CREATE QUERY SETS
# obtain file specific frame_shift and save to meta.csv

segment_length_here = min(min_segment_lengths["POS"], MAX_SEGMENT_LENGTH)
segment_length_here = min(min_segment_lengths["POS"], max_segment_length)
frame_shift = np.round(segment_length_here / tensor_length * 1000)
frame_shift = 1 if frame_shift < 1 else frame_shift
print(file_name)
@@ -566,7 +567,6 @@ def preprocess_df(df):
type=str,
)


# get input
cli_args = parser.parse_args()

@@ -578,19 +578,28 @@ def preprocess_df(df):
cfg = yaml.load(f, Loader=FullLoader)

# Check values in config file
if not (cfg["data"]["status"]=="train" or cfg["data"]["status"]=="validate" or cfg["data"]["status"]=="test"):
raise Exception("ERROR: "+ str(cli_args.config) + ": Accepted values for 'status' are 'train', 'validate', or 'test'. Received '" + str(cfg["data"]["status"]) + "'.")

if not (
cfg["data"]["status"] == "train"
or cfg["data"]["status"] == "validate"
or cfg["data"]["status"] == "test"
):
raise Exception(
"ERROR: "
+ str(cli_args.config)
+ ": Accepted values for 'status' are 'train', 'validate', or 'test'. Received '"
+ str(cfg["data"]["status"])
+ "'."
)

# Select 'set_type' depending on chosen status
if cfg["data"]["status"]=="train":
if cfg["data"]["status"] == "train":
cfg["data"]["set_type"] = "Training_Set"

elif cfg["data"]["status"]=="validate":
elif cfg["data"]["status"] == "validate":
cfg["data"]["set_type"] = "Validation_Set"

else:
cfg["data"]["set_type"] = "Evaluation_Set"


prepare_training_val_data(
cfg["data"]["status"],
@@ -604,4 +613,5 @@ def preprocess_df(df):
cfg["data"]["target_fs"],
cfg["data"]["overlap"],
cfg["data"]["num_mel_bins"],
cfg["data"]["max_segment_length"],
)
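With the module-level MAX_SEGMENT_LENGTH constant removed, the cap is applied per label before the frame shift is derived, at both call sites shown above. A self-contained sketch of that arithmetic (tensor_length and the per-label durations below are made-up illustration values, not taken from the repository):

import numpy as np

tensor_length = 128        # illustrative; the real value comes from CONFIG.yaml
max_segment_length = 1.0   # the newly exposed parameter, in seconds

# Hypothetical minimum event lengths per label, in seconds.
min_segment_lengths = {"POS": 2.5, "short_call": 0.3}

for label in min_segment_lengths:
    # Long events are capped at max_segment_length, as in preprocess_df.
    segment_length_here = min(min_segment_lengths[label], max_segment_length)
    # Frame shift in milliseconds, floored at 1 ms.
    frame_shift = np.round(segment_length_here / tensor_length * 1000)
    frame_shift = 1 if frame_shift < 1 else frame_shift
    print(label, frame_shift)  # POS -> 8.0, short_call -> 2.0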
5 changes: 3 additions & 2 deletions datamodules/DCASEDataModule.py
@@ -103,8 +103,7 @@ def __init__(
n_subsample: int = 1,
overlap: float = 0.5,
num_mel_bins: int = 128,

**kwargs,
        max_segment_length: float = 1.0,
        **kwargs,
):
super().__init__(**kwargs)
self.n_task_train = n_task_train
@@ -123,6 +122,7 @@ def __init__(
self.n_subsample = n_subsample
self.overlap = overlap
self.num_mel_bins = num_mel_bins
self.max_segment_length = max_segment_length
self.setup()

def setup(self, stage=None):
@@ -136,6 +136,7 @@ def setup(self, stage=None):
"set_type": self.set_type,
"overlap": self.overlap,
"num_mel_bins": self.num_mel_bins,
"max_segment_length": self.max_segment_length,
}
if self.resample:
my_hash_dict["target_fs"] = self.target_fs
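Since max_segment_length is now part of my_hash_dict, preprocessed data is cached under a directory name that changes whenever the parameter changes, so stale features are not silently reused. A sketch of that keying, assuming the directory name is a hash over the serialized dict; the actual hashing helper is not visible in this diff, so hashlib.md5 over a key-sorted JSON dump is an assumption:

import hashlib
import json

# Mirror of the dict built in setup(); values are illustrative defaults.
my_hash_dict = {
    "target_fs": 16000,
    "tensor_length": 128,
    "set_type": "Training_Set",
    "overlap": 0.5,
    "num_mel_bins": 128,
    "max_segment_length": 1.0,  # new key: changing it invalidates old caches
}

# Assumed scheme: md5 over the JSON dump of the dict.
hash_dir_name = hashlib.md5(
    json.dumps(my_hash_dict, sort_keys=True).encode()
).hexdigest()
print(hash_dir_name)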
74 changes: 52 additions & 22 deletions evaluate/evaluateDCASE.py
@@ -75,10 +75,11 @@ def train_model(
try:
pretrained_model = ProtoBEATsModel.load_from_checkpoint(pretrained_model)
except KeyError:
print("Failed to load the pretrained model. Please check the checkpoint file.")
print(
"Failed to load the pretrained model. Please check the checkpoint file."
)
return None


# train the model
trainer.fit(model, datamodule=datamodule_class)

@@ -296,7 +297,11 @@ def main(
# Train the model with the support data
print("[INFO] TRAINING THE MODEL FOR {}".format(filename))

model = training(cfg["model"]["model_path"], custom_dcasedatamodule, max_epoch=cfg["trainer"]["max_epochs"])
model = training(
cfg["model"]["model_path"],
custom_dcasedatamodule,
max_epoch=cfg["trainer"]["max_epochs"],
)

# Get the prototypes coordinates
a = custom_dcasedatamodule.test_dataloader()
@@ -364,7 +369,9 @@ def main(
# Train the model with the support data
print("[INFO] TRAINING THE MODEL FOR {}".format(filename))

model = training(cfg["model"]["model_path"], custom_dcasedatamodule, max_epoch=1)
model = training(
cfg["model"]["model_path"], custom_dcasedatamodule, max_epoch=1
)

# Get the prototypes coordinates
a = custom_dcasedatamodule.test_dataloader()
@@ -424,7 +431,9 @@ def main(
)

result_POS_merged = merge_preds(
df=result_POS, tolerence=cfg["tolerance"], tensor_length=cfg["data"]["tensor_length"]
df=result_POS,
tolerence=cfg["tolerance"],
tensor_length=cfg["data"]["tensor_length"],
)

# Add the filename
@@ -478,19 +487,43 @@ def write_wav(
# Expand the dimensions
gt_labels = np.repeat(
np.squeeze(gt_labels, axis=1).T,
int(cfg["data"]["tensor_length"] * cfg["data"]["overlap"] * target_fs * frame_shift / 1000),
int(
cfg["data"]["tensor_length"]
* cfg["data"]["overlap"]
* target_fs
* frame_shift
/ 1000
),
)
pred_labels = np.repeat(
pred_labels.T,
int(cfg["data"]["tensor_length"] * cfg["data"]["overlap"] * target_fs * frame_shift / 1000),
int(
cfg["data"]["tensor_length"]
* cfg["data"]["overlap"]
* target_fs
* frame_shift
/ 1000
),
)
distances_to_pos = np.repeat(
distances_to_pos.T,
int(cfg["data"]["tensor_length"] * cfg["data"]["overlap"] * target_fs * frame_shift / 1000),
int(
cfg["data"]["tensor_length"]
* cfg["data"]["overlap"]
* target_fs
* frame_shift
/ 1000
),
)
z_scores_pos = np.repeat(
z_scores_pos.T,
int(cfg["data"]["tensor_length"] * cfg["data"]["overlap"] * target_fs * frame_shift / 1000),
int(
cfg["data"]["tensor_length"]
* cfg["data"]["overlap"]
* target_fs
* frame_shift
/ 1000
),
)

# pad with zeros
@@ -575,10 +608,10 @@ def write_wav(
version_name = os.path.basename(version_path)

# Select 'set_type' depending on chosen status
if cfg["data"]["status"]=="train":
if cfg["data"]["status"] == "train":
cfg["data"]["set_type"] = "Training_Set"

elif cfg["data"]["status"]=="validate":
elif cfg["data"]["status"] == "validate":
cfg["data"]["set_type"] = "Validation_Set"

else:
Expand All @@ -593,7 +626,8 @@ def write_wav(
"tensor_length": cfg["data"]["tensor_length"],
"set_type": cfg["data"]["set_type"],
"overlap": cfg["data"]["overlap"],
"num_mel_bins": cfg["data"]["num_mel_bins"]
"num_mel_bins": cfg["data"]["num_mel_bins"],
"max_segment_length": cfg["data"]["max_segment_length"],
}
if cfg["data"]["resample"]:
my_hash_dict["target_fs"] = cfg["data"]["target_fs"]
Expand All @@ -617,11 +651,11 @@ def write_wav(
"support_labels_*.npy",
)
query_data_path = os.path.join(
"/data/DCASEfewshot",
cfg["data"]["status"],
hash_dir_name,
"audio",
"query_data_*.npz"
"/data/DCASEfewshot",
cfg["data"]["status"],
hash_dir_name,
"audio",
"query_data_*.npz",
)
query_labels_path = os.path.join(
"/data/DCASEfewshot",
@@ -631,11 +665,7 @@ def write_wav(
"query_labels_*.npy",
)
meta_df_path = os.path.join(
"/data/DCASEfewshot",
cfg["data"]["status"],
hash_dir_name,
"audio",
"meta.csv"
"/data/DCASEfewshot", cfg["data"]["status"], hash_dir_name, "audio", "meta.csv"
)

# set target path
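The reformatting also spreads the expansion factor used by np.repeat in write_wav across several lines, which makes the unit conversion easier to check: tensor_length frames, scaled by the overlap, times samples per millisecond of frame shift. A worked instance with representative numbers (all values below are illustrative):

# Expansion factor: tensor_length * overlap * target_fs * frame_shift / 1000
tensor_length = 128   # frames per tensor (illustrative)
overlap = 0.5         # as in CONFIG.yaml
target_fs = 16000     # Hz
frame_shift = 8.0     # ms, e.g. the POS value from the preprocessing sketch

samples = int(tensor_length * overlap * target_fs * frame_shift / 1000)
print(samples)  # -> 8192 audio samples per prediction step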
