From ae1758846d16b4ceb7284579cc06b46feb322a4a Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Thu, 5 Oct 2023 13:35:27 -0700 Subject: [PATCH] resolve conflict --- src/imitation/algorithms/bc.py | 2 - tests/algorithms/test_bc.py | 1 - tests/algorithms/test_density_baselines.py | 2 - tests/data/test_rollout.py | 44 ---------------------- 4 files changed, 49 deletions(-) diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index dbf8f248c..207fc3d66 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -10,7 +10,6 @@ Any, Callable, Dict, - Dict, Iterable, Iterator, Mapping, @@ -25,7 +24,6 @@ import torch as th import tqdm from stable_baselines3.common import policies, torch_layers, utils, vec_env -from stable_baselines3.common import policies, torch_layers, utils, vec_env from imitation.algorithms import base as algo_base from imitation.data import rollout, types, wrappers diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index 9ec92c00a..8de49c66e 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -10,7 +10,6 @@ import numpy as np import pytest import torch as th -from stable_baselines3.common import envs as sb_envs from stable_baselines3.common import evaluation from stable_baselines3.common import policies as sb_policies from stable_baselines3.common import vec_env diff --git a/tests/algorithms/test_density_baselines.py b/tests/algorithms/test_density_baselines.py index 370068092..5c92feb58 100644 --- a/tests/algorithms/test_density_baselines.py +++ b/tests/algorithms/test_density_baselines.py @@ -7,13 +7,11 @@ import numpy as np import pytest import stable_baselines3 -from stable_baselines3.common import envs as sb_envs from stable_baselines3.common import policies, vec_env from imitation.algorithms.density import DensityAlgorithm, DensityType from imitation.data import rollout, types from imitation.data.types import TrajectoryWithRew -from imitation.data.wrappers import RolloutInfoWrapper from imitation.policies.base import RandomPolicy from imitation.testing import reward_improvement diff --git a/tests/data/test_rollout.py b/tests/data/test_rollout.py index 9a2a2d53f..c8c8cd021 100644 --- a/tests/data/test_rollout.py +++ b/tests/data/test_rollout.py @@ -403,50 +403,6 @@ def observation(self, observation): return {"a": observation, "b": observation / 2} -def test_dictionary_observations(rng): - """Test we can generate a rollout for a dict-type observation environment. - - Args: - rng: Random state to use (with fixed seed). - """ - env = gym.make("CartPole-v1") - env = monitor.Monitor(env, None) - env = DictObsWrapper(env) - venv = vec_env.DummyVecEnv([lambda: env]) - - policy = serialize.load_policy("zero", venv) - trajs = rollout.generate_trajectories( - policy, - venv, - rollout.make_min_episodes(10), - rng=rng, - ) - for traj in trajs: - assert isinstance(traj.obs, types.DictObs) - np.testing.assert_allclose(traj.obs.get("a") / 2, traj.obs.get("b")) - - -class DictObsWrapper(gym.ObservationWrapper): - """Simple wrapper that turns the observation into a dictionary. - - The observation is duplicated, with "b" rescaled. - """ - - def __init__(self, env: gym.Env): - """Builds DictObsWrapper. - - Args: - env: The wrapped Env. - """ - super().__init__(env) - self.observation_space = gym.spaces.Dict( - {"a": env.observation_space, "b": env.observation_space}, - ) - - def observation(self, observation): - return {"a": observation, "b": observation / 2} - - def test_dictionary_observations(rng): """Test we can generate a rollout for a dict-type observation environment.