Skip to content

Commit

Permalink
[UPDATE] README + small changes in the .sh script to facilitate repro…
Browse files Browse the repository at this point in the history
…ducibility
  • Loading branch information
BenCretois committed Jan 9, 2024
1 parent 4913ff4 commit f24030e
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 44 deletions.
50 changes: 26 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Because of the duration of the preprocessing, we save the preprocessed files as
./preprocess_data.sh /BASE/FOLDER
```

The script will create a new folder `DCASEfewshot` containing three subfolders (`train`, `validate` and `evaluate`). Each of these folder contains the processed data in the form of `numpy arrays`.
The script will create a new folder `DCASEfewshot` containing three subfolders (`train`, `validate` and `evaluate`). Each of these folder **contains a subfolder with a hash as a name**. **The hash has been created based on the processing parameters**. The processed data in the form of `numpy arrays`.

:black_nib: You can change the parameters for preprocessing the data in the [CONFIG.yaml file](/CONFIG.yaml)

Expand All @@ -70,13 +70,15 @@ The training script should create a `log` folder in the base folder (`lightning_

:black_nib: Update the `status` parameter of the [CONFIG.yaml file](/CONFIG.yaml) to the dataset you want to use the model on. Change `status` to either **validate** or **evaluate**.

:black_nib: Also update the `model_path` in the [CONFIG.yaml file](/CONFIG.yaml) to the checkpoints (`ckpt`) that has been trained in the previous step (stored in `lightning_logs`)

To run the prediction use the script `test_model`.

```bash
./test_model.sh /BASE/FOLDER
```

`test_model.sh` creates a result file `eval_out.csv` in the `BASE/FOLDER` containing all the detections made by the model.
`test_model.sh` creates a result file `eval_out.csv` in the folder containing the processed `validation` data. **The full path is printed in the console**

