Skip to content

Commit

Permalink
Merge pull request #31 from octo-models/fix_robot_eval
Browse files Browse the repository at this point in the history
Fix WidowX evaluation script
  • Loading branch information
mees authored Jan 24, 2024
2 parents ae97d6b + 9994ea0 commit 653c54a
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 22 deletions.
38 changes: 26 additions & 12 deletions examples/04_eval_finetuned_on_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs, WidowXStatus

from octo.model.octo_model import OctoModel
from octo.utils.gym_wrappers import HistoryWrapper, RHCWrapper, UnnormalizeActionProprio
from octo.utils.gym_wrappers import (
HistoryWrapper,
TemporalEnsembleWrapper,
UnnormalizeActionProprio,
)

np.set_printoptions(suppress=True)

Expand Down Expand Up @@ -87,7 +91,7 @@ def main(_):

env_params = WidowXConfigs.DefaultEnvParams.copy()
env_params.update(ENV_PARAMS)
env_params["state_state"] = list(start_state)
env_params["start_state"] = list(start_state)
widowx_client = WidowXClient(host=FLAGS.ip, port=FLAGS.port)
widowx_client.init(env_params, image_size=FLAGS.im_size)
env = WidowXGym(
Expand All @@ -104,12 +108,14 @@ def main(_):

# wrap the robot environment
env = UnnormalizeActionProprio(
env, model.dataset_statistics, normalization_type="normal"
env, model.dataset_statistics["bridge_dataset"], normalization_type="normal"
)
env = HistoryWrapper(env, FLAGS.horizon)
env = RHCWrapper(env, FLAGS.exec_horizon)
env = TemporalEnsembleWrapper(env, FLAGS.pred_horizon)
# switch TemporalEnsembleWrapper with RHCWrapper for receding horizon control
# env = RHCWrapper(env, FLAGS.exec_horizon)

# create policy functions
# create policy function
@jax.jit
def sample_actions(
pretrained_model: OctoModel,
Expand All @@ -127,11 +133,19 @@ def sample_actions(
# remove batch dim
return actions[0]

policy_fn = partial(
sample_actions,
model,
argmax=FLAGS.deterministic,
temperature=FLAGS.temperature,
def supply_rng(f, rng=jax.random.PRNGKey(0)):
def wrapped(*args, **kwargs):
nonlocal rng
rng, key = jax.random.split(rng)
return f(*args, rng=key, **kwargs)

return wrapped

policy_fn = supply_rng(
partial(
sample_actions,
model,
)
)

goal_image = jnp.zeros((FLAGS.im_size, FLAGS.im_size, 3), dtype=np.uint8)
Expand Down Expand Up @@ -177,12 +191,12 @@ def sample_actions(
else:
raise NotImplementedError()

input("Press [Enter] to start.")

# reset env
obs, _ = env.reset()
time.sleep(2.0)

input("Press [Enter] to start.")

# do rollout
last_tstep = time.time()
images = []
Expand Down
15 changes: 9 additions & 6 deletions examples/envs/README.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
# Octo Eval Environments
# Octo Evaluation Environments

The `step` and `reset` functions of the Gym environment should return observations with the images, depth images, and/or
proprioceptive information that the policy expects as input. Specifically, the returned observations should be dictionaries
proprioceptive information that the model expects as input. Specifically, the returned observations should be dictionaries
of the form:
```
obs = {
"image_0": ...,
"image_1": ...,
"image_primary": ...,
"image_wrist": ...,
...
"depth_0": ...,
"depth_1": ...,
"depth_primary": ...,
"depth_wrist": ...,
...
"proprio": ...,
}
```

Note that the image keys should be `image_{key}` where `key` is one of the `image_obs_keys` specified in the data loading config used to train the model (typically this is `primary` and/or `wrist`).
If a key is not present in the observation dictionary, the model will substitute it with padding.

Check out the example environments in this folder to help you integrate your own environment!
6 changes: 3 additions & 3 deletions examples/envs/widowx_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ def convert_obs(obs, im_size):
proprio = np.concatenate([obs["state"][:6], [0], obs["state"][-1:]])
# NOTE: assume image_1 is not available
return {
"image_0": image_obs,
"image_primary": image_obs,
"proprio": proprio,
}


def null_obs(img_size):
return {
"image_0": np.zeros((img_size, img_size, 3), dtype=np.uint8),
"image_primary": np.zeros((img_size, img_size, 3), dtype=np.uint8),
"proprio": np.zeros((8,), dtype=np.float64),
}

Expand All @@ -72,7 +72,7 @@ def __init__(
self.blocking = blocking
self.observation_space = gym.spaces.Dict(
{
"image_0": gym.spaces.Box(
"image_primary": gym.spaces.Box(
low=np.zeros((im_size, im_size, 3)),
high=255 * np.ones((im_size, im_size, 3)),
dtype=np.uint8,
Expand Down
3 changes: 2 additions & 1 deletion octo/utils/gym_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,11 @@ def __init__(
self.observation_space, gym.spaces.Dict
), "Only Dict observation spaces are supported."
spaces = self.observation_space.spaces
self.resize_size = resize_size

if resize_size is None:
self.keys_to_resize = {}
elif isinstance(self.resize_size, tuple):
elif isinstance(resize_size, tuple):
self.keys_to_resize = {k: resize_size for k in spaces if "image_" in k}
else:
self.keys_to_resize = {
Expand Down

0 comments on commit 653c54a

Please sign in to comment.