From 3918d3ed35554b820123bfdce8607c5d75279238 Mon Sep 17 00:00:00 2001
From: HomerW
Date: Mon, 20 May 2024 12:02:14 -0700
Subject: [PATCH] add normalize proprio wrapper

---
 examples/03_eval_finetuned.py   |  5 ++++-
 octo/utils/gym_wrappers.py      | 36 +++++++++++++++++++++++++++++++++
 octo/utils/train_callbacks.py   |  5 +++--
 octo/utils/visualization_lib.py | 15 +++++++++++---
 scripts/train.py                |  4 +---
 5 files changed, 56 insertions(+), 9 deletions(-)

diff --git a/examples/03_eval_finetuned.py b/examples/03_eval_finetuned.py
index ad0d594b..f897b21a 100644
--- a/examples/03_eval_finetuned.py
+++ b/examples/03_eval_finetuned.py
@@ -30,7 +30,7 @@
 from envs.aloha_sim_env import AlohaGymEnv  # noqa
 
 from octo.model.octo_model import OctoModel
-from octo.utils.gym_wrappers import HistoryWrapper, RHCWrapper
+from octo.utils.gym_wrappers import HistoryWrapper, NormalizeProprio, RHCWrapper
 from octo.utils.train_callbacks import supply_rng
 
 FLAGS = flags.FLAGS
@@ -64,6 +64,9 @@ def main(_):
     ##################################################################################################################
     env = gym.make("aloha-sim-cube-v0")
 
+    # wrap env to normalize proprio
+    env = NormalizeProprio(env, model.dataset_statistics)
+
     # add wrappers for history and "receding horizon control", i.e. action chunking
     env = HistoryWrapper(env, horizon=1)
     env = RHCWrapper(env, exec_horizon=50)
diff --git a/octo/utils/gym_wrappers.py b/octo/utils/gym_wrappers.py
index 6b124c67..2b1ff03e 100644
--- a/octo/utils/gym_wrappers.py
+++ b/octo/utils/gym_wrappers.py
@@ -4,6 +4,7 @@
 
 import gym
 import gym.spaces
+import jax
 import numpy as np
 import tensorflow as tf
 
@@ -243,3 +244,38 @@ def observation(self, observation):
             image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8).numpy()
             observation[k] = image
         return observation
+
+
+class NormalizeProprio(gym.ObservationWrapper):
+    """
+    Normalizes proprio observations using the given dataset statistics.
+    """
+
+    def __init__(
+        self,
+        env: gym.Env,
+        action_proprio_metadata: dict,
+    ):
+        self.action_proprio_metadata = jax.tree_map(
+            lambda x: np.array(x),
+            action_proprio_metadata,
+            is_leaf=lambda x: isinstance(x, list),
+        )
+        super().__init__(env)
+
+    def normalize(self, data, metadata):
+        mask = metadata.get("mask", np.ones_like(metadata["mean"], dtype=bool))
+        return np.where(
+            mask,
+            (data - metadata["mean"]) / (metadata["std"] + 1e-8),
+            data,
+        )
+
+    def observation(self, obs):
+        if "proprio" in self.action_proprio_metadata:
+            obs["proprio"] = self.normalize(
+                obs["proprio"], self.action_proprio_metadata["proprio"]
+            )
+        else:
+            assert "proprio" not in obs, "Cannot normalize proprio without metadata."
+        return obs
diff --git a/octo/utils/train_callbacks.py b/octo/utils/train_callbacks.py
index 2a3491db..a406162f 100644
--- a/octo/utils/train_callbacks.py
+++ b/octo/utils/train_callbacks.py
@@ -333,7 +333,7 @@ class RolloutVisualizationCallback(Callback):
     visualizer_kwargs_list: Sequence[Mapping[str, Any]]
     text_processor: TextProcessor
     trajs_for_rollouts: int
-    unnormalization_statistics: dict
+    action_proprio_metadata: dict
     modes_to_evaluate: str = ("text_conditioned", "image_conditioned")
 
     def __post_init__(self):
@@ -346,6 +346,7 @@ def __post_init__(self):
 
         self.rollout_visualizers = [
             RolloutVisualizer(
+                action_proprio_metadata=self.action_proprio_metadata,
                 **kwargs,
             )
             for kwargs in self.visualizer_kwargs_list
@@ -358,7 +359,7 @@ def __call__(self, train_state: TrainState, step: int):
                     partial(
                         get_policy_sampled_actions,
                         train_state,
-                        unnormalization_statistics=self.unnormalization_statistics,
+                        unnormalization_statistics=self.action_proprio_metadata["action"],
                         zero_text=self.zero_text,
                         samples_per_state=1,
                         policy_mode=mode,
diff --git a/octo/utils/visualization_lib.py b/octo/utils/visualization_lib.py
index 10de5b29..6671be44 100644
--- a/octo/utils/visualization_lib.py
+++ b/octo/utils/visualization_lib.py
@@ -4,7 +4,7 @@
 
 matplotlib.use("Agg")
 from dataclasses import dataclass
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import dlimp as dl
 import flax
@@ -20,7 +20,12 @@
 import tqdm
 import wandb
 
-from octo.utils.gym_wrappers import HistoryWrapper, RHCWrapper, TemporalEnsembleWrapper
+from octo.utils.gym_wrappers import (
+    HistoryWrapper,
+    NormalizeProprio,
+    RHCWrapper,
+    TemporalEnsembleWrapper,
+)
 
 BASE_METRIC_KEYS = {
     "mse": ("mse", tuple()),  # What is the MSE
@@ -275,6 +280,7 @@ class RolloutVisualizer:
         use_temp_ensembling (bool): Whether to use temporal ensembling or receding horizon control.
         vis_fps (int): FPS of logged rollout video
         video_subsample_rate (int): Subsampling rate for video logging (to reduce video size for high-frequency control)
+        action_proprio_metadata (dict): Dictionary of normalization statistics for proprio and actions.
     """
 
     name: str
@@ -286,9 +292,12 @@ class RolloutVisualizer:
     use_temp_ensembling: bool = True
     vis_fps: int = 10
    video_subsample_rate: int = 1
+    action_proprio_metadata: Optional[dict] = None
 
     def __post_init__(self):
         self._env = gym.make(self.env_name, **self.env_kwargs)
+        if self.action_proprio_metadata is not None:
+            self._env = NormalizeProprio(self._env, self.action_proprio_metadata)
         self._env = HistoryWrapper(
             self._env,
             self.history_length,
@@ -321,7 +330,7 @@ def listdict2dictlist(LD):
                 # policy outputs are shape [batch, n_samples, pred_horizon, act_dim]
                 # we remove batch dimension & use first sampled action, ignoring other samples
                 actions = policy_fn(jax.tree_map(lambda x: x[None], obs), task)
-                actions = actions[0, 0]
+                actions = np.array(actions[0, 0])
                 obs, reward, done, trunc, info = self._env.step(actions)
                 if "observations" in info:
                     images.extend(
diff --git a/scripts/train.py b/scripts/train.py
index d6b0d08d..33dd80f6 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -282,9 +282,7 @@ def train_step(state: TrainState, batch: Data):
         dataset_name = rollout_kwargs.pop("dataset_name")
         rollout_callback = RolloutVisualizationCallback(
             text_processor=text_processor,
-            unnormalization_statistics=train_data.dataset_statistics[dataset_name][
-                "action"
-            ],
+            action_proprio_metadata=train_data.dataset_statistics[dataset_name],
             **rollout_kwargs,
         )
     else:
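
Usage sketch (illustrative, not part of the diff above): how the new NormalizeProprio wrapper would be applied standalone. The statistics layout (per-key dicts with "mean", "std", and an optional boolean "mask") is inferred from the wrapper's normalize() method; in practice the metadata comes from model.dataset_statistics, and the env name and numbers below are placeholders for any env whose observations include a "proprio" key.

    import gym
    import numpy as np

    from octo.utils.gym_wrappers import NormalizeProprio

    # Hypothetical statistics for a 7-dim proprio vector; the masked-out last
    # dimension (e.g. a gripper state) is passed through unchanged.
    dataset_statistics = {
        "proprio": {
            "mean": np.zeros(7),
            "std": np.ones(7),
            "mask": np.array([True] * 6 + [False]),
        }
    }

    env = gym.make("aloha-sim-cube-v0")  # placeholder env exposing a "proprio" observation
    env = NormalizeProprio(env, dataset_statistics)

    obs, info = env.reset()
    # obs["proprio"] is now (proprio - mean) / (std + 1e-8) wherever mask is True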