Skip to content

Commit

Permalink
resolve conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueWang25 committed Oct 5, 2023
1 parent 1073967 commit ae17588
Show file tree
Hide file tree
Showing 4 changed files with 0 additions and 49 deletions.
2 changes: 0 additions & 2 deletions src/imitation/algorithms/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
Any,
Callable,
Dict,
Dict,
Iterable,
Iterator,
Mapping,
Expand All @@ -25,7 +24,6 @@
import torch as th
import tqdm
from stable_baselines3.common import policies, torch_layers, utils, vec_env
from stable_baselines3.common import policies, torch_layers, utils, vec_env

from imitation.algorithms import base as algo_base
from imitation.data import rollout, types, wrappers
Expand Down
1 change: 0 additions & 1 deletion tests/algorithms/test_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import numpy as np
import pytest
import torch as th
from stable_baselines3.common import envs as sb_envs
from stable_baselines3.common import evaluation
from stable_baselines3.common import policies as sb_policies
from stable_baselines3.common import vec_env
Expand Down
2 changes: 0 additions & 2 deletions tests/algorithms/test_density_baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@
import numpy as np
import pytest
import stable_baselines3
from stable_baselines3.common import envs as sb_envs
from stable_baselines3.common import policies, vec_env

from imitation.algorithms.density import DensityAlgorithm, DensityType
from imitation.data import rollout, types
from imitation.data.types import TrajectoryWithRew
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.base import RandomPolicy
from imitation.testing import reward_improvement

Expand Down
44 changes: 0 additions & 44 deletions tests/data/test_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,50 +403,6 @@ def observation(self, observation):
return {"a": observation, "b": observation / 2}


def test_dictionary_observations(rng):
"""Test we can generate a rollout for a dict-type observation environment.
Args:
rng: Random state to use (with fixed seed).
"""
env = gym.make("CartPole-v1")
env = monitor.Monitor(env, None)
env = DictObsWrapper(env)
venv = vec_env.DummyVecEnv([lambda: env])

policy = serialize.load_policy("zero", venv)
trajs = rollout.generate_trajectories(
policy,
venv,
rollout.make_min_episodes(10),
rng=rng,
)
for traj in trajs:
assert isinstance(traj.obs, types.DictObs)
np.testing.assert_allclose(traj.obs.get("a") / 2, traj.obs.get("b"))


class DictObsWrapper(gym.ObservationWrapper):
"""Simple wrapper that turns the observation into a dictionary.
The observation is duplicated, with "b" rescaled.
"""

def __init__(self, env: gym.Env):
"""Builds DictObsWrapper.
Args:
env: The wrapped Env.
"""
super().__init__(env)
self.observation_space = gym.spaces.Dict(
{"a": env.observation_space, "b": env.observation_space},
)

def observation(self, observation):
return {"a": observation, "b": observation / 2}


def test_dictionary_observations(rng):
"""Test we can generate a rollout for a dict-type observation environment.
Expand Down

0 comments on commit ae17588

Please sign in to comment.