Note that there are other advanced options. For instance, if `--wav_save` is specified, the script will also return a `.wav` file for all files containing additional channels: the ground truth labels, the predicted labels, the distance to the POS prototype and finally the p-values. The `.wav` file can be opened in [Audacity](https://www.audacityteam.org/) to be inspected more closely.

Expand All @@ -85,33 +87,33 @@ Note that there are other advanced options. For instance, if `--wav_save` is spe
Once the `eval_out.csv` has been created, it is possible to get the results for our approach. Note that the metrics can only be computed for the `Validation_Set` as it contains all ground truth labels as opposed to the `Evaluation_Set` for which only the 5 first samples of the POS class are labelled.

```bash
./compute_metrics.sh /BASE/FOLDER
./compute_metrics.sh /BASE/FOLDER /PATH/TO/eval_out.csv
```

Here are the results we obtain using our pipeline described in our [Technical Report](https://dcase.community/documents/challenge2023/technical_reports/DCASE2023_Gelderblom_SINTEF_t5.pdf)

```
Evaluation for: BEATs VAL
BUK1_20181011_001004.wav {'TP': 15, 'FP': 35, 'FN': 16, 'total_n_pos_events': 31}
BUK1_20181013_023504.wav {'TP': 2, 'FP': 258, 'FN': 22, 'total_n_pos_events': 24}
BUK4_20161011_000804.wav {'TP': 1, 'FP': 30, 'FN': 46, 'total_n_pos_events': 47}
BUK4_20171022_004304a.wav {'TP': 7, 'FP': 17, 'FN': 10, 'total_n_pos_events': 17}
BUK5_20161101_002104a.wav {'TP': 31, 'FP': 7, 'FN': 57, 'total_n_pos_events': 88}
BUK5_20180921_015906a.wav {'TP': 4, 'FP': 24, 'FN': 19, 'total_n_pos_events': 23}
ME1.wav {'TP': 9, 'FP': 18, 'FN': 2, 'total_n_pos_events': 11}
ME2.wav {'TP': 41, 'FP': 27, 'FN': 0, 'total_n_pos_events': 41}
R4_cleaned recording_13-10-17.wav {'TP': 19, 'FP': 14, 'FN': 0, 'total_n_pos_events': 19}
R4_cleaned recording_16-10-17.wav {'TP': 30, 'FP': 8, 'FN': 0, 'total_n_pos_events': 30}
R4_cleaned recording_17-10-17.wav {'TP': 36, 'FP': 9, 'FN': 0, 'total_n_pos_events': 36}
R4_cleaned recording_TEL_19-10-17.wav {'TP': 52, 'FP': 12, 'FN': 2, 'total_n_pos_events': 54}
R4_cleaned recording_TEL_20-10-17.wav {'TP': 64, 'FP': 8, 'FN': 0, 'total_n_pos_events': 64}
R4_cleaned recording_TEL_23-10-17.wav {'TP': 84, 'FP': 8, 'FN': 0, 'total_n_pos_events': 84}
R4_cleaned recording_TEL_24-10-17.wav {'TP': 99, 'FP': 14, 'FN': 0, 'total_n_pos_events': 99}
R4_cleaned recording_TEL_25-10-17.wav {'TP': 99, 'FP': 9, 'FN': 0, 'total_n_pos_events': 99}
file_423_487.wav {'TP': 57, 'FP': 13, 'FN': 0, 'total_n_pos_events': 57}
file_97_113.wav {'TP': 11, 'FP': 27, 'FN': 109, 'total_n_pos_events': 120}
Overall_scores: {'precision': 0.2911279078300433, 'recall': 0.4938446186969832, 'fmeasure (percentage)': 36.631}
Evaluation for: TeamBEATs VAL
BUK1_20181011_001004.wav {'TP': 13, 'FP': 22, 'FN': 18, 'total_n_pos_events': 31}
BUK1_20181013_023504.wav {'TP': 3, 'FP': 206, 'FN': 21, 'total_n_pos_events': 24}
BUK4_20161011_000804.wav {'TP': 1, 'FP': 22, 'FN': 46, 'total_n_pos_events': 47}
BUK4_20171022_004304a.wav {'TP': 6, 'FP': 15, 'FN': 11, 'total_n_pos_events': 17}
BUK5_20161101_002104a.wav {'TP': 39, 'FP': 7, 'FN': 49, 'total_n_pos_events': 88}
BUK5_20180921_015906a.wav {'TP': 4, 'FP': 9, 'FN': 19, 'total_n_pos_events': 23}
ME1.wav {'TP': 10, 'FP': 21, 'FN': 1, 'total_n_pos_events': 11}
ME2.wav {'TP': 41, 'FP': 35, 'FN': 0, 'total_n_pos_events': 41}
R4_cleaned recording_13-10-17.wav {'TP': 19, 'FP': 23, 'FN': 0, 'total_n_pos_events': 19}
R4_cleaned recording_16-10-17.wav {'TP': 30, 'FP': 9, 'FN': 0, 'total_n_pos_events': 30}
R4_cleaned recording_17-10-17.wav {'TP': 36, 'FP': 6, 'FN': 0, 'total_n_pos_events': 36}
R4_cleaned recording_TEL_19-10-17.wav {'TP': 52, 'FP': 29, 'FN': 2, 'total_n_pos_events': 54}
R4_cleaned recording_TEL_20-10-17.wav {'TP': 64, 'FP': 10, 'FN': 0, 'total_n_pos_events': 64}
R4_cleaned recording_TEL_23-10-17.wav {'TP': 84, 'FP': 5, 'FN': 0, 'total_n_pos_events': 84}
R4_cleaned recording_TEL_24-10-17.wav {'TP': 99, 'FP': 13, 'FN': 0, 'total_n_pos_events': 99}
R4_cleaned recording_TEL_25-10-17.wav {'TP': 99, 'FP': 8, 'FN': 0, 'total_n_pos_events': 99}
file_423_487.wav {'TP': 57, 'FP': 7, 'FN': 0, 'total_n_pos_events': 57}
file_97_113.wav {'TP': 11, 'FP': 30, 'FN': 109, 'total_n_pos_events': 120}
Overall_scores: {'precision': 0.348444259075038, 'recall': 0.525770811091538, 'fmeasure (percentage)': 41.912}
```

## Taking the idea further:
Expand Down
8 changes: 5 additions & 3 deletions compute_metrics.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
#!/bin/bash

BASE_FOLDER=$1
EVAL_CSV_PATH=$2
CONFIG_PATH="/app/CONFIG.yaml"

docker run -v $PWD:/app \
-v $BASE_FOLDER:/data \
-v $BASE_FOLDER:/data \
-v $EVAL_CSV_PATH:/eval_folder/eval_out.csv \
--gpus all \
beats \
poetry run python evaluation/evaluation_metrics/evaluation.py \
-pred_file /data/eval_out.csv \
poetry run python /app/evaluate/evaluation_metrics/evaluation.py \
-pred_file /eval_folder/eval_out.csv \
-ref_files_path /data/DCASE/Development_Set_annotations/Validation_Set \
-team_name TeamBEATs \
-dataset VAL \
Expand Down
25 changes: 14 additions & 11 deletions dcase_setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ mkdir -p $TARGET_FOLDER
############################
# Download the BEATs model #
############################
MODEL_FOLDER=$BASE_FOLDER/BEATs
MODEL_FOLDER=$BASE_FOLDER/model/BEATs
mkdir -p $MODEL_FOLDER
wget -O "$MODEL_FOLDER/BEATs_iter3_plus_AS2M.pt" "https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D"

Expand All @@ -47,18 +47,21 @@ download_and_unzip "https://zenodo.org/record/6482837/files/Development_Set_anno
# Acoustic data
download_and_unzip "https://zenodo.org/record/6482837/files/Development_Set.zip?download=1" "$TARGET_FOLDER"

###############################
# Download the evaluation set #
###############################
mkdir -p "$TARGET_FOLDER/Development_Set/Evaluation_Set"
#####################################################################
# Download the evaluation set - OUTDATED AS THIS WAS FOR DCASE 2023 #
#####################################################################

download_and_unzip "https://zenodo.org/record/7879692/files/Annotations_only.zip?download=1" "$TARGET_FOLDER"
mv "$TARGET_FOLDER/Annotations_only" "$TARGET_FOLDER/Development_Set_annotations/Evaluation_Set"


#mkdir -p "$TARGET_FOLDER/Development_Set/Evaluation_Set"

#download_and_unzip "https://zenodo.org/record/7879692/files/Annotations_only.zip?download=1" "$TARGET_FOLDER"
#mv "$TARGET_FOLDER/Annotations_only" "$TARGET_FOLDER/Development_Set_annotations/Evaluation_Set"

# Acoustic data
for i in {1..3}
do
download_and_unzip "https://zenodo.org/record/7879692/files/eval_$i.zip?download=1" "$TARGET_FOLDER/Development_Set/Evaluation_Set"
done
#for i in {1..3}
#do
# download_and_unzip "https://zenodo.org/record/7879692/files/eval_$i.zip?download=1" "$TARGET_FOLDER/Development_Set/Evaluation_Set"
#done


17 changes: 16 additions & 1 deletion evaluate/evaluateDCASE.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,12 @@ def train_model(

if pretrained_model:
# Load the pretrained model
pretrained_model = ProtoBEATsModel.load_from_checkpoint(pretrained_model)
try:
pretrained_model = ProtoBEATsModel.load_from_checkpoint(pretrained_model)
except KeyError:
print("Failed to load the pretrained model. Please check the checkpoint file.")
return None


# train the model
trainer.fit(model, datamodule=datamodule_class)
Expand Down Expand Up @@ -569,6 +574,16 @@ def write_wav(
training_config_path = os.path.join(version_path, "config.yaml")
version_name = os.path.basename(version_path)

# Select 'set_type' depending on chosen status
if cfg["data"]["status"]=="train":
cfg["data"]["set_type"] = "Training_Set"

elif cfg["data"]["status"]=="validate":
cfg["data"]["set_type"] = "Validation_Set"

else:
cfg["data"]["set_type"] = "Evaluation_Set"

# Get correct paths to dataset
my_hash_dict = {
"resample": cfg["data"]["resample"],
Expand Down
12 changes: 7 additions & 5 deletions train_model.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
BASE_FOLDER=$1
CONFIG_PATH="/app/CONFIG.yaml"

docker run -v $BASE_FOLDER:/data \
-v $PWD:/app \
--gpus all \
beats \
poetry run prototypicalbeats/trainer.py fit --config $CONFIG_PATH
# Check if BASE_FOLDER is not set or empty
if [ -z "$BASE_FOLDER" ]; then
echo "Error: BASE_FOLDER is not specified."
exit 1
fi

docker run -v $BASE_FOLDER:/data -v $PWD:/app --gpus all beats poetry run prototypicalbeats/trainer.py fit --config $CONFIG_PATH

0 comments on commit f24030e

Please sign in to comment.