Skip to content

Commit

Permalink
Introduce a dataclass for the return type of compute_probability_of_i…
Browse files Browse the repository at this point in the history
…mprovement.
  • Loading branch information
ernestum committed Oct 17, 2023
1 parent 7f4eb8c commit 7bc8dbe
Showing 1 changed file with 31 additions and 16 deletions.
47 changes: 31 additions & 16 deletions benchmarking/compute_probability_of_improvement.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Compute the probability that one algorithm improved over another."""
import argparse
import dataclasses
import pathlib
import warnings
from typing import Dict, List, Optional
Expand Down Expand Up @@ -55,6 +56,24 @@ def sample_matrix_from_runs_by_env(
).T


@dataclasses.dataclass
class ProbabilityOfImprovementResult:
    """Bundle of values returned by the probability-of-improvement computation.

    Fields are declared in the order in which they are passed positionally
    when the result is constructed, so the order must not be changed.

    Attributes:
        probability_of_improvement: Estimated probability of improvement.
        confidence_interval: Confidence interval belonging to that estimate,
            stored as a NumPy array.
        samples_per_env: Number of samples per environment left after
            truncation.
        baseline_samples_per_env: Number of baseline samples per environment
            left after truncation.
    """

    probability_of_improvement: float
    confidence_interval: np.ndarray
    samples_per_env: int
    baseline_samples_per_env: int


def compute_probability_of_improvement(
runs_by_env: Dict[str, List[SacredRun]],
baseline_runs_by_env: Dict[str, List[SacredRun]],
Expand All @@ -69,11 +88,9 @@ def compute_probability_of_improvement(
interval.
Returns:
A tuple of:
- probability of improvement
- confidence interval
- number of samples per env
- number of baseline samples per env
A ProbabilityOfImprovementResult object that contains the probability with its
confidence interval and the number of samples per environment that were
effectively used for the algorithm and for the baseline.
"""
envs = set(runs_by_env.keys())
baseline_envs = set(baseline_runs_by_env.keys())
Expand Down Expand Up @@ -103,7 +120,7 @@ def compute_probability_of_improvement(
probability_of_improvement = probabs["baseline_vs_new"]
confidence_interval = np.squeeze(error_intervals["baseline_vs_new"])

return (
return ProbabilityOfImprovementResult(
probability_of_improvement,
confidence_interval,
samples_per_env,
Expand Down Expand Up @@ -251,12 +268,7 @@ def main():
f"[{', '.join(comparison_envs)}].",
)

(
probability_of_improvement,
error_interval,
n_samples,
n_baseline_samples,
) = compute_probability_of_improvement(
probability_of_improvement_result = compute_probability_of_improvement(
runs_by_env=runs_by_algo_and_env[args.algo],
baseline_runs_by_env=baseline_runs_by_algo_and_env[args.baseline_algo],
reps=args.bootstrap_reps,
Expand All @@ -271,15 +283,18 @@ def main():
)

print(
f"Comparison based on {n_samples} samples per environment for {algo_str} and"
f" {n_baseline_samples} samples per environment for {baseline_algo_str}.",
# Attribute names must match the fields of ProbabilityOfImprovementResult
# (samples_per_env / baseline_samples_per_env); the result object has no
# n_samples or n_baseline_samples attribute.
f"Comparison based on {probability_of_improvement_result.samples_per_env} samples per "
f"environment for {algo_str} and "
f"{probability_of_improvement_result.baseline_samples_per_env} samples per "
f"environment for {baseline_algo_str}.",
)
print(f"Samples taken in {', '.join(comparison_envs)}")
print()
print(f"Probability of improvement of {algo_str} over {baseline_algo_str}:")
print(
f"{probability_of_improvement:.3f} "
f"({error_interval[0]:.3f}, {error_interval[1]:.3f}, "
f"{probability_of_improvement_result.probability_of_improvement:.3f} "
f"({probability_of_improvement_result.confidence_interval[0]:.3f}, "
f"{probability_of_improvement_result.confidence_interval[1]:.3f}, "
f"reps={args.bootstrap_reps:,})",
)

Expand Down

0 comments on commit 7bc8dbe

Please sign in to comment.