Add scripts and configs for hyperparameter tuning (#675)
* Merge py file changes from benchmark-algs

* Clean parallel script

* Undo the changes from #653 to the dagger benchmark config files.

This change only silenced the error messages about the missing imitation.algorithms.dagger.ExponentialBetaSchedule; it did not fix the root cause.

* Improve readability and interpretability of benchmarking tests.

* Add exponential beta scheduler for dagger

* Ignore coverage for unknown algorithms.

* Cleanup and extend tests for beta schedules in dagger.

* Add optuna to dependencies

* Fix test case

* Clean up the scripts

* Remove reporter(done) since mean_return is reported by the runs

* Add beta_schedule parameter to dagger script

* Update config policy kwargs

* Changes from review

* Fix errors with some configs

* Updates based on review

* Change metric everywhere

* Separate tuning code from parallel.py

* Fix docstring

* Remove resume option as it is tricky to implement correctly

* Minor fixes

* Updates from review

* fix lint error

* Add documentation for using the tuning script

* Fix lint error

* Updates from the review

* Fix file name test errors

* Add tune_run_kwargs in parallel script

* Fix test errors

* Fix test

* Fix lint

* Updates from review

* Simplify a few lines of code

* Updates from review

* Fix test

* Revert "Fix test"

This reverts commit 8b55134.

* Fix test

* Convert Dict to Mapping in input argument

* Ignore coverage in script configurations.

* Pin huggingface_sb3 version.

* Update to the newest seals environment versions.

* Push gymnasium dependency to 0.29 to ensure mujoco envs work.

* Incorporate review comments

* Fix test errors

* Move benchmarking/ to scripts/ and add named configs for tuned hyperparams

* Bump cache version & remove unnecessary files

* Include tuned hyperparam json files in package data

* Update storage hash

* Update search space of bc

* Update benchmark and hyperparameter tuning README

* Update README.md

* Incorporate reviewer's comments in benchmarking readme

* Update gymnasium version and render mode in eval policy

* Fix error

* Update commands.py hex strings

---------

Co-authored-by: Maximilian Ernestus <maximilian@ernestus.de>
Co-authored-by: ZiyueWang25 <wfuymu@gmail.com>
3 people authored Oct 10, 2023
1 parent f099c33 commit 20366b0
Showing 43 changed files with 1,023 additions and 264 deletions.
33 changes: 28 additions & 5 deletions benchmarking/README.md
@@ -1,19 +1,42 @@
# Benchmarking imitation

This directory contains sacred configuration files for benchmarking imitation's algorithms. For v0.3.2, these correspond to the hyperparameters used in the paper [imitation: Clean Imitation Learning Implementations](https://www.rocamonde.com/publication/gleave-imitation-2022/).
The `src/imitation/scripts/config/tuned_hps` directory provides the tuned hyperparameter configs for benchmarking imitation. For v0.4.0, these correspond to the hyperparameters used in the paper [imitation: Clean Imitation Learning Implementations](https://arxiv.org/abs/2211.11972).

Configuration files can be loaded either from the CLI or from the Python API. The examples below assume that your current working directory is the root of the `imitation` repository. This is not necessarily the case and you should adjust your paths accordingly.
Configuration files can be loaded either from the CLI or from the Python API.

## CLI

```bash
python -m imitation.scripts.<train_script> <algo> with benchmarking/<config_name>.json
python -m imitation.scripts.<train_script> <algo> with <algo>_<env>
```
`train_script` can be either 1) `train_imitation` with `algo` as `bc` or `dagger` or 2) `train_adversarial` with `algo` as `gail` or `airl`.
`train_script` can be either 1) `train_imitation` with `algo` as `bc` or `dagger`, or 2) `train_adversarial` with `algo` as `gail` or `airl`. The `env` can be any of `seals_ant`, `seals_half_cheetah`, `seals_hopper`, `seals_swimmer`, or `seals_walker`. The hyperparameters for other environments have not been tuned yet. You may be able to get reasonable performance by using hyperparameters tuned for a similar environment; alternatively, you can tune the hyperparameters using the `tuning` script.

## Python

```python
...
ex.add_config('benchmarking/<config_name>.json')
from imitation.scripts.<train_script> import <train_ex>
<train_ex>.run(command_name="<algo>", named_configs=["<algo>_<env>"])
```
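
For example, a minimal sketch of a concrete invocation (the experiment object name `train_imitation_ex` is an assumption based on the repository's naming conventions; adjust it to the actual name in your checkout):

```python
# Minimal sketch: run DAgger with the tuned seals_half_cheetah hyperparameters.
# `train_imitation_ex` is an assumed name following the other scripts' conventions.
from imitation.scripts.train_imitation import train_imitation_ex

run = train_imitation_ex.run(
    command_name="dagger",
    named_configs=["dagger_seals_half_cheetah"],
)
print(run.result)  # summary statistics of the trained policy, e.g. mean return
```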

# Tuning Hyperparameters

The hyperparameters of any algorithm in imitation can be tuned using `src/imitation/scripts/tuning.py`.
The benchmarking hyperparameter configs were generated by tuning the hyperparameters using
the search space defined in `scripts/config/tuning.py`.

The tuning script proceeds in two phases:
1. Tune the hyperparameters using the search space provided.
2. Re-evaluate the best hyperparameter config from the first phase (selected by maximum mean return) on a separate set of seeds, and report the mean and standard deviation of these trials (a minimal sketch of this summary step is shown below).
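
The summary in the second phase is a plain mean and standard deviation over the re-evaluation seeds; an illustrative sketch with hypothetical return values:

```python
# Illustrative sketch of the phase-2 summary: the best config is re-run on a
# fresh set of seeds and the per-seed mean returns are aggregated.
import numpy as np

eval_returns = np.array([812.4, 798.1, 820.6, 805.3, 790.9])  # hypothetical values
print(f"mean return: {eval_returns.mean():.1f} +/- {eval_returns.std():.1f}")
```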

To use it with the default search space:
```bash
python -m imitation.scripts.tuning with <algo> 'parallel_run_config.base_named_configs=["<env>"]'
```

In this command:
- `<algo>` provides the default search space and settings for the specific algorithm, as defined in `scripts/config/tuning.py`
- `<env>` sets the environment to tune the algorithm in. The environment named configs are defined in the algo-specific `scripts/config/train_[adversarial|imitation|preference_comparisons|rl].py` files. For environments that have already been tuned, use the `<algo>_<env>` named configs here.

See the documentation of `scripts/tuning.py` and `scripts/parallel.py` for many other arguments that can be
provided through the command line to change the tuning behavior.
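
The tuning experiment can also be launched from Python instead of the CLI. A minimal sketch, assuming the script exposes a Sacred experiment named `tuning_ex` (an assumption based on the naming convention of the other scripts):

```python
# Minimal sketch: tune DAgger on seals_half_cheetah from Python.
# `tuning_ex` is an assumed name; check imitation/scripts/tuning.py for the
# actual experiment object.
from imitation.scripts.tuning import tuning_ex

tuning_ex.run(
    named_configs=["dagger"],  # default search space and settings for DAgger
    config_updates={
        # or "dagger_seals_half_cheetah" for the already-tuned named config,
        # per the note above
        "parallel_run_config": {"base_named_configs": ["seals_half_cheetah"]},
    },
)
```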
2 changes: 1 addition & 1 deletion benchmarking/util.py
@@ -79,7 +79,7 @@ def clean_config_file(file: pathlib.Path, write_path: pathlib.Path, /) -> None:

remove_empty_dicts(config)
# files are of the format
# /path/to/file/example_<algo>_<env>_best_hp_eval/<other_info>/sacred/1/config.json
# /path/to/file/<algo>_<env>_best_hp_eval/<other_info>/sacred/1/config.json
# we want to write to /<write_path>/<algo>_<env>.json
with open(write_path / f"{file.parents[3].name}.json", "w") as f:
json.dump(config, f, indent=4)
31 changes: 17 additions & 14 deletions experiments/commands.py
@@ -12,23 +12,24 @@
For example, we can run:
TUNED_HPS_DIR=../src/imitation/scripts/config/tuned_hps
python commands.py \
--name=run0 \
--cfg_pattern=../benchmarking/*ai*_seals_walker_*.json \
--cfg_pattern=$TUNED_HPS_DIR/*ai*_seals_walker_*.json \
--output_dir=output
And get the following commands printed out:
python -m imitation.scripts.train_adversarial airl \
--capture=sys --name=run0 \
--file_storage=output/sacred/$USER-cmd-run0-airl-0-a3531726 \
with ../benchmarking/example_airl_seals_walker_best_hp_eval.json \
with ../src/imitation/scripts/config/tuned_hps/airl_seals_walker_best_hp_eval.json \
seed=0 logging.log_root=output
python -m imitation.scripts.train_adversarial gail \
--capture=sys --name=run0 \
--file_storage=output/sacred/$USER-cmd-run0-gail-0-a1ec171b \
with ../benchmarking/example_gail_seals_walker_best_hp_eval.json \
with $TUNED_HPS_DIR/gail_seals_walker_best_hp_eval.json \
seed=0 logging.log_root=output
We can execute commands in parallel by piping them to GNU parallel:
@@ -40,9 +41,10 @@
For example, we can run:
TUNED_HPS_DIR=../src/imitation/scripts/config/tuned_hps
python commands.py \
--name=run0 \
--cfg_pattern=../benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json \
--cfg_pattern=$TUNED_HPS_DIR/bc_seals_half_cheetah_best_hp_eval.json \
--output_dir=/data/output \
--remote
@@ -51,8 +53,9 @@
ctl job run --name $USER-cmd-run0-bc-0-72cb1df3 \
--command "python -m imitation.scripts.train_imitation bc \
--capture=sys --name=run0 \
--file_storage=/data/output/sacred/$USER-cmd-run0-bc-0-72cb1df3 \
with /data/imitation/benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json \
--file_storage=/data/output/sacred/$USER-cmd-run0-bc-0-72cb1df3 with \
/data/imitation/src/imitation/scripts/config/tuned_hps/
bc_seals_half_cheetah_best_hp_eval.json \
seed=0 logging.log_root=/data/output" \
--container hacobe/devbox:imitation \
--login --force-pull --never-restart --gpu 0 --shared-host-dir-mount /data
@@ -85,7 +88,7 @@ def _get_algo_name(cfg_file: str) -> str:
"""Get the algorithm name from the given config filename."""
algo_names = set()
for key in _ALGO_NAME_TO_SCRIPT_NAME:
if cfg_file.find("_" + key + "_") != -1:
if cfg_file.find(key + "_") != -1:
algo_names.add(key)

if len(algo_names) == 0:
@@ -121,7 +124,7 @@ def main(args: argparse.Namespace) -> None:
else:
cfg_path = os.path.join(args.remote_cfg_dir, cfg_file)

cfg_id = _get_cfg_id(cfg_path)
cfg_id = _get_cfg_id(cfg_file)

for seed in args.seeds:
cmd_id = _CMD_ID_TEMPLATE.format(
@@ -177,19 +180,19 @@ def parse() -> argparse.Namespace:
parser.add_argument(
"--cfg_pattern",
type=str,
default="example_bc_seals_half_cheetah_best_hp_eval.json",
default="bc_seals_half_cheetah_best_hp_eval.json",
help="""Generate a command for every file that matches this glob pattern. \
Each matching file should be a config file that has its algorithm name \
(bc, dagger, airl or gail) bookended by underscores in the filename. \
If the --remote flag is enabled, then generate a command for every file in the \
--remote_cfg_dir directory that has the same filename as a file that matches this \
glob pattern. E.g., suppose the current, local working directory is 'foo' and \
the subdirectory 'foo/bar' contains the config files 'example_bc_best.json' and \
'example_dagger_best.json'. If the pattern 'bar/*.json' is supplied, then globbing \
will return ['bar/example_bc_best.json', 'bar/example_dagger_best.json']. \
the subdirectory 'foo/bar' contains the config files 'bc_best.json' and \
'dagger_best.json'. If the pattern 'bar/*.json' is supplied, then globbing \
will return ['bar/bc_best.json', 'bar/dagger_best.json']. \
If the --remote flag is enabled, 'bar' will be replaced with `remote_cfg_dir` and \
commands will be created for the following configs: \
[`remote_cfg_dir`/example_bc_best.json, `remote_cfg_dir`/example_dagger_best.json] \
[`remote_cfg_dir`/bc_best.json, `remote_cfg_dir`/dagger_best.json] \
Why not just supply the pattern '`remote_cfg_dir`/*.json' directly? \
Because the `remote_cfg_dir` directory may not exist on the local machine.""",
)
@@ -220,7 +223,7 @@ def parse() -> argparse.Namespace:
parser.add_argument(
"--remote_cfg_dir",
type=str,
default="/data/imitation/benchmarking",
default="/data/imitation/src/imitation/scripts/config/tuned_hps",
help="""Path to a directory storing config files \
accessible from each container. """,
)
2 changes: 2 additions & 0 deletions setup.cfg
@@ -45,6 +45,8 @@ source = imitation
include=
src/*
tests/*
omit =
src/imitation/scripts/config/*

[coverage:report]
exclude_lines =
3 changes: 2 additions & 1 deletion setup.py
@@ -182,7 +182,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
python_requires=">=3.8.0",
packages=find_packages("src"),
package_dir={"": "src"},
package_data={"imitation": ["py.typed", "envs/examples/airl_envs/assets/*.xml"]},
package_data={"imitation": ["py.typed", "scripts/config/tuned_hps/*.json"]},
# Note: while we are strict with our test and doc requirement versions, we try to
# impose as little restrictions on the install requirements as possible. Try to
# encode only known incompatibilities here. This prevents nasty dependency issues
@@ -200,6 +200,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"sacred>=0.8.4",
"tensorboard>=1.14",
"huggingface_sb3~=3.0",
"optuna>=3.0.1",
"datasets>=2.8.0",
],
tests_require=TESTS_REQUIRE,
37 changes: 23 additions & 14 deletions src/imitation/scripts/analyze.py
@@ -262,38 +262,47 @@ def analyze_imitation(
csv_output_path: If provided, then save a CSV output file to this path.
tex_output_path: If provided, then save a LaTeX-format table to this path.
print_table: If True, then print the dataframe to stdout.
table_verbosity: Increasing levels of verbosity, from 0 to 2, increase the
number of columns in the table.
table_verbosity: Increasing levels of verbosity, from 0 to 3, increase the
number of columns in the table. Level 3 prints all of the columns available.
Returns:
The DataFrame generated from the Sacred logs.
"""
table_entry_fns_subset = _get_table_entry_fns_subset(table_verbosity)
# Get the column names whose values are computed with make_entry_fn.
# These are the same across levels 2 & 3; at level 3 we additionally add the
# remaining config columns.
table_entry_fns_subset = _get_table_entry_fns_subset(min(table_verbosity, 2))

rows = []
output_table = pd.DataFrame()
for sd in _gather_sacred_dicts():
row = {}
if table_verbosity == 3:
# gets all config columns
row = pd.json_normalize(sd.config)
else:
# create an empty dataframe with a single row
row = pd.DataFrame(index=[0])

for col_name, make_entry_fn in table_entry_fns_subset.items():
row[col_name] = make_entry_fn(sd)
rows.append(row)

df = pd.DataFrame(rows)
if len(df) > 0:
df.sort_values(by=["algo", "env_name"], inplace=True)
output_table = pd.concat([output_table, row])

if len(output_table) > 0:
output_table.sort_values(by=["algo", "env_name"], inplace=True)

display_options = dict(index=False)
display_options: Mapping[str, Any] = dict(index=False)
if csv_output_path is not None:
df.to_csv(csv_output_path, **display_options)
output_table.to_csv(csv_output_path, **display_options)
print(f"Wrote CSV file to {csv_output_path}")
if tex_output_path is not None:
s: str = df.to_latex(**display_options)
s: str = output_table.to_latex(**display_options)
with open(tex_output_path, "w") as f:
f.write(s)
print(f"Wrote TeX file to {tex_output_path}")

if print_table:
print(df.to_string(**display_options))
return df
print(output_table.to_string(**display_options))
return output_table


def _make_return_summary(stats: dict, prefix="") -> str:
2 changes: 1 addition & 1 deletion src/imitation/scripts/config/analyze.py
@@ -18,7 +18,7 @@ def config():
tex_output_path = None # Write LaTex output to this path
print_table = True # Set to True to print analysis to stdout
split_str = "," # str used to split source_dir_str into multiple source dirs
table_verbosity = 1 # Choose from 0, 1, or 2
table_verbosity = 1 # Choose from 0, 1, 2 or 3
source_dirs = None


79 changes: 13 additions & 66 deletions src/imitation/scripts/config/parallel.py
@@ -5,7 +5,10 @@
`@parallel_ex.named_config` to define a new parallel experiment.
Adding custom named configs is necessary because the CLI interface can't add
search spaces to the config like `"seed": tune.grid_search([0, 1, 2, 3])`.
search spaces to the config like `"seed": tune.choice([0, 1, 2, 3])`.
For tuning hyperparameters of an algorithm on a given environment,
check out the imitation/scripts/tuning.py script.
"""

import numpy as np
@@ -31,19 +34,10 @@ def config():
"config_updates": {},
} # `config` argument to `ray.tune.run(trainable, config)`

local_dir = None # `local_dir` arg for `ray.tune.run`
upload_dir = None # `upload_dir` arg for `ray.tune.run`
n_seeds = 3 # Number of seeds to search over by default


@parallel_ex.config
def seeds(n_seeds):
search_space = {"config_updates": {"seed": tune.grid_search(list(range(n_seeds)))}}


@parallel_ex.named_config
def s3():
upload_dir = "s3://shwang-chai/private"
num_samples = 1 # Number of samples per grid search configuration
repeat = 1 # Number of times to repeat a sampled configuration
experiment_checkpoint_path = "" # Path to checkpoint of experiment
tune_run_kwargs = {} # Additional kwargs to pass to `tune.run`


# Debug named configs
@@ -58,12 +52,12 @@ def generate_test_data():
"""
sacred_ex_name = "train_rl"
run_name = "TEST"
n_seeds = 1
repeat = 1
search_space = {
"config_updates": {
"rl": {
"rl_kwargs": {
"learning_rate": tune.grid_search(
"learning_rate": tune.choice(
[3e-4 * x for x in (1 / 3, 1 / 2)],
),
},
@@ -86,63 +80,16 @@
def example_cartpole_rl():
sacred_ex_name = "train_rl"
run_name = "example-cartpole"
n_seeds = 2
repeat = 2
search_space = {
"config_updates": {
"rl": {
"rl_kwargs": {
"learning_rate": tune.grid_search(np.logspace(3e-6, 1e-1, num=3)),
"nminibatches": tune.grid_search([16, 32, 64]),
"learning_rate": tune.choice(np.logspace(3e-6, 1e-1, num=3)),
"nminibatches": tune.choice([16, 32, 64]),
},
},
},
}
base_named_configs = ["cartpole"]
resources_per_trial = dict(cpu=4)


EASY_ENVS = ["cartpole", "pendulum", "mountain_car"]


@parallel_ex.named_config
def example_rl_easy():
sacred_ex_name = "train_rl"
run_name = "example-rl-easy"
n_seeds = 2
search_space = {
"named_configs": tune.grid_search([[env] for env in EASY_ENVS]),
"config_updates": {
"rl": {
"rl_kwargs": {
"learning_rate": tune.grid_search(np.logspace(3e-6, 1e-1, num=3)),
"nminibatches": tune.grid_search([16, 32, 64]),
},
},
},
}
resources_per_trial = dict(cpu=4)


@parallel_ex.named_config
def example_gail_easy():
sacred_ex_name = "train_adversarial"
run_name = "example-gail-easy"
n_seeds = 1
search_space = {
"named_configs": tune.grid_search([[env] for env in EASY_ENVS]),
"config_updates": {
"init_trainer_kwargs": {
"rl": {
"rl_kwargs": {
"learning_rate": tune.grid_search(
np.logspace(3e-6, 1e-1, num=3),
),
"nminibatches": tune.grid_search([16, 32, 64]),
},
},
},
},
}
search_space = {
"command_name": "gail",
}
