Upgrade pytype (#801)
* Upgrade pytype and remove workaround for old versions

* New pytype needs an input directory or file

* Fix np.dtype usage

* Ignore typed-dict-error

* Fix context-manager-related errors

* Keep pytype checking after failures (--keep-going)

* Move pytype config to pyproject.toml

* Use inputs specified in pyproject.toml

* Fix lint

* Fix undefined variable

* Fix end of file

* Fix typo

---------

Co-authored-by: Adam Gleave <adam@gleave.me>
ZiyueWang25 and AdamGleave authored Oct 8, 2023
1 parent aca4c07 commit 7b8b4bf
Showing 12 changed files with 30 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -78,7 +78,7 @@ repos:
name: pytype
language: system
types: [python]
entry: "bash -c 'pytype -j ${NUM_CPUS:-auto}'"
entry: "bash -c 'pytype --keep-going -j ${NUM_CPUS:-auto}'"
require_serial: true
verbose: true
- id: docs
9 changes: 9 additions & 0 deletions pyproject.toml
@@ -14,3 +14,12 @@ build-backend = "setuptools.build_meta"

[tool.black]
target-version = ["py38"]

+ [tool.pytype]
+ inputs = [
+     "src/",
+     "tests/",
+     "experiments/",
+     "setup.py"
+ ]
+ python_version = "3.8"
8 changes: 0 additions & 8 deletions setup.cfg
@@ -59,11 +59,3 @@ omit =
source =
src/imitation
*venv/lib/python*/site-packages/imitation

- [pytype]
- inputs =
-     src/
-     tests/
-     experiments/
-     setup.py
- python_version >= 3.8
2 changes: 1 addition & 1 deletion setup.py
@@ -17,7 +17,7 @@
ATARI_REQUIRE = [
"seals[atari]~=0.2.1",
]
PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else []
PYTYPE = ["pytype==2023.9.27"] if IS_NOT_WINDOWS else []

# Note: the versions of the test and doc requirements should be tightly pinned to known
# working versions to make our CI/CD pipeline as stable as possible.
2 changes: 1 addition & 1 deletion src/imitation/scripts/eval_policy.py
@@ -96,7 +96,7 @@ def eval_policy(
sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes)
post_wrappers = [video_wrapper_factory(log_dir, **video_kwargs)] if videos else None
render_mode = "rgb_array" if videos else None
- with environment.make_venv(
+ with environment.make_venv(  # type: ignore[wrong-keyword-args]
post_wrappers=post_wrappers,
env_make_kwargs=dict(render_mode=render_mode),
) as venv:
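
environment.make_venv is a Sacred-captured context manager, so the upgraded pytype cannot see its true signature and reports wrong-keyword-args at call sites that pass keyword arguments. The fix suppresses only that error class rather than all checking on the line. A minimal sketch of the pattern, with a hypothetical identity decorator standing in for Sacred's capture:

import contextlib
from typing import Any, Callable, Iterator

def capture(fn: Callable[..., Any]) -> Callable[..., Any]:
    # Hypothetical stand-in for Sacred's capture, which rewrites the
    # signature at runtime and defeats pytype's inference.
    return fn

@capture
@contextlib.contextmanager
def make_venv(post_wrappers: Any = None) -> Iterator[str]:
    yield "venv"

# Silencing one named error class keeps every other check live on this
# line; pytype's native directive, # pytype: disable=wrong-keyword-args,
# is an equivalent alternative.
with make_venv(post_wrappers=[]) as venv:  # type: ignore[wrong-keyword-args]
    print(venv)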
6 changes: 3 additions & 3 deletions src/imitation/scripts/ingredients/demonstrations.py
@@ -143,10 +143,10 @@ def _generate_expert_trajs(
raise ValueError("n_expert_demos must be specified when generating demos.")

logger.info(f"Generating {n_expert_demos} expert trajectories")
- with environment.make_rollout_venv() as rollout_env:
+ with environment.make_rollout_venv() as env:  # type: ignore[wrong-arg-count]
return rollout.rollout(
-     expert.get_expert_policy(rollout_env),
-     rollout_env,
+     expert.get_expert_policy(env),
+     env,
rollout.make_sample_until(min_episodes=n_expert_demos),
rng=_rnd,
)
2 changes: 1 addition & 1 deletion src/imitation/scripts/train_adversarial.py
@@ -119,7 +119,7 @@ def train_adversarial(
custom_logger, log_dir = logging_ingredient.setup_logging()
expert_trajs = demonstrations.get_expert_trajectories()

- with environment.make_venv() as venv:
+ with environment.make_venv() as venv:  # type: ignore[wrong-arg-count]
reward_net = reward.make_reward_net(venv)
relabel_reward_fn = functools.partial(
reward_net.predict_processed,
6 changes: 3 additions & 3 deletions src/imitation/scripts/train_imitation.py
@@ -73,7 +73,7 @@ def bc(
custom_logger, log_dir = logging_ingredient.setup_logging()

expert_trajs = demonstrations.get_expert_trajectories()
- with environment.make_venv() as venv:
+ with environment.make_venv() as venv:  # type: ignore[wrong-arg-count]
bc_trainer = bc_ingredient.make_bc(venv, expert_trajs, custom_logger)

bc_train_kwargs = dict(log_rollouts_venv=venv, **bc["train_kwargs"])
@@ -115,7 +115,7 @@ def dagger(
if dagger["use_offline_rollouts"]:
expert_trajs = demonstrations.get_expert_trajectories()

- with environment.make_venv() as venv:
+ with environment.make_venv() as venv:  # type: ignore[wrong-arg-count]
bc_trainer = bc_ingredient.make_bc(venv, expert_trajs, custom_logger)

bc_train_kwargs = dict(log_rollouts_venv=venv, **bc["train_kwargs"])
@@ -161,7 +161,7 @@ def sqil(
custom_logger, log_dir = logging_ingredient.setup_logging()
expert_trajs = demonstrations.get_expert_trajectories()

- with environment.make_venv() as venv:
+ with environment.make_venv() as venv:  # type: ignore[wrong-arg-count]
sqil_trainer = sqil_algorithm.SQIL(
venv=venv,
demonstrations=expert_trajs,
2 changes: 1 addition & 1 deletion src/imitation/scripts/train_preference_comparisons.py
@@ -166,7 +166,7 @@ def train_preference_comparisons(

custom_logger, log_dir = logging_ingredient.setup_logging()

- with environment.make_venv() as venv:
+ with environment.make_venv() as venv:  # type: ignore[wrong-arg-count]
reward_net = reward.make_reward_net(venv)
relabel_reward_fn = functools.partial(
reward_net.predict_processed,
4 changes: 3 additions & 1 deletion src/imitation/scripts/train_rl.py
@@ -99,7 +99,9 @@ def train_rl(
policy_dir.mkdir(parents=True, exist_ok=True)

post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)]
- with environment.make_venv(post_wrappers=post_wrappers) as venv:
+ with environment.make_venv(  # type: ignore[wrong-keyword-args]
+     post_wrappers=post_wrappers,
+ ) as venv:
callback_objs = []
if reward_type is not None:
reward_fn = load_reward(
2 changes: 1 addition & 1 deletion tests/algorithms/test_base.py
@@ -147,7 +147,7 @@ def test_make_data_loader():
for batch, expected_batch in zip(data_loader, trans_mapping):
assert batch.keys() == expected_batch.keys()
for k in batch.keys():
- v = batch[k]
+ v = batch[k]  # type: ignore[typed-dict-error]
if isinstance(v, th.Tensor):
v = v.numpy()
assert np.all(v == expected_batch[k])
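
pytype's typed-dict-error fires when a TypedDict is subscripted with a key it cannot resolve to a literal; the loop variable k defeats that analysis even though every key is valid at runtime, so the single access is suppressed. A minimal reproduction, under the assumption that the batch type is a TypedDict:

from typing import TypedDict

class Batch(TypedDict):
    obs: int
    acts: int

batch: Batch = {"obs": 1, "acts": 2}
for k in batch.keys():
    # k is typed as str, not as a literal key, so a strict checker flags
    # the lookup even though it is safe at runtime.
    v = batch[k]  # type: ignore[typed-dict-error]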
12 changes: 6 additions & 6 deletions tests/data/test_buffer.py
@@ -176,7 +176,7 @@ def test_replay_buffer(capacity, chunk_len, obs_shape, act_shape, dtype):
@pytest.mark.parametrize("sample_shape", [(), (1,), (5, 2)])
def test_buffer_store_errors(sample_shape):
capacity = 11
dtype = "float32"
dtype = np.float32

def buf():
return Buffer(capacity, {"k": sample_shape}, {"k": dtype})
@@ -208,14 +208,14 @@ def buf():


def test_buffer_sample_errors():
- b = Buffer(10, {"k": (2, 1)}, dtypes={"k": bool})
+ b = Buffer(10, {"k": (2, 1)}, dtypes={"k": np.bool_})
with pytest.raises(ValueError):
b.sample(5)


def test_buffer_init_errors():
with pytest.raises(KeyError, match=r"sample_shape and dtypes.*"):
- Buffer(10, dict(a=(2, 1), b=(3,)), dtypes=dict(a="float32", c=bool))
+ Buffer(10, dict(a=(2, 1), b=(3,)), dtypes=dict(a=np.float32, c=np.bool_))


def test_replay_buffer_init_errors():
@@ -225,13 +225,13 @@ def test_replay_buffer_init_errors():
):
ReplayBuffer(15, venv=gym.make("CartPole-v1"), obs_shape=(10, 10))
with pytest.raises(ValueError, match=r"Shape or dtype missing.*"):
- ReplayBuffer(15, obs_shape=(10, 10), act_shape=(15,), obs_dtype=bool)
+ ReplayBuffer(15, obs_shape=(10, 10), act_shape=(15,), obs_dtype=np.bool_)
with pytest.raises(ValueError, match=r"Shape or dtype missing.*"):
- ReplayBuffer(15, obs_shape=(10, 10), obs_dtype=bool, act_dtype=bool)
+ ReplayBuffer(15, obs_shape=(10, 10), obs_dtype=np.bool_, act_dtype=np.bool_)


def test_buffer_from_data():
- data = np.ndarray([50, 30], dtype=bool)
+ data = np.ndarray([50, 30], dtype=np.bool_)
buf = Buffer.from_data({"k": data})
assert buf._arrays["k"] is not data
assert data.dtype == buf._arrays["k"].dtype
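
The dtype edits in this file swap string and builtin specifiers ("float32", bool) for NumPy scalar types (np.float32, np.bool_). The forms are interchangeable at runtime, since np.dtype normalizes all of them; the NumPy types are simply what the upgraded pytype can resolve against NumPy's stubs. A quick illustration of the runtime equivalence:

import numpy as np

# All of these specifiers normalize to the same dtype at runtime; only
# the np.float32 / np.bool_ forms are transparent to the type checker.
assert np.dtype("float32") == np.dtype(np.float32)
assert np.dtype(bool) == np.dtype(np.bool_)
assert np.zeros(3, dtype="float32").dtype == np.zeros(3, dtype=np.float32).dtype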
