Incorporate review comments
taufeeque9 committed Oct 4, 2023
1 parent 71f6c92 commit 747ad32
Showing 5 changed files with 31 additions and 32 deletions.
11 changes: 4 additions & 7 deletions src/imitation/scripts/analyze.py
@@ -268,13 +268,10 @@ def analyze_imitation(
    Returns:
        The DataFrame generated from the Sacred logs.
    """
-    if table_verbosity == 3:
-        # Get column names for which we have get value using make_entry_fn
-        # These are same across Level 2 & 3. In Level 3, we additionally add remaining
-        # config columns.
-        table_entry_fns_subset = _get_table_entry_fns_subset(2)
-    else:
-        table_entry_fns_subset = _get_table_entry_fns_subset(table_verbosity)
+    # Get the column names for which we get values using make_entry_fn.
+    # These are the same across levels 2 and 3; level 3 additionally adds the
+    # remaining config columns.
+    table_entry_fns_subset = _get_table_entry_fns_subset(min(table_verbosity, 2))

    output_table = pd.DataFrame()
    for sd in _gather_sacred_dicts():
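Note on the change above: since levels 2 and 3 share the same entry functions, `min(table_verbosity, 2)` reproduces the old branch exactly. A standalone sketch of the equivalence (the stub below is a placeholder for the real `_get_table_entry_fns_subset`):

# Sketch: the old if/else and the new min() expression pick the same level.
def _get_table_entry_fns_subset(level: int) -> str:
    return f"entry fns for level {level}"  # stub for illustration only

for table_verbosity in (0, 1, 2, 3):
    if table_verbosity == 3:  # old behaviour
        old = _get_table_entry_fns_subset(2)
    else:
        old = _get_table_entry_fns_subset(table_verbosity)
    new = _get_table_entry_fns_subset(min(table_verbosity, 2))  # new behaviour
    assert old == new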
src/imitation/scripts/config/tuning.py
@@ -49,24 +49,24 @@ def bc():
        search_space={
            "config_updates": {
                "bc": dict(
-                    batch_size=tune.choice([8, 16, 32, 64]),
+                    batch_size=tune.choice([8]),
                    l2_weight=tune.loguniform(1e-6, 1e-2),  # L2 regularization weight
                    optimizer_kwargs=dict(
                        lr=tune.loguniform(1e-5, 1e-2),
                    ),
                    train_kwargs=dict(
-                        n_epochs=tune.choice([1, 5, 10, 20]),
+                        n_epochs=tune.choice([1]),
                    ),
                ),
            },
            "command_name": "bc",
        },
-        num_samples=64,
-        repeat=3,
+        num_samples=2,
+        repeat=2,
        resources_per_trial=dict(cpu=1),
    )

-    num_eval_seeds = 5
+    num_eval_seeds = 1
    eval_best_trial_resource_multiplier = 1


@@ -117,7 +117,7 @@ def dagger():
def gail():
    parallel_run_config = dict(
        sacred_ex_name="train_adversarial",
-        run_name="gail_tuning_hc",
+        run_name="gail_tuning",
        base_named_configs=["logging.wandb_logging"],
        base_config_updates={
            "environment": {"num_vec": 1},
12 changes: 2 additions & 10 deletions src/imitation/scripts/parallel.py
@@ -2,7 +2,6 @@

import collections.abc
import copy
-import glob
import pathlib
from typing import Any, Callable, Dict, Mapping, Sequence

@@ -37,8 +36,8 @@ def parallel(
    to `upload_dir` if that argument is provided in `tune_run_kwargs`.

    Args:
-        sacred_ex_name: The Sacred experiment to tune. Either "train_rl" or
-            "train_imitation" or "train_adversarial" or "train_preference_comparisons".
+        sacred_ex_name: The Sacred experiment to tune. Either "train_rl",
+            "train_imitation", "train_adversarial" or "train_preference_comparisons".
        run_name: A name describing this parallelizing experiment.
            This argument is also passed to `ray.tune.run` as the `name` argument.
            It is also saved in 'sacred/run.json' of each inner Sacred experiment
@@ -132,14 +131,7 @@ def parallel(

    try:
        if experiment_checkpoint_path:
-            # load experiment analysis results
            result = ray.tune.ExperimentAnalysis(experiment_checkpoint_path)
-            result._load_checkpoints_from_latest(
-                glob.glob(experiment_checkpoint_path + "/experiment_state*.json"),
-            )
-            # update result.trials using all the experiment_state json files
-            result.trials = None
-            result.fetch_trial_dataframes()
        else:
            result = ray.tune.run(
                trainable,
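Note on the change above: the deleted lines reached into Ray's private `_load_checkpoints_from_latest` helper, while constructing `ray.tune.ExperimentAnalysis` from the checkpoint directory already loads the saved trial state. A minimal usage sketch, assuming a results directory written by an earlier `ray.tune.run` (the path is a placeholder):

import ray.tune

# Placeholder path: point this at a directory containing an
# experiment_state-*.json file produced by a previous ray.tune.run call.
analysis = ray.tune.ExperimentAnalysis("~/ray_results/gail_tuning")
print(analysis.results_df.head())  # per-trial metrics as a pandas DataFrame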
26 changes: 18 additions & 8 deletions benchmarking/tuning.py → src/imitation/scripts/tuning.py
@@ -9,7 +9,9 @@
from pandas.api import types as pd_types
from ray.tune.search import optuna
from sacred.observers import FileStorageObserver
-from tuning_config import parallel_ex, tuning_ex
+
+from imitation.scripts.config.parallel import parallel_ex
+from imitation.scripts.config.tuning import tuning_ex


@tuning_ex.main
@@ -18,10 +20,15 @@ def tune(
    eval_best_trial_resource_multiplier: int = 1,
    num_eval_seeds: int = 5,
) -> None:
-    """Tune hyperparameters of imitation algorithms using parallel script.
+    """Tune hyperparameters of imitation algorithms using the parallel script.
+
+    The parallel script is called twice in this function. The first call is to
+    tune the hyperparameters. The second call is to evaluate the best trial on
+    a separate set of seeds.

    Args:
        parallel_run_config: Dictionary of arguments to pass to the parallel script.
            This is used to define the search space for tuning the hyperparameters.
        eval_best_trial_resource_multiplier: Factor by which to multiply the
            number of cpus per trial in `resources_per_trial`. This is useful for
            allocating more resources per trial to the evaluation trials than the
@@ -35,10 +42,8 @@
"""
updated_parallel_run_config = copy.deepcopy(parallel_run_config)
search_alg = optuna.OptunaSearch()
if "tune_run_kwargs" in updated_parallel_run_config:
updated_parallel_run_config["tune_run_kwargs"]["search_alg"] = search_alg
else:
updated_parallel_run_config["tune_run_kwargs"] = dict(search_alg=search_alg)
tune_run_kwargs = updated_parallel_run_config.setdefault("tune_run_kwargs", dict())
tune_run_kwargs["search_alg"] = search_alg
run = parallel_ex.run(config_updates=updated_parallel_run_config)
experiment_analysis = run.result
if not experiment_analysis.trials:
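Note on the `setdefault` rewrite above: `dict.setdefault` returns the stored value when the key already exists and otherwise inserts and returns the default, so writing through the returned reference mutates the config in place. A minimal illustration:

# setdefault hands back the existing inner dict (or the freshly inserted one),
# so mutating the returned value updates the outer config either way.
config = {"tune_run_kwargs": {"num_samples": 4}}
kwargs = config.setdefault("tune_run_kwargs", dict())
kwargs["search_alg"] = "optuna"
assert config["tune_run_kwargs"] == {"num_samples": 4, "search_alg": "optuna"}

empty = {}
kwargs = empty.setdefault("tune_run_kwargs", dict())
kwargs["search_alg"] = "optuna"
assert empty == {"tune_run_kwargs": {"search_alg": "optuna"}}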
@@ -93,9 +98,13 @@ def find_best_trial(
        if pd_types.is_object_dtype(df[col]):
            df[col] = df[col].astype("str")
    # group into separate HP configs
-    grp_keys = [c for c in df.columns if c.startswith("config") and "seed" not in c]
+    grp_keys = [c for c in df.columns if c.startswith("config")]
+    grp_keys = [c for c in grp_keys if "seed" not in c and "trial_index" not in c]
    grps = df.groupby(grp_keys)
    # store mean return of runs across all seeds in a group
+    # `transform` computes the mean return for every trial rather than once per
+    # group, so every trial in a group gets the same value in the mean_return
+    # column.
    df["mean_return"] = grps[return_key].transform(lambda x: x.mean())
    best_config_df = df[df["mean_return"] == df["mean_return"].max()]
    row = best_config_df.iloc[0]
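The new comment is easiest to see on a toy frame: `groupby(...).transform` broadcasts each group's mean back onto every member row, which lets whole groups be filtered by their mean afterwards. A self-contained pandas sketch (column names are illustrative):

import pandas as pd

df = pd.DataFrame({
    "config.lr": [1e-3, 1e-3, 1e-4, 1e-4],  # two HP configs, two seeds each
    "seed": [0, 1, 0, 1],
    "return": [10.0, 12.0, 20.0, 22.0],
})
# transform writes the group mean onto every row: [11.0, 11.0, 21.0, 21.0]
df["mean_return"] = df.groupby("config.lr")["return"].transform("mean")
best_config_df = df[df["mean_return"] == df["mean_return"].max()]
print(best_config_df)  # both seed rows of the lr=1e-4 config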
@@ -149,10 +158,11 @@ def evaluate_trial(
        num_samples=1,
        search_space=config,
        resources_per_trial=resources_per_trial,
-        search_alg=None,
        repeat=1,
        experiment_checkpoint_path="",
    )
+    # required for grid search
+    eval_config_updates["tune_run_kwargs"].update(search_alg=None)
    eval_run = parallel_ex.run(config_updates=eval_config_updates)
    eval_result = eval_run.result
    returns = eval_result.results_df[return_key].to_numpy()
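The `# required for grid search` comment is terse; the likely reason is that Ray Tune's searchers (such as `OptunaSearch`) cannot expand `tune.grid_search` specs, so the evaluation run must leave `search_alg` unset and let the default variant generator enumerate the grid. A hedged sketch under that assumption (the trainable is a stand-in):

from ray import tune

def trainable(config):
    # stand-in for the real training function
    tune.report(ret=float(config["seed"]))

# With search_alg left as None (the default), the variant generator expands
# the grid; combining a Searcher with grid_search would raise an error.
analysis = tune.run(trainable, config={"seed": tune.grid_search([0, 1, 2])})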
2 changes: 1 addition & 1 deletion tests/scripts/test_scripts.py
@@ -981,7 +981,7 @@ def test_analyze_imitation(tmpdir: str, run_names: List[str], run_sacred_fn):
    assert run.status == "COMPLETED"

    # Check that analyze script finds the correct number of logs.
-    def check(run_name: Optional[str], count: int, table_verbosity=1) -> None:
+    def check(run_name: Optional[str], count: int, table_verbosity: int = 1) -> None:
        run = analyze.analysis_ex.run(
            command_name="analyze_imitation",
            config_updates=dict(
