diff --git a/examples/04_eval_finetuned_on_robot.py b/examples/04_eval_finetuned_on_robot.py
index 3afa953b..221d476b 100644
--- a/examples/04_eval_finetuned_on_robot.py
+++ b/examples/04_eval_finetuned_on_robot.py
@@ -22,7 +22,11 @@
 from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs, WidowXStatus
 
 from octo.model.octo_model import OctoModel
-from octo.utils.gym_wrappers import HistoryWrapper, RHCWrapper, UnnormalizeActionProprio
+from octo.utils.gym_wrappers import (
+    HistoryWrapper,
+    TemporalEnsembleWrapper,
+    UnnormalizeActionProprio,
+)
 
 np.set_printoptions(suppress=True)
@@ -87,7 +91,7 @@ def main(_):
     env_params = WidowXConfigs.DefaultEnvParams.copy()
     env_params.update(ENV_PARAMS)
-    env_params["state_state"] = list(start_state)
+    env_params["start_state"] = list(start_state)
     widowx_client = WidowXClient(host=FLAGS.ip, port=FLAGS.port)
     widowx_client.init(env_params, image_size=FLAGS.im_size)
     env = WidowXGym(
@@ -104,12 +108,14 @@ def main(_):
     # wrap the robot environment
     env = UnnormalizeActionProprio(
-        env, model.dataset_statistics, normalization_type="normal"
+        env, model.dataset_statistics["bridge_dataset"], normalization_type="normal"
     )
     env = HistoryWrapper(env, FLAGS.horizon)
-    env = RHCWrapper(env, FLAGS.exec_horizon)
+    env = TemporalEnsembleWrapper(env, FLAGS.pred_horizon)
+    # switch TemporalEnsembleWrapper with RHCWrapper for receding horizon control
+    # env = RHCWrapper(env, FLAGS.exec_horizon)
 
-    # create policy functions
+    # create policy function
     @jax.jit
     def sample_actions(
         pretrained_model: OctoModel,
@@ -127,11 +133,19 @@ def sample_actions(
         # remove batch dim
         return actions[0]
 
-    policy_fn = partial(
-        sample_actions,
-        model,
-        argmax=FLAGS.deterministic,
-        temperature=FLAGS.temperature,
+    def supply_rng(f, rng=jax.random.PRNGKey(0)):
+        def wrapped(*args, **kwargs):
+            nonlocal rng
+            rng, key = jax.random.split(rng)
+            return f(*args, rng=key, **kwargs)
+
+        return wrapped
+
+    policy_fn = supply_rng(
+        partial(
+            sample_actions,
+            model,
+        )
     )
 
     goal_image = jnp.zeros((FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8)
@@ -177,12 +191,12 @@
         else:
             raise NotImplementedError()
 
+        input("Press [Enter] to start.")
+
         # reset env
         obs, _ = env.reset()
         time.sleep(2.0)
 
-        input("Press [Enter] to start.")
-
         # do rollout
         last_tstep = time.time()
         images = []
diff --git a/examples/envs/README.md b/examples/envs/README.md
index 601d58f3..ad170f2a 100644
--- a/examples/envs/README.md
+++ b/examples/envs/README.md
@@ -1,18 +1,21 @@
-# Octo Eval Environments
+# Octo Evaluation Environments
 
 The `step` and `reset` functions of the Gym environment should return observations with the images, depth images, and/or
-proprioceptive information that the policy expects as input. Specifically, the returned observations should be dictionaries
+proprioceptive information that the model expects as input. Specifically, the returned observations should be dictionaries
 of the form:
 ```
 obs = {
-    "image_0": ...,
-    "image_1": ...,
+    "image_primary": ...,
+    "image_wrist": ...,
     ...
-    "depth_0": ...,
-    "depth_1": ...,
+    "depth_primary": ...,
+    "depth_wrist": ...,
     ...
     "proprio": ...,
 }
 ```
+Note that the image keys should be `image_{key}` where `key` is one of the `image_obs_keys` specified in the data loading config used to train the model (typically this is `primary` and/or `wrist`).
+If a key is not present in the observation dictionary, the model will substitute it with padding.
+
 Check out the example environments in this folder to help you integrate your own environment!
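To make the README convention above concrete, here is a minimal sketch of a custom evaluation environment that returns observations under the renamed keys. It assumes the `gymnasium` five-tuple `step` API; the class name `MyRobotEnv`, the image size, and the 8-dim proprio layout are illustrative assumptions, not part of the Octo interface.

```python
import gymnasium as gym
import numpy as np


class MyRobotEnv(gym.Env):
    """Toy environment that emits observations in the format Octo expects."""

    def __init__(self, im_size=256):
        self.im_size = im_size
        self.observation_space = gym.spaces.Dict(
            {
                "image_primary": gym.spaces.Box(
                    low=0, high=255, shape=(im_size, im_size, 3), dtype=np.uint8
                ),
                "proprio": gym.spaces.Box(
                    low=-np.inf, high=np.inf, shape=(8,), dtype=np.float64
                ),
            }
        )
        self.action_space = gym.spaces.Box(
            low=-1.0, high=1.0, shape=(7,), dtype=np.float64
        )

    def _get_obs(self):
        # replace these zeros with real camera frames and robot state reads
        return {
            "image_primary": np.zeros((self.im_size, self.im_size, 3), np.uint8),
            "proprio": np.zeros((8,), np.float64),
        }

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        return self._get_obs(), {}

    def step(self, action):
        # send `action` to the robot here; report reward/termination as needed
        return self._get_obs(), 0.0, False, False, {}
```

Wrapped in `HistoryWrapper` and one of the chunking wrappers, as in the eval script above, such an environment should plug directly into `policy_fn`.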
diff --git a/examples/envs/widowx_env.py b/examples/envs/widowx_env.py
index 0018b760..90b8e013 100644
--- a/examples/envs/widowx_env.py
+++ b/examples/envs/widowx_env.py
@@ -41,14 +41,14 @@ def convert_obs(obs, im_size):
     proprio = np.concatenate([obs["state"][:6], [0], obs["state"][-1:]])
     # NOTE: assume image_1 is not available
     return {
-        "image_0": image_obs,
+        "image_primary": image_obs,
         "proprio": proprio,
     }
 
 
 def null_obs(img_size):
     return {
-        "image_0": np.zeros((img_size, img_size, 3), dtype=np.uint8),
+        "image_primary": np.zeros((img_size, img_size, 3), dtype=np.uint8),
         "proprio": np.zeros((8,), dtype=np.float64),
     }
 
@@ -72,7 +72,7 @@ def __init__(
         self.blocking = blocking
         self.observation_space = gym.spaces.Dict(
             {
-                "image_0": gym.spaces.Box(
+                "image_primary": gym.spaces.Box(
                     low=np.zeros((im_size, im_size, 3)),
                     high=255 * np.ones((im_size, im_size, 3)),
                     dtype=np.uint8,
diff --git a/octo/utils/gym_wrappers.py b/octo/utils/gym_wrappers.py
index 84a499dc..4ef500a2 100644
--- a/octo/utils/gym_wrappers.py
+++ b/octo/utils/gym_wrappers.py
@@ -222,10 +222,11 @@ def __init__(
             self.observation_space, gym.spaces.Dict
         ), "Only Dict observation spaces are supported."
         spaces = self.observation_space.spaces
+        self.resize_size = resize_size
         if resize_size is None:
             self.keys_to_resize = {}
-        elif isinstance(self.resize_size, tuple):
+        elif isinstance(resize_size, tuple):
             self.keys_to_resize = {k: resize_size for k in spaces if "image_" in k}
         else:
             self.keys_to_resize = {
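A note on the last hunk: the old code tested `self.resize_size` before the attribute was ever assigned, so passing a tuple raised `AttributeError`. Storing the argument first (and testing the local `resize_size`) fixes this. The renamed `image_primary` key still matches the wrapper's `"image_" in k` filter, which is why both changes land together. Below is a minimal sketch of the corrected ordering and the key selection it enables; `ResizeSketch` is a simplified stand-in for illustration, not the real wrapper (the actual image resizing and the list-of-sizes branch are omitted).

```python
class ResizeSketch:
    def __init__(self, spaces, resize_size=None):
        self.resize_size = resize_size  # fix: store before any attribute access
        if resize_size is None:
            self.keys_to_resize = {}
        elif isinstance(resize_size, tuple):
            # a single size is broadcast to every image key in the observation dict
            self.keys_to_resize = {k: resize_size for k in spaces if "image_" in k}


sketch = ResizeSketch(["image_primary", "image_wrist", "proprio"], (256, 256))
assert sketch.keys_to_resize == {
    "image_primary": (256, 256),
    "image_wrist": (256, 256),
}
```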