handle command names
timbauman committed May 11, 2023
Parent: 27b9d82, commit: 761a7e4
Showing 3 changed files with 45 additions and 57 deletions.
src/imitation/scripts/analyze.py (13 additions, 9 deletions)
@@ -152,16 +152,20 @@ def _get_exp_command(sd: sacred_util.SacredDicts) -> str:
 def _get_algo_name(sd: sacred_util.SacredDicts) -> str:
     exp_command = _get_exp_command(sd)

-    if exp_command == "gail":
-        return "GAIL"
-    elif exp_command == "airl":
-        return "AIRL"
-    elif exp_command == "train_bc":
-        return "BC"
-    elif exp_command == "train_dagger":
-        return "DAgger"
+    COMMAND_TO_ALGO = {
+        "train_bc": "BC",
+        "bc": "BC",
+        "train_dagger": "DAgger",
+        "dagger": "DAgger",
+        "gail": "GAIL",
+        "airl": "AIRL",
+        "preference_comparisons": "Preference Comparisons",
+    }
+
+    if exp_command.lower() in COMMAND_TO_ALGO.keys():
+        return COMMAND_TO_ALGO[exp_command.lower()]
     else:
-        return f"??exp_command={exp_command}"
+        raise ValueError(f"Unknown command: {exp_command}")


 def _return_summaries(sd: sacred_util.SacredDicts) -> dict:
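Not part of the commit: a minimal standalone sketch of the new lookup's behavior. The helper name pretty_algo_name is hypothetical; the mapping and error message mirror the diff above.

COMMAND_TO_ALGO = {
    "train_bc": "BC",
    "bc": "BC",
    "train_dagger": "DAgger",
    "dagger": "DAgger",
    "gail": "GAIL",
    "airl": "AIRL",
    "preference_comparisons": "Preference Comparisons",
}


def pretty_algo_name(exp_command: str) -> str:
    # Case-insensitive lookup; unknown commands now raise ValueError instead
    # of returning a "??exp_command=..." placeholder string.
    try:
        return COMMAND_TO_ALGO[exp_command.lower()]
    except KeyError:
        raise ValueError(f"Unknown command: {exp_command}") from None


assert pretty_algo_name("GAIL") == "GAIL"
assert pretty_algo_name("train_bc") == "BC"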
src/imitation/scripts/compare_to_baseline.py (29 additions, 45 deletions)
@@ -15,66 +15,26 @@
 experiment returns, as reported by `imitation.scripts.analyze`.
 """

-import numpy as np
 import pandas as pd
 import scipy

 from imitation.data import types


-def compare_results_to_baseline(results_file: types.AnyPath) -> pd.DataFrame:
+def compare_results_to_baseline(results_filename: types.AnyPath) -> pd.DataFrame:
     """Compare benchmark results to baseline results.

     Args:
-        results_file: Path to a CSV file containing experiment results.
+        results_filename: Path to a CSV file containing experiment results.

     Returns:
         A DataFrame containing a table of p-values comparing the experiment
         results to the baseline results.
     """
-    data = pd.read_csv(results_file)
-    data["imit_return"] = data["imit_return_summary"].apply(
-        lambda x: float(x.split(" ")[0]),
-    )
-    summary = (
-        data[["algo", "env_name", "imit_return"]]
-        .groupby(["algo", "env_name"])
-        .describe()
-    )
-    summary.columns = summary.columns.get_level_values(1)
-    summary = summary.reset_index()
-
-    # Table 2 (https://arxiv.org/pdf/2211.11972.pdf)
-    # todo: store results in this repo outside this file
-    baseline = pd.DataFrame.from_records(
-        [
-            {
-                "algo": "??exp_command=bc",
-                "env_name": "seals/Ant-v0",
-                "mean": 1953,
-                "margin": 123,
-            },
-            {
-                "algo": "??exp_command=bc",
-                "env_name": "seals/HalfCheetah-v0",
-                "mean": 3446,
-                "margin": 130,
-            },
-        ],
-    )
-    baseline["count"] = 5
-    baseline["confidence_level"] = 0.95
-    # Back out the standard deviation from the margin of error.
+    results_summary = load_and_summarize_csv(results_filename)
+    baseline_summary = load_and_summarize_csv("baseline.csv")

-    t_score = scipy.stats.t.ppf(
-        1 - ((1 - baseline["confidence_level"]) / 2),
-        baseline["count"] - 1,
-    )
-    std_err = baseline["margin"] / t_score
-
-    baseline["std"] = std_err * np.sqrt(baseline["count"])
-
-    comparison = pd.merge(summary, baseline, on=["algo", "env_name"])
+    comparison = pd.merge(results_summary, baseline_summary, on=["algo", "env_name"])

     comparison["pvalue"] = scipy.stats.ttest_ind_from_stats(
         comparison["mean_x"],
@@ -88,6 +48,30 @@ def compare_results_to_baseline(results_file: types.AnyPath) -> pd.DataFrame:
     return comparison[["algo", "env_name", "pvalue"]]


+def load_and_summarize_csv(results_filename: types.AnyPath) -> pd.DataFrame:
+    """Load a results CSV file and summarize the statistics.
+
+    Args:
+        results_filename: Path to a CSV file containing experiment results.
+
+    Returns:
+        A DataFrame containing the mean and standard deviation of the experiment
+        returns, grouped by algorithm and environment.
+    """
+    data = pd.read_csv(results_filename)
+    data["imit_return"] = data["imit_return_summary"].apply(
+        lambda x: float(x.split(" ")[0]),
+    )
+    summary = (
+        data[["algo", "env_name", "imit_return"]]
+        .groupby(["algo", "env_name"])
+        .describe()
+    )
+    summary.columns = summary.columns.get_level_values(1)
+    summary = summary.reset_index()
+    return summary


 def main() -> None:  # pragma: no cover
     """Run the script."""
     import sys
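The comparison itself boils down to two-sample t-tests computed from summary statistics. A hedged sketch of the same computation outside the script, assuming pandas' default _x/_y merge suffixes; the numbers are made up for illustration:

import pandas as pd
import scipy.stats

# Toy per-(algo, env_name) summaries in the shape load_and_summarize_csv
# produces: one row per group with mean/std/count of imit_return.
results = pd.DataFrame(
    [{"algo": "BC", "env_name": "seals/Ant-v0", "mean": 1900.0, "std": 120.0, "count": 5}],
)
baseline = pd.DataFrame(
    [{"algo": "BC", "env_name": "seals/Ant-v0", "mean": 1953.0, "std": 110.0, "count": 5}],
)

# pd.merge suffixes the shared columns with _x (left) and _y (right).
comparison = pd.merge(results, baseline, on=["algo", "env_name"])
comparison["pvalue"] = scipy.stats.ttest_ind_from_stats(
    comparison["mean_x"],
    comparison["std_x"],
    comparison["count_x"],
    comparison["mean_y"],
    comparison["std_y"],
    comparison["count_y"],
).pvalue
print(comparison[["algo", "env_name", "pvalue"]])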
tests/scripts/test_scripts.py (3 additions, 3 deletions)
@@ -1110,8 +1110,8 @@ def test_compare_to_baseline_p_values(
     comparison.to_csv(tmpfile)

     assert (
-        compare_to_baseline.compare_results_to_baseline(results_file=tmpfile)["pvalue"][
-            0
-        ]
+        compare_to_baseline.compare_results_to_baseline(results_filename=tmpfile)[
+            "pvalue"
+        ][0]
         < p_value
     )

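For reference, the parsing and reshaping that load_and_summarize_csv performs can be checked in isolation. A minimal sketch, assuming imit_return_summary holds strings whose first space-separated token is the mean return; toy values only:

import pandas as pd

data = pd.DataFrame(
    {
        "algo": ["BC"] * 3,
        "env_name": ["seals/Ant-v0"] * 3,
        "imit_return_summary": ["1900.0 (+/- 50)", "2000.0 (+/- 40)", "1950.0 (+/- 60)"],
    },
)
# The first space-separated token of the summary string is the mean return.
data["imit_return"] = data["imit_return_summary"].apply(lambda x: float(x.split(" ")[0]))

summary = (
    data[["algo", "env_name", "imit_return"]]
    .groupby(["algo", "env_name"])
    .describe()
)
# describe() yields two-level columns like ("imit_return", "mean"); keeping
# only level 1 flattens them to count/mean/std/min/.../max.
summary.columns = summary.columns.get_level_values(1)
summary = summary.reset_index()
assert {"count", "mean", "std"} <= set(summary.columns)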