From 747ad32787e56a6939f6064eedb0cda8a67c3b1a Mon Sep 17 00:00:00 2001
From: taufeeque9 <9taufeeque9@gmail.com>
Date: Wed, 4 Oct 2023 05:58:13 +0530
Subject: [PATCH] Incorporate review comments

---
 src/imitation/scripts/analyze.py              | 11 +++-----
 .../imitation/scripts/config/tuning.py        | 12 ++++-----
 src/imitation/scripts/parallel.py             | 12 ++-------
 .../imitation/scripts}/tuning.py              | 26 +++++++++++++------
 tests/scripts/test_scripts.py                 |  2 +-
 5 files changed, 31 insertions(+), 32 deletions(-)
 rename benchmarking/tuning_config.py => src/imitation/scripts/config/tuning.py (97%)
 rename {benchmarking => src/imitation/scripts}/tuning.py (85%)

diff --git a/src/imitation/scripts/analyze.py b/src/imitation/scripts/analyze.py
index 96b34bd6e..b63538f6d 100644
--- a/src/imitation/scripts/analyze.py
+++ b/src/imitation/scripts/analyze.py
@@ -268,13 +268,10 @@ def analyze_imitation(
     Returns:
         The DataFrame generated from the Sacred logs.
     """
-    if table_verbosity == 3:
-        # Get column names for which we have get value using make_entry_fn
-        # These are same across Level 2 & 3. In Level 3, we additionally add remaining
-        # config columns.
-        table_entry_fns_subset = _get_table_entry_fns_subset(2)
-    else:
-        table_entry_fns_subset = _get_table_entry_fns_subset(table_verbosity)
+    # Get the column names for which we compute values using make_entry_fn.
+    # These are the same across Levels 2 & 3. In Level 3, we additionally add the
+    # remaining config columns.
+    table_entry_fns_subset = _get_table_entry_fns_subset(min(table_verbosity, 2))
 
     output_table = pd.DataFrame()
     for sd in _gather_sacred_dicts():
diff --git a/benchmarking/tuning_config.py b/src/imitation/scripts/config/tuning.py
similarity index 97%
rename from benchmarking/tuning_config.py
rename to src/imitation/scripts/config/tuning.py
index 239537406..07161d04c 100644
--- a/benchmarking/tuning_config.py
+++ b/src/imitation/scripts/config/tuning.py
@@ -49,24 +49,24 @@ def bc():
         search_space={
             "config_updates": {
                 "bc": dict(
-                    batch_size=tune.choice([8, 16, 32, 64]),
+                    batch_size=tune.choice([8]),
                     l2_weight=tune.loguniform(1e-6, 1e-2),  # L2 regularization weight
                     optimizer_kwargs=dict(
                         lr=tune.loguniform(1e-5, 1e-2),
                     ),
                     train_kwargs=dict(
-                        n_epochs=tune.choice([1, 5, 10, 20]),
+                        n_epochs=tune.choice([1]),
                     ),
                 ),
             },
             "command_name": "bc",
         },
-        num_samples=64,
-        repeat=3,
+        num_samples=2,
+        repeat=2,
         resources_per_trial=dict(cpu=1),
     )
 
-    num_eval_seeds = 5
+    num_eval_seeds = 1
     eval_best_trial_resource_multiplier = 1
 
 
@@ -117,7 +117,7 @@
 def gail():
     parallel_run_config = dict(
         sacred_ex_name="train_adversarial",
-        run_name="gail_tuning_hc",
+        run_name="gail_tuning",
         base_named_configs=["logging.wandb_logging"],
         base_config_updates={
             "environment": {"num_vec": 1},
diff --git a/src/imitation/scripts/parallel.py b/src/imitation/scripts/parallel.py
index 38881ee2b..d5e5e2378 100644
--- a/src/imitation/scripts/parallel.py
+++ b/src/imitation/scripts/parallel.py
@@ -2,7 +2,6 @@
 
 import collections.abc
 import copy
-import glob
 import pathlib
 from typing import Any, Callable, Dict, Mapping, Sequence
 
@@ -37,8 +36,8 @@ def parallel(
     to `upload_dir` if that argument is provided in `tune_run_kwargs`.
 
     Args:
-        sacred_ex_name: The Sacred experiment to tune. Either "train_rl" or
-            "train_imitation" or "train_adversarial" or "train_preference_comparisons".
+        sacred_ex_name: The Sacred experiment to tune. Either "train_rl",
+            "train_imitation", "train_adversarial", or "train_preference_comparisons".
        run_name: A name describing this parallelizing experiment.
            This argument is also passed to `ray.tune.run` as the `name` argument.
            It is also saved in 'sacred/run.json' of each inner Sacred experiment
@@ -132,14 +131,7 @@ def parallel(
 
     try:
         if experiment_checkpoint_path:
-            # load experiment analysis results
             result = ray.tune.ExperimentAnalysis(experiment_checkpoint_path)
-            result._load_checkpoints_from_latest(
-                glob.glob(experiment_checkpoint_path + "/experiment_state*.json"),
-            )
-            # update result.trials using all the experiment_state json files
-            result.trials = None
-            result.fetch_trial_dataframes()
         else:
             result = ray.tune.run(
                 trainable,
diff --git a/benchmarking/tuning.py b/src/imitation/scripts/tuning.py
similarity index 85%
rename from benchmarking/tuning.py
rename to src/imitation/scripts/tuning.py
index 9c3f52498..a605a206a 100644
--- a/benchmarking/tuning.py
+++ b/src/imitation/scripts/tuning.py
@@ -9,7 +9,9 @@
 from pandas.api import types as pd_types
 from ray.tune.search import optuna
 from sacred.observers import FileStorageObserver
-from tuning_config import parallel_ex, tuning_ex
+
+from imitation.scripts.config.parallel import parallel_ex
+from imitation.scripts.config.tuning import tuning_ex
 
 
 @tuning_ex.main
@@ -18,10 +20,15 @@ def tune(
     eval_best_trial_resource_multiplier: int = 1,
     num_eval_seeds: int = 5,
 ) -> None:
-    """Tune hyperparameters of imitation algorithms using parallel script.
+    """Tune hyperparameters of imitation algorithms using the parallel script.
+
+    The parallel script is called twice in this function: first to tune the
+    hyperparameters, and then to evaluate the best trial on a separate set
+    of seeds.
 
     Args:
         parallel_run_config: Dictionary of arguments to pass to the parallel script.
+            This is used to define the search space for tuning the hyperparameters.
         eval_best_trial_resource_multiplier: Factor by which to multiply the
             number of cpus per trial in `resources_per_trial`. This is useful for
             allocating more resources per trial to the evaluation trials than the
@@ -35,10 +42,8 @@ def tune(
     """
     updated_parallel_run_config = copy.deepcopy(parallel_run_config)
     search_alg = optuna.OptunaSearch()
-    if "tune_run_kwargs" in updated_parallel_run_config:
-        updated_parallel_run_config["tune_run_kwargs"]["search_alg"] = search_alg
-    else:
-        updated_parallel_run_config["tune_run_kwargs"] = dict(search_alg=search_alg)
+    tune_run_kwargs = updated_parallel_run_config.setdefault("tune_run_kwargs", dict())
+    tune_run_kwargs["search_alg"] = search_alg
     run = parallel_ex.run(config_updates=updated_parallel_run_config)
     experiment_analysis = run.result
     if not experiment_analysis.trials:
@@ -93,9 +98,13 @@ def find_best_trial(
         if pd_types.is_object_dtype(df[col]):
             df[col] = df[col].astype("str")
     # group into separate HP configs
-    grp_keys = [c for c in df.columns if c.startswith("config") and "seed" not in c]
+    grp_keys = [c for c in df.columns if c.startswith("config")]
+    grp_keys = [c for c in grp_keys if "seed" not in c and "trial_index" not in c]
     grps = df.groupby(grp_keys)
     # store mean return of runs across all seeds in a group
+    # `transform` returns the mean return for every trial rather than one value
+    # per group, so every trial in a group gets the same value in the
+    # mean_return column.
     df["mean_return"] = grps[return_key].transform(lambda x: x.mean())
     best_config_df = df[df["mean_return"] == df["mean_return"].max()]
     row = best_config_df.iloc[0]
@@ -149,10 +158,11 @@ def evaluate_trial(
         num_samples=1,
         search_space=config,
         resources_per_trial=resources_per_trial,
-        search_alg=None,
         repeat=1,
         experiment_checkpoint_path="",
     )
+    # Setting search_alg to None is required for grid search.
+    eval_config_updates["tune_run_kwargs"].update(search_alg=None)
     eval_run = parallel_ex.run(config_updates=eval_config_updates)
     eval_result = eval_run.result
     returns = eval_result.results_df[return_key].to_numpy()
diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py
index a44639cef..5fc2f122d 100644
--- a/tests/scripts/test_scripts.py
+++ b/tests/scripts/test_scripts.py
@@ -981,7 +981,7 @@ def test_analyze_imitation(tmpdir: str, run_names: List[str], run_sacred_fn):
         assert run.status == "COMPLETED"
 
     # Check that analyze script finds the correct number of logs.
-    def check(run_name: Optional[str], count: int, table_verbosity=1) -> None:
+    def check(run_name: Optional[str], count: int, table_verbosity: int = 1) -> None:
         run = analyze.analysis_ex.run(
             command_name="analyze_imitation",
             config_updates=dict(