From 691e75945579cd8aaea3a133ffd1178bb978a450 Mon Sep 17 00:00:00 2001 From: taufeeque9 <9taufeeque9@gmail.com> Date: Wed, 4 Oct 2023 07:49:02 +0530 Subject: [PATCH] Fix test errors --- src/imitation/scripts/tuning.py | 6 +++--- tests/test_benchmarking.py | 34 +++++++++++++-------------------- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/src/imitation/scripts/tuning.py b/src/imitation/scripts/tuning.py index a605a206a..24095b1de 100644 --- a/src/imitation/scripts/tuning.py +++ b/src/imitation/scripts/tuning.py @@ -2,7 +2,7 @@ import copy import pathlib -from typing import Any, Dict +from typing import Dict import numpy as np import ray @@ -16,7 +16,7 @@ @tuning_ex.main def tune( - parallel_run_config: Dict[str, Any], + parallel_run_config, eval_best_trial_resource_multiplier: int = 1, num_eval_seeds: int = 5, ) -> None: @@ -128,7 +128,7 @@ def evaluate_trial( trial: ray.tune.experiment.Trial, num_eval_seeds: int, run_name: str, - parallel_run_config: Dict[str, Any], + parallel_run_config, resources_per_trial: Dict[str, int], return_key: str, print_return: bool = False, diff --git a/tests/test_benchmarking.py b/tests/test_benchmarking.py index 18d4f12cf..0a93943ef 100644 --- a/tests/test_benchmarking.py +++ b/tests/test_benchmarking.py @@ -1,11 +1,9 @@ """Tests for config files in benchmarking/ folder.""" import pathlib -import subprocess -import sys import pytest -from imitation.scripts import train_adversarial, train_imitation +from imitation.scripts import train_adversarial, train_imitation, tuning THIS_DIR = pathlib.Path(__file__).absolute().parent BENCHMARKING_DIR = THIS_DIR.parent / "benchmarking" @@ -48,26 +46,20 @@ def test_benchmarks_print_config_succeeds(algorithm: str, environment: str): assert run.status == "COMPLETED" +@pytest.mark.parametrize("environment", ENVIRONMENTS) @pytest.mark.parametrize("algorithm", ALGORITHMS) -def test_tuning_print_config_succeeds(algorithm: str): +def test_tuning_print_config_succeeds(algorithm: str, environment: str): # We test the configs using the print_config command, # because running the configs requires MuJoCo. # Requiring MuJoCo to run the tests adds too much complexity. - - # We need to use sys.executable, not just "python", on Windows as - # subprocess.call ignores PATH (unless shell=True) so runs a - # system-wide Python interpreter outside of our venv. See: - # https://stackoverflow.com/questions/5658622/ - tuning_path = str(BENCHMARKING_DIR / "tuning.py") - env = 'parallel_run_config.base_named_configs=["seals_cartpole"]' - exit_code = subprocess.call( - [ - sys.executable, - tuning_path, - "print_config", - "with", - f"{algorithm}", - env, - ], + experiment = tuning.tuning_ex + run = experiment.run( + command_name="print_config", + named_configs=[algorithm], + config_updates=dict( + parallel_run_config=dict( + base_named_configs=[environment], + ), + ), ) - assert exit_code == 0 + assert run.status == "COMPLETED"