Skip to content

Commit

Permalink
add normalize proprio wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
HomerW committed May 20, 2024
1 parent 86ff03a commit 3918d3e
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 9 deletions.
5 changes: 4 additions & 1 deletion examples/03_eval_finetuned.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from envs.aloha_sim_env import AlohaGymEnv # noqa

from octo.model.octo_model import OctoModel
from octo.utils.gym_wrappers import HistoryWrapper, RHCWrapper
from octo.utils.gym_wrappers import HistoryWrapper, NormalizeProprio, RHCWrapper
from octo.utils.train_callbacks import supply_rng

FLAGS = flags.FLAGS
Expand Down Expand Up @@ -64,6 +64,9 @@ def main(_):
##################################################################################################################
env = gym.make("aloha-sim-cube-v0")

# wrap env to normalize proprio
env = NormalizeProprio(env, model.dataset_statistics)

# add wrappers for history and "receding horizon control", i.e. action chunking
env = HistoryWrapper(env, horizon=1)
env = RHCWrapper(env, exec_horizon=50)
Expand Down
36 changes: 36 additions & 0 deletions octo/utils/gym_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import gym
import gym.spaces
import jax
import numpy as np
import tensorflow as tf

Expand Down Expand Up @@ -243,3 +244,38 @@ def observation(self, observation):
image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8).numpy()
observation[k] = image
return observation


class NormalizeProprio(gym.ObservationWrapper):
    """
    Normalizes the proprio observation with precomputed statistics.

    When the supplied metadata has a "proprio" entry, every observation's
    "proprio" value is replaced by ``(proprio - mean) / (std + 1e-8)``.
    Dimensions whose optional "mask" entry is False are passed through
    unchanged. All other observation keys are left untouched.
    """

    def __init__(
        self,
        env: gym.Env,
        action_proprio_metadata: dict,
    ):
        """
        Args:
            env: The environment to wrap.
            action_proprio_metadata: Nested dict of statistics. If it has a
                "proprio" entry, that entry must provide "mean" and "std"
                and may provide a boolean "mask".
        """
        # Treat each list as a single leaf so it becomes one array instead of
        # being traversed element by element. `jax.tree_util.tree_map` is the
        # stable spelling of the deprecated `jax.tree_map` alias.
        self.action_proprio_metadata = jax.tree_util.tree_map(
            np.array,
            action_proprio_metadata,
            is_leaf=lambda x: isinstance(x, list),
        )
        super().__init__(env)

    def normalize(self, data, metadata: dict):
        """
        Standardizes ``data`` with ``metadata["mean"]`` and ``metadata["std"]``.

        Entries where ``metadata["mask"]`` is False are returned unchanged;
        if no mask is given, every entry is normalized.
        """
        mean = metadata["mean"]
        mask = metadata.get("mask", np.ones_like(mean, dtype=bool))
        # The small epsilon guards against division by zero for constant dims.
        standardized = (data - mean) / (metadata["std"] + 1e-8)
        return np.where(mask, standardized, data)

    def observation(self, obs: dict) -> dict:
        """
        Normalizes ``obs["proprio"]`` in place and returns ``obs``.

        Raises:
            AssertionError: If ``obs`` contains "proprio" but no proprio
                statistics were provided to the wrapper.
        """
        if "proprio" in self.action_proprio_metadata:
            obs["proprio"] = self.normalize(
                obs["proprio"], self.action_proprio_metadata["proprio"]
            )
        elif "proprio" in obs:
            # Raised explicitly (rather than via `assert`) so the check still
            # runs under `python -O`; the exception type is kept for callers.
            raise AssertionError("Cannot normalize proprio without metadata.")
        return obs
5 changes: 3 additions & 2 deletions octo/utils/train_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ class RolloutVisualizationCallback(Callback):
visualizer_kwargs_list: Sequence[Mapping[str, Any]]
text_processor: TextProcessor
trajs_for_rollouts: int
unnormalization_statistics: dict
action_proprio_metadata: dict
modes_to_evaluate: str = ("text_conditioned", "image_conditioned")

def __post_init__(self):
Expand All @@ -346,6 +346,7 @@ def __post_init__(self):

self.rollout_visualizers = [
RolloutVisualizer(
action_proprio_metadata=self.action_proprio_metadata,
**kwargs,
)
for kwargs in self.visualizer_kwargs_list
Expand All @@ -358,7 +359,7 @@ def __call__(self, train_state: TrainState, step: int):
partial(
get_policy_sampled_actions,
train_state,
unnormalization_statistics=self.unnormalization_statistics,
unnormalization_statistics=self.action_proprio_metadata["action"],
zero_text=self.zero_text,
samples_per_state=1,
policy_mode=mode,
Expand Down
15 changes: 12 additions & 3 deletions octo/utils/visualization_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

matplotlib.use("Agg")
from dataclasses import dataclass
from typing import Any, Dict
from typing import Any, Dict, Optional

import dlimp as dl
import flax
Expand All @@ -20,7 +20,12 @@
import tqdm
import wandb

from octo.utils.gym_wrappers import HistoryWrapper, RHCWrapper, TemporalEnsembleWrapper
from octo.utils.gym_wrappers import (
HistoryWrapper,
NormalizeProprio,
RHCWrapper,
TemporalEnsembleWrapper,
)

BASE_METRIC_KEYS = {
"mse": ("mse", tuple()), # What is the MSE
Expand Down Expand Up @@ -275,6 +280,7 @@ class RolloutVisualizer:
use_temp_ensembling (bool): Whether to use temporal ensembling or receding horizon control.
vis_fps (int): FPS of logged rollout video
video_subsample_rate (int): Subsampling rate for video logging (to reduce video size for high-frequency control)
action_proprio_metadata (dict): Dictionary of normalization statistics for proprio and actions.
"""

name: str
Expand All @@ -286,9 +292,12 @@ class RolloutVisualizer:
use_temp_ensembling: bool = True
vis_fps: int = 10
video_subsample_rate: int = 1
action_proprio_metadata: Optional[dict] = None

def __post_init__(self):
self._env = gym.make(self.env_name, **self.env_kwargs)
if self.action_proprio_metadata is not None:
self._env = NormalizeProprio(self._env, self.action_proprio_metadata)
self._env = HistoryWrapper(
self._env,
self.history_length,
Expand Down Expand Up @@ -321,7 +330,7 @@ def listdict2dictlist(LD):
# policy outputs are shape [batch, n_samples, pred_horizon, act_dim]
# we remove batch dimension & use first sampled action, ignoring other samples
actions = policy_fn(jax.tree_map(lambda x: x[None], obs), task)
actions = actions[0, 0]
actions = np.array(actions[0, 0])
obs, reward, done, trunc, info = self._env.step(actions)
if "observations" in info:
images.extend(
Expand Down
4 changes: 1 addition & 3 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,7 @@ def train_step(state: TrainState, batch: Data):
dataset_name = rollout_kwargs.pop("dataset_name")
rollout_callback = RolloutVisualizationCallback(
text_processor=text_processor,
unnormalization_statistics=train_data.dataset_statistics[dataset_name][
"action"
],
action_proprio_metadata=train_data.dataset_statistics[dataset_name],
**rollout_kwargs,
)
else:
Expand Down

0 comments on commit 3918d3e

Please sign in to comment.