Skip to content

Commit

Permalink
Add HumanReadableWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueWang25 committed Oct 3, 2023
1 parent be79cf5 commit 6e5c3e8
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 9 deletions.
61 changes: 59 additions & 2 deletions src/imitation/data/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
"""Environment wrappers for collecting rollouts."""

from typing import List, Optional, Sequence, Tuple
from typing import List, Optional, Sequence, Tuple, Dict, Union

import gymnasium as gym
from gymnasium.core import Env
import numpy as np
import numpy.typing as npt
from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper

from imitation.data import rollout, types

# The key for human readable data in the observation.
HR_OBS_KEY = "HR_OBS"


class BufferingWrapper(VecEnvWrapper):
"""Saves transitions of underlying VecEnv.
Expand Down Expand Up @@ -170,7 +174,7 @@ def pop_transitions(self) -> types.TransitionsWithRew:


class RolloutInfoWrapper(gym.Wrapper):
"""Add the entire episode's rewards and observations to `info` at episode end.
"""Adds the entire episode's rewards and observations to `info` at episode end.
Whenever done=True, `info["rollouts"]` is a dict with keys "obs" and "rews", whose
corresponding values hold the NumPy arrays containing the raw observations and
Expand Down Expand Up @@ -206,3 +210,56 @@ def step(self, action):
"rews": np.stack(self._rews),
}
return obs, rew, terminated, truncated, info


class HumanReadableWrapper(gym.Wrapper):
"""Adds human-readable observation to `obs` at every step."""

def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"):
"""Builds HumanReadableWrapper
Args:
env: Environment to wrap.
original_obs_key: The key for original observation if the original
observation is not in dict format.
"""
env.render_mode = "rgb_array"
self._original_obs_key = original_obs_key
super().__init__(env)

def _add_hr_obs(
self, obs: Union[np.ndarray, Dict[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
"""Adds human-readable observation to obs.
Transforms obs into dictionary if it is not already, and adds the human-readable
observation from `env.render()` under the key HR_OBS_KEY.
Args:
obs: Observation from environment.
Returns:
Observation dictionary with the human-readable data
Raises:
KeyError: When the key HR_OBS_KEY already exists in the observation
dictionary.
"""
assert self.env.render_mode is not None
assert self.env.render_mode == "rgb_array"
hr_obs = self.env.render()
if not isinstance(obs, Dict):
obs = {self._original_obs_key: obs}

if HR_OBS_KEY in obs:
raise KeyError(f"{HR_OBS_KEY!r} already exists in observation dict")
obs[HR_OBS_KEY] = hr_obs
return obs

def reset(self, **kwargs):
obs, info = super().reset(**kwargs)
return self._add_hr_obs(obs), info

def step(self, action):
obs, rew, terminated, truncated, info = self.env.step(action)
return self._add_hr_obs(obs), rew, terminated, truncated, info
43 changes: 36 additions & 7 deletions tests/data/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
"""Tests for `imitation.data.wrappers`."""

from typing import List, Sequence, Type
from typing import List, Sequence, Type, Dict, Optional

import gymnasium as gym
import numpy as np
import pytest
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.data import types
from imitation.data.wrappers import BufferingWrapper
from imitation.data.wrappers import (
BufferingWrapper,
HumanReadableWrapper,
HR_OBS_KEY,
)


class _CountingEnv(gym.Env): # pragma: no cover
Expand All @@ -31,7 +35,7 @@ def __init__(self, episode_length=5):
self.episode_length = episode_length
self.timestep = None

def reset(self, seed=None):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
t, self.timestep = 0, 1
return t, {}

Expand All @@ -47,6 +51,9 @@ def step(self, action):
done = t == self.episode_length
return t, t * 10, done, False, {}

def render(self) -> np.ndarray:
return np.array([self.timestep] * 10)


class _CountingDictEnv(_CountingEnv): # pragma: no cover
"""Similar to _CountingEnv, but with Dict observation."""
Expand All @@ -57,9 +64,9 @@ def __init__(self, episode_length=5):
spaces={"t": gym.spaces.Box(low=0, high=np.inf, shape=())},
)

def reset(self, seed=None):
t, self.timestep = 0.0, 1.0
return {"t": t}, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
t, self.timestep = 0, 1
return {"t": t, "2t": 2 * t}, {}

def step(self, action):
if self.timestep is None:
Expand All @@ -71,7 +78,7 @@ def step(self, action):

t, self.timestep = self.timestep, self.timestep + 1
done = t == self.episode_length
return {"t": t}, t * 10, done, False, {}
return {"t": t, "2t": 2 * t}, t * 10, done, False, {}


Envs = [_CountingEnv, _CountingDictEnv]
Expand Down Expand Up @@ -278,3 +285,25 @@ def test_n_transitions_and_empty_error(Env: Type[gym.Env]):
assert venv.n_transitions == 0
with pytest.raises(RuntimeError, match=".* empty .*"):
venv.pop_transitions()


@pytest.mark.parametrize("Env", Envs)
@pytest.mark.parametrize("original_obs_key", ["k1", "k2"])
def test_human_readable_wrapper(Env: Type[gym.Env], original_obs_key: str):
num_obs_key_expected = 2 if Env == _CountingEnv else 3
origin_obs_key = original_obs_key if Env == _CountingEnv else "t"
env = HumanReadableWrapper(Env(), original_obs_key=original_obs_key)

obs, _ = env.reset()
assert isinstance(obs, Dict)
assert HR_OBS_KEY in obs
assert len(obs) == num_obs_key_expected
assert obs[origin_obs_key] == 0
_assert_equal_scrambled_vectors(obs[HR_OBS_KEY], np.array([1] * 10))

next_obs, *_ = env.step(env.action_space.sample())
assert isinstance(next_obs, Dict)
assert HR_OBS_KEY in next_obs
assert len(next_obs) == num_obs_key_expected
assert next_obs[origin_obs_key] == 1
_assert_equal_scrambled_vectors(next_obs[HR_OBS_KEY], np.array([2] * 10))

0 comments on commit 6e5c3e8

Please sign in to comment.