diff --git a/benchmarking/README.md b/benchmarking/README.md index 3f5114545..80cbc0247 100644 --- a/benchmarking/README.md +++ b/benchmarking/README.md @@ -1,19 +1,42 @@ # Benchmarking imitation -This directory contains sacred configuration files for benchmarking imitation's algorithms. For v0.3.2, these correspond to the hyperparameters used in the paper [imitation: Clean Imitation Learning Implementations](https://www.rocamonde.com/publication/gleave-imitation-2022/). +The `src/imitation/scripts/config/tuned_hps` directory provides the tuned hyperparameter configs for benchmarking imitation. For v0.4.0, these correspond to the hyperparameters used in the paper [imitation: Clean Imitation Learning Implementations](https://arxiv.org/abs/2211.11972). -Configuration files can be loaded either from the CLI or from the Python API. The examples below assume that your current working directory is the root of the `imitation` repository. This is not necessarily the case and you should adjust your paths accordingly. +Configuration files can be loaded either from the CLI or from the Python API. ## CLI ```bash -python -m imitation.scripts. with benchmarking/.json +python -m imitation.scripts. with _ ``` -`train_script` can be either 1) `train_imitation` with `algo` as `bc` or `dagger` or 2) `train_adversarial` with `algo` as `gail` or `airl`. +`train_script` can be either 1) `train_imitation` with `algo` as `bc` or `dagger` or 2) `train_adversarial` with `algo` as `gail` or `airl`. The `env` can be either of `seals_ant`, `seals_half_cheetah`, `seals_hopper`, `seals_swimmer`, or `seals_walker`. The hyperparameters for other environments are not tuned yet. You may be able to get reasonable performance by using hyperparameters tuned for a similar environment; alternatively, you can tune the hyperparameters using the `tuning` script. ## Python ```python ... -ex.add_config('benchmarking/.json') +from imitation.scripts. import +.run(command_name="", named_configs=["_"]) ``` + +# Tuning Hyperparameters + +The hyperparameters of any algorithm in imitation can be tuned using `src/imitation/scripts/tuning.py`. +The benchmarking hyperparameter configs were generated by tuning the hyperparameters using +the search space defined in the `scripts/config/tuning.py`. + +The tuning script proceeds in two phases: +1. Tune the hyperparameters using the search space provided. +2. Re-evaluate the best hyperparameter config found in the first phase based on the maximum mean return on a separate set of seeds. Report the mean and standard deviation of these trials. + +To use it with the default search space: +```bash +python -m imitation.scripts.tuning with 'parallel_run_config.base_named_configs=[""]' +``` + +In this command: +- `` provides the default search space and settings for the specific algorithm, which is defined in the `scripts/config/tuning.py` +- `` sets the environment to tune the algorithm in. They are defined in the algo-specifc `scripts/config/train_[adversarial|imitation|preference_comparisons|rl].py` files. For the already tuned environments, use the `_` named configs here. + +See the documentation of `scripts/tuning.py` and `scripts/parallel.py` for many other arguments that can be +provided through the command line to change the tuning behavior. diff --git a/benchmarking/util.py b/benchmarking/util.py index 408f0d812..88416344d 100644 --- a/benchmarking/util.py +++ b/benchmarking/util.py @@ -79,7 +79,7 @@ def clean_config_file(file: pathlib.Path, write_path: pathlib.Path, /) -> None: remove_empty_dicts(config) # files are of the format - # /path/to/file/example___best_hp_eval//sacred/1/config.json + # /path/to/file/__best_hp_eval//sacred/1/config.json # we want to write to //_.json with open(write_path / f"{file.parents[3].name}.json", "w") as f: json.dump(config, f, indent=4) diff --git a/experiments/commands.py b/experiments/commands.py index 2ac737e06..a05db867c 100644 --- a/experiments/commands.py +++ b/experiments/commands.py @@ -12,9 +12,10 @@ For example, we can run: +TUNED_HPS_DIR=../src/imitation/scripts/config/tuned_hps python commands.py \ --name=run0 \ - --cfg_pattern=../benchmarking/*ai*_seals_walker_*.json \ + --cfg_pattern=$TUNED_HPS_DIR/*ai*_seals_walker_*.json \ --output_dir=output And get the following commands printed out: @@ -22,13 +23,13 @@ python -m imitation.scripts.train_adversarial airl \ --capture=sys --name=run0 \ --file_storage=output/sacred/$USER-cmd-run0-airl-0-a3531726 \ - with ../benchmarking/example_airl_seals_walker_best_hp_eval.json \ + with ../src/imitation/scripts/config/tuned_hps/airl_seals_walker_best_hp_eval.json \ seed=0 logging.log_root=output python -m imitation.scripts.train_adversarial gail \ --capture=sys --name=run0 \ --file_storage=output/sacred/$USER-cmd-run0-gail-0-a1ec171b \ - with ../benchmarking/example_gail_seals_walker_best_hp_eval.json \ + with $TUNED_HPS_DIR/gail_seals_walker_best_hp_eval.json \ seed=0 logging.log_root=output We can execute commands in parallel by piping them to GNU parallel: @@ -40,9 +41,10 @@ For example, we can run: +TUNED_HPS_DIR=../src/imitation/scripts/config/tuned_hps python commands.py \ --name=run0 \ - --cfg_pattern=../benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json \ + --cfg_pattern=$TUNED_HPS_DIR/bc_seals_half_cheetah_best_hp_eval.json \ --output_dir=/data/output \ --remote @@ -51,8 +53,9 @@ ctl job run --name $USER-cmd-run0-bc-0-72cb1df3 \ --command "python -m imitation.scripts.train_imitation bc \ --capture=sys --name=run0 \ - --file_storage=/data/output/sacred/$USER-cmd-run0-bc-0-72cb1df3 \ - with /data/imitation/benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json \ + --file_storage=/data/output/sacred/$USER-cmd-run0-bc-0-72cb1df3 with \ + /data/imitation/src/imitation/scripts/config/tuned_hps/ + bc_seals_half_cheetah_best_hp_eval.json \ seed=0 logging.log_root=/data/output" \ --container hacobe/devbox:imitation \ --login --force-pull --never-restart --gpu 0 --shared-host-dir-mount /data @@ -85,7 +88,7 @@ def _get_algo_name(cfg_file: str) -> str: """Get the algorithm name from the given config filename.""" algo_names = set() for key in _ALGO_NAME_TO_SCRIPT_NAME: - if cfg_file.find("_" + key + "_") != -1: + if cfg_file.find(key + "_") != -1: algo_names.add(key) if len(algo_names) == 0: @@ -121,7 +124,7 @@ def main(args: argparse.Namespace) -> None: else: cfg_path = os.path.join(args.remote_cfg_dir, cfg_file) - cfg_id = _get_cfg_id(cfg_path) + cfg_id = _get_cfg_id(cfg_file) for seed in args.seeds: cmd_id = _CMD_ID_TEMPLATE.format( @@ -177,19 +180,19 @@ def parse() -> argparse.Namespace: parser.add_argument( "--cfg_pattern", type=str, - default="example_bc_seals_half_cheetah_best_hp_eval.json", + default="bc_seals_half_cheetah_best_hp_eval.json", help="""Generate a command for every file that matches this glob pattern. \ Each matching file should be a config file that has its algorithm name \ (bc, dagger, airl or gail) bookended by underscores in the filename. \ If the --remote flag is enabled, then generate a command for every file in the \ --remote_cfg_dir directory that has the same filename as a file that matches this \ glob pattern. E.g., suppose the current, local working directory is 'foo' and \ -the subdirectory 'foo/bar' contains the config files 'example_bc_best.json' and \ -'example_dagger_best.json'. If the pattern 'bar/*.json' is supplied, then globbing \ -will return ['bar/example_bc_best.json', 'bar/example_dagger_best.json']. \ +the subdirectory 'foo/bar' contains the config files 'bc_best.json' and \ +'dagger_best.json'. If the pattern 'bar/*.json' is supplied, then globbing \ +will return ['bar/bc_best.json', 'bar/dagger_best.json']. \ If the --remote flag is enabled, 'bar' will be replaced with `remote_cfg_dir` and \ commands will be created for the following configs: \ -[`remote_cfg_dir`/example_bc_best.json, `remote_cfg_dir`/example_dagger_best.json] \ +[`remote_cfg_dir`/bc_best.json, `remote_cfg_dir`/dagger_best.json] \ Why not just supply the pattern '`remote_cfg_dir`/*.json' directly? \ Because the `remote_cfg_dir` directory may not exist on the local machine.""", ) @@ -220,7 +223,7 @@ def parse() -> argparse.Namespace: parser.add_argument( "--remote_cfg_dir", type=str, - default="/data/imitation/benchmarking", + default="/data/imitation/src/imitation/scripts/config/tuned_hps", help="""Path to a directory storing config files \ accessible from each container. """, ) diff --git a/setup.cfg b/setup.cfg index dc06cb335..0a294baab 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,6 +45,8 @@ source = imitation include= src/* tests/* +omit = + src/imitation/scripts/config/* [coverage:report] exclude_lines = diff --git a/setup.py b/setup.py index 4b3349493..825902e89 100644 --- a/setup.py +++ b/setup.py @@ -182,7 +182,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: python_requires=">=3.8.0", packages=find_packages("src"), package_dir={"": "src"}, - package_data={"imitation": ["py.typed", "envs/examples/airl_envs/assets/*.xml"]}, + package_data={"imitation": ["py.typed", "scripts/config/tuned_hps/*.json"]}, # Note: while we are strict with our test and doc requirement versions, we try to # impose as little restrictions on the install requirements as possible. Try to # encode only known incompatibilities here. This prevents nasty dependency issues @@ -200,6 +200,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: "sacred>=0.8.4", "tensorboard>=1.14", "huggingface_sb3~=3.0", + "optuna>=3.0.1", "datasets>=2.8.0", ], tests_require=TESTS_REQUIRE, diff --git a/src/imitation/scripts/analyze.py b/src/imitation/scripts/analyze.py index df8ad6b79..b63538f6d 100644 --- a/src/imitation/scripts/analyze.py +++ b/src/imitation/scripts/analyze.py @@ -262,38 +262,47 @@ def analyze_imitation( csv_output_path: If provided, then save a CSV output file to this path. tex_output_path: If provided, then save a LaTeX-format table to this path. print_table: If True, then print the dataframe to stdout. - table_verbosity: Increasing levels of verbosity, from 0 to 2, increase the - number of columns in the table. + table_verbosity: Increasing levels of verbosity, from 0 to 3, increase the + number of columns in the table. Level 3 prints all of the columns available. Returns: The DataFrame generated from the Sacred logs. """ - table_entry_fns_subset = _get_table_entry_fns_subset(table_verbosity) + # Get column names for which we have get value using make_entry_fn + # These are same across Level 2 & 3. In Level 3, we additionally add remaining + # config columns. + table_entry_fns_subset = _get_table_entry_fns_subset(min(table_verbosity, 2)) - rows = [] + output_table = pd.DataFrame() for sd in _gather_sacred_dicts(): - row = {} + if table_verbosity == 3: + # gets all config columns + row = pd.json_normalize(sd.config) + else: + # create an empty dataframe with a single row + row = pd.DataFrame(index=[0]) + for col_name, make_entry_fn in table_entry_fns_subset.items(): row[col_name] = make_entry_fn(sd) - rows.append(row) - df = pd.DataFrame(rows) - if len(df) > 0: - df.sort_values(by=["algo", "env_name"], inplace=True) + output_table = pd.concat([output_table, row]) + + if len(output_table) > 0: + output_table.sort_values(by=["algo", "env_name"], inplace=True) - display_options = dict(index=False) + display_options: Mapping[str, Any] = dict(index=False) if csv_output_path is not None: - df.to_csv(csv_output_path, **display_options) + output_table.to_csv(csv_output_path, **display_options) print(f"Wrote CSV file to {csv_output_path}") if tex_output_path is not None: - s: str = df.to_latex(**display_options) + s: str = output_table.to_latex(**display_options) with open(tex_output_path, "w") as f: f.write(s) print(f"Wrote TeX file to {tex_output_path}") if print_table: - print(df.to_string(**display_options)) - return df + print(output_table.to_string(**display_options)) + return output_table def _make_return_summary(stats: dict, prefix="") -> str: diff --git a/src/imitation/scripts/config/analyze.py b/src/imitation/scripts/config/analyze.py index 5213a875d..01cc2d035 100644 --- a/src/imitation/scripts/config/analyze.py +++ b/src/imitation/scripts/config/analyze.py @@ -18,7 +18,7 @@ def config(): tex_output_path = None # Write LaTex output to this path print_table = True # Set to True to print analysis to stdout split_str = "," # str used to split source_dir_str into multiple source dirs - table_verbosity = 1 # Choose from 0, 1, or 2 + table_verbosity = 1 # Choose from 0, 1, 2 or 3 source_dirs = None diff --git a/src/imitation/scripts/config/parallel.py b/src/imitation/scripts/config/parallel.py index 8ea76f522..62ebbd9e3 100644 --- a/src/imitation/scripts/config/parallel.py +++ b/src/imitation/scripts/config/parallel.py @@ -5,7 +5,10 @@ `@parallel_ex.named_config` to define a new parallel experiment. Adding custom named configs is necessary because the CLI interface can't add -search spaces to the config like `"seed": tune.grid_search([0, 1, 2, 3])`. +search spaces to the config like `"seed": tune.choice([0, 1, 2, 3])`. + +For tuning hyperparameters of an algorithm on a given environment, +check out the imitation/scripts/tuning.py script. """ import numpy as np @@ -31,19 +34,10 @@ def config(): "config_updates": {}, } # `config` argument to `ray.tune.run(trainable, config)` - local_dir = None # `local_dir` arg for `ray.tune.run` - upload_dir = None # `upload_dir` arg for `ray.tune.run` - n_seeds = 3 # Number of seeds to search over by default - - -@parallel_ex.config -def seeds(n_seeds): - search_space = {"config_updates": {"seed": tune.grid_search(list(range(n_seeds)))}} - - -@parallel_ex.named_config -def s3(): - upload_dir = "s3://shwang-chai/private" + num_samples = 1 # Number of samples per grid search configuration + repeat = 1 # Number of times to repeat a sampled configuration + experiment_checkpoint_path = "" # Path to checkpoint of experiment + tune_run_kwargs = {} # Additional kwargs to pass to `tune.run` # Debug named configs @@ -58,12 +52,12 @@ def generate_test_data(): """ sacred_ex_name = "train_rl" run_name = "TEST" - n_seeds = 1 + repeat = 1 search_space = { "config_updates": { "rl": { "rl_kwargs": { - "learning_rate": tune.grid_search( + "learning_rate": tune.choice( [3e-4 * x for x in (1 / 3, 1 / 2)], ), }, @@ -86,63 +80,16 @@ def generate_test_data(): def example_cartpole_rl(): sacred_ex_name = "train_rl" run_name = "example-cartpole" - n_seeds = 2 + repeat = 2 search_space = { "config_updates": { "rl": { "rl_kwargs": { - "learning_rate": tune.grid_search(np.logspace(3e-6, 1e-1, num=3)), - "nminibatches": tune.grid_search([16, 32, 64]), + "learning_rate": tune.choice(np.logspace(3e-6, 1e-1, num=3)), + "nminibatches": tune.choice([16, 32, 64]), }, }, }, } base_named_configs = ["cartpole"] resources_per_trial = dict(cpu=4) - - -EASY_ENVS = ["cartpole", "pendulum", "mountain_car"] - - -@parallel_ex.named_config -def example_rl_easy(): - sacred_ex_name = "train_rl" - run_name = "example-rl-easy" - n_seeds = 2 - search_space = { - "named_configs": tune.grid_search([[env] for env in EASY_ENVS]), - "config_updates": { - "rl": { - "rl_kwargs": { - "learning_rate": tune.grid_search(np.logspace(3e-6, 1e-1, num=3)), - "nminibatches": tune.grid_search([16, 32, 64]), - }, - }, - }, - } - resources_per_trial = dict(cpu=4) - - -@parallel_ex.named_config -def example_gail_easy(): - sacred_ex_name = "train_adversarial" - run_name = "example-gail-easy" - n_seeds = 1 - search_space = { - "named_configs": tune.grid_search([[env] for env in EASY_ENVS]), - "config_updates": { - "init_trainer_kwargs": { - "rl": { - "rl_kwargs": { - "learning_rate": tune.grid_search( - np.logspace(3e-6, 1e-1, num=3), - ), - "nminibatches": tune.grid_search([16, 32, 64]), - }, - }, - }, - }, - } - search_space = { - "command_name": "gail", - } diff --git a/src/imitation/scripts/config/train_adversarial.py b/src/imitation/scripts/config/train_adversarial.py index 55e6effec..ff32a551b 100644 --- a/src/imitation/scripts/config/train_adversarial.py +++ b/src/imitation/scripts/config/train_adversarial.py @@ -1,5 +1,7 @@ """Configuration for imitation.scripts.train_adversarial.""" +import pathlib + import sacred from imitation.rewards import reward_nets @@ -7,6 +9,10 @@ from imitation.scripts.ingredients import logging as logging_ingredient from imitation.scripts.ingredients import policy_evaluation, reward, rl +# Note: All the hyperparameter configs in the file are of the tuned +# hyperparameters of the RL algorithm of the respective environment. +# Taken from imitation/scripts/config/train_rl.py + train_adversarial_ex = sacred.Experiment( "train_adversarial", ingredients=[ @@ -96,13 +102,6 @@ def pendulum(): # Standard MuJoCo Gym environment named configs -@train_adversarial_ex.named_config -def seals_ant(): - locals().update(**MUJOCO_SHARED_LOCALS) - locals().update(**ANT_SHARED_LOCALS) - environment = dict(gym_id="seals/Ant-v0") - - CHEETAH_SHARED_LOCALS = dict( MUJOCO_SHARED_LOCALS, rl=dict(batch_size=16384, rl_kwargs=dict(batch_size=1024)), @@ -137,18 +136,6 @@ def half_cheetah(): environment = dict(gym_id="HalfCheetah-v2") -@train_adversarial_ex.named_config -def seals_half_cheetah(): - locals().update(**CHEETAH_SHARED_LOCALS) - environment = dict(gym_id="seals/HalfCheetah-v0") - - -@train_adversarial_ex.named_config -def seals_hopper(): - locals().update(**MUJOCO_SHARED_LOCALS) - environment = dict(gym_id="seals/Hopper-v0") - - @train_adversarial_ex.named_config def seals_humanoid(): locals().update(**MUJOCO_SHARED_LOCALS) @@ -162,19 +149,6 @@ def reacher(): algorithm_kwargs = {"allow_variable_horizon": True} -@train_adversarial_ex.named_config -def seals_swimmer(): - locals().update(**MUJOCO_SHARED_LOCALS) - environment = dict(gym_id="seals/Swimmer-v0") - total_timesteps = int(2e6) - - -@train_adversarial_ex.named_config -def seals_walker(): - locals().update(**MUJOCO_SHARED_LOCALS) - environment = dict(gym_id="seals/Walker2d-v0") - - # Debug configs @@ -189,3 +163,23 @@ def fast(): demo_batch_size=1, n_disc_updates_per_round=4, ) + + +hyperparam_dir = pathlib.Path(__file__).absolute().parent / "tuned_hps" +tuned_alg_envs = [ + "airl_seals_ant", + "airl_seals_half_cheetah", + "airl_seals_hopper", + "airl_seals_swimmer", + "airl_seals_walker", + "gail_seals_ant", + "gail_seals_half_cheetah", + "gail_seals_hopper", + "gail_seals_swimmer", + "gail_seals_walker", +] + +for tuned_alg_env in tuned_alg_envs: + config_file = hyperparam_dir / f"{tuned_alg_env}_best_hp_eval.json" + assert config_file.is_file(), f"{config_file} does not exist" + train_adversarial_ex.add_named_config(tuned_alg_env, str(config_file)) diff --git a/src/imitation/scripts/config/train_imitation.py b/src/imitation/scripts/config/train_imitation.py index 88bc4888c..f151e768e 100644 --- a/src/imitation/scripts/config/train_imitation.py +++ b/src/imitation/scripts/config/train_imitation.py @@ -1,5 +1,7 @@ """Configuration settings for train_dagger, training DAgger from synthetic demos.""" +import pathlib + import sacred from imitation.scripts.ingredients import bc @@ -67,11 +69,6 @@ def ant(): environment = dict(gym_id="Ant-v2") -@train_imitation_ex.named_config -def seals_ant(): - environment = dict(gym_id="seals/Ant-v0") - - @train_imitation_ex.named_config def half_cheetah(): environment = dict(gym_id="HalfCheetah-v2") @@ -79,13 +76,6 @@ def half_cheetah(): dagger = dict(total_timesteps=60000) -@train_imitation_ex.named_config -def seals_half_cheetah(): - environment = dict(gym_id="seals/HalfCheetah-v0") - bc = dict(l2_weight=0.0) - dagger = dict(total_timesteps=60000) - - @train_imitation_ex.named_config def humanoid(): environment = dict(gym_id="Humanoid-v2") @@ -101,3 +91,23 @@ def fast(): dagger = dict(total_timesteps=50) bc = dict(train_kwargs=dict(n_batches=50)) sqil = dict(total_timesteps=50) + + +hyperparam_dir = pathlib.Path(__file__).absolute().parent / "tuned_hps" +tuned_alg_envs = [ + "bc_seals_ant", + "bc_seals_half_cheetah", + "bc_seals_hopper", + "bc_seals_swimmer", + "bc_seals_walker", + "dagger_seals_ant", + "dagger_seals_half_cheetah", + "dagger_seals_hopper", + "dagger_seals_swimmer", + "dagger_seals_walker", +] + +for tuned_alg_env in tuned_alg_envs: + config_file = hyperparam_dir / f"{tuned_alg_env}_best_hp_eval.json" + assert config_file.is_file(), f"{config_file} does not exist" + train_imitation_ex.add_named_config(tuned_alg_env, str(config_file)) diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index 28890bf33..4d8531732 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -1,12 +1,17 @@ """Configuration for imitation.scripts.train_preference_comparisons.""" import sacred +from torch import nn from imitation.algorithms import preference_comparisons from imitation.scripts.ingredients import environment from imitation.scripts.ingredients import logging as logging_ingredient from imitation.scripts.ingredients import policy_evaluation, reward, rl +# Note: All the hyperparameter configs in the file are of the tuned +# hyperparameters of the RL algorithm of the respective environment. +# Taken from imitation/scripts/config/train_rl.py + train_preference_comparisons_ex = sacred.Experiment( "train_preference_comparisons", ingredients=[ @@ -72,9 +77,22 @@ def cartpole(): @train_preference_comparisons_ex.named_config def seals_ant(): - locals().update(**MUJOCO_SHARED_LOCALS) - locals().update(**ANT_SHARED_LOCALS) environment = dict(gym_id="seals/Ant-v0") + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=16, + clip_range=0.3, + ent_coef=3.1441389214159857e-06, + gae_lambda=0.8, + gamma=0.995, + learning_rate=0.00017959211641976886, + max_grad_norm=0.9, + n_epochs=10, + # policy_kwargs are same as the defaults + vf_coef=0.4351450387648799, + ), + ) @train_preference_comparisons_ex.named_config @@ -84,10 +102,105 @@ def half_cheetah(): rl = dict(batch_size=16384, rl_kwargs=dict(batch_size=1024)) +@train_preference_comparisons_ex.named_config +def seals_half_cheetah(): + environment = dict(gym_id="seals/HalfCheetah-v0") + rl = dict( + batch_size=512, + rl_kwargs=dict( + batch_size=64, + clip_range=0.1, + ent_coef=3.794797423594763e-06, + gae_lambda=0.95, + gamma=0.95, + learning_rate=0.0003286871805949382, + max_grad_norm=0.8, + n_epochs=5, + vf_coef=0.11483689492120866, + ), + ) + num_iterations = 50 + total_timesteps = 20000000 + + @train_preference_comparisons_ex.named_config def seals_hopper(): - locals().update(**MUJOCO_SHARED_LOCALS) environment = dict(gym_id="seals/Hopper-v0") + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=512, + clip_range=0.1, + ent_coef=0.0010159833764878474, + gae_lambda=0.98, + gamma=0.995, + learning_rate=0.0003904770450788824, + max_grad_norm=0.9, + n_epochs=20, + vf_coef=0.20315938606555833, + ), + ) + + +@train_preference_comparisons_ex.named_config +def seals_swimmer(): + environment = dict(gym_id="seals/Swimmer-v0") + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=64, + clip_range=0.1, + ent_coef=5.167107294612664e-08, + gae_lambda=0.95, + gamma=0.999, + learning_rate=0.000414936134792374, + max_grad_norm=2, + n_epochs=5, + # policy_kwargs are same as the defaults + vf_coef=0.6162112311062333, + ), + ) + + +@train_preference_comparisons_ex.named_config +def seals_walker(): + environment = dict(gym_id="seals/Walker2d-v0") + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + rl = dict( + batch_size=8192, + rl_kwargs=dict( + batch_size=128, + clip_range=0.4, + ent_coef=0.00013057334805552262, + gae_lambda=0.92, + gamma=0.98, + learning_rate=0.000138575372312869, + max_grad_norm=0.6, + n_epochs=20, + # policy_kwargs are same as the defaults + vf_coef=0.6167177795726859, + ), + ) @train_preference_comparisons_ex.named_config diff --git a/src/imitation/scripts/config/train_rl.py b/src/imitation/scripts/config/train_rl.py index d24d9492d..e4ab71da1 100644 --- a/src/imitation/scripts/config/train_rl.py +++ b/src/imitation/scripts/config/train_rl.py @@ -1,11 +1,18 @@ """Configuration settings for train_rl, training a policy with RL.""" + import sacred +from torch import nn from imitation.scripts.ingredients import environment from imitation.scripts.ingredients import logging as logging_ingredient from imitation.scripts.ingredients import policy_evaluation, rl +# Note: All the hyperparameter configs in the file are tuned +# for the PPO algorithm on the respective environment using the +# RL Baselines Zoo library: +# https://github.com/HumanCompatibleAI/rl-baselines3-zoo/ + train_rl_ex = sacred.Experiment( "train_rl", ingredients=[ @@ -70,8 +77,30 @@ def cartpole(): @train_rl_ex.named_config def seals_cartpole(): - environment = dict(gym_id="seals/CartPole-v0") - total_timesteps = int(1e6) + environment = dict(gym_id="seals/CartPole-v0", num_vec=8) + total_timesteps = int(1e5) + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + normalize_reward = False + rl = dict( + batch_size=4096, + rl_kwargs=dict( + batch_size=256, + clip_range=0.4, + ent_coef=0.008508727919228772, + gae_lambda=0.9, + gamma=0.9999, + learning_rate=0.0012403278189645594, + max_grad_norm=0.8, + n_epochs=10, + vf_coef=0.489343896591493, + ), + ) @train_rl_ex.named_config @@ -80,9 +109,69 @@ def half_cheetah(): total_timesteps = int(5e6) # does OK after 1e6, but continues improving +@train_rl_ex.named_config +def seals_half_cheetah(): + environment = dict( + gym_id="seals/HalfCheetah-v0", + num_vec=1, + ) + + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.Tanh, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + # total_timesteps = int(5e6) # does OK after 1e6, but continues improving + total_timesteps = 1e6 + normalize_reward = False + + rl = dict( + batch_size=512, + rl_kwargs=dict( + batch_size=64, + clip_range=0.1, + ent_coef=3.794797423594763e-06, + gae_lambda=0.95, + gamma=0.95, + learning_rate=0.0003286871805949382, + max_grad_norm=0.8, + n_epochs=5, + vf_coef=0.11483689492120866, + ), + ) + + @train_rl_ex.named_config def seals_hopper(): - environment = dict(gym_id="seals/Hopper-v0") + environment = dict(gym_id="seals/Hopper-v0", num_vec=1) + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + + total_timesteps = 1e6 + normalize_reward = False + + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=512, + clip_range=0.1, + ent_coef=0.0010159833764878474, + gae_lambda=0.98, + gamma=0.995, + learning_rate=0.0003904770450788824, + max_grad_norm=0.9, + n_epochs=20, + # policy_kwargs are same as the defaults + vf_coef=0.20315938606555833, + ), + ) @train_rl_ex.named_config @@ -122,17 +211,99 @@ def reacher(): @train_rl_ex.named_config def seals_ant(): - environment = dict(gym_id="seals/Ant-v0") + environment = dict( + gym_id="seals/Ant-v0", + num_vec=1, + ) + + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.Tanh, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + + total_timesteps = 1e6 + normalize_reward = False + + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=16, + clip_range=0.3, + ent_coef=3.1441389214159857e-06, + gae_lambda=0.8, + gamma=0.995, + learning_rate=0.00017959211641976886, + max_grad_norm=0.9, + n_epochs=10, + # policy_kwargs are same as the defaults + vf_coef=0.4351450387648799, + ), + ) @train_rl_ex.named_config def seals_swimmer(): - environment = dict(gym_id="seals/Swimmer-v0") + environment = dict(gym_id="seals/Swimmer-v0", num_vec=1) + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + + total_timesteps = 1e6 + normalize_reward = False + + rl = dict( + batch_size=2048, + rl_kwargs=dict( + batch_size=64, + clip_range=0.1, + ent_coef=5.167107294612664e-08, + gae_lambda=0.95, + gamma=0.999, + learning_rate=0.000414936134792374, + max_grad_norm=2, + n_epochs=5, + # policy_kwargs are same as the defaults + vf_coef=0.6162112311062333, + ), + ) @train_rl_ex.named_config def seals_walker(): - environment = dict(gym_id="seals/Walker2d-v0") + environment = dict(gym_id="seals/Walker2d-v0", num_vec=1) + policy = dict( + policy_cls="MlpPolicy", + policy_kwargs=dict( + activation_fn=nn.ReLU, + net_arch=[dict(pi=[64, 64], vf=[64, 64])], + ), + ) + + total_timesteps = 1e6 + normalize_reward = False + + rl = dict( + batch_size=8192, + rl_kwargs=dict( + batch_size=128, + clip_range=0.4, + ent_coef=0.00013057334805552262, + gae_lambda=0.92, + gamma=0.98, + learning_rate=0.000138575372312869, + max_grad_norm=0.6, + n_epochs=20, + # policy_kwargs are same as the defaults + vf_coef=0.6167177795726859, + ), + ) # Debug configs diff --git a/benchmarking/example_airl_seals_ant_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/airl_seals_ant_best_hp_eval.json similarity index 98% rename from benchmarking/example_airl_seals_ant_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/airl_seals_ant_best_hp_eval.json index 17f969ff0..d4131433e 100644 --- a/benchmarking/example_airl_seals_ant_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/airl_seals_ant_best_hp_eval.json @@ -62,6 +62,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Ant-v0" + "gym_id": "seals/Ant-v1" } } diff --git a/benchmarking/example_airl_seals_half_cheetah_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/airl_seals_half_cheetah_best_hp_eval.json similarity index 97% rename from benchmarking/example_airl_seals_half_cheetah_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/airl_seals_half_cheetah_best_hp_eval.json index 754ba6736..f69ba5cb5 100644 --- a/benchmarking/example_airl_seals_half_cheetah_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/airl_seals_half_cheetah_best_hp_eval.json @@ -62,6 +62,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/HalfCheetah-v0" + "gym_id": "seals/HalfCheetah-v1" } } diff --git a/benchmarking/example_airl_seals_hopper_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/airl_seals_hopper_best_hp_eval.json similarity index 98% rename from benchmarking/example_airl_seals_hopper_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/airl_seals_hopper_best_hp_eval.json index 91080d7ce..58c2475f5 100644 --- a/benchmarking/example_airl_seals_hopper_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/airl_seals_hopper_best_hp_eval.json @@ -75,6 +75,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Hopper-v0" + "gym_id": "seals/Hopper-v1" } } diff --git a/benchmarking/example_airl_seals_swimmer_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/airl_seals_swimmer_best_hp_eval.json similarity index 96% rename from benchmarking/example_airl_seals_swimmer_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/airl_seals_swimmer_best_hp_eval.json index fcca8e6b3..8529c58b5 100644 --- a/benchmarking/example_airl_seals_swimmer_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/airl_seals_swimmer_best_hp_eval.json @@ -12,7 +12,7 @@ }, "expert": { "loader_kwargs": { - "gym_id": "seals/Swimmer-v0", + "gym_id": "seals/Swimmer-v1", "organization": "HumanCompatibleAI" } }, @@ -81,6 +81,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Swimmer-v0" + "gym_id": "seals/Swimmer-v1" } } diff --git a/benchmarking/example_airl_seals_walker_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/airl_seals_walker_best_hp_eval.json similarity index 96% rename from benchmarking/example_airl_seals_walker_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/airl_seals_walker_best_hp_eval.json index c63070751..edd99806d 100644 --- a/benchmarking/example_airl_seals_walker_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/airl_seals_walker_best_hp_eval.json @@ -12,7 +12,7 @@ }, "expert": { "loader_kwargs": { - "gym_id": "seals/Walker2d-v0", + "gym_id": "seals/Walker2d-v1", "organization": "HumanCompatibleAI" } }, @@ -81,6 +81,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Walker2d-v0" + "gym_id": "seals/Walker2d-v1" } } diff --git a/benchmarking/example_bc_seals_ant_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/bc_seals_ant_best_hp_eval.json similarity index 97% rename from benchmarking/example_bc_seals_ant_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/bc_seals_ant_best_hp_eval.json index 108a93ce7..e9baa8fc1 100644 --- a/benchmarking/example_bc_seals_ant_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/bc_seals_ant_best_hp_eval.json @@ -43,6 +43,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Ant-v0" + "gym_id": "seals/Ant-v1" } } diff --git a/benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/bc_seals_half_cheetah_best_hp_eval.json similarity index 96% rename from benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/bc_seals_half_cheetah_best_hp_eval.json index ecaff2eb0..041f159b0 100644 --- a/benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/bc_seals_half_cheetah_best_hp_eval.json @@ -43,6 +43,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/HalfCheetah-v0" + "gym_id": "seals/HalfCheetah-v1" } } diff --git a/benchmarking/example_bc_seals_hopper_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/bc_seals_hopper_best_hp_eval.json similarity index 96% rename from benchmarking/example_bc_seals_hopper_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/bc_seals_hopper_best_hp_eval.json index e8c821841..9a7872d37 100644 --- a/benchmarking/example_bc_seals_hopper_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/bc_seals_hopper_best_hp_eval.json @@ -43,6 +43,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Hopper-v0" + "gym_id": "seals/Hopper-v1" } } diff --git a/benchmarking/example_bc_seals_swimmer_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/bc_seals_swimmer_best_hp_eval.json similarity index 96% rename from benchmarking/example_bc_seals_swimmer_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/bc_seals_swimmer_best_hp_eval.json index 30884c9c4..8a8f2456a 100644 --- a/benchmarking/example_bc_seals_swimmer_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/bc_seals_swimmer_best_hp_eval.json @@ -43,6 +43,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Swimmer-v0" + "gym_id": "seals/Swimmer-v1" } } diff --git a/benchmarking/example_bc_seals_walker_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/bc_seals_walker_best_hp_eval.json similarity index 96% rename from benchmarking/example_bc_seals_walker_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/bc_seals_walker_best_hp_eval.json index 0ca30120e..f33e6c5a2 100644 --- a/benchmarking/example_bc_seals_walker_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/bc_seals_walker_best_hp_eval.json @@ -43,6 +43,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Walker2d-v0" + "gym_id": "seals/Walker2d-v1" } } diff --git a/benchmarking/example_dagger_seals_ant_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/dagger_seals_ant_best_hp_eval.json similarity index 97% rename from benchmarking/example_dagger_seals_ant_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/dagger_seals_ant_best_hp_eval.json index de75b80f1..e02828667 100644 --- a/benchmarking/example_dagger_seals_ant_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/dagger_seals_ant_best_hp_eval.json @@ -47,6 +47,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Ant-v0" + "gym_id": "seals/Ant-v1" } } diff --git a/benchmarking/example_dagger_seals_half_cheetah_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/dagger_seals_half_cheetah_best_hp_eval.json similarity index 96% rename from benchmarking/example_dagger_seals_half_cheetah_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/dagger_seals_half_cheetah_best_hp_eval.json index 7f42bfdf9..d1c9e5923 100644 --- a/benchmarking/example_dagger_seals_half_cheetah_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/dagger_seals_half_cheetah_best_hp_eval.json @@ -47,6 +47,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/HalfCheetah-v0" + "gym_id": "seals/HalfCheetah-v1" } } diff --git a/benchmarking/example_dagger_seals_hopper_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/dagger_seals_hopper_best_hp_eval.json similarity index 97% rename from benchmarking/example_dagger_seals_hopper_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/dagger_seals_hopper_best_hp_eval.json index 1cf29a1a4..b91f66298 100644 --- a/benchmarking/example_dagger_seals_hopper_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/dagger_seals_hopper_best_hp_eval.json @@ -47,6 +47,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Hopper-v0" + "gym_id": "seals/Hopper-v1" } } diff --git a/benchmarking/example_dagger_seals_swimmer_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/dagger_seals_swimmer_best_hp_eval.json similarity index 97% rename from benchmarking/example_dagger_seals_swimmer_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/dagger_seals_swimmer_best_hp_eval.json index c112db680..545761cbc 100644 --- a/benchmarking/example_dagger_seals_swimmer_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/dagger_seals_swimmer_best_hp_eval.json @@ -47,6 +47,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Swimmer-v0" + "gym_id": "seals/Swimmer-v1" } } diff --git a/benchmarking/example_dagger_seals_walker_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/dagger_seals_walker_best_hp_eval.json similarity index 97% rename from benchmarking/example_dagger_seals_walker_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/dagger_seals_walker_best_hp_eval.json index e59bef464..7b694c8d2 100644 --- a/benchmarking/example_dagger_seals_walker_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/dagger_seals_walker_best_hp_eval.json @@ -47,6 +47,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Walker2d-v0" + "gym_id": "seals/Walker2d-v1" } } diff --git a/benchmarking/fast_dagger_seals_cartpole.json b/src/imitation/scripts/config/tuned_hps/fast_dagger_seals_cartpole.json similarity index 100% rename from benchmarking/fast_dagger_seals_cartpole.json rename to src/imitation/scripts/config/tuned_hps/fast_dagger_seals_cartpole.json diff --git a/benchmarking/example_gail_seals_ant_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/gail_seals_ant_best_hp_eval.json similarity index 98% rename from benchmarking/example_gail_seals_ant_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/gail_seals_ant_best_hp_eval.json index 81399b00c..3d43b34ba 100644 --- a/benchmarking/example_gail_seals_ant_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/gail_seals_ant_best_hp_eval.json @@ -62,6 +62,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Ant-v0" + "gym_id": "seals/Ant-v1" } } diff --git a/benchmarking/example_gail_seals_half_cheetah_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/gail_seals_half_cheetah_best_hp_eval.json similarity index 97% rename from benchmarking/example_gail_seals_half_cheetah_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/gail_seals_half_cheetah_best_hp_eval.json index 1d2f26648..914f3712a 100644 --- a/benchmarking/example_gail_seals_half_cheetah_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/gail_seals_half_cheetah_best_hp_eval.json @@ -62,6 +62,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/HalfCheetah-v0" + "gym_id": "seals/HalfCheetah-v1" } } diff --git a/benchmarking/example_gail_seals_hopper_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/gail_seals_hopper_best_hp_eval.json similarity index 98% rename from benchmarking/example_gail_seals_hopper_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/gail_seals_hopper_best_hp_eval.json index 70787ff7e..cebdae71c 100644 --- a/benchmarking/example_gail_seals_hopper_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/gail_seals_hopper_best_hp_eval.json @@ -75,6 +75,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Hopper-v0" + "gym_id": "seals/Hopper-v1" } } diff --git a/benchmarking/example_gail_seals_swimmer_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/gail_seals_swimmer_best_hp_eval.json similarity index 96% rename from benchmarking/example_gail_seals_swimmer_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/gail_seals_swimmer_best_hp_eval.json index 650c5f46a..b0bd0e645 100644 --- a/benchmarking/example_gail_seals_swimmer_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/gail_seals_swimmer_best_hp_eval.json @@ -12,7 +12,7 @@ }, "expert": { "loader_kwargs": { - "gym_id": "seals/Swimmer-v0", + "gym_id": "seals/Swimmer-v1", "organization": "HumanCompatibleAI" } }, @@ -81,6 +81,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Swimmer-v0" + "gym_id": "seals/Swimmer-v1" } } diff --git a/benchmarking/example_gail_seals_walker_best_hp_eval.json b/src/imitation/scripts/config/tuned_hps/gail_seals_walker_best_hp_eval.json similarity index 96% rename from benchmarking/example_gail_seals_walker_best_hp_eval.json rename to src/imitation/scripts/config/tuned_hps/gail_seals_walker_best_hp_eval.json index d85eb46d5..2626b4c43 100644 --- a/benchmarking/example_gail_seals_walker_best_hp_eval.json +++ b/src/imitation/scripts/config/tuned_hps/gail_seals_walker_best_hp_eval.json @@ -12,7 +12,7 @@ }, "expert": { "loader_kwargs": { - "gym_id": "seals/Walker2d-v0", + "gym_id": "seals/Walker2d-v1", "organization": "HumanCompatibleAI" } }, @@ -81,6 +81,6 @@ "n_episodes_eval": 50 }, "environment": { - "gym_id": "seals/Walker2d-v0" + "gym_id": "seals/Walker2d-v1" } } diff --git a/src/imitation/scripts/config/tuning.py b/src/imitation/scripts/config/tuning.py new file mode 100644 index 000000000..73313770a --- /dev/null +++ b/src/imitation/scripts/config/tuning.py @@ -0,0 +1,232 @@ +"""Config files for tuning experiments.""" + +import ray.tune as tune +import sacred +from torch import nn + +from imitation.algorithms import dagger as dagger_alg +from imitation.scripts.parallel import parallel_ex + +tuning_ex = sacred.Experiment("tuning", ingredients=[parallel_ex]) + + +@tuning_ex.named_config +def rl(): + parallel_run_config = dict( + sacred_ex_name="train_rl", + run_name="rl_tuning", + base_named_configs=["logging.wandb_logging"], + base_config_updates={"environment": {"num_vec": 1}}, + search_space={ + "config_updates": { + "rl": { + "batch_size": tune.choice([512, 1024, 2048, 4096, 8192]), + "rl_kwargs": { + "learning_rate": tune.loguniform(1e-5, 1e-2), + "batch_size": tune.choice([64, 128, 256, 512]), + "n_epochs": tune.choice([5, 10, 20]), + }, + }, + }, + }, + num_samples=100, + repeat=1, + resources_per_trial=dict(cpu=1), + ) + num_eval_seeds = 5 + + +@tuning_ex.named_config +def bc(): + parallel_run_config = dict( + sacred_ex_name="train_imitation", + run_name="bc_tuning", + base_named_configs=["logging.wandb_logging"], + base_config_updates={ + "environment": {"num_vec": 1}, + "demonstrations": {"source": "huggingface"}, + }, + search_space={ + "config_updates": { + "bc": dict( + batch_size=tune.choice([8, 16, 32, 64]), + l2_weight=tune.loguniform(1e-6, 1e-2), # L2 regularization weight + optimizer_kwargs=dict( + lr=tune.loguniform(1e-5, 1e-2), + ), + train_kwargs=dict( + n_epochs=tune.choice([1, 5, 10, 20]), + ), + ), + }, + "command_name": "bc", + }, + num_samples=64, + repeat=3, + resources_per_trial=dict(cpu=1), + ) + + num_eval_seeds = 5 + eval_best_trial_resource_multiplier = 1 + + +@tuning_ex.named_config +def dagger(): + parallel_run_config = dict( + sacred_ex_name="train_imitation", + run_name="dagger_tuning", + base_named_configs=["logging.wandb_logging"], + base_config_updates={ + "environment": {"num_vec": 1}, + "demonstrations": {"source": "huggingface"}, + "dagger": {"total_timesteps": 1e5}, + "bc": { + "batch_size": 16, + "l2_weight": 1e-4, + "optimizer_kwargs": {"lr": 1e-3}, + }, + }, + search_space={ + "config_updates": { + "bc": dict( + train_kwargs=dict( + n_epochs=tune.choice([1, 5, 10]), + ), + ), + "dagger": dict( + beta_schedule=tune.choice( + [dagger_alg.LinearBetaSchedule(i) for i in [1, 5, 15]] + + [ + dagger_alg.ExponentialBetaSchedule(i) + for i in [0.3, 0.5, 0.7] + ], + ), + rollout_round_min_episodes=tune.choice([3, 5, 10]), + ), + }, + "command_name": "dagger", + }, + num_samples=50, + repeat=3, + resources_per_trial=dict(cpu=1), + ) + num_eval_seeds = 5 + + +@tuning_ex.named_config +def gail(): + parallel_run_config = dict( + sacred_ex_name="train_adversarial", + run_name="gail_tuning", + base_named_configs=["logging.wandb_logging"], + base_config_updates={ + "environment": {"num_vec": 1}, + "demonstrations": {"source": "huggingface"}, + "total_timesteps": 1e7, + }, + search_space={ + "config_updates": { + "algorithm_kwargs": dict( + demo_batch_size=tune.choice([32, 128, 512, 2048, 8192]), + n_disc_updates_per_round=tune.choice([8, 16]), + ), + "rl": { + "batch_size": tune.choice([4096, 8192, 16384]), + "rl_kwargs": { + "ent_coef": tune.loguniform(1e-7, 1e-3), + "learning_rate": tune.loguniform(1e-5, 1e-2), + }, + }, + "algorithm_specific": {}, + }, + "command_name": "gail", + }, + num_samples=100, + repeat=3, + resources_per_trial=dict(cpu=1), + ) + num_eval_seeds = 5 + + +@tuning_ex.named_config +def airl(): + parallel_run_config = dict( + sacred_ex_name="train_adversarial", + run_name="airl_tuning", + base_named_configs=["logging.wandb_logging"], + base_config_updates={ + "environment": {"num_vec": 1}, + "demonstrations": {"source": "huggingface"}, + "total_timesteps": 1e7, + }, + search_space={ + "config_updates": { + "algorithm_kwargs": dict( + demo_batch_size=tune.choice([32, 128, 512, 2048, 8192]), + n_disc_updates_per_round=tune.choice([8, 16]), + ), + "rl": { + "batch_size": tune.choice([4096, 8192, 16384]), + "rl_kwargs": { + "ent_coef": tune.loguniform(1e-7, 1e-3), + "learning_rate": tune.loguniform(1e-5, 1e-2), + }, + }, + "algorithm_specific": {}, + }, + "command_name": "airl", + }, + num_samples=100, + repeat=3, + resources_per_trial=dict(cpu=1), + ) + num_eval_seeds = 5 + + +@tuning_ex.named_config +def pc(): + parallel_run_config = dict( + sacred_ex_name="train_preference_comparisons", + run_name="pc_tuning", + base_named_configs=["logging.wandb_logging"], + base_config_updates={ + "environment": {"num_vec": 1}, + "demonstrations": {"source": "huggingface"}, + "total_timesteps": 2e7, + "total_comparisons": 5000, + "query_schedule": "hyperbolic", + "gatherer_kwargs": {"sample": True}, + }, + search_space={ + "named_configs": [ + ["reward.normalize_output_disable"], + ], + "config_updates": { + "train": { + "policy_kwargs": { + "activation_fn": tune.choice( + [ + nn.ReLU, + ], + ), + }, + }, + "num_iterations": tune.choice([25, 50]), + "initial_comparison_frac": tune.choice([0.1, 0.25]), + "reward_trainer_kwargs": { + "epochs": tune.choice([1, 3, 6]), + }, + "rl": { + "batch_size": tune.choice([512, 2048, 8192]), + "rl_kwargs": { + "learning_rate": tune.loguniform(1e-5, 1e-2), + "ent_coef": tune.loguniform(1e-7, 1e-3), + }, + }, + }, + }, + num_samples=100, + repeat=3, + resources_per_trial=dict(cpu=1), + ) + num_eval_seeds = 5 diff --git a/src/imitation/scripts/ingredients/reward.py b/src/imitation/scripts/ingredients/reward.py index 2e2b67022..6b2e0195e 100644 --- a/src/imitation/scripts/ingredients/reward.py +++ b/src/imitation/scripts/ingredients/reward.py @@ -46,6 +46,11 @@ def normalize_output_running(): normalize_output_layer = networks.RunningNorm # noqa: F841 +@reward_ingredient.named_config +def normalize_output_ema(): + normalize_output_layer = networks.EMANorm # noqa: F841 + + @reward_ingredient.named_config def reward_ensemble(): net_cls = reward_nets.RewardEnsemble diff --git a/src/imitation/scripts/parallel.py b/src/imitation/scripts/parallel.py index 6014a08b6..d5e5e2378 100644 --- a/src/imitation/scripts/parallel.py +++ b/src/imitation/scripts/parallel.py @@ -3,11 +3,13 @@ import collections.abc import copy import pathlib -from typing import Any, Callable, Dict, Mapping, Optional, Sequence +from typing import Any, Callable, Dict, Mapping, Sequence import ray import ray.tune import sacred +from ray.tune import search +from ray.tune.search import optuna from sacred.observers import FileStorageObserver from imitation.scripts.config.parallel import parallel_ex @@ -17,29 +19,33 @@ def parallel( sacred_ex_name: str, run_name: str, + num_samples: int, search_space: Mapping[str, Any], base_named_configs: Sequence[str], base_config_updates: Mapping[str, Any], resources_per_trial: Mapping[str, Any], init_kwargs: Mapping[str, Any], - local_dir: Optional[str], - upload_dir: Optional[str], -) -> None: + repeat: int, + experiment_checkpoint_path: str, + tune_run_kwargs: Dict[str, Any], +) -> ray.tune.ExperimentAnalysis: """Parallelize multiple runs of another Sacred Experiment using Ray Tune. A Sacred FileObserver is attached to the inner experiment and writes Sacred logs to "{RAY_LOCAL_DIR}/sacred/". These files are automatically copied over - to `upload_dir` if that argument is provided. + to `upload_dir` if that argument is provided in `tune_run_kwargs`. Args: - sacred_ex_name: The Sacred experiment to tune. Either "train_rl" or - "train_adversarial". + sacred_ex_name: The Sacred experiment to tune. Either "train_rl", + "train_imitation", "train_adversarial" or "train_preference_comparisons". run_name: A name describing this parallelizing experiment. This argument is also passed to `ray.tune.run` as the `name` argument. It is also saved in 'sacred/run.json' of each inner Sacred experiment under the 'experiment.name' key. This is equivalent to using the Sacred CLI '--name' option on the inner experiment. Offline analysis jobs can use this argument to group similar data. + num_samples: Number of times to sample from the hyperparameter space without + considering repetition using `repeat`. search_space: A dictionary which can contain Ray Tune search objects like `ray.tune.grid_search` and `ray.tune.sample_from`, and is passed as the `config` argument to `ray.tune.run()`. After the @@ -60,11 +66,22 @@ def parallel( generated Ray directory name, unlike config updates from `search_space`. resources_per_trial: Argument to `ray.tune.run()`. init_kwargs: Arguments to pass to `ray.init`. - local_dir: `local_dir` argument to `ray.tune.run()`. - upload_dir: `upload_dir` argument to `ray.tune.run()`. + repeat: Number of runs to repeat each trial for. + If `repeat` > 1, then optuna is used as the default search algorithm + unless specified otherwise in `tune_run_kwargs`. + experiment_checkpoint_path: Path containing the checkpoints of a previous + experiment ran using this script. Useful for evaluating the best trial + of the experiment. + tune_run_kwargs: Other arguments to pass to `ray.tune.run()`. Raises: TypeError: Named configs not string sequences or config updates not mappings. + ValueError: `repeat` > 1 but `search_alg` is not an instance of + `ray.tune.search.SearchAlgorithm`. + + Returns: + The result of running the parallel experiment with `ray.tune.run()`. + Useful for fetching the configs and results dataframe of all the trials. """ # Basic validation for config options before we enter parallel jobs. if not isinstance(base_named_configs, collections.abc.Sequence): @@ -95,15 +112,38 @@ def parallel( ) ray.init(**init_kwargs) + updated_tune_run_kwargs = copy.deepcopy(tune_run_kwargs) + if repeat > 1: + try: + # Use optuna as the default search algorithm for repeat runs. + algo = tune_run_kwargs.get("search_alg", optuna.OptunaSearch()) + updated_tune_run_kwargs["search_alg"] = search.Repeater(algo, repeat) + except AttributeError as e: + raise ValueError( + "repeat > 1 but search_alg is not an instance of " + "ray.tune.search.SearchAlgorithm", + ) from e + + if sacred_ex_name == "train_rl": + return_key = "monitor_return_mean" + else: + return_key = "imit_stats/monitor_return_mean" + try: - ray.tune.run( - trainable, - config=search_space, - name=run_name, - local_dir=local_dir, - resources_per_trial=resources_per_trial, - sync_config=ray.tune.syncer.SyncConfig(upload_dir=upload_dir), - ) + if experiment_checkpoint_path: + result = ray.tune.ExperimentAnalysis(experiment_checkpoint_path) + else: + result = ray.tune.run( + trainable, + config=search_space, + num_samples=num_samples * repeat, + name=run_name, + resources_per_trial=resources_per_trial, + metric=return_key, + mode="max", + **updated_tune_run_kwargs, + ) + return result finally: ray.shutdown() @@ -113,7 +153,7 @@ def _ray_tune_sacred_wrapper( run_name: str, base_named_configs: list, base_config_updates: Mapping[str, Any], -) -> Callable[[Mapping[str, Any], Any], Mapping[str, Any]]: +) -> Callable[[Dict[str, Any], Any], Mapping[str, Any]]: """From an Experiment build a wrapped run function suitable for Ray Tune. `ray.tune.run(...)` expects a trainable function that takes a dict @@ -164,16 +204,22 @@ def inner(config: Mapping[str, Any], reporter) -> Mapping[str, Any]: # TODO(shwang): Stop modifying CAPTURE_MODE once the issue is fixed. sacred.SETTINGS.CAPTURE_MODE = "sys" - run_kwargs = config + run_kwargs = dict(**config) updated_run_kwargs: Dict[str, Any] = {} # Import inside function rather than in module because Sacred experiments # are not picklable, and Ray requires this function to be picklable. from imitation.scripts.train_adversarial import train_adversarial_ex + from imitation.scripts.train_imitation import train_imitation_ex + from imitation.scripts.train_preference_comparisons import ( + train_preference_comparisons_ex, + ) from imitation.scripts.train_rl import train_rl_ex experiments = { "train_rl": train_rl_ex, "train_adversarial": train_adversarial_ex, + "train_imitation": train_imitation_ex, + "train_preference_comparisons": train_preference_comparisons_ex, } ex = experiments[sacred_ex_name] @@ -181,23 +227,23 @@ def inner(config: Mapping[str, Any], reporter) -> Mapping[str, Any]: named_configs = base_named_configs + run_kwargs["named_configs"] updated_run_kwargs["named_configs"] = named_configs - config_updates = {**base_config_updates, **run_kwargs["config_updates"]} + config_updates: Dict[str, Any] = {} + config_updates.update(base_config_updates) + config_updates.update(run_kwargs["config_updates"]) + # for repeat runs, set the seed using their trial index + if "__trial_index__" in run_kwargs: + config_updates.update(seed=run_kwargs.pop("__trial_index__")) updated_run_kwargs["config_updates"] = config_updates # Add other run_kwargs items to updated_run_kwargs. for k, v in run_kwargs.items(): if k not in updated_run_kwargs: updated_run_kwargs[k] = v - run = ex.run( **updated_run_kwargs, options={"--run": run_name, "--file_storage": "sacred"}, ) - # Ray Tune has a string formatting error if raylet completes without - # any calls to `reporter`. - reporter(done=True) - assert run.status == "COMPLETED" return run.result diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index 292597561..c7b757c52 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -131,6 +131,7 @@ def dagger( expert_policy=expert_policy, custom_logger=custom_logger, bc_trainer=bc_trainer, + beta_schedule=dagger["beta_schedule"], rng=_rnd, ) diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 8fb13f4c4..71363daee 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -280,7 +280,7 @@ def save_callback(iteration_num): # Storing and evaluating policy only useful if we generated trajectory data if bool(trajectory_path is None): results = dict(results) - results["rollout"] = policy_evaluation.eval_policy(agent, venv) + results["imit_stats"] = policy_evaluation.eval_policy(agent, venv) if save_preferences: main_trainer.dataset.save(log_dir / "preferences.pkl") diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 199440163..1b5dfc028 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -160,7 +160,8 @@ def train_rl( policies_serialize.save_stable_model(output_dir, rl_algo) # Final evaluation of expert policy. - return policy_evaluation.eval_policy(rl_algo, venv) + eval_stats = policy_evaluation.eval_policy(rl_algo, venv) + return eval_stats def main_console(): diff --git a/src/imitation/scripts/tuning.py b/src/imitation/scripts/tuning.py new file mode 100644 index 000000000..24095b1de --- /dev/null +++ b/src/imitation/scripts/tuning.py @@ -0,0 +1,184 @@ +"""Tunes the hyperparameters of the algorithms.""" + +import copy +import pathlib +from typing import Dict + +import numpy as np +import ray +from pandas.api import types as pd_types +from ray.tune.search import optuna +from sacred.observers import FileStorageObserver + +from imitation.scripts.config.parallel import parallel_ex +from imitation.scripts.config.tuning import tuning_ex + + +@tuning_ex.main +def tune( + parallel_run_config, + eval_best_trial_resource_multiplier: int = 1, + num_eval_seeds: int = 5, +) -> None: + """Tune hyperparameters of imitation algorithms using the parallel script. + + The parallel script is called twice in this function. The first call is to + tune the hyperparameters. The second call is to evaluate the best trial on + a separate set of seeds. + + Args: + parallel_run_config: Dictionary of arguments to pass to the parallel script. + This is used to define the search space for tuning the hyperparameters. + eval_best_trial_resource_multiplier: Factor by which to multiply the + number of cpus per trial in `resources_per_trial`. This is useful for + allocating more resources per trial to the evaluation trials than the + resources for hyperparameter tuning since number of evaluation trials + is usually much smaller than the number of tuning trials. + num_eval_seeds: Number of distinct seeds to evaluate the best trial on. + Set to 0 to disable evaluation. + + Raises: + ValueError: If no trials are returned by the parallel run of tuning. + """ + updated_parallel_run_config = copy.deepcopy(parallel_run_config) + search_alg = optuna.OptunaSearch() + tune_run_kwargs = updated_parallel_run_config.setdefault("tune_run_kwargs", dict()) + tune_run_kwargs["search_alg"] = search_alg + run = parallel_ex.run(config_updates=updated_parallel_run_config) + experiment_analysis = run.result + if not experiment_analysis.trials: + raise ValueError( + "No trials found. Please ensure that the `experiment_checkpoint_path` " + "in `parallel_run_config` is passed correctly " + "or that the tuning run finished properly.", + ) + + return_key = "imit_stats/monitor_return_mean" + if updated_parallel_run_config["sacred_ex_name"] == "train_rl": + return_key = "monitor_return_mean" + best_trial = find_best_trial(experiment_analysis, return_key, print_return=True) + + if num_eval_seeds > 0: # evaluate the best trial + resources_per_trial_eval = copy.deepcopy( + updated_parallel_run_config["resources_per_trial"], + ) + # update cpus per trial only if it is provided in `resources_per_trial` + # Uses the default values (cpu=1) if it is not provided + if "cpu" in updated_parallel_run_config["resources_per_trial"]: + resources_per_trial_eval["cpu"] *= eval_best_trial_resource_multiplier + evaluate_trial( + best_trial, + num_eval_seeds, + updated_parallel_run_config["run_name"] + "_best_hp_eval", + updated_parallel_run_config, + resources_per_trial_eval, + return_key, + ) + + +def find_best_trial( + experiment_analysis: ray.tune.analysis.ExperimentAnalysis, + return_key: str, + print_return: bool = False, +) -> ray.tune.experiment.Trial: + """Find the trial with the best mean return across all seeds. + + Args: + experiment_analysis: The result of a parallel/tuning experiment. + return_key: The key of the return metric in the results dataframe. + print_return: Whether to print the mean and std of the returns + of the best trial. + + Returns: + best_trial: The trial with the best mean return across all seeds. + """ + df = experiment_analysis.results_df + # convert object dtype to str required by df.groupby + for col in df.columns: + if pd_types.is_object_dtype(df[col]): + df[col] = df[col].astype("str") + # group into separate HP configs + grp_keys = [c for c in df.columns if c.startswith("config")] + grp_keys = [c for c in grp_keys if "seed" not in c and "trial_index" not in c] + grps = df.groupby(grp_keys) + # store mean return of runs across all seeds in a group + # the transform method is applied to get the mean return for every trial + # instead of for every group. So every trial in a group will have the same + # mean return column. + df["mean_return"] = grps[return_key].transform(lambda x: x.mean()) + best_config_df = df[df["mean_return"] == df["mean_return"].max()] + row = best_config_df.iloc[0] + best_config_tag = row["experiment_tag"] + assert experiment_analysis.trials is not None # for mypy + best_trial = [ + t for t in experiment_analysis.trials if best_config_tag in t.experiment_tag + ][0] + + if print_return: + all_returns = df[df["mean_return"] == row["mean_return"]][return_key] + all_returns = all_returns.to_numpy() + print("All returns:", all_returns) + print("Mean return:", row["mean_return"]) + print("Std return:", np.std(all_returns)) + print("Total seeds:", len(all_returns)) + return best_trial + + +def evaluate_trial( + trial: ray.tune.experiment.Trial, + num_eval_seeds: int, + run_name: str, + parallel_run_config, + resources_per_trial: Dict[str, int], + return_key: str, + print_return: bool = False, +): + """Evaluate a given trial of a parallel run on a separate set of seeds. + + Args: + trial: The trial to evaluate. + num_eval_seeds: Number of distinct seeds to evaluate the best trial on. + run_name: The name of the evaluation run. + parallel_run_config: Dictionary of arguments passed to the parallel + script to get best_trial. + resources_per_trial: Resources to be used for each evaluation trial. + return_key: The key of the return metric in the results dataframe. + print_return: Whether to print the mean and std of the evaluation returns. + + Returns: + eval_run: The result of the evaluation run. + """ + config = trial.config + config["config_updates"].update( + seed=ray.tune.grid_search(list(range(100, 100 + num_eval_seeds))), + ) + eval_config_updates = parallel_run_config.copy() + eval_config_updates.update( + run_name=run_name, + num_samples=1, + search_space=config, + resources_per_trial=resources_per_trial, + repeat=1, + experiment_checkpoint_path="", + ) + # required for grid search + eval_config_updates["tune_run_kwargs"].update(search_alg=None) + eval_run = parallel_ex.run(config_updates=eval_config_updates) + eval_result = eval_run.result + returns = eval_result.results_df[return_key].to_numpy() + if print_return: + print("All returns:", returns) + print("Mean:", np.mean(returns)) + print("Std:", np.std(returns)) + return eval_run + + +def main_console(): + observer_path = pathlib.Path.cwd() / "output" / "sacred" / "tuning" + observer = FileStorageObserver(observer_path) + tuning_ex.observers.append(observer) + tuning_ex.run_commandline() + + +if __name__ == "__main__": # pragma: no cover + main_console() diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index b16dada10..ae39116e7 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -824,10 +824,10 @@ def test_train_rl_cnn_policy(tmpdir: str, rng): dict( sacred_ex_name="train_rl", base_named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["rl"], - n_seeds=2, + repeat=2, search_space={ "config_updates": { - "rl": {"rl_kwargs": {"learning_rate": tune.grid_search([3e-4, 1e-4])}}, + "rl": {"rl_kwargs": {"learning_rate": tune.choice([3e-4, 1e-4])}}, }, "meta_info": {"asdf": "I exist for coverage purposes"}, }, @@ -840,7 +840,8 @@ def test_train_rl_cnn_policy(tmpdir: str, rng): "demonstrations.path": CARTPOLE_TEST_ROLLOUT_PATH.absolute(), }, search_space={ - "command_name": tune.grid_search(["gail", "airl"]), + "command_name": "airl", + "config_updates": {"total_timesteps": tune.choice([5, 10])}, }, ), ] @@ -920,13 +921,16 @@ def test_parallel_train_adversarial_custom_env(tmpdir): config_updates = dict( sacred_ex_name="train_adversarial", - n_seeds=1, + repeat=2, base_named_configs=[env_named_config] + ALGO_FAST_CONFIGS["adversarial"], base_config_updates=dict( logging=dict(log_root=tmpdir), demonstrations=dict(path=path), ), - search_space=dict(command_name="gail"), + # specifying repeat=2 uses the optuna search algorithm which + # requires the search space to be non-empty. So we provide + # the command name using tune.choice. + search_space=dict(command_name=tune.choice(["gail"])), ) config_updates.update(PARALLEL_CONFIG_LOW_RESOURCE) run = parallel.parallel_ex.run(config_updates=config_updates) @@ -978,7 +982,7 @@ def test_analyze_imitation(tmpdir: str, run_names: List[str], run_sacred_fn): assert run.status == "COMPLETED" # Check that analyze script finds the correct number of logs. - def check(run_name: Optional[str], count: int) -> None: + def check(run_name: Optional[str], count: int, table_verbosity: int = 1) -> None: run = analyze.analysis_ex.run( command_name="analyze_imitation", config_updates=dict( @@ -988,6 +992,7 @@ def check(run_name: Optional[str], count: int) -> None: csv_output_path=tmpdir_path / "analysis.csv", tex_output_path=tmpdir_path / "analysis.tex", print_table=True, + table_verbosity=table_verbosity, ), ) assert run.status == "COMPLETED" @@ -997,15 +1002,19 @@ def check(run_name: Optional[str], count: int) -> None: for run_name, count in Counter(run_names).items(): check(run_name, count) - check(None, len(run_names)) # Check total number of logs. + check(None, len(run_names), table_verbosity=3) # Check total number of logs. def test_analyze_gather_tb(tmpdir: str): if os.name == "nt": # pragma: no cover pytest.skip("gather_tb uses symlinks: not supported by Windows") - - config_updates: Dict[str, Any] = dict(local_dir=tmpdir, run_name="test") + num_runs = 2 + config_updates: Dict[str, Any] = dict( + tune_run_kwargs=dict(local_dir=tmpdir), + run_name="test", + ) config_updates.update(PARALLEL_CONFIG_LOW_RESOURCE) + config_updates.update(num_samples=num_runs) parallel_run = parallel.parallel_ex.run( named_configs=["generate_test_data"], config_updates=config_updates, @@ -1020,7 +1029,7 @@ def test_analyze_gather_tb(tmpdir: str): ) assert run.status == "COMPLETED" assert isinstance(run.result, dict) - assert run.result["n_tb_dirs"] == 2 + assert run.result["n_tb_dirs"] == num_runs def test_pickle_fmt_rollout_test_data_is_pickle(): diff --git a/tests/test_benchmarking.py b/tests/test_benchmarking.py index ba01b38a2..cbae34688 100644 --- a/tests/test_benchmarking.py +++ b/tests/test_benchmarking.py @@ -1,12 +1,8 @@ -"""Tests for config files in benchmarking/ folder.""" -import pathlib +"""Tests for config files in imitation/scripts/config/tuned_hps/ folder.""" import pytest -from imitation.scripts import train_adversarial, train_imitation - -THIS_DIR = pathlib.Path(__file__).absolute().parent -BENCHMARKING_DIR = THIS_DIR.parent / "benchmarking" +from imitation.scripts import train_adversarial, train_imitation, tuning ALGORITHMS = ["bc", "dagger", "airl", "gail"] ENVIRONMENTS = [ @@ -25,7 +21,6 @@ def test_benchmarks_print_config_succeeds(algorithm: str, environment: str): # because running the configs requires MuJoCo. # Requiring MuJoCo to run the tests adds too much complexity. - # GIVEN if algorithm in ("bc", "dagger"): experiment = train_imitation.train_imitation_ex elif algorithm in ("airl", "gail"): @@ -34,13 +29,24 @@ def test_benchmarks_print_config_succeeds(algorithm: str, environment: str): raise ValueError(f"Unknown algorithm: {algorithm}") # pragma: no cover config_name = f"{algorithm}_{environment}" - config_file = str( - BENCHMARKING_DIR / f"example_{algorithm}_{environment}_best_hp_eval.json", - ) - - # WHEN - experiment.add_named_config(config_name, config_file) run = experiment.run(command_name="print_config", named_configs=[config_name]) + assert run.status == "COMPLETED" + - # THEN +@pytest.mark.parametrize("environment", ENVIRONMENTS) +@pytest.mark.parametrize("algorithm", ALGORITHMS) +def test_tuning_print_config_succeeds(algorithm: str, environment: str): + # We test the configs using the print_config command, + # because running the configs requires MuJoCo. + # Requiring MuJoCo to run the tests adds too much complexity. + experiment = tuning.tuning_ex + run = experiment.run( + command_name="print_config", + named_configs=[algorithm], + config_updates=dict( + parallel_run_config=dict( + base_named_configs=[f"{algorithm}_{environment}"], + ), + ), + ) assert run.status == "COMPLETED" diff --git a/tests/test_experiments.py b/tests/test_experiments.py index 0f6d314fe..c21ed0317 100644 --- a/tests/test_experiments.py +++ b/tests/test_experiments.py @@ -3,6 +3,7 @@ import glob import os import pathlib +import re import subprocess from typing import List @@ -18,30 +19,31 @@ ) THIS_DIR = pathlib.Path(__file__).absolute().parent -BENCHMARKING_DIR = THIS_DIR.parent / "benchmarking" +BENCHMARKING_DIR = THIS_DIR.parent / "src/imitation/scripts/config/tuned_hps" EXPERIMENTS_DIR = THIS_DIR.parent / "experiments" COMMANDS_PY_PATH = EXPERIMENTS_DIR / "commands.py" -EXPECTED_LOCAL_CONFIG_TEMPLATE = """python -m imitation.scripts.train_imitation dagger \ ---capture=sys --name=run0 --file_storage={output_dir}/sacred/\ -$USER-cmd-run0-dagger-0-8bf911a8 \ -with benchmarking/fast_dagger_seals_cartpole.json \ -seed=0 logging.log_root={output_dir}""" +EXPECTED_LOCAL_CONFIG_TEMPLATE = f"""python -m imitation.scripts.train_imitation \ +dagger --capture=sys --name=run0 --file_storage={{output_dir}}/sacred/\ +$USER-cmd-run0-dagger-0-c9420c90 \ +with {BENCHMARKING_DIR}/fast_dagger_seals_cartpole.json \ +seed=0 logging.log_root={{output_dir}}""" -EXPECTED_HOFVARPNIR_CONFIG_TEMPLATE = """ctl job run \ ---name $USER-cmd-run0-dagger-0-c3ac179d \ +BENCHMARKING_DIR_SUFFIX = re.sub(r".*/src/", "", str(BENCHMARKING_DIR)) +EXPECTED_HOFVARPNIR_CONFIG_TEMPLATE = f"""ctl job run \ +--name $USER-cmd-run0-dagger-0-c9420c90 \ --command "python -m imitation.scripts.train_imitation dagger \ ---capture=sys --name=run0 --file_storage={output_dir}/sacred/\ -$USER-cmd-run0-dagger-0-c3ac179d \ -with /data/imitation/benchmarking/fast_dagger_seals_cartpole.json \ -seed=0 logging.log_root={output_dir}" \ +--capture=sys --name=run0 --file_storage={{output_dir}}/sacred/\ +$USER-cmd-run0-dagger-0-c9420c90 \ +with /data/imitation/src/{BENCHMARKING_DIR_SUFFIX}/fast_dagger_seals_cartpole.json \ +seed=0 logging.log_root={{output_dir}}" \ --container hacobe/devbox:imitation \ --login --force-pull --never-restart --gpu 0 \ --shared-host-dir-mount /data""" def _get_benchmarking_path(benchmarking_file): - return os.path.join(BENCHMARKING_DIR.stem, benchmarking_file) + return os.path.join(BENCHMARKING_DIR, benchmarking_file) def _run_commands_from_flags(**kwargs) -> List[str]: @@ -148,10 +150,10 @@ def test_commands_local_config_with_custom_flags(): output_dir="/foo/bar", ) assert len(commands) == 1 - expected = """python -m imitation.scripts.train_imitation dagger \ + expected = f"""python -m imitation.scripts.train_imitation dagger \ --capture=sys --name=baz --file_storage=/foo/bar/sacred/\ -$USER-cmd-baz-dagger-1-8bf911a8 \ -with benchmarking/fast_dagger_seals_cartpole.json \ +$USER-cmd-baz-dagger-1-c9420c90 \ +with {BENCHMARKING_DIR}/fast_dagger_seals_cartpole.json \ seed=1 logging.log_root=/foo/bar""" assert commands[0] == expected @@ -178,10 +180,10 @@ def test_commands_hofvarpnir_config_with_custom_flags(): remote=True, ) assert len(commands) == 1 - expected = """ctl job run --name $USER-cmd-baz-dagger-1-345d0f8a \ + expected = """ctl job run --name $USER-cmd-baz-dagger-1-c9420c90 \ --command "python -m imitation.scripts.train_imitation dagger \ --capture=sys --name=baz --file_storage=/foo/bar/sacred/\ -$USER-cmd-baz-dagger-1-345d0f8a \ +$USER-cmd-baz-dagger-1-c9420c90 \ with /bas/bat/fast_dagger_seals_cartpole.json \ seed=1 logging.log_root=/foo/bar" --container bam \ --login --force-pull --never-restart --gpu 0 \ @@ -245,13 +247,13 @@ def test_commands_hofvarpnir_config_with_special_characters_in_flags(tmpdir): def test_commands_bc_config(): if os.name == "nt": # pragma: no cover pytest.skip("commands.py not ported to Windows.") - cfg_pattern = _get_benchmarking_path("example_bc_seals_ant_best_hp_eval.json") + cfg_pattern = _get_benchmarking_path("bc_seals_ant_best_hp_eval.json") commands = _run_commands_from_flags(cfg_pattern=cfg_pattern) assert len(commands) == 1 - expected = """python -m imitation.scripts.train_imitation bc \ + expected = f"""python -m imitation.scripts.train_imitation bc \ --capture=sys --name=run0 --file_storage=output/sacred/\ -$USER-cmd-run0-bc-0-138a1475 \ -with benchmarking/example_bc_seals_ant_best_hp_eval.json \ +$USER-cmd-run0-bc-0-bb460c12 \ +with {BENCHMARKING_DIR}/bc_seals_ant_best_hp_eval.json \ seed=0 logging.log_root=output""" assert commands[0] == expected @@ -259,13 +261,13 @@ def test_commands_bc_config(): def test_commands_dagger_config(): if os.name == "nt": # pragma: no cover pytest.skip("commands.py not ported to Windows.") - cfg_pattern = _get_benchmarking_path("example_dagger_seals_ant_best_hp_eval.json") + cfg_pattern = _get_benchmarking_path("dagger_seals_ant_best_hp_eval.json") commands = _run_commands_from_flags(cfg_pattern=cfg_pattern) assert len(commands) == 1 - expected = """python -m imitation.scripts.train_imitation dagger \ + expected = f"""python -m imitation.scripts.train_imitation dagger \ --capture=sys --name=run0 --file_storage=output/sacred/\ -$USER-cmd-run0-dagger-0-6a49161a \ -with benchmarking/example_dagger_seals_ant_best_hp_eval.json \ +$USER-cmd-run0-dagger-0-f0790db7 \ +with {BENCHMARKING_DIR}/dagger_seals_ant_best_hp_eval.json \ seed=0 logging.log_root=output""" assert commands[0] == expected @@ -273,13 +275,13 @@ def test_commands_dagger_config(): def test_commands_gail_config(): if os.name == "nt": # pragma: no cover pytest.skip("commands.py not ported to Windows.") - cfg_pattern = _get_benchmarking_path("example_gail_seals_ant_best_hp_eval.json") + cfg_pattern = _get_benchmarking_path("gail_seals_ant_best_hp_eval.json") commands = _run_commands_from_flags(cfg_pattern=cfg_pattern) assert len(commands) == 1 - expected = """python -m imitation.scripts.train_adversarial gail \ + expected = f"""python -m imitation.scripts.train_adversarial gail \ --capture=sys --name=run0 --file_storage=output/sacred/\ -$USER-cmd-run0-gail-0-3ec8154d \ -with benchmarking/example_gail_seals_ant_best_hp_eval.json \ +$USER-cmd-run0-gail-0-d5be0cea \ +with {BENCHMARKING_DIR}/gail_seals_ant_best_hp_eval.json \ seed=0 logging.log_root=output""" assert commands[0] == expected @@ -287,13 +289,13 @@ def test_commands_gail_config(): def test_commands_airl_config(): if os.name == "nt": # pragma: no cover pytest.skip("commands.py not ported to Windows.") - cfg_pattern = _get_benchmarking_path("example_airl_seals_ant_best_hp_eval.json") + cfg_pattern = _get_benchmarking_path("airl_seals_ant_best_hp_eval.json") commands = _run_commands_from_flags(cfg_pattern=cfg_pattern) assert len(commands) == 1 - expected = """python -m imitation.scripts.train_adversarial airl \ + expected = f"""python -m imitation.scripts.train_adversarial airl \ --capture=sys --name=run0 \ ---file_storage=output/sacred/$USER-cmd-run0-airl-0-400e1558 \ -with benchmarking/example_airl_seals_ant_best_hp_eval.json \ +--file_storage=output/sacred/$USER-cmd-run0-airl-0-d7040cf5 \ +with {BENCHMARKING_DIR}/airl_seals_ant_best_hp_eval.json \ seed=0 logging.log_root=output""" assert commands[0] == expected