Skip to content

Commit

Permalink
Add ability to return nan/null from script to mark failure
Browse files Browse the repository at this point in the history
When running a script, if your function fails and produces a NaN or null value of some sort,
you can now just pass that along to BOA instead of checking for it and marking the trial as failed yourself.
BOA will mark it as failed for you.
  • Loading branch information
madeline-scyphers committed Mar 11, 2024
1 parent e87826f commit 54d5f4a
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 11 deletions.
21 changes: 10 additions & 11 deletions boa/async_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,20 +156,19 @@ def run(config_path, scheduler_path, num_trials, experiment_dir=None):
if scheduler.experiment.fetch_data().df.empty:
trials = scheduler.experiment.trials
metrics = scheduler.experiment.metrics
for metric in metrics.keys():
scheduler.experiment.attach_data(
Data(
df=pd.DataFrame.from_records(
dict(
trial_index=list(trials.keys()),
arm_name=[f"{i}_0" for i in trials.keys()],
metric_name=metric,
mean=None,
sem=0.0,
)
scheduler.experiment.attach_data(
Data(
df=pd.DataFrame(
dict(
trial_index=[i for i in trials.keys() for m in metrics.keys()],
arm_name=[f"{i}_0" for i in trials.keys() for m in metrics.keys()],
metric_name=[m for i in trials.keys() for m in metrics.keys()],
mean=None,
sem=0.0,
)
)
)
)

scheduler.save_data(metrics_to_end=True, ax_kwargs=dict(always_include_field_columns=True))
return scheduler
Expand Down
14 changes: 14 additions & 0 deletions boa/metrics/modular_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,20 @@ def fetch_trial_data(self, trial: Trial, **kwargs):
if self.wrapper
else {}
)
if isinstance(wrapper_kwargs, dict):
nan_checks = list(wrapper_kwargs.values())
elif isinstance(wrapper_kwargs, list):
nan_checks = wrapper_kwargs
else:
nan_checks = [wrapper_kwargs]
for elem in nan_checks:
if (
(isinstance(elem, str) and ("nan" == elem.lower() or "na" == elem.lower()))
or (isinstance(elem, float) and pd.isna(elem))
or (elem is None)
):
return Err(f"NaNs in Results for Trial {trial.index}, failing trial")

wrapper_kwargs = wrapper_kwargs if wrapper_kwargs is not None else {}
if wrapper_kwargs is not None and not isinstance(wrapper_kwargs, dict):
wrapper_kwargs = {"wrapper_args": wrapper_kwargs}
Expand Down
6 changes: 6 additions & 0 deletions tests/integration_tests/test_dunder_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,9 @@ def test_non_zero_exit_code_fails_trial():
with pytest.raises(FailureRateExceededError):
config_path = ROOT / "tests" / f"scripts/other_langs/r_failure_exit_code/config.yaml"
dunder_main.main(split_shell_command(f"--config-path {config_path} -td"), standalone_mode=False)


def test_return_nan_fails_trial():
    """A model script that reports NaN for its metric should have its trials
    failed by BOA, eventually exceeding the failure rate and raising."""
    config_path = ROOT / "tests" / f"scripts/other_langs/r_failure_nan/config.yaml"
    cli_args = split_shell_command(f"--config-path {config_path} -td")
    with pytest.raises(FailureRateExceededError):
        dunder_main.main(cli_args, standalone_mode=False)
16 changes: 16 additions & 0 deletions tests/scripts/other_langs/r_failure_nan/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Test fixture config: runs an R model that deliberately reports NaN for its
# metric (see run_model.R), so BOA should fail every trial.
objective:
  metrics:
    # Name must match the key the R script writes into output.json.
    - name: metric
scheduler:
  n_trials: 10

parameters:
  x0:
    'bounds': [ 0, 1 ]
    'type': 'range'
    'value_type': 'float'


script_options:
  # Command BOA invokes for each trial.
  run_model: Rscript run_model.R
  exp_name: "r_error_nan"
8 changes: 8 additions & 0 deletions tests/scripts/other_langs/r_failure_nan/run_model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
library(jsonlite)

# BOA passes the trial directory as the final command-line argument.
cli_args <- commandArgs(trailingOnly=TRUE)
trial_dir <- cli_args[length(cli_args)]

# Deliberately report NaN for the metric so BOA marks this trial as failed.
result <- list(
  metric=NaN
)
write(toJSON(result, pretty = TRUE), file.path(trial_dir, "output.json"))

0 comments on commit 54d5f4a

Please sign in to comment.