From 6e5c3e8d695c0c29b721672dbaf3dabae628b23f Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Tue, 3 Oct 2023 12:18:06 -0700 Subject: [PATCH] Add HumanReadableWrapper --- src/imitation/data/wrappers.py | 61 ++++++++++++++++++++++++++++++++-- tests/data/test_wrappers.py | 43 ++++++++++++++++++++---- 2 files changed, 95 insertions(+), 9 deletions(-) diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index 94c88111d..83e4dd0cc 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -1,14 +1,18 @@ """Environment wrappers for collecting rollouts.""" -from typing import List, Optional, Sequence, Tuple +from typing import List, Optional, Sequence, Tuple, Dict, Union import gymnasium as gym +from gymnasium.core import Env import numpy as np import numpy.typing as npt from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper from imitation.data import rollout, types +# The key for human readable data in the observation. +HR_OBS_KEY = "HR_OBS" + class BufferingWrapper(VecEnvWrapper): """Saves transitions of underlying VecEnv. @@ -170,7 +174,7 @@ def pop_transitions(self) -> types.TransitionsWithRew: class RolloutInfoWrapper(gym.Wrapper): - """Add the entire episode's rewards and observations to `info` at episode end. + """Adds the entire episode's rewards and observations to `info` at episode end. Whenever done=True, `info["rollouts"]` is a dict with keys "obs" and "rews", whose corresponding values hold the NumPy arrays containing the raw observations and @@ -206,3 +210,56 @@ def step(self, action): "rews": np.stack(self._rews), } return obs, rew, terminated, truncated, info + + +class HumanReadableWrapper(gym.Wrapper): + """Adds human-readable observation to `obs` at every step.""" + + def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"): + """Builds HumanReadableWrapper + + Args: + env: Environment to wrap. + original_obs_key: The key for original observation if the original + observation is not in dict format. + """ + env.render_mode = "rgb_array" + self._original_obs_key = original_obs_key + super().__init__(env) + + def _add_hr_obs( + self, obs: Union[np.ndarray, Dict[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + """Adds human-readable observation to obs. + + Transforms obs into dictionary if it is not already, and adds the human-readable + observation from `env.render()` under the key HR_OBS_KEY. + + Args: + obs: Observation from environment. + + Returns: + Observation dictionary with the human-readable data + + Raises: + KeyError: When the key HR_OBS_KEY already exists in the observation + dictionary. + """ + assert self.env.render_mode is not None + assert self.env.render_mode == "rgb_array" + hr_obs = self.env.render() + if not isinstance(obs, Dict): + obs = {self._original_obs_key: obs} + + if HR_OBS_KEY in obs: + raise KeyError(f"{HR_OBS_KEY!r} already exists in observation dict") + obs[HR_OBS_KEY] = hr_obs + return obs + + def reset(self, **kwargs): + obs, info = super().reset(**kwargs) + return self._add_hr_obs(obs), info + + def step(self, action): + obs, rew, terminated, truncated, info = self.env.step(action) + return self._add_hr_obs(obs), rew, terminated, truncated, info diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index 33677c68f..cfde9dbcc 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -1,6 +1,6 @@ """Tests for `imitation.data.wrappers`.""" -from typing import List, Sequence, Type +from typing import List, Sequence, Type, Dict, Optional import gymnasium as gym import numpy as np @@ -8,7 +8,11 @@ from stable_baselines3.common.vec_env import DummyVecEnv from imitation.data import types -from imitation.data.wrappers import BufferingWrapper +from imitation.data.wrappers import ( + BufferingWrapper, + HumanReadableWrapper, + HR_OBS_KEY, +) class _CountingEnv(gym.Env): # pragma: no cover @@ -31,7 +35,7 @@ def __init__(self, episode_length=5): self.episode_length = episode_length self.timestep = None - def reset(self, seed=None): + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): t, self.timestep = 0, 1 return t, {} @@ -47,6 +51,9 @@ def step(self, action): done = t == self.episode_length return t, t * 10, done, False, {} + def render(self) -> np.ndarray: + return np.array([self.timestep] * 10) + class _CountingDictEnv(_CountingEnv): # pragma: no cover """Similar to _CountingEnv, but with Dict observation.""" @@ -57,9 +64,9 @@ def __init__(self, episode_length=5): spaces={"t": gym.spaces.Box(low=0, high=np.inf, shape=())}, ) - def reset(self, seed=None): - t, self.timestep = 0.0, 1.0 - return {"t": t}, {} + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + t, self.timestep = 0, 1 + return {"t": t, "2t": 2 * t}, {} def step(self, action): if self.timestep is None: @@ -71,7 +78,7 @@ def step(self, action): t, self.timestep = self.timestep, self.timestep + 1 done = t == self.episode_length - return {"t": t}, t * 10, done, False, {} + return {"t": t, "2t": 2 * t}, t * 10, done, False, {} Envs = [_CountingEnv, _CountingDictEnv] @@ -278,3 +285,25 @@ def test_n_transitions_and_empty_error(Env: Type[gym.Env]): assert venv.n_transitions == 0 with pytest.raises(RuntimeError, match=".* empty .*"): venv.pop_transitions() + + +@pytest.mark.parametrize("Env", Envs) +@pytest.mark.parametrize("original_obs_key", ["k1", "k2"]) +def test_human_readable_wrapper(Env: Type[gym.Env], original_obs_key: str): + num_obs_key_expected = 2 if Env == _CountingEnv else 3 + origin_obs_key = original_obs_key if Env == _CountingEnv else "t" + env = HumanReadableWrapper(Env(), original_obs_key=original_obs_key) + + obs, _ = env.reset() + assert isinstance(obs, Dict) + assert HR_OBS_KEY in obs + assert len(obs) == num_obs_key_expected + assert obs[origin_obs_key] == 0 + _assert_equal_scrambled_vectors(obs[HR_OBS_KEY], np.array([1] * 10)) + + next_obs, *_ = env.step(env.action_space.sample()) + assert isinstance(next_obs, Dict) + assert HR_OBS_KEY in next_obs + assert len(next_obs) == num_obs_key_expected + assert next_obs[origin_obs_key] == 1 + _assert_equal_scrambled_vectors(next_obs[HR_OBS_KEY], np.array([2] * 10))