Merge branch 'master' into benchmark-pr
taufeeque9 committed Oct 9, 2023
2 parents 01755a2 + f099c33 commit fdcef92
Showing 17 changed files with 43 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -78,7 +78,7 @@ repos:
name: pytype
language: system
types: [python]
- entry: "bash -c 'pytype -j ${NUM_CPUS:-auto}'"
+ entry: "bash -c 'pytype --keep-going -j ${NUM_CPUS:-auto}'"
require_serial: true
verbose: true
- id: docs
2 changes: 1 addition & 1 deletion docs/tutorials/1_train_bc.ipynb
@@ -47,7 +47,7 @@
"expert = load_policy(\n",
" \"ppo-huggingface\",\n",
" organization=\"HumanCompatibleAI\",\n",
- " env_name=\"seals-CartPole-v0\",\n",
+ " env_name=\"seals/CartPole-v0\",\n",
" venv=env,\n",
")"
]
2 changes: 1 addition & 1 deletion docs/tutorials/2_train_dagger.ipynb
@@ -39,7 +39,7 @@
"expert = load_policy(\n",
" \"ppo-huggingface\",\n",
" organization=\"HumanCompatibleAI\",\n",
- " env_name=\"seals-CartPole-v0\",\n",
+ " env_name=\"seals/CartPole-v0\",\n",
" venv=env,\n",
")"
]
2 changes: 1 addition & 1 deletion docs/tutorials/3_train_gail.ipynb
@@ -44,7 +44,7 @@
"expert = load_policy(\n",
" \"ppo-huggingface\",\n",
" organization=\"HumanCompatibleAI\",\n",
- " env_name=\"seals:seals/CartPole-v0\",\n",
+ " env_name=\"seals/CartPole-v0\",\n",
" venv=env,\n",
")"
]
7 changes: 3 additions & 4 deletions docs/tutorials/4_train_airl.ipynb
@@ -23,7 +23,6 @@
"metadata": {},
"outputs": [],
"source": [
- "import seals # noqa: F401 # needed to load \"seals/\" environments\n",
"import numpy as np\n",
"from imitation.policies.serialize import load_policy\n",
"from imitation.util.util import make_vec_env\n",
@@ -34,12 +33,12 @@
"FAST = True\n",
"\n",
"if FAST:\n",
- " N_RL_TRAIN_STEPS = 300_000\n",
+ " N_RL_TRAIN_STEPS = 100_000\n",
"else:\n",
" N_RL_TRAIN_STEPS = 2_000_000\n",
"\n",
"venv = make_vec_env(\n",
- " \"seals/CartPole-v0\",\n",
+ " \"seals:seals/CartPole-v0\",\n",
" rng=np.random.default_rng(SEED),\n",
" n_envs=8,\n",
" post_wrappers=[\n",
@@ -49,7 +48,7 @@
"expert = load_policy(\n",
" \"ppo-huggingface\",\n",
" organization=\"HumanCompatibleAI\",\n",
- " env_name=\"seals-CartPole-v0\",\n",
+ " env_name=\"seals/CartPole-v0\",\n",
" venv=venv,\n",
")"
]
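Note: taken together, the tutorial hunks above settle on one naming convention: make_vec_env now receives the module-qualified Gymnasium id "seals:seals/CartPole-v0" (which is why the explicit `import seals` line can be dropped), while the HuggingFace expert loader receives the plain "seals/CartPole-v0" id. A minimal sketch of the resulting pattern; the seed and environment count here are illustrative, not taken from the diff:

```python
# Sketch of the updated tutorial setup after this merge (illustrative values).
import numpy as np

from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env

SEED = 42  # illustrative seed, not from the diff

# "seals:seals/CartPole-v0" tells Gymnasium to import the seals module itself,
# so the tutorials no longer need a separate `import seals  # noqa: F401`.
venv = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=np.random.default_rng(SEED),
    n_envs=8,
)

# The HuggingFace loader takes the plain environment id.
expert = load_policy(
    "ppo-huggingface",
    organization="HumanCompatibleAI",
    env_name="seals/CartPole-v0",
    venv=venv,
)
```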
11 changes: 7 additions & 4 deletions docs/tutorials/7_train_density.ipynb
@@ -56,9 +56,9 @@
"metadata": {},
"outputs": [],
"source": [
+ "from imitation.policies.serialize import load_policy\n",
"from stable_baselines3.common.policies import ActorCriticPolicy\n",
"from stable_baselines3 import PPO\n",
- "from huggingface_sb3 import load_from_hub\n",
"from imitation.data import rollout\n",
"from stable_baselines3.common.vec_env import DummyVecEnv\n",
"from stable_baselines3.common.evaluation import evaluate_policy\n",
@@ -70,12 +70,15 @@
"\n",
"rng = np.random.default_rng(seed=SEED)\n",
"env_name = \"Pendulum-v1\"\n",
- "expert = PPO.load(\n",
- " load_from_hub(\"HumanCompatibleAI/ppo-Pendulum-v1\", \"ppo-Pendulum-v1.zip\")\n",
- ").policy\n",
"rollout_env = DummyVecEnv(\n",
" [lambda: RolloutInfoWrapper(gym.make(env_name)) for _ in range(N_VEC)]\n",
")\n",
+ "expert = load_policy(\n",
+ " \"ppo-huggingface\",\n",
+ " organization=\"HumanCompatibleAI\",\n",
+ " env_name=env_name,\n",
+ " venv=rollout_env,\n",
+ ")\n",
"rollouts = rollout.rollout(\n",
" expert,\n",
" rollout_env,\n",
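Note: the hunks above replace the manual huggingface_sb3 download (load_from_hub + PPO.load) with imitation's load_policy loader. A self-contained sketch of the new path; the gymnasium import, the RolloutInfoWrapper import path, N_VEC, the seed, and the episode count are assumptions filled in from the surrounding notebook rather than shown in this hunk:

```python
# Sketch of the new expert-loading path for the density tutorial.
import gymnasium as gym  # assumed; the import lives outside this hunk
import numpy as np

from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper  # assumed import path
from imitation.policies.serialize import load_policy
from stable_baselines3.common.vec_env import DummyVecEnv

SEED = 42      # illustrative
N_VEC = 8      # illustrative; defined elsewhere in the notebook
env_name = "Pendulum-v1"

rollout_env = DummyVecEnv(
    [lambda: RolloutInfoWrapper(gym.make(env_name)) for _ in range(N_VEC)]
)
# The expert now comes from the Hugging Face hub via imitation's own loader.
expert = load_policy(
    "ppo-huggingface",
    organization="HumanCompatibleAI",
    env_name=env_name,
    venv=rollout_env,
)
rollouts = rollout.rollout(
    expert,
    rollout_env,
    rollout.make_sample_until(min_episodes=10),  # illustrative episode count
    rng=np.random.default_rng(SEED),
)
```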
9 changes: 9 additions & 0 deletions pyproject.toml
@@ -14,3 +14,12 @@ build-backend = "setuptools.build_meta"

[tool.black]
target-version = ["py38"]

+ [tool.pytype]
+ inputs = [
+ "src/",
+ "tests/",
+ "experiments/",
+ "setup.py"
+ ]
+ python_version = "3.8"
8 changes: 0 additions & 8 deletions setup.cfg
@@ -61,11 +61,3 @@ omit =
source =
src/imitation
*venv/lib/python*/site-packages/imitation

- [pytype]
- inputs =
- src/
- tests/
- experiments/
- setup.py
- python_version >= 3.8
2 changes: 1 addition & 1 deletion setup.py
@@ -17,7 +17,7 @@
ATARI_REQUIRE = [
"seals[atari]~=0.2.1",
]
- PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else []
+ PYTYPE = ["pytype==2023.9.27"] if IS_NOT_WINDOWS else []

# Note: the versions of the test and doc requirements should be tightly pinned to known
# working versions to make our CI/CD pipeline as stable as possible.
2 changes: 1 addition & 1 deletion src/imitation/scripts/eval_policy.py
@@ -96,7 +96,7 @@ def eval_policy(
sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes)
post_wrappers = [video_wrapper_factory(log_dir, **video_kwargs)] if videos else None
render_mode = "rgb_array" if videos else None
- with environment.make_venv(
+ with environment.make_venv( # type: ignore[wrong-keyword-args]
post_wrappers=post_wrappers,
env_make_kwargs=dict(render_mode=render_mode),
) as venv:
6 changes: 3 additions & 3 deletions src/imitation/scripts/ingredients/demonstrations.py
@@ -143,10 +143,10 @@ def _generate_expert_trajs(
raise ValueError("n_expert_demos must be specified when generating demos.")

logger.info(f"Generating {n_expert_demos} expert trajectories")
- with environment.make_rollout_venv() as rollout_env:
+ with environment.make_rollout_venv() as env: # type: ignore[wrong-arg-count]
return rollout.rollout(
- expert.get_expert_policy(rollout_env),
- rollout_env,
+ expert.get_expert_policy(env),
+ env,
rollout.make_sample_until(min_episodes=n_expert_demos),
rng=_rnd,
)
2 changes: 1 addition & 1 deletion src/imitation/scripts/train_adversarial.py
@@ -119,7 +119,7 @@ def train_adversarial(
custom_logger, log_dir = logging_ingredient.setup_logging()
expert_trajs = demonstrations.get_expert_trajectories()

- with environment.make_venv() as venv:
+ with environment.make_venv() as venv: # type: ignore[wrong-arg-count]
reward_net = reward.make_reward_net(venv)
relabel_reward_fn = functools.partial(
reward_net.predict_processed,
6 changes: 3 additions & 3 deletions src/imitation/scripts/train_imitation.py
@@ -73,7 +73,7 @@ def bc(
custom_logger, log_dir = logging_ingredient.setup_logging()

expert_trajs = demonstrations.get_expert_trajectories()
- with environment.make_venv() as venv:
+ with environment.make_venv() as venv: # type: ignore[wrong-arg-count]
bc_trainer = bc_ingredient.make_bc(venv, expert_trajs, custom_logger)

bc_train_kwargs = dict(log_rollouts_venv=venv, **bc["train_kwargs"])
@@ -115,7 +115,7 @@ def dagger(
if dagger["use_offline_rollouts"]:
expert_trajs = demonstrations.get_expert_trajectories()

- with environment.make_venv() as venv:
+ with environment.make_venv() as venv: # type: ignore[wrong-arg-count]
bc_trainer = bc_ingredient.make_bc(venv, expert_trajs, custom_logger)

bc_train_kwargs = dict(log_rollouts_venv=venv, **bc["train_kwargs"])
@@ -162,7 +162,7 @@ def sqil(
custom_logger, log_dir = logging_ingredient.setup_logging()
expert_trajs = demonstrations.get_expert_trajectories()

- with environment.make_venv() as venv:
+ with environment.make_venv() as venv: # type: ignore[wrong-arg-count]
sqil_trainer = sqil_algorithm.SQIL(
venv=venv,
demonstrations=expert_trajs,
2 changes: 1 addition & 1 deletion src/imitation/scripts/train_preference_comparisons.py
@@ -166,7 +166,7 @@ def train_preference_comparisons(

custom_logger, log_dir = logging_ingredient.setup_logging()

- with environment.make_venv() as venv:
+ with environment.make_venv() as venv: # type: ignore[wrong-arg-count]
reward_net = reward.make_reward_net(venv)
relabel_reward_fn = functools.partial(
reward_net.predict_processed,
4 changes: 3 additions & 1 deletion src/imitation/scripts/train_rl.py
@@ -99,7 +99,9 @@ def train_rl(
policy_dir.mkdir(parents=True, exist_ok=True)

post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)]
- with environment.make_venv(post_wrappers=post_wrappers) as venv:
+ with environment.make_venv( # type: ignore[wrong-keyword-args]
+ post_wrappers=post_wrappers,
+ ) as venv:
callback_objs = []
if reward_type is not None:
reward_fn = load_reward(
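Note: environment.make_venv above is the script's Sacred-ingredient wrapper around venv creation; the post_wrappers list it forwards follows the same convention as make_vec_env in the AIRL tutorial earlier in this diff: each entry is a callable taking the environment and its index and returning a wrapped environment. A rough sketch of that convention using make_vec_env directly, with an illustrative environment id:

```python
# Sketch of the post-wrapper convention used by train_rl.py: every
# sub-environment is wrapped so rollouts keep their episode info dicts.
import numpy as np

from imitation.data import wrappers  # assumed import path, as in train_rl.py
from imitation.util.util import make_vec_env

venv = make_vec_env(
    "CartPole-v1",  # illustrative environment id
    rng=np.random.default_rng(0),
    n_envs=2,
    post_wrappers=[lambda env, idx: wrappers.RolloutInfoWrapper(env)],
)
print(venv.num_envs)  # 2 wrapped sub-environments
venv.close()
```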
2 changes: 1 addition & 1 deletion tests/algorithms/test_base.py
@@ -147,7 +147,7 @@ def test_make_data_loader():
for batch, expected_batch in zip(data_loader, trans_mapping):
assert batch.keys() == expected_batch.keys()
for k in batch.keys():
- v = batch[k]
+ v = batch[k] # type: ignore[typed-dict-error]
if isinstance(v, th.Tensor):
v = v.numpy()
assert np.all(v == expected_batch[k])
12 changes: 6 additions & 6 deletions tests/data/test_buffer.py
@@ -176,7 +176,7 @@ def test_replay_buffer(capacity, chunk_len, obs_shape, act_shape, dtype):
@pytest.mark.parametrize("sample_shape", [(), (1,), (5, 2)])
def test_buffer_store_errors(sample_shape):
capacity = 11
- dtype = "float32"
+ dtype = np.float32

def buf():
return Buffer(capacity, {"k": sample_shape}, {"k": dtype})
@@ -208,14 +208,14 @@ def buf():


def test_buffer_sample_errors():
- b = Buffer(10, {"k": (2, 1)}, dtypes={"k": bool})
+ b = Buffer(10, {"k": (2, 1)}, dtypes={"k": np.bool_})
with pytest.raises(ValueError):
b.sample(5)


def test_buffer_init_errors():
with pytest.raises(KeyError, match=r"sample_shape and dtypes.*"):
- Buffer(10, dict(a=(2, 1), b=(3,)), dtypes=dict(a="float32", c=bool))
+ Buffer(10, dict(a=(2, 1), b=(3,)), dtypes=dict(a=np.float32, c=np.bool_))


def test_replay_buffer_init_errors():
@@ -225,13 +225,13 @@
):
ReplayBuffer(15, venv=gym.make("CartPole-v1"), obs_shape=(10, 10))
with pytest.raises(ValueError, match=r"Shape or dtype missing.*"):
- ReplayBuffer(15, obs_shape=(10, 10), act_shape=(15,), obs_dtype=bool)
+ ReplayBuffer(15, obs_shape=(10, 10), act_shape=(15,), obs_dtype=np.bool_)
with pytest.raises(ValueError, match=r"Shape or dtype missing.*"):
- ReplayBuffer(15, obs_shape=(10, 10), obs_dtype=bool, act_dtype=bool)
+ ReplayBuffer(15, obs_shape=(10, 10), obs_dtype=np.bool_, act_dtype=np.bool_)


def test_buffer_from_data():
- data = np.ndarray([50, 30], dtype=bool)
+ data = np.ndarray([50, 30], dtype=np.bool_)
buf = Buffer.from_data({"k": data})
assert buf._arrays["k"] is not data
assert data.dtype == buf._arrays["k"].dtype
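Note: the test changes above standardize on explicit numpy dtype objects (np.float32, np.bool_) in place of strings like "float32" or the builtin bool. A minimal sketch of the convention, assuming Buffer is imported from imitation.data.buffer as the test module path suggests:

```python
# Sketch of the dtype convention after this change: pass numpy dtype objects
# rather than strings like "float32" or the Python builtin bool.
import numpy as np

from imitation.data.buffer import Buffer  # assumed module path

# Capacity, per-key sample shapes, and per-key numpy dtypes.
buf = Buffer(10, {"obs": (2, 1)}, dtypes={"obs": np.float32})

# Keys of the shape dict and the dtype dict must match; a mismatch raises
# KeyError, which is exactly what test_buffer_init_errors asserts above.
```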
