Commit
Add partial support for dictionary observation spaces (bc, density) (#785)

* first pass of dict obs functionality

* cleanup DictObs

* add dict space to test_types.py, fix some problems

* add dict-obs test for rollout

* add bc.py test

* cleanup

* small fixes

* small fixes

* fix type error in interactive.py

* fix introduced error in mce_irl.py

* fix minor ci complaint

* add basic dictobs tests

* change default bc policy for dict obs space

* refine rollout.py typechecks, comments

* check rollout produces dictobs of correct shape

* cleanup types and dictobs helpers

* clean useless lines

* clean up print statements

* fix typos

Co-authored-by: Adam Gleave <adam@gleave.me>

* assert matching keys in from_obs_list

* move maybe_wrap, clean rollout

* change policy callable to take dict[str, np.ndarray] not dictobs

* rollout info wrapper supports dictobs

* fix from_obs_list key consistency check

* xfail save/load tests with dictobs

* doc for dictobs wrapper

* don't error on int observations

* lint fixes

* cleanup bc test for dict obs

* cleanup bc.py unwrapping

* cleanup rollout.py

* cleanup dictobs interface

* small cleanups

* coverage fixes, test fix

* adjust error types

* docstrings for type helpers

* add dict obs space support for density

* fix typos

Co-authored-by: Adam Gleave <adam@gleave.me>

* Adam suggestions from code review

Co-authored-by: Adam Gleave <adam@gleave.me>

* small changes for code review

* fix docstring

* remove FloatReward

* Fix test_bc

* Turn off GPU finding to avoid using gpu device

* Check None to ensure __add__ can work

* fix docstring

* bypass pytype and lint test

* format with black

* Test dict space in density algo

* black format

* small fix

* Add DictObs into test_wrappers

* fix format

* minor fix

* type and lint fix

* Add policy training test

* suppress line too long lint check on a line

* acts to obs for clarity

* Add HumanReadableWrapper

* fix dict env observation space

* adjust wrapper and not set render_mode inside

* Add additional obs check

* Upgrade pytype and remove workaround for old versions

* Fix test_rollout test

* add RemoveHumanReadableWrapper and update ob space

* Revert "add RemoveHumanReadableWrapper and update ob space"

This reverts commit ee83ec5.

* Revert "adjust wrapper and not set render_mode inside"

This reverts commit a9b32bd.

* Revert "fix dict env observation space"

This reverts commit ba6a6a7.

* Revert "Add HumanReadableWrapper"

This reverts commit 6e5c3e8.

* Revert "acts to obs for clarity"

This reverts commit be79cf5.

* address comments

* new pytype need input directory or file

* fix np.dtype

* ignore typed-dict-error

* context manager related fix

* keep pytype checking more failures

* Revert "keep pytype checking more failures"

This reverts commit f5288c6.

* Revert "context manager related fix"

This reverts commit 5c1d751.

* Revert "ignore typed-dict-error"

This reverts commit 5c6e5b8.

* Revert "fix np.dtype"

This reverts commit 6884538.

* Revert "new pytype need input directory or file"

This reverts commit 15541cd.

* Revert "Upgrade pytype and remove workaround for old versions"

This reverts commit 194ec1a.

* lint fix

* fix type check

* fix lint

---------

Co-authored-by: Adam Gleave <adam@gleave.me>
Co-authored-by: ZiyueWang25 <wfuymu@gmail.com>
3 people authored Oct 5, 2023
1 parent 573b086 commit e6d8886
Showing 25 changed files with 881 additions and 195 deletions.
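
For context, this commit targets environments whose observations are dictionaries of arrays rather than a single array. A minimal sketch of such a space, assuming gymnasium's standard spaces API (keys and shapes below are illustrative, not taken from this commit):

import gymnasium as gym
import numpy as np

# Hypothetical dictionary observation space of the kind gaining partial support here.
obs_space = gym.spaces.Dict(
    {
        "position": gym.spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32),
        "image": gym.spaces.Box(low=0, high=255, shape=(8, 8, 1), dtype=np.uint8),
    }
)
sample = obs_space.sample()  # a dict mapping each key to an np.ndarray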
40 changes: 33 additions & 7 deletions src/imitation/algorithms/bc.py
@@ -9,6 +9,7 @@
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
Mapping,
@@ -22,7 +23,7 @@
import numpy as np
import torch as th
import tqdm
from stable_baselines3.common import policies, utils, vec_env
from stable_baselines3.common import policies, torch_layers, utils, vec_env

from imitation.algorithms import base as algo_base
from imitation.data import rollout, types
@@ -99,7 +100,12 @@ class BehaviorCloningLossCalculator:
def __call__(
self,
policy: policies.ActorCriticPolicy,
obs: Union[th.Tensor, np.ndarray],
obs: Union[
types.AnyTensor,
types.DictObs,
Dict[str, np.ndarray],
Dict[str, th.Tensor],
],
acts: Union[th.Tensor, np.ndarray],
) -> BCTrainingMetrics:
"""Calculate the supervised learning loss used to train the behavioral clone.
@@ -113,9 +119,18 @@ def __call__(
A BCTrainingMetrics object with the loss and all the components it
consists of.
"""
obs = util.safe_to_tensor(obs)
tensor_obs = types.map_maybe_dict(
util.safe_to_tensor,
types.maybe_unwrap_dictobs(obs),
)
acts = util.safe_to_tensor(acts)
_, log_prob, entropy = policy.evaluate_actions(obs, acts)

# policy.evaluate_actions's type signatures are incorrect.
# See https://github.com/DLR-RM/stable-baselines3/issues/1679
(_, log_prob, entropy) = policy.evaluate_actions(
tensor_obs, # type: ignore[arg-type]
acts,
)
prob_true_act = th.exp(log_prob).mean()
log_prob = log_prob.mean()
entropy = entropy.mean() if entropy is not None else None
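
The hunk above routes plain and dictionary observations through the same conversion: unwrap a possible DictObs, then map the tensor conversion over either a dict of arrays or a single array. A standalone sketch of that pattern, assuming only NumPy and PyTorch (the helper below merely mimics what types.map_maybe_dict appears to do and is not the library implementation):

from typing import Callable, Dict, Union

import numpy as np
import torch as th


def map_over_maybe_dict(fn: Callable, x):
    """Apply fn to every value of a dict, or directly to x if it is not a dict."""
    if isinstance(x, dict):
        return {k: fn(v) for k, v in x.items()}
    return fn(x)


obs: Union[np.ndarray, Dict[str, np.ndarray]] = {
    "position": np.zeros((4, 3), dtype=np.float32),
    "image": np.zeros((4, 8, 8, 1), dtype=np.uint8),
}
tensor_obs = map_over_maybe_dict(th.as_tensor, obs)  # dict of torch tensors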
@@ -324,12 +339,18 @@ def __init__(
self.rng = rng

if policy is None:
extractor = (
torch_layers.CombinedExtractor
if isinstance(observation_space, gym.spaces.Dict)
else torch_layers.FlattenExtractor
)
policy = policy_base.FeedForward32Policy(
observation_space=observation_space,
action_space=action_space,
# Set lr_schedule to max value to force error if policy.optimizer
# is used by mistake (should use self.optimizer instead).
lr_schedule=lambda _: th.finfo(th.float32).max,
features_extractor_class=extractor,
)
self._policy = policy.to(utils.get_device(device))
# TODO(adam): make policy mandatory and delete observation/action space params?
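
The extractor selection above follows stable-baselines3's convention of using CombinedExtractor for Dict observation spaces and FlattenExtractor otherwise. A hedged standalone sketch of the same choice (the function name is illustrative, not part of this diff):

import gymnasium as gym
from stable_baselines3.common.torch_layers import CombinedExtractor, FlattenExtractor


def default_extractor_class(observation_space: gym.Space):
    # Dict spaces need a per-key features extractor; flat spaces can simply be flattened.
    if isinstance(observation_space, gym.spaces.Dict):
        return CombinedExtractor
    return FlattenExtractor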
@@ -464,9 +485,14 @@ def process_batch():
minibatch_size,
num_samples_so_far,
), batch in batches_with_stats:
obs = th.as_tensor(batch["obs"], device=self.policy.device).detach()
acts = th.as_tensor(batch["acts"], device=self.policy.device).detach()
training_metrics = self.loss_calculator(self.policy, obs, acts)
obs_tensor: Union[th.Tensor, Dict[str, th.Tensor]]
# unwraps the observation if it's a dictobs and converts arrays to tensors
obs_tensor = types.map_maybe_dict(
lambda x: util.safe_to_tensor(x, device=self.policy.device),
types.maybe_unwrap_dictobs(batch["obs"]),
)
acts = util.safe_to_tensor(batch["acts"], device=self.policy.device)
training_metrics = self.loss_calculator(self.policy, obs_tensor, acts)

# Renormalise the loss to be averaged over the whole
# batch size instead of the minibatch size.
107 changes: 52 additions & 55 deletions src/imitation/algorithms/density.py
@@ -134,9 +134,9 @@

def _get_demo_from_batch(
self,
obs_b: np.ndarray,
obs_b: types.Observation,
act_b: np.ndarray,
next_obs_b: Optional[np.ndarray],
next_obs_b: Optional[types.Observation],
) -> Dict[Optional[int], List[np.ndarray]]:
if next_obs_b is None and self.density_type == DensityType.STATE_STATE_DENSITY:
raise ValueError(
@@ -145,11 +145,18 @@ def _get_demo_from_batch(
)

assert act_b.shape[1:] == self.venv.action_space.shape
assert obs_b.shape[1:] == self.venv.observation_space.shape
ob_space = self.venv.observation_space
if isinstance(obs_b, types.DictObs):
exp_shape = {
k: v.shape for k, v in ob_space.items() # type: ignore[attr-defined]
}
obs_shape = {k: v.shape[1:] for k, v in obs_b.items()}
assert exp_shape == obs_shape, f"Expected {exp_shape}, got {obs_shape}"
else:
assert obs_b.shape[1:] == ob_space.shape
assert len(act_b) == len(obs_b)
if next_obs_b is not None:
assert next_obs_b.shape[1:] == self.venv.observation_space.shape
assert len(next_obs_b) == len(obs_b)
assert next_obs_b.shape == obs_b.shape

if next_obs_b is not None:
next_obs_b_iterator: Iterable = next_obs_b
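
For dictionary observations, the batch shape check above has to compare shapes per key rather than against a single observation_space.shape. A standalone sketch of the same check, assuming gymnasium spaces and a dict of batched numpy arrays (names here are illustrative):

import gymnasium as gym


def check_batched_obs_shape(obs_b, ob_space: gym.Space) -> None:
    # For Dict spaces, compare per-key shapes with the batch axis dropped;
    # otherwise fall back to the usual single-array check.
    if isinstance(ob_space, gym.spaces.Dict):
        expected = {k: sub.shape for k, sub in ob_space.spaces.items()}
        actual = {k: v.shape[1:] for k, v in obs_b.items()}
        assert expected == actual, f"Expected {expected}, got {actual}"
    else:
        assert obs_b.shape[1:] == ob_space.shape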
@@ -200,14 +207,17 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None:
# analogous to cast above.
demonstrations = cast(Iterable[types.TransitionMapping], demonstrations)

def to_np_maybe_dictobs(x):
if isinstance(x, types.DictObs):
return x
else:
return util.safe_to_numpy(x, warn=True)

for batch in demonstrations:
transitions.update(
self._get_demo_from_batch(
util.safe_to_numpy(batch["obs"], warn=True),
util.safe_to_numpy(batch["acts"], warn=True),
util.safe_to_numpy(batch.get("next_obs"), warn=True),
),
)
obs = to_np_maybe_dictobs(batch["obs"])
acts = util.safe_to_numpy(batch["acts"], warn=True)
next_obs = to_np_maybe_dictobs(batch.get("next_obs"))
transitions.update(self._get_demo_from_batch(obs, acts, next_obs))
else:
raise TypeError(
f"Unsupported demonstration type {type(demonstrations)}",
@@ -253,65 +263,40 @@ def _fit_density(self, transitions: np.ndarray) -> neighbors.KernelDensity:

def _preprocess_transition(
self,
obs: np.ndarray,
obs: types.Observation,
act: np.ndarray,
next_obs: Optional[np.ndarray],
next_obs: Optional[types.Observation],
) -> np.ndarray:
"""Compute flattened transition on subset specified by `self.density_type`."""
flattened_obs = space_utils.flatten(
self.venv.observation_space,
types.maybe_unwrap_dictobs(obs),
)
flattened_obs = _check_data_is_np_array(flattened_obs, "observation")
if self.density_type == DensityType.STATE_DENSITY:
flat_observations = space_utils.flatten(self.venv.observation_space, obs)
if not isinstance(flat_observations, np.ndarray):
raise ValueError(
"The density estimator only supports spaces that "
"flatten to a numpy array but the observation space "
f"flattens to {type(flat_observations)}",
)

return flat_observations
return flattened_obs
elif self.density_type == DensityType.STATE_ACTION_DENSITY:
flat_observation = space_utils.flatten(self.venv.observation_space, obs)
flat_action = space_utils.flatten(self.venv.action_space, act)

if not isinstance(flat_observation, np.ndarray):
raise ValueError(
"The density estimator only supports spaces that "
"flatten to a numpy array but the observation space "
f"flattens to {type(flat_observation)}",
)
if not isinstance(flat_action, np.ndarray):
raise ValueError(
"The density estimator only supports spaces that "
"flatten to a numpy array but the action space "
f"flattens to {type(flat_action)}",
)

return np.concatenate([flat_observation, flat_action])
flattened_action = space_utils.flatten(self.venv.action_space, act)
flattened_action = _check_data_is_np_array(flattened_action, "action")
return np.concatenate([flattened_obs, flattened_action])
elif self.density_type == DensityType.STATE_STATE_DENSITY:
assert next_obs is not None
flat_observation = space_utils.flatten(self.venv.observation_space, obs)
flat_next_observation = space_utils.flatten(
flat_next_obs = space_utils.flatten(
self.venv.observation_space,
next_obs,
types.maybe_unwrap_dictobs(next_obs),
)
flat_next_obs = _check_data_is_np_array(flat_next_obs, "observation")
assert type(flattened_obs) is type(flat_next_obs)

if not isinstance(flat_observation, np.ndarray):
raise ValueError(
"The density estimator only supports spaces that "
"flatten to a numpy array but the observation space "
f"flattens to {type(flat_observation)}",
)

assert type(flat_observation) is type(flat_next_observation)

return np.concatenate([flat_observation, flat_next_observation])
return np.concatenate([flattened_obs, flat_next_obs])
else:
raise ValueError(f"Unknown density type {self.density_type}")

def __call__(
self,
state: np.ndarray,
state: types.Observation,
action: np.ndarray,
next_state: np.ndarray,
next_state: types.Observation,
done: np.ndarray,
steps: Optional[np.ndarray] = None,
) -> np.ndarray:
@@ -347,6 +332,8 @@ def __call__(

rew_list = []
assert len(state) == len(action) and len(state) == len(next_state)
state = types.maybe_wrap_in_dictobs(state)
next_state = types.maybe_wrap_in_dictobs(next_state)
for idx, (obs, act, next_obs) in enumerate(zip(state, action, next_state)):
flat_trans = self._preprocess_transition(obs, act, next_obs)
assert self._scaler is not None
@@ -424,3 +411,13 @@ def policy(self) -> base_class.BasePolicy:
assert self.rl_algo is not None
assert self.rl_algo.policy is not None
return self.rl_algo.policy


def _check_data_is_np_array(data: space_utils.FlatType, name: str) -> np.ndarray:
"""Raises error if the flattened data is not a numpy array."""
assert isinstance(data, np.ndarray), (
"The density estimator only supports spaces that "
f"flatten to a numpy array but the {name} space "
f"flattens to {type(data)}",
)
return data
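
The density code relies on gymnasium's flatten utility, which turns a dictionary observation into a single 1-D array by concatenating its flattened entries; spaces that do not flatten to numpy arrays are rejected by the new _check_data_is_np_array helper. A small sketch of that behaviour, assuming gymnasium's spaces.flatten and flatdim (the space itself is illustrative):

import gymnasium as gym
import numpy as np

space = gym.spaces.Dict(
    {
        "pos": gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
        "goal": gym.spaces.Discrete(3),
    }
)
obs = space.sample()
flat = gym.spaces.flatten(space, obs)
assert isinstance(flat, np.ndarray)
assert flat.shape == (gym.spaces.flatdim(space),)  # 2 + 3 (one-hot) = 5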
36 changes: 28 additions & 8 deletions src/imitation/algorithms/mce_irl.py
@@ -7,7 +7,18 @@
"""
import collections
import warnings
from typing import Any, Iterable, List, Mapping, NoReturn, Optional, Tuple, Type, Union
from typing import (
Any,
Dict,
Iterable,
List,
Mapping,
NoReturn,
Optional,
Tuple,
Type,
Union,
)

import gymnasium as gym
import numpy as np
@@ -347,7 +358,7 @@ def _set_demo_from_trajectories(self, trajs: Iterable[types.Trajectory]) -> None
num_demos = 0
for traj in trajs:
cum_discount = 1.0
for obs in traj.obs:
for obs in types.assert_not_dictobs(traj.obs):
self.demo_state_om[obs] += cum_discount
cum_discount *= self.discount
num_demos += 1
@@ -411,23 +422,32 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None:

if isinstance(demonstrations, types.Transitions):
self._set_demo_from_obs(
demonstrations.obs,
types.assert_not_dictobs(demonstrations.obs),
demonstrations.dones,
demonstrations.next_obs,
types.assert_not_dictobs(demonstrations.next_obs),
)
elif isinstance(demonstrations, types.TransitionsMinimal):
self._set_demo_from_obs(demonstrations.obs, None, None)
self._set_demo_from_obs(
types.assert_not_dictobs(demonstrations.obs),
None,
None,
)
elif isinstance(demonstrations, Iterable):
# Demonstrations are a Torch DataLoader or other Mapping iterable
# Collect them together into one big NumPy array. This is inefficient,
# we could compute the running statistics instead, but in practice do
# not expect large dataset sizes together with MCE IRL.
collated_list = collections.defaultdict(list)
collated_list: Dict[
str,
List[types.AnyTensor],
] = collections.defaultdict(list)
for batch in demonstrations:
assert isinstance(batch, Mapping)
for k in ("obs", "dones", "next_obs"):
if k in batch:
collated_list[k].append(batch[k])
x = batch.get(k)
if x is not None:
assert isinstance(x, (np.ndarray, th.Tensor))
collated_list[k].append(x)
collated = {k: np.concatenate(v) for k, v in collated_list.items()}

assert "obs" in collated
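MCE IRL works with tabular, integer-indexed observations (note the demo_state_om[obs] indexing above) and does not support dictionary observations, so the diff narrows observation types up front with types.assert_not_dictobs. A hypothetical mirror of that narrowing (the real helper lives in imitation.data.types; this standalone version only illustrates the idea):

import numpy as np


def assert_not_dict_obs(obs) -> np.ndarray:
    """Illustrative narrowing helper: fail fast if obs is not a plain array."""
    if not isinstance(obs, np.ndarray):
        raise TypeError(f"Dictionary observations are not supported here, got {type(obs)}")
    return obs
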
4 changes: 2 additions & 2 deletions src/imitation/algorithms/preference_comparisons.py
@@ -465,9 +465,9 @@ def rewards(self, transitions: Transitions) -> th.Tensor:
Shape - (num_transitions, ) for Single reward network and
(num_transitions, num_networks) for ensemble of networks.
"""
state = transitions.obs
state = types.assert_not_dictobs(transitions.obs)
action = transitions.acts
next_state = transitions.next_obs
next_state = types.assert_not_dictobs(transitions.next_obs)
done = transitions.dones
if self.ensemble_model is not None:
rews_np = self.ensemble_model.predict_processed_all(
10 changes: 5 additions & 5 deletions src/imitation/data/buffer.py
@@ -1,6 +1,5 @@
"""Buffers to store NumPy arrays and transitions in."""

import dataclasses
from typing import Any, Mapping, Optional, Tuple

import numpy as np
@@ -368,15 +367,16 @@ def from_data(
Returns:
A new ReplayBuffer.
"""
obs_shape = transitions.obs.shape[1:]
obs = types.assert_not_dictobs(transitions.obs)
obs_shape = obs.shape[1:]
act_shape = transitions.acts.shape[1:]
if capacity is None:
capacity = transitions.obs.shape[0]
capacity = obs.shape[0]
instance = cls(
capacity=capacity,
obs_shape=obs_shape,
act_shape=act_shape,
obs_dtype=transitions.obs.dtype,
obs_dtype=obs.dtype,
act_dtype=transitions.acts.dtype,
)
instance.store(transitions, truncate_ok=truncate_ok)
@@ -406,7 +406,7 @@ def store(self, transitions: types.Transitions, truncate_ok: bool = True) -> None:
Raises:
ValueError: The arguments didn't have the same length.
""" # noqa: DAR402
trans_dict = dataclasses.asdict(transitions)
trans_dict = types.dataclass_quick_asdict(transitions)
# Remove unnecessary fields
trans_dict = {k: trans_dict[k] for k in self._buffer.sample_shapes.keys()}
self._buffer.store(trans_dict, truncate_ok=truncate_ok)
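The buffer change swaps dataclasses.asdict for types.dataclass_quick_asdict. dataclasses.asdict recurses into and deep-copies field values, which would both copy large arrays and unwrap dataclass-based containers such as DictObs; a shallow conversion avoids that. A hypothetical sketch of such a shallow helper (the real implementation is in imitation.data.types):

import dataclasses
from typing import Any, Dict


def quick_asdict(obj: Any) -> Dict[str, Any]:
    # Shallow: keep each field value as-is instead of recursing into it
    # and deep-copying it the way dataclasses.asdict does.
    return {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}
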
2 changes: 2 additions & 0 deletions src/imitation/data/huggingface_utils.py
@@ -124,6 +124,8 @@ def trajectories_to_dict(
],
terminal=[traj.terminal for traj in trajectories],
)
if any(isinstance(traj.obs, types.DictObs) for traj in trajectories):
raise ValueError("DictObs are not currently supported")

# Encode infos as jsonpickled strings
trajectory_dict["infos"] = [
(Diffs for the remaining changed files are not shown.)
