From 7b8b4bf446d10a768c86895df31d78e150717680 Mon Sep 17 00:00:00 2001
From: Ziyue Wang <38305983+ZiyueWang25@users.noreply.github.com>
Date: Sat, 7 Oct 2023 20:43:22 -0700
Subject: [PATCH] Upgrade pytype (#801)

* Upgrade pytype and remove workaround for old versions

* new pytype needs an input directory or file

* fix np.dtype

* ignore typed-dict-error

* context-manager-related fixes

* keep pytype going so it reports more failures

* Move pytype config to pyproject.toml

* Use inputs specified in pyproject.toml

* Fix lint

* Fix undefined variable

* Fix end of file

* Fix typo

---------

Co-authored-by: Adam Gleave
---
 .pre-commit-config.yaml                             |  2 +-
 pyproject.toml                                      |  9 +++++++++
 setup.cfg                                           |  8 --------
 setup.py                                            |  2 +-
 src/imitation/scripts/eval_policy.py                |  2 +-
 src/imitation/scripts/ingredients/demonstrations.py |  6 +++---
 src/imitation/scripts/train_adversarial.py          |  2 +-
 src/imitation/scripts/train_imitation.py            |  6 +++---
 .../scripts/train_preference_comparisons.py         |  2 +-
 src/imitation/scripts/train_rl.py                   |  4 +++-
 tests/algorithms/test_base.py                       |  2 +-
 tests/data/test_buffer.py                           | 12 ++++++------
 12 files changed, 30 insertions(+), 27 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f07fd58d9..4d266d302 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -78,7 +78,7 @@ repos:
         name: pytype
         language: system
         types: [python]
-        entry: "bash -c 'pytype -j ${NUM_CPUS:-auto}'"
+        entry: "bash -c 'pytype --keep-going -j ${NUM_CPUS:-auto}'"
         require_serial: true
         verbose: true
     - id: docs
diff --git a/pyproject.toml b/pyproject.toml
index d7f080706..dac31079f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,3 +14,12 @@ build-backend = "setuptools.build_meta"
 
 [tool.black]
 target-version = ["py38"]
+
+[tool.pytype]
+inputs = [
+    "src/",
+    "tests/",
+    "experiments/",
+    "setup.py"
+]
+python_version = "3.8"
diff --git a/setup.cfg b/setup.cfg
index 5a3a93cf0..dc06cb335 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -59,11 +59,3 @@ omit =
 source =
     src/imitation
     *venv/lib/python*/site-packages/imitation
-
-[pytype]
-inputs =
-    src/
-    tests/
-    experiments/
-    setup.py
-python_version >= 3.8
diff --git a/setup.py b/setup.py
index e9edc116a..4b3349493 100644
--- a/setup.py
+++ b/setup.py
@@ -17,7 +17,7 @@ ATARI_REQUIRE = [
     "seals[atari]~=0.2.1",
 ]
 
-PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else []
+PYTYPE = ["pytype==2023.9.27"] if IS_NOT_WINDOWS else []
 # Note: the versions of the test and doc requirements should be tightly pinned to known
 # working versions to make our CI/CD pipeline as stable as possible.
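(Reviewer aside, not part of the patch.) The setup.py hunk above only bumps the pin; the platform guard around it is unchanged. A minimal sketch of that pattern, where the definition of `IS_NOT_WINDOWS` is an assumption (the real setup.py defines it elsewhere):

```python
import os

# Assumed definition of the guard; setup.py defines IS_NOT_WINDOWS itself.
IS_NOT_WINDOWS = os.name != "nt"

# pytype is not available as a Windows dependency, so the dev extra is empty
# there; upgrading the checker is then just a matter of changing the pin.
PYTYPE = ["pytype==2023.9.27"] if IS_NOT_WINDOWS else []
```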
diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py
index 952d85e16..86e6f8d53 100644
--- a/src/imitation/scripts/eval_policy.py
+++ b/src/imitation/scripts/eval_policy.py
@@ -96,7 +96,7 @@ def eval_policy(
     sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes)
     post_wrappers = [video_wrapper_factory(log_dir, **video_kwargs)] if videos else None
     render_mode = "rgb_array" if videos else None
-    with environment.make_venv(
+    with environment.make_venv(  # type: ignore[wrong-keyword-args]
         post_wrappers=post_wrappers,
         env_make_kwargs=dict(render_mode=render_mode),
     ) as venv:
diff --git a/src/imitation/scripts/ingredients/demonstrations.py b/src/imitation/scripts/ingredients/demonstrations.py
index 1367c0722..33ad68520 100644
--- a/src/imitation/scripts/ingredients/demonstrations.py
+++ b/src/imitation/scripts/ingredients/demonstrations.py
@@ -143,10 +143,10 @@ def _generate_expert_trajs(
         raise ValueError("n_expert_demos must be specified when generating demos.")
 
     logger.info(f"Generating {n_expert_demos} expert trajectories")
-    with environment.make_rollout_venv() as rollout_env:
+    with environment.make_rollout_venv() as env:  # type: ignore[wrong-arg-count]
         return rollout.rollout(
-            expert.get_expert_policy(rollout_env),
-            rollout_env,
+            expert.get_expert_policy(env),
+            env,
             rollout.make_sample_until(min_episodes=n_expert_demos),
             rng=_rnd,
         )
diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py
index 26c8d7bcf..9afc51135 100644
--- a/src/imitation/scripts/train_adversarial.py
+++ b/src/imitation/scripts/train_adversarial.py
@@ -119,7 +119,7 @@ def train_adversarial(
     custom_logger, log_dir = logging_ingredient.setup_logging()
     expert_trajs = demonstrations.get_expert_trajectories()
 
-    with environment.make_venv() as venv:
+    with environment.make_venv() as venv:  # type: ignore[wrong-arg-count]
         reward_net = reward.make_reward_net(venv)
         relabel_reward_fn = functools.partial(
             reward_net.predict_processed,
diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py
index 58dae3484..292597561 100644
--- a/src/imitation/scripts/train_imitation.py
+++ b/src/imitation/scripts/train_imitation.py
@@ -73,7 +73,7 @@ def bc(
     custom_logger, log_dir = logging_ingredient.setup_logging()
     expert_trajs = demonstrations.get_expert_trajectories()
 
-    with environment.make_venv() as venv:
+    with environment.make_venv() as venv:  # type: ignore[wrong-arg-count]
         bc_trainer = bc_ingredient.make_bc(venv, expert_trajs, custom_logger)
 
         bc_train_kwargs = dict(log_rollouts_venv=venv, **bc["train_kwargs"])
@@ -115,7 +115,7 @@ def dagger(
     if dagger["use_offline_rollouts"]:
         expert_trajs = demonstrations.get_expert_trajectories()
 
-    with environment.make_venv() as venv:
+    with environment.make_venv() as venv:  # type: ignore[wrong-arg-count]
         bc_trainer = bc_ingredient.make_bc(venv, expert_trajs, custom_logger)
 
         bc_train_kwargs = dict(log_rollouts_venv=venv, **bc["train_kwargs"])
@@ -161,7 +161,7 @@ def sqil(
     custom_logger, log_dir = logging_ingredient.setup_logging()
     expert_trajs = demonstrations.get_expert_trajectories()
 
-    with environment.make_venv() as venv:
+    with environment.make_venv() as venv:  # type: ignore[wrong-arg-count]
         sqil_trainer = sqil_algorithm.SQIL(
             venv=venv,
             demonstrations=expert_trajs,
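(Reviewer aside, not part of the patch.) The `# type: ignore[wrong-arg-count]` and `[wrong-keyword-args]` suppressions above all sit on calls to Sacred-captured context managers, where pytype presumably checks the declared signature rather than the arguments Sacred injects from the experiment config. A minimal sketch of the mismatch, with hypothetical names standing in for `environment.make_venv`:

```python
from contextlib import contextmanager
from typing import Iterator, List

# Hypothetical stand-in for environment.make_venv. In the real code this is
# additionally wrapped by Sacred, which fills in config values like num_envs
# at call time, so call sites pass no explicit arguments.
@contextmanager
def make_venv(num_envs: int) -> Iterator[List[str]]:
    venv = [f"env-{i}" for i in range(num_envs)]  # stand-in for real envs
    try:
        yield venv
    finally:
        venv.clear()  # stand-in for venv.close()

# pytype analyzes the undecorated signature, sees num_envs missing at the
# zero-argument call sites, and reports wrong-arg-count; the patch silences
# exactly that check:
#     with environment.make_venv() as venv:  # type: ignore[wrong-arg-count]
with make_venv(2) as venv:
    print(venv)
```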
diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py
index 79ee4c136..8fb13f4c4 100644
--- a/src/imitation/scripts/train_preference_comparisons.py
+++ b/src/imitation/scripts/train_preference_comparisons.py
@@ -166,7 +166,7 @@ def train_preference_comparisons(
 
     custom_logger, log_dir = logging_ingredient.setup_logging()
 
-    with environment.make_venv() as venv:
+    with environment.make_venv() as venv:  # type: ignore[wrong-arg-count]
         reward_net = reward.make_reward_net(venv)
         relabel_reward_fn = functools.partial(
             reward_net.predict_processed,
diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py
index 96d35122c..199440163 100644
--- a/src/imitation/scripts/train_rl.py
+++ b/src/imitation/scripts/train_rl.py
@@ -99,7 +99,9 @@ def train_rl(
     policy_dir.mkdir(parents=True, exist_ok=True)
 
     post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)]
-    with environment.make_venv(post_wrappers=post_wrappers) as venv:
+    with environment.make_venv(  # type: ignore[wrong-keyword-args]
+        post_wrappers=post_wrappers,
+    ) as venv:
         callback_objs = []
         if reward_type is not None:
             reward_fn = load_reward(
diff --git a/tests/algorithms/test_base.py b/tests/algorithms/test_base.py
index 1654a1482..23868a893 100644
--- a/tests/algorithms/test_base.py
+++ b/tests/algorithms/test_base.py
@@ -147,7 +147,7 @@ def test_make_data_loader():
     for batch, expected_batch in zip(data_loader, trans_mapping):
         assert batch.keys() == expected_batch.keys()
         for k in batch.keys():
-            v = batch[k]
+            v = batch[k]  # type: ignore[typed-dict-error]
             if isinstance(v, th.Tensor):
                 v = v.numpy()
             assert np.all(v == expected_batch[k])
diff --git a/tests/data/test_buffer.py b/tests/data/test_buffer.py
index e7615461e..64f607df2 100644
--- a/tests/data/test_buffer.py
+++ b/tests/data/test_buffer.py
@@ -176,7 +176,7 @@ def test_replay_buffer(capacity, chunk_len, obs_shape, act_shape, dtype):
 @pytest.mark.parametrize("sample_shape", [(), (1,), (5, 2)])
 def test_buffer_store_errors(sample_shape):
     capacity = 11
-    dtype = "float32"
+    dtype = np.float32
 
     def buf():
         return Buffer(capacity, {"k": sample_shape}, {"k": dtype})
@@ -208,14 +208,14 @@ def buf():
 
 
 def test_buffer_sample_errors():
-    b = Buffer(10, {"k": (2, 1)}, dtypes={"k": bool})
+    b = Buffer(10, {"k": (2, 1)}, dtypes={"k": np.bool_})
     with pytest.raises(ValueError):
         b.sample(5)
 
 
 def test_buffer_init_errors():
     with pytest.raises(KeyError, match=r"sample_shape and dtypes.*"):
-        Buffer(10, dict(a=(2, 1), b=(3,)), dtypes=dict(a="float32", c=bool))
+        Buffer(10, dict(a=(2, 1), b=(3,)), dtypes=dict(a=np.float32, c=np.bool_))
 
 
 def test_replay_buffer_init_errors():
@@ -225,13 +225,13 @@ def test_replay_buffer_init_errors():
     ):
         ReplayBuffer(15, venv=gym.make("CartPole-v1"), obs_shape=(10, 10))
     with pytest.raises(ValueError, match=r"Shape or dtype missing.*"):
-        ReplayBuffer(15, obs_shape=(10, 10), act_shape=(15,), obs_dtype=bool)
+        ReplayBuffer(15, obs_shape=(10, 10), act_shape=(15,), obs_dtype=np.bool_)
     with pytest.raises(ValueError, match=r"Shape or dtype missing.*"):
-        ReplayBuffer(15, obs_shape=(10, 10), obs_dtype=bool, act_dtype=bool)
+        ReplayBuffer(15, obs_shape=(10, 10), obs_dtype=np.bool_, act_dtype=np.bool_)
 
 
 def test_buffer_from_data():
-    data = np.ndarray([50, 30], dtype=bool)
+    data = np.ndarray([50, 30], dtype=np.bool_)
     buf = Buffer.from_data({"k": data})
     assert buf._arrays["k"] is not data
     assert data.dtype == buf._arrays["k"].dtype
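(Reviewer aside, not part of the patch.) The test fixes follow one theme: the upgraded pytype is stricter about TypedDict lookups with non-literal keys and about what unifies with an `np.dtype` annotation, so the tests suppress the former and pass numpy scalar types for the latter. A self-contained sketch of both, with hypothetical names standing in for the real batch type:

```python
from typing import TypedDict

import numpy as np


class Batch(TypedDict):  # hypothetical stand-in for the transitions batch
    obs: np.ndarray
    acts: np.ndarray


batch: Batch = {"obs": np.zeros(2), "acts": np.zeros(2)}
for k in batch.keys():
    # Indexing a TypedDict with a variable key is flagged by the upgraded
    # pytype, hence the targeted suppression in test_base.py:
    v = batch[k]  # type: ignore[typed-dict-error]
    assert isinstance(v, np.ndarray)

# Where an np.dtype is expected, the tests now pass numpy scalar types rather
# than strings or builtins ("float32" -> np.float32, bool -> np.bool_); the
# resulting dtypes are identical at runtime:
assert np.dtype(np.float32) == np.dtype("float32")
assert np.dtype(np.bool_) == np.dtype(bool)
```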