
Commit

Add scripts to generate benchmark summary information and compute the probability of improvement.
ernestum committed Oct 16, 2023
1 parent baa1134 commit 000975e
Showing 3 changed files with 398 additions and 0 deletions.
251 changes: 251 additions & 0 deletions benchmarking/compute_probability_of_improvement.py
@@ -0,0 +1,251 @@
"""Compute the probability that one algorithm improved over another."""
import argparse
import pathlib
import sys
import warnings
from typing import Dict, List, Optional

import numpy as np
from rliable import library as rly
from rliable import metrics

from imitation.util.sacred_file_parsing import SacredRun, group_runs_by_algo_and_env


def sample_matrix_from_runs_by_env(
runs_by_env: Dict[str, List[SacredRun]],
envs: Optional[List[str]] = None,
) -> np.ndarray:
"""Samples a matrix of scores from the runs for each environment.
Note: when the number of samples for each environment is not equal, the samples
will be truncated to the minimum sample count.
Args:
runs_by_env: A dictionary mapping environment names to lists of runs.
envs: The environments to sample from. If None, all environments are used.
Returns:
A matrix of scores of shape (n_samples, n_envs).
"""
if envs is None:
envs = list(runs_by_env.keys())

sample_counts_by_env = {env: len(runs_by_env[env]) for env in envs}

min_sample_count = min(sample_counts_by_env.values())
if not all(
sample_counts_by_env[env] == sample_counts_by_env[envs[0]] for env in envs
):
warnings.warn(
f"The runs for the environments have different sample counts "
f"{sample_counts_by_env}. "
f"This is not supported by the probability of improvement. Therefore, "
f"samples will be truncated to the minimum sample count of"
f" {min_sample_count}",
)

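    # rliable expects score matrices of shape (n_runs, n_tasks); the comprehension
    # below builds (n_envs, n_samples), hence the final transpose.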
return np.asarray(
[
[
run["result"]["imit_stats"]["monitor_return_mean"]
for run in runs_by_env[env][:min_sample_count]
]
for env in envs
],
).T


def compute_probability_of_improvement(
runs_by_env: Dict[str, List[SacredRun]],
baseline_runs_by_env: Dict[str, List[SacredRun]],
reps: int,
):
"""Computes the probability of improvement of the runs over the baseline runs.
Args:
runs_by_env: A dictionary mapping environment names to lists of runs.
baseline_runs_by_env: A dictionary mapping environment names to lists of runs.
reps: The number of bootstrap repetitions to use to compute the confidence
interval.
Returns:
A tuple of:
- probability of improvement
- confidence interval
- number of samples per env
- number of baseline samples per env
"""
envs = runs_by_env.keys()
baseline_envs = baseline_runs_by_env.keys()
comparison_envs = sorted(set(envs).intersection(set(baseline_envs)))

run_scores = sample_matrix_from_runs_by_env(runs_by_env, comparison_envs)
baseline_run_scores = sample_matrix_from_runs_by_env(
baseline_runs_by_env,
comparison_envs,
)
samples_per_env = run_scores.shape[0]
baseline_samples_per_env = baseline_run_scores.shape[0]

    # rliable's probability_of_improvement(scores_x, scores_y) estimates P(X > Y),
    # so pass the new runs' scores first and the baseline scores second.
    probabs, error_intervals = rly.get_interval_estimates(
        {"new_vs_baseline": (run_scores, baseline_run_scores)},
        metrics.probability_of_improvement,
        reps=reps,
    )
    probability_of_improvement = probabs["new_vs_baseline"]
    confidence_interval = np.squeeze(error_intervals["new_vs_baseline"])

return (
probability_of_improvement,
confidence_interval,
samples_per_env,
baseline_samples_per_env,
)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("runs_dir", type=pathlib.Path)
parser.add_argument("baseline_runs_dir", nargs="?", default=None, type=pathlib.Path)
parser.add_argument("--baseline-algo", type=str)
parser.add_argument("--algo", type=str)
parser.add_argument("--bootstrap-reps", type=int, default=2000)

args = parser.parse_args()

if args.baseline_runs_dir is None:
args.baseline_runs_dir = args.runs_dir

runs_by_algo_and_env = group_runs_by_algo_and_env(
args.runs_dir,
only_completed_runs=True,
)
baseline_runs_by_algo_and_env = group_runs_by_algo_and_env(
args.baseline_runs_dir,
only_completed_runs=True,
)

algos = sorted(runs_by_algo_and_env.keys())
baseline_algos = sorted(baseline_runs_by_algo_and_env.keys())

try:
if len(algos) == 0:
raise ValueError(f"The run directory [{args.runs_dir}] contains no runs.")

if len(baseline_algos) == 0:
raise ValueError(
f"The baseline run directory [{args.baseline_runs_dir}] "
f"contains no runs.",
)

if "algo" not in args is None:
if len(algos) == 1:
args.algo = algos[0]
else:
raise ValueError(
f"The run directory [{args.runs_dir}] contains runs for the "
f"algorithms [{', '.join(algos)}]. Please use the --algo option "
f" to specify which algorithms runs to compare.",
)

if args.baseline_algo is None:
if len(baseline_algos) == 1:
args.baseline_algo = baseline_algos[0]
elif args.algo in baseline_algos:
args.baseline_algo = args.algo
else:
raise ValueError(
f"The baseline run directory [{args.baseline_runs_dir}] contains "
f"runs for the algorithms [{', '.join(baseline_algos)}]. "
f"Please use the --baseline-algo option specify which one to "
f"compare to.",
)

if args.algo not in algos:
raise ValueError(
f"The run directory [{args.runs_dir}] contains runs for the algorithms"
f" [{', '.join(algos)}]. You specified [{args.algo}], for which no"
f" runs can be found in the run directory",
)

if args.baseline_algo not in baseline_algos:
raise ValueError(
f"The baseline run directory [{args.baseline_runs_dir}] contains runs "
f"for the algorithms [{', '.join(baseline_algos)}]. "
f"You specified [{args.baseline_algo}], for which no runs can be found"
f" in the baseline run directory",
)

if (args.algo == args.baseline_algo) and (
args.runs_dir == args.baseline_runs_dir
):
warnings.warn(
"You are comparing two equal sets of runs. "
"This is probably not what you want.",
)

envs = runs_by_algo_and_env[args.algo].keys()
baseline_envs = baseline_runs_by_algo_and_env[args.baseline_algo].keys()

comparison_envs = set(envs).intersection(set(baseline_envs))

if len(comparison_envs) == 0:
raise ValueError(
f"The baseline runs are for the environments "
f"[{', '.join(baseline_envs)}], while the runs are for the "
f"environments [{', '.join(envs)}]. "
f"There is no overlap in the environments of the two run sets, so no "
f"comparison can be made",
)

ignoring_some_envs = len(comparison_envs) < len(envs)
ignoring_some_baseline_envs = len(comparison_envs) < len(baseline_envs)
if ignoring_some_envs or ignoring_some_baseline_envs:
warnings.warn(
f"The baseline runs are for the environments "
f"[{', '.join(baseline_envs)}], "
f"while the runs are for the environments [{', '.join(envs)}]. "
f"The comparison will only be made for the environments "
f"[{', '.join(comparison_envs)}].",
)

except ValueError as e:
print(e)
sys.exit(1)

(
probability_of_improvement,
error_interval,
n_samples,
n_baseline_samples,
) = compute_probability_of_improvement(
runs_by_env=runs_by_algo_and_env[args.algo],
baseline_runs_by_env=baseline_runs_by_algo_and_env[args.baseline_algo],
reps=args.bootstrap_reps,
)

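    # When the algorithm names match, append the run directories so the two sets of
    # runs can still be told apart in the output below.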
show_path = args.algo == args.baseline_algo
algo_str = f"{args.algo} ({args.runs_dir})" if show_path else args.algo
baseline_algo_str = (
f"{args.baseline_algo} ({args.baseline_runs_dir})"
if show_path
else args.baseline_algo
)

print(
f"Comparison based on {n_samples} samples per environment for {algo_str} and"
f" {n_baseline_samples} samples per environment for {baseline_algo_str}.",
)
print(f"Samples taken in {', '.join(comparison_envs)}")
print()
print(f"Probability of improvement of {algo_str} over {baseline_algo_str}:")
print(
f"{probability_of_improvement:.3f} "
f"({error_interval[0]:.3f}, {error_interval[1]:.3f}, "
f"reps={args.bootstrap_reps})",
)


if __name__ == "__main__":
main()
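A typical invocation compares two run directories from the command line, for example (directory and algorithm names are placeholders): python benchmarking/compute_probability_of_improvement.py new_runs/ baseline_runs/ --algo bc --baseline-algo bc --bootstrap-reps 2000

Internally, compute_probability_of_improvement reduces to a single rliable call on two score matrices of shape (n_samples, n_envs). A minimal, self-contained sketch of that call with synthetic scores (the array values and shapes are illustrative assumptions, not taken from the commit):

import numpy as np
from rliable import library as rly
from rliable import metrics

# Synthetic scores: 5 runs x 3 environments per algorithm (illustrative only).
rng = np.random.default_rng(0)
baseline_scores = rng.normal(loc=100.0, scale=10.0, size=(5, 3))
new_scores = rng.normal(loc=110.0, scale=10.0, size=(5, 3))

# Estimate P(new > baseline) and a bootstrapped confidence interval, mirroring
# compute_probability_of_improvement above.
probs, intervals = rly.get_interval_estimates(
    {"new_vs_baseline": (new_scores, baseline_scores)},
    metrics.probability_of_improvement,
    reps=2000,
)
print(probs["new_vs_baseline"], np.squeeze(intervals["new_vs_baseline"]))

probability_of_improvement is a Mann-Whitney-U-based estimate of how often a run of the first algorithm outscores a run of the second, averaged over environments, and get_interval_estimates bootstraps a confidence interval around it using reps resamples.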
77 changes: 77 additions & 0 deletions benchmarking/sacred_output_to_markdown_summary.py
@@ -0,0 +1,77 @@
"""Generate a markdown summary of the results of a benchmarking run."""
import pathlib
import sys
from collections import Counter

from imitation.util.sacred_file_parsing import (
find_sacred_runs,
group_runs_by_algo_and_env,
)


def print_markdown_summary(path: pathlib.Path):
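    """Print a markdown summary of the benchmark runs found under the given path.

    Args:
        path: Path to a directory containing sacred run outputs.

    Raises:
        NotADirectoryError: If the given path does not exist.
    """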
if not path.exists():
raise NotADirectoryError(f"Path {path} does not exist.")

print("# Benchmark Summary")
runs_by_algo_and_env = group_runs_by_algo_and_env(path)
algos = sorted(runs_by_algo_and_env.keys())

print("## Run status" "")
print("Status | Count")
print("--- | ---")
status_counts = Counter((run["status"] for _, run in find_sacred_runs(path)))
statuses = sorted(list(status_counts))
for status in statuses:
print(f"{status} | {status_counts[status]}")
print()

print("## Detailed Run Status")
print(f"Algorithm | Environment | {' | '.join(sorted(list(status_counts)))}")
print("--- | --- " + " | --- " * len(statuses))
for algo in algos:
envs = sorted(runs_by_algo_and_env[algo].keys())
for env in envs:
status_counts = Counter(
(run["status"] for run in runs_by_algo_and_env[algo][env]),
)
print(
f"{algo} | {env} | "
f"{' | '.join([str(status_counts[status]) for status in statuses])}",
)
print()
print("## Raw Scores")
print()
for algo in algos:
print(f"### {algo.upper()}")
print("Environment | Scores | Expert Scores")
print("--- | --- | ---")
envs = sorted(runs_by_algo_and_env[algo].keys())
for env in envs:
completed_runs = [
run
for run in runs_by_algo_and_env[algo][env]
if run["status"] == "COMPLETED"
]
algo_scores = [
run["result"]["imit_stats"]["monitor_return_mean"]
for run in completed_runs
]
expert_scores = [
run["result"]["expert_stats"]["monitor_return_mean"]
for run in completed_runs
]
print(
f"{env} | "
f"{', '.join([f'{score:.2f}' for score in algo_scores])} | "
f"{', '.join([f'{score:.2f}' for score in expert_scores])}",
)
print()


if __name__ == "__main__":
if len(sys.argv) != 2:
print(f"Usage: {sys.argv[0]} <path to sacred run folder>")
sys.exit(1)

print_markdown_summary(pathlib.Path(sys.argv[1]))
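
As the usage message above indicates, the script takes a single path to a folder of sacred run outputs and prints the markdown summary to stdout, so a typical invocation (the path is a placeholder) is: python benchmarking/sacred_output_to_markdown_summary.py /path/to/sacred/runs > benchmark_summary.md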
