Skip to content

Commit

Permalink
Merge pull request #9 from NINAnor/filippo_dev
Browse files Browse the repository at this point in the history
[FIX] Handling command line inputs, status, and set_type
  • Loading branch information
femke-sintef authored Oct 5, 2023
2 parents 5f7ef23 + 668ec74 commit 4798fc4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
5 changes: 2 additions & 3 deletions CONFIG.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ data:
n_query: 10
overlap: 0.5
n_subsample: 1 # ask Femke what this stands for
status: validate # train or validate or evaluate
set_type: Validation_Set # Training_Set or Validation_Set or Evaluation_Set
status: train # train or validate or evaluate

#################################
# PARAMETERS FOR MODEL TRAINING #
Expand All @@ -37,7 +36,7 @@ trainer:
model:
distance: euclidean # other option is mahalanobis
lr: 1.0e-05
model_path: /data/lightning_logs/version_0/checkpoints/epoch=14-step=1500.ckpt # /data/BEATs/BEATs_iter3_plus_AS2M.pt # Or FOR INFERENCE: /data/lightning_logs/version_X/checkpoints/epoch=X-step=1500.ckpt
model_path: /data/BEATs/BEATs_iter3_plus_AS2M.pt # Or FOR INFERENCE: /data/lightning_logs/version_X/checkpoints/epoch=X-step=1500.ckpt


##################################################################
Expand Down
28 changes: 16 additions & 12 deletions data_utils/DCASEfewshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,19 +560,8 @@ def preprocess_df(df):
)


# check input
# get input
cli_args = parser.parse_args()
assert (
cli_args.status == "validate"
or cli_args.status == "train"
or cli_args.status == "test"
)
if cli_args.status == "validate":
assert cli_args.set_type == "Validation_Set"
elif cli_args.status == "train":
assert cli_args.set_type == "Training_Set"
elif cli_args.status == "test":
assert cli_args.set_type == "Evaluation_Set"

# Open the config file
import yaml
Expand All @@ -581,6 +570,21 @@ def preprocess_df(df):
with open(cli_args.config) as f:
cfg = yaml.load(f, Loader=FullLoader)

# Check values in config file
if not (cfg["data"]["status"]=="train" or cfg["data"]["status"]=="validate" or cfg["data"]["status"]=="test"):
raise Exception("ERROR: "+ str(cli_args.config) + ": Accepted values for 'status' are 'train', 'validate', or 'test'. Received '" + str(cfg["data"]["status"]) + "'.")

# Select 'set_type' depending on chosen status
if cfg["data"]["status"]=="train":
cfg["data"]["set_type"] = "Training_Set"

elif cfg["data"]["status"]=="validate":
cfg["data"]["set_type"] = "Validation_Set"

else:
cfg["data"]["set_type"] = "Evaluation_Set"


prepare_training_val_data(
cfg["data"]["status"],
cfg["data"]["set_type"],
Expand Down

0 comments on commit 4798fc4

Please sign in to comment.