Add mean/std/IQM and confidence intervals to the markdown summary script.
ernestum committed Oct 17, 2023
1 parent c53888c commit 057ce8b
Showing 1 changed file with 124 additions and 44 deletions.
benchmarking/sacred_output_to_markdown_summary.py
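The gist of what the commit adds: each run's return is rescaled against a random-agent baseline and the expert's return, and rliable's bootstrap interval estimates are used to report the aggregate mean and IQM of the normalized scores with confidence intervals. The sketch below reproduces that computation in isolation; the per-run returns, environment count, and baseline/expert values are made-up illustrative numbers, not part of the commit, and only the rliable calls and `reps=1000` mirror the diff.

# Minimal sketch (synthetic data) of the normalization and rliable aggregation
# added by this commit; not the benchmarking script itself.
import numpy as np
from rliable import library as rly
from rliable import metrics

# Hypothetical per-run returns for two environments, five runs each.
scores = np.array([[310.0, 295.0, 330.0, 305.0, 320.0],
                   [55.0, 60.0, 52.0, 58.0, 61.0]])
random_scores = np.array([20.0, 5.0])    # random-agent baseline per environment
expert_scores = np.array([350.0, 70.0])  # expert return per environment

# (score - random_score) / (expert_score - random_score), as in the summary.
normalized = (scores - random_scores[:, None]) / (
    expert_scores[:, None] - random_scores[:, None]
)

# rliable expects arrays of shape (num_runs, num_tasks), hence the transpose.
aggregate, cis = rly.get_interval_estimates(
    {"normalized_score": normalized.T},
    lambda x: np.array([metrics.aggregate_mean(x), metrics.aggregate_iqm(x)]),
    reps=1000,
)
print(aggregate["normalized_score"])  # point estimates: [mean, IQM]
print(cis["normalized_score"])        # bootstrap confidence-interval bounds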
@@ -1,77 +1,157 @@
"""Generate a markdown summary of the results of a benchmarking run."""
import argparse
import pathlib
import sys
from collections import Counter
from functools import lru_cache
from typing import Generator, Sequence, cast

import datasets
import numpy as np
from huggingface_sb3 import EnvironmentName
from rliable import library as rly
from rliable import metrics

from imitation.data import rollout, types
from imitation.data.huggingface_utils import TrajectoryDatasetSequence
from imitation.util.sacred_file_parsing import (
find_sacred_runs,
group_runs_by_algo_and_env,
)


def print_markdown_summary(path: pathlib.Path):
@lru_cache(maxsize=None)
def get_random_agent_score(env: str):
stats = rollout.rollout_stats(
cast(
Sequence[types.TrajectoryWithRew],
TrajectoryDatasetSequence(
datasets.load_dataset(
f"HumanCompatibleAI/random-{EnvironmentName(env)}",
)["train"],
),
),
)
return stats["monitor_return_mean"]


def print_markdown_summary(path: pathlib.Path) -> Generator[str, None, None]:
if not path.exists():
raise NotADirectoryError(f"Path {path} does not exist.")

print("# Benchmark Summary")
yield "# Benchmark Summary"
yield ""
yield (
f"This is a summary of the sacred runs in `{path}` generated by "
f"`sacred_output_to_markdown_summary.py`."
)

runs_by_algo_and_env = group_runs_by_algo_and_env(path)
algos = sorted(runs_by_algo_and_env.keys())

print("## Run status" "")
print("Status | Count")
print("--- | ---")
status_counts = Counter((run["status"] for _, run in find_sacred_runs(path)))
statuses = sorted(list(status_counts))
for status in statuses:
print(f"{status} | {status_counts[status]}")
print()
# Note: we only print the status section if there are multiple statuses
if not (len(statuses) == 1 and statuses[0] == "COMPLETED"):
yield "## Run status" ""
yield "Status | Count"
yield "--- | ---"
for status in statuses:
yield f"{status} | {status_counts[status]}"
yield ""

yield "## Detailed Run Status"
yield f"Algorithm | Environment | {' | '.join(statuses)}"
yield "--- | --- " + " | --- " * len(statuses)
for algo in algos:
envs = sorted(runs_by_algo_and_env[algo].keys())
for env in envs:
status_counts = Counter(
(run["status"] for run in runs_by_algo_and_env[algo][env]),
)
yield (
f"{algo} | {env} | "
f"{' | '.join([str(status_counts[status]) for status in statuses])}"
)

yield "## Scores"
yield ""
yield (
"The scores are normalized based on the performance of a random agent as the"
" baseline and the expert as the maximum possible score as explained "
"[in this blog post](https://araffin.github.io/post/rliable/):"
)
yield "> `(score - random_score) / (expert_score - random_score)`"
yield ""
yield (
"Aggregate scores and confidence intervals are computed using the "
"[rliable library](https://agarwl.github.io/rliable/)."
)

print("## Detailed Run Status")
print(f"Algorithm | Environment | {' | '.join(sorted(list(status_counts)))}")
print("--- | --- " + " | --- " * len(statuses))
for algo in algos:
envs = sorted(runs_by_algo_and_env[algo].keys())
for env in envs:
status_counts = Counter(
(run["status"] for run in runs_by_algo_and_env[algo][env]),
)
print(
f"{algo} | {env} | "
f"{' | '.join([str(status_counts[status]) for status in statuses])}",
)
print()
print("## Raw Scores")
print()
for algo in algos:
print(f"### {algo.upper()}")
print("Environment | Scores | Expert Scores")
print("--- | --- | ---")
yield f"### {algo.upper()}"
yield "Environment | Score (mean/std)| Normalized Score (mean/std) | N"
yield " --- | --- | --- | --- "
envs = sorted(runs_by_algo_and_env[algo].keys())
accumulated_normalized_scores = []
for env in envs:
completed_runs = [
run
for run in runs_by_algo_and_env[algo][env]
if run["status"] == "COMPLETED"
]
algo_scores = [
scores = [
run["result"]["imit_stats"]["monitor_return_mean"]
for run in completed_runs
for run in runs_by_algo_and_env[algo][env]
]
expert_scores = [
run["result"]["expert_stats"]["monitor_return_mean"]
for run in completed_runs
for run in runs_by_algo_and_env[algo][env]
]
print(
random_score = get_random_agent_score(env)
normalized_score = [
(score - random_score) / (expert_score - random_score)
for score, expert_score in zip(scores, expert_scores)
]
accumulated_normalized_scores.append(normalized_score)

yield (
f"{env} | "
f"{', '.join([f'{score:.2f}' for score in algo_scores])} | "
f"{', '.join([f'{score:.2f}' for score in expert_scores])}",
f"{np.mean(scores):.3f} / {np.std(scores):.3f} | "
f"{np.mean(normalized_score):.3f} / {np.std(normalized_score):.3f} | "
f"{len(scores)}"
)
print()

aggregate_scores, aggregate_score_cis = rly.get_interval_estimates(
{"normalized_score": np.asarray(accumulated_normalized_scores).T},
lambda x: np.array([metrics.aggregate_mean(x), metrics.aggregate_iqm(x)]),
reps=1000,
)
yield ""
yield "#### Aggregate Normalized scores"

yield "Metric | Value | 95% CI"
yield " --- | --- | --- "
yield (
f"Mean | "
f"{aggregate_scores['normalized_score'][0]:.3f} | "
f"[{aggregate_score_cis['normalized_score'][0][0]:.3f}, "
f"{aggregate_score_cis['normalized_score'][0][1]:.3f}]"
)
yield (
f"IQM | "
f"{aggregate_scores['normalized_score'][1]:.3f} | "
f"[{aggregate_score_cis['normalized_score'][1][0]:.3f}, "
f"{aggregate_score_cis['normalized_score'][1][1]:.3f}]"
)
yield ""


if __name__ == "__main__":
if len(sys.argv) != 2:
print(f"Usage: {sys.argv[0]} <path to sacred run folder>")
sys.exit(1)
parser = argparse.ArgumentParser(
description="Generate a markdown summary of the results of a benchmarking run.",
)
parser.add_argument("path", type=pathlib.Path)
parser.add_argument("--output", type=pathlib.Path, default="summary.md")

args = parser.parse_args()

print_markdown_summary(pathlib.Path(sys.argv[1]))
with open(args.output, "w") as fh:
for line in print_markdown_summary(pathlib.Path(args.path)):
fh.write(line)
fh.write("\n")
fh.flush()

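After this change the summary is produced by a generator and written through an argparse-based entry point, so the module can also be driven programmatically. A rough usage sketch, assuming the script is imported from the benchmarking/ directory; the sacred output directory "output/sacred" is a placeholder, not from the commit:

# Hypothetical usage sketch; "output/sacred" is a placeholder path.
import pathlib

from sacred_output_to_markdown_summary import print_markdown_summary

with open("summary.md", "w") as fh:
    for line in print_markdown_summary(pathlib.Path("output/sacred")):
        fh.write(line + "\n")

From the command line, the equivalent is passing the sacred run directory as the positional path argument, with --output optionally choosing where summary.md is written.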