Skip to content

Commit

Permalink
Fix test errors
Browse files Browse the repository at this point in the history
  • Loading branch information
taufeeque9 committed Oct 4, 2023
1 parent 747ad32 commit 691e759
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 24 deletions.
6 changes: 3 additions & 3 deletions src/imitation/scripts/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import copy
import pathlib
from typing import Any, Dict
from typing import Dict

import numpy as np
import ray
Expand All @@ -16,7 +16,7 @@

@tuning_ex.main
def tune(
parallel_run_config: Dict[str, Any],
parallel_run_config,
eval_best_trial_resource_multiplier: int = 1,
num_eval_seeds: int = 5,
) -> None:
Expand Down Expand Up @@ -128,7 +128,7 @@ def evaluate_trial(
trial: ray.tune.experiment.Trial,
num_eval_seeds: int,
run_name: str,
parallel_run_config: Dict[str, Any],
parallel_run_config,
resources_per_trial: Dict[str, int],
return_key: str,
print_return: bool = False,
Expand Down
34 changes: 13 additions & 21 deletions tests/test_benchmarking.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""Tests for config files in benchmarking/ folder."""
import pathlib
import subprocess
import sys

import pytest

from imitation.scripts import train_adversarial, train_imitation
from imitation.scripts import train_adversarial, train_imitation, tuning

THIS_DIR = pathlib.Path(__file__).absolute().parent
BENCHMARKING_DIR = THIS_DIR.parent / "benchmarking"
Expand Down Expand Up @@ -48,26 +46,20 @@ def test_benchmarks_print_config_succeeds(algorithm: str, environment: str):
assert run.status == "COMPLETED"


@pytest.mark.parametrize("environment", ENVIRONMENTS)
@pytest.mark.parametrize("algorithm", ALGORITHMS)
def test_tuning_print_config_succeeds(algorithm: str):
def test_tuning_print_config_succeeds(algorithm: str, environment: str):
# We test the configs using the print_config command,
# because running the configs requires MuJoCo.
# Requiring MuJoCo to run the tests adds too much complexity.

# We need to use sys.executable, not just "python", on Windows as
# subprocess.call ignores PATH (unless shell=True) so runs a
# system-wide Python interpreter outside of our venv. See:
# https://stackoverflow.com/questions/5658622/
tuning_path = str(BENCHMARKING_DIR / "tuning.py")
env = 'parallel_run_config.base_named_configs=["seals_cartpole"]'
exit_code = subprocess.call(
[
sys.executable,
tuning_path,
"print_config",
"with",
f"{algorithm}",
env,
],
experiment = tuning.tuning_ex
run = experiment.run(
command_name="print_config",
named_configs=[algorithm],
config_updates=dict(
parallel_run_config=dict(
base_named_configs=[environment],
),
),
)
assert exit_code == 0
assert run.status == "COMPLETED"

0 comments on commit 691e759

Please sign in to comment.