diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c0f81218..c6f212bf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -22,4 +22,4 @@ repos:
   hooks:
     - id: isort
       exclude: ^experiments/
-      args: ["--profile", "black", "--src", "jaxrl_m", "--src", "experiments"]
+      args: ["--profile", "black", "--src", "orca", "--src", "experiments"]
diff --git a/README.md b/README.md
index 7953b627..ae5b0178 100644
--- a/README.md
+++ b/README.md
@@ -2,8 +2,19 @@
 A robot foundation model.
 
+## Training
+
+```
+python experiments/main/train.py --config.data_path=<...> --config.save_dir=<...>
+```
+
 ## Contributing
 
-Experimental things and training/eval scripts should go in `experiments/`. To make any changes to files outside of `experiments/`, please open a pull request.
+Experimental things and training/eval scripts should go in `experiments/`. To make any changes to files outside of your experiments directory, please open a pull request.
+
+To enable code checks and auto-formatting, please install pre-commit hooks:
+```
+pre-commit install
+```
 
 ## Environment
 ```
diff --git a/experiments/dibya/scripts/train_real.sh b/experiments/dibya/scripts/train_real.sh
new file mode 100644
index 00000000..0db29ff4
--- /dev/null
+++ b/experiments/dibya/scripts/train_real.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+CONFIG_NAME=${1:-transformer_bc}
+DATA_CONFIG_NAME=${2:-all}
+echo "Using config $CONFIG_NAME and data config $DATA_CONFIG_NAME"
+NAME="test"
+
+CMD="python experiments/main/train.py \
+    --config experiments/main/configs/train_config.py:$CONFIG_NAME \
+    --bridgedata_config experiments/main/configs/data_config.py:$DATA_CONFIG_NAME \
+    --name $NAME \
+    --config.data_path=gs://rail-tpus-homer-v4/data_new"
+shift 2
+echo $CMD "$@"
+$CMD "$@"
diff --git a/orca/data/bridge/bridgedata_numpy_to_tfrecord.py b/experiments/homer/bridge_preprocessing/bridgedata_numpy_to_tfrecord.py
similarity index 100%
rename from orca/data/bridge/bridgedata_numpy_to_tfrecord.py
rename to experiments/homer/bridge_preprocessing/bridgedata_numpy_to_tfrecord.py
diff --git a/orca/data/bridge/bridgedata_raw_to_numpy.py b/experiments/homer/bridge_preprocessing/bridgedata_raw_to_numpy.py
similarity index 100%
rename from orca/data/bridge/bridgedata_raw_to_numpy.py
rename to experiments/homer/bridge_preprocessing/bridgedata_raw_to_numpy.py
diff --git a/experiments/homer/configs/data_config.py b/experiments/homer/configs/data_config.py
index a2d6801e..00890ce0 100644
--- a/experiments/homer/configs/data_config.py
+++ b/experiments/homer/configs/data_config.py
@@ -180,7 +180,7 @@ def get_config(config_string):
                     "proprio": {
                         "mean": ACT_MEAN,
                         "std": ACT_STD,
-                    }
+                    },
                 },
             }
         ),
@@ -201,7 +201,7 @@ def get_config(config_string):
                     "proprio": {
                         "mean": ACT_MEAN,
                         "std": ACT_STD,
-                    }
+                    },
                 },
             }
         ),
diff --git a/experiments/homer/configs/train_config.py b/experiments/homer/configs/train_config.py
index b0e410b6..bb292d54 100644
--- a/experiments/homer/configs/train_config.py
+++ b/experiments/homer/configs/train_config.py
@@ -12,6 +12,7 @@ def update_config(_prototype, **kwargs):
     result.pop("_overwrite", None)
     return ConfigDict(result)
 
+
 def get_config(config_string):
     base_sim_config = dict(
         batch_size=256,
@@ -102,7 +103,7 @@ def get_config(config_string):
                 task_tokenizers=["sim-goal-obs-tokenizer"],
                 task_tokenizer_kwargs={"sim-goal-obs-tokenizer": {}},
                 dataset_kwargs=base_data_config,
-                **base_sim_config
+                **base_sim_config,
             )
         ),
         "transformer_bc": ConfigDict(
             dict(
@@ -115,7 +116,7 @@
task_tokenizers=["goal-obs-tokenizer"], task_tokenizer_kwargs={"goal-obs-tokenizer": {}}, dataset_kwargs=base_data_config, - **base_real_config + **base_real_config, ) ), "transformer_bc_film_lang": ConfigDict( @@ -142,9 +143,7 @@ def get_config(config_string): observation_tokenizers=["obs-tokenizer"], observation_tokenizer_kwargs={"obs-tokenizer": {"num_tokens": 64}}, task_tokenizers=["language-tokenizer"], - task_tokenizer_kwargs={ - "language-tokenizer": {"num_tokens": 16} - }, + task_tokenizer_kwargs={"language-tokenizer": {"num_tokens": 16}}, **base_real_config, ) ), diff --git a/orca/utils/python_utils.py b/experiments/homer/env_utils/python_utils.py similarity index 100% rename from orca/utils/python_utils.py rename to experiments/homer/env_utils/python_utils.py diff --git a/orca/utils/rlds_data_utils.py b/experiments/homer/env_utils/rlds_data_utils.py similarity index 100% rename from orca/utils/rlds_data_utils.py rename to experiments/homer/env_utils/rlds_data_utils.py diff --git a/orca/utils/sim_utils.py b/experiments/homer/env_utils/sim_utils.py similarity index 100% rename from orca/utils/sim_utils.py rename to experiments/homer/env_utils/sim_utils.py diff --git a/orca/utils/timer_utils.py b/experiments/homer/env_utils/timer_utils.py similarity index 100% rename from orca/utils/timer_utils.py rename to experiments/homer/env_utils/timer_utils.py diff --git a/orca/utils/train_utils.py b/experiments/homer/env_utils/train_utils.py similarity index 100% rename from orca/utils/train_utils.py rename to experiments/homer/env_utils/train_utils.py diff --git a/orca/envs/wrappers/chunking.py b/experiments/homer/env_wrappers/wrappers/chunking.py similarity index 100% rename from orca/envs/wrappers/chunking.py rename to experiments/homer/env_wrappers/wrappers/chunking.py diff --git a/orca/envs/wrappers/dmcgym.py b/experiments/homer/env_wrappers/wrappers/dmcgym.py similarity index 100% rename from orca/envs/wrappers/dmcgym.py rename to experiments/homer/env_wrappers/wrappers/dmcgym.py diff --git a/orca/envs/wrappers/mujoco.py b/experiments/homer/env_wrappers/wrappers/mujoco.py similarity index 100% rename from orca/envs/wrappers/mujoco.py rename to experiments/homer/env_wrappers/wrappers/mujoco.py diff --git a/orca/envs/wrappers/norm.py b/experiments/homer/env_wrappers/wrappers/norm.py similarity index 100% rename from orca/envs/wrappers/norm.py rename to experiments/homer/env_wrappers/wrappers/norm.py diff --git a/orca/envs/wrappers/roboverse.py b/experiments/homer/env_wrappers/wrappers/roboverse.py similarity index 99% rename from orca/envs/wrappers/roboverse.py rename to experiments/homer/env_wrappers/wrappers/roboverse.py index 835749b8..dd4a6abb 100644 --- a/orca/envs/wrappers/roboverse.py +++ b/experiments/homer/env_wrappers/wrappers/roboverse.py @@ -58,7 +58,6 @@ def __init__( env: gym.Env, goal_sampler: Union[np.ndarray, Callable], ): - super().__init__(env) self.env = env self.observation_space = gym.spaces.Dict( @@ -96,7 +95,6 @@ def render(self, *args, **kwargs): return self.env.render_obs() def reset(self, **kwargs): - if not callable(self.goal_sampler): idx = np.random.randint(len(self.goal_sampler["observations"]["image"])) goal_image = self.goal_sampler["observations"]["image"][idx] diff --git a/orca/envs/wrappers/video_recorder.py b/experiments/homer/env_wrappers/wrappers/video_recorder.py similarity index 99% rename from orca/envs/wrappers/video_recorder.py rename to experiments/homer/env_wrappers/wrappers/video_recorder.py index 658417db..505bd6fc 100644 --- 
--- a/orca/envs/wrappers/video_recorder.py
+++ b/experiments/homer/env_wrappers/wrappers/video_recorder.py
@@ -107,7 +107,6 @@ def stop_recording(self):
         self.num_record_episodes = None
 
     def step(self, action: np.ndarray):  # NOQA
-
         if self.num_record_episodes is None or self.num_record_episodes == 0:
             observation, reward, terminated, truncated, info = self.env.step(action)
diff --git a/experiments/homer/eval.py b/experiments/homer/eval.py
index f0c91523..60513e85 100644
--- a/experiments/homer/eval.py
+++ b/experiments/homer/eval.py
@@ -1,26 +1,27 @@
-import sys
-from widowx_envs.widowx.widowx_env import BridgeDataRailRLPrivateWidowX
 import os
-import numpy as np
-from PIL import Image
-from flax.training import checkpoints
+import sys
 import traceback
-import wandb
-from orca.vision import encoders
-from orca.agents import agents
+
 import matplotlib
+import numpy as np
+import wandb
 from absl import app, flags, logging
+from flax.training import checkpoints
+from PIL import Image
+from widowx_envs.widowx.widowx_env import BridgeDataRailRLPrivateWidowX
+
+from orca.agents import agents
+from orca.vision import encoders
 
 matplotlib.use("Agg")
-import matplotlib.pyplot as plt
 import time
+from collections import deque
 from datetime import datetime
+
 import jax
-import time
+import matplotlib.pyplot as plt
 import tensorflow as tf
 from widowx_envs.utils.multicam_server_rospkg.src.topic_utils import IMTopic
-from collections import deque
-from orca.utils.python_utils import list_of_dicts_to_dict_of_lists
 
 np.set_printoptions(suppress=True)
@@ -56,6 +57,16 @@
 WORKSPACE_BOUNDS = np.array([[0.1, -0.15, -0.1, -1.57, 0], [0.45, 0.25, 0.25, 1.57, 0]])
 
 
+def list_of_dicts_to_dict_of_lists(list_of_dicts):
+    dict_of_lists = {}
+    for dictionary in list_of_dicts:
+        for key, value in dictionary.items():
+            if key not in dict_of_lists:
+                dict_of_lists[key] = []
+            dict_of_lists[key].append(value)
+    return dict_of_lists
+
+
 def unnormalize_action(action, mean, std):
     return action * std + mean
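`list_of_dicts_to_dict_of_lists`, previously imported from `orca.utils.python_utils` (which moves to `experiments/homer/env_utils/` above), is now inlined into `eval.py`. A minimal sanity check of the transposition it performs — the sample dicts below are made up:

```python
# hypothetical observation history, just to show the transposition
obs_history = [{"image": 1, "proprio": 10}, {"image": 2, "proprio": 20}]
assert list_of_dicts_to_dict_of_lists(obs_history) == {
    "image": [1, 2],
    "proprio": [10, 20],
}
```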
"bridge_data_v1/berkeley/toykitchen4/put_lid_on_pot_or_pan/", + "bridge_data_v1/berkeley/toykitchen4/put_pear_on_plate/", + "bridge_data_v1/berkeley/toykitchen4/put_carrot_in_bowl/", + "bridge_data_v1/berkeley/toykitchen4/put_sushi_in_pot_or_pan/", + "bridge_data_v1/berkeley/toykitchen4/put_detergent_in_sink/", + "bridge_data_v1/berkeley/laundry_machine/take_clothes_out_of_laundry_machine/", + "bridge_data_v1/berkeley/laundry_machine/put_clothes_in_laundry_machine/", + "bridge_data_v1/berkeley/toysink3_bww/put_cup_into_pot_or_pan/", + "bridge_data_v1/berkeley/toysink3_bww/put_lid_on_pot_or_pan/", + "bridge_data_v1/berkeley/toysink3_bww/put_cup_from_anywhere_into_sink/", + "bridge_data_v1/berkeley/toysink3_bww/put_knife_in_pot_or_pan/", + "bridge_data_v1/berkeley/toysink3_bww/put_green_squash_into_pot_or_pan/", + "bridge_data_v1/berkeley/toysink3_bww/take_lid_off_pot_or_pan/", + "bridge_data_v1/berkeley/toysink3_bww/put_pot_or_pan_from_sink_into_drying_rack/", + "bridge_data_v1/berkeley/toysink3_bww/put_brush_into_pot_or_pan/", + "bridge_data_v1/berkeley/toysink3_bww/put_detergent_from_sink_into_drying_rack/", + "bridge_data_v1/berkeley/toykitchen1/put_sushi_on_plate/", + "bridge_data_v1/berkeley/toykitchen1/put_pan_in_sink/", + "bridge_data_v1/berkeley/toykitchen1/put_broccoli_in_bowl/", + "bridge_data_v1/berkeley/toykitchen1/put_pot_on_stove_which_is_near_stove_distractors/", + "bridge_data_v1/berkeley/toykitchen1/put_corn_into_bowl/", + "bridge_data_v1/berkeley/toykitchen1/take_can_out_of_pan/", + "bridge_data_v1/berkeley/toykitchen1/put_fork_from_basket_to_tray/", + "bridge_data_v1/berkeley/toykitchen1/put_eggplant_on_plate/", + "bridge_data_v1/berkeley/toykitchen1/put_lid_on_pot_or_pan/", + "bridge_data_v1/berkeley/toykitchen1/put_corn_in_pot_which_is_in_sink_distractors/", + "bridge_data_v1/berkeley/toykitchen1/put_carrot_on_plate/", + "bridge_data_v1/berkeley/toykitchen1/take_carrot_off_plate_cardboardfence/", + "bridge_data_v1/berkeley/toykitchen1/put_sweet_potato_in_pan_which_is_on_stove_distractors/", + "bridge_data_v1/berkeley/toykitchen1/put_big_spoon_from_basket_to_tray/", + "bridge_data_v1/berkeley/toykitchen1/take_broccoli_out_of_pan/", + "bridge_data_v1/berkeley/toykitchen1/put_lid_on_stove/", + "bridge_data_v1/berkeley/toykitchen1/put_sweet_potato_in_pan_which_is_on_stove/", + "bridge_data_v1/berkeley/toykitchen1/put_corn_in_pan_which_is_on_stove_distractors/", + "bridge_data_v1/berkeley/toykitchen1/put_pear_in_bowl/", + "bridge_data_v1/berkeley/toykitchen1/pick_up_pan_from_stove_distractors/", + "bridge_data_v1/berkeley/toykitchen1/put_banana_on_plate/", + "bridge_data_v1/berkeley/toykitchen1/pick_up_pot_from_sink_distractors/", + "bridge_data_v1/berkeley/toykitchen1/put_sweet_potato_in_pot_which_is_in_sink_distractors/", + "bridge_data_v1/berkeley/toykitchen1/put_pot_in_sink/", + "bridge_data_v1/berkeley/toykitchen1/take_sushi_out_of_pan/", + "bridge_data_v1/berkeley/toykitchen1/put_detergent_in_sink/", + "bridge_data_v1/berkeley/toykitchen1/take_broccoli_out_of_pan_cardboardfence/", + "bridge_data_v1/berkeley/toykitchen1/put_broccoli_in_pot_cardboardfence/", + "bridge_data_v1/berkeley/toykitchen1/take_lid_off_pot_or_pan/", + "bridge_data_v1/berkeley/toykitchen1/take_carrot_off_plate/", + "bridge_data_v1/berkeley/toykitchen1/put_eggplant_in_pot_or_pan/", + "bridge_data_v1/berkeley/toykitchen1/put_carrot_on_plate_cardboardfence/", + "bridge_data_v1/berkeley/toykitchen1/pick_up_bowl_and_put_in_small4fbox/", + 
"bridge_data_v1/berkeley/toykitchen1/put_carrot_on_cutting_board/", + "bridge_data_v1/berkeley/toykitchen1/put_red_bottle_in_sink/", + "bridge_data_v1/berkeley/toykitchen1/put_pepper_in_pan/", + "bridge_data_v1/berkeley/toykitchen1/put_broccoli_in_pot_or_pan/", + "bridge_data_v1/berkeley/toykitchen1/put_knife_on_cutting_board/", + "bridge_data_v1/berkeley/toykitchen1/put_small_spoon_from_basket_to_tray/", + "bridge_data_v1/berkeley/toykitchen1/put_corn_in_pan_which-is_on_stove_distractors/", + "bridge_data_v1/berkeley/toykitchen1/put_pepper_in_pot_or_pan/", + "bridge_data_v1/berkeley/toykitchen1/put_green_squash_in_pot_or_pan/", + "bridge_data_v1/berkeley/tabletop_dark_wood/put_spatula_on_cutting_board/", + "bridge_data_v1/berkeley/tabletop_dark_wood/put_banana_in_colander/", + "bridge_data_v1/berkeley/tabletop_dark_wood/take_banana_out_of_colander/", + "bridge_data_v1/berkeley/toykitchen6/take_cup_off_plate/", + "bridge_data_v1/berkeley/toykitchen6/put_spatula_on_plate_sink/", + "bridge_data_v1/berkeley/toykitchen6/take_spoon_out_of_bowl_sink/", + "bridge_data_v1/berkeley/toykitchen6/put_beet_in_pot_sink/", + "bridge_data_v1/berkeley/toykitchen6/put_corn_in_bowl_sink/", + "bridge_data_v1/berkeley/toykitchen6/put_blueberries_on_plate_sink/", + "bridge_data_v1/berkeley/toykitchen6/take_corn_out_of_bowl_sink/", + "bridge_data_v1/berkeley/toykitchen6/put_cup_on_plate/", + "bridge_data_v1/berkeley/toykitchen6/take_blueberries_off_plate_sink/", + "bridge_data_v1/berkeley/toykitchen6/take_beet_from_pot_sink/", + "bridge_data_v1/berkeley/toykitchen6/take_spatula_off_plate_sink/", + "bridge_data_v1/berkeley/toykitchen6/put_spoon_in_bowl_sink/", + "bridge_data_v1/berkeley/tabletop_white/put_sushi_on_plate/", + "bridge_data_v1/berkeley/tabletop_white/take_sushi_off_plate/", + "bridge_data_v1/berkeley/tabletop_light_wood/put_cucumber_in_cup/", + "bridge_data_v1/berkeley/tabletop_light_wood/take_cucumber_out_of_cup/", + "bridge_data_v1/berkeley/toykitchen2/take_bowl_off_plate_cardboard_fence/", + "bridge_data_v1/berkeley/toykitchen2/put_bowl_on_plate/", + "bridge_data_v1/berkeley/toykitchen2/take_bowl_off_plate/", + "bridge_data_v1/berkeley/toykitchen2/take_sushi_out_of_pot_cardboard_fence/", + "bridge_data_v1/berkeley/toykitchen2/take_lid_off_pot_cardboardfence/", + "bridge_data_v1/berkeley/toykitchen2/put_potato_in_pot_cardboard_fence/", + "bridge_data_v1/berkeley/toykitchen2/take_carrot_out_of_pot_cardboard_fence/", + "bridge_data_v1/berkeley/toykitchen2/put_carrot_in_pot_cardboard_fence/", + "bridge_data_v1/berkeley/toykitchen2/put_knife_on_cutting_board_cardboard_fence/", + "bridge_data_v1/berkeley/toykitchen2/put_cap_on_container/", + "bridge_data_v1/berkeley/toykitchen2/put_sushi_in_pot_cardboard_fence/", + "bridge_data_v1/berkeley/toykitchen2/put_bowl_on_plate_cardboard_fence/", + "bridge_data_v1/berkeley/toykitchen2/put_lid_on_pot_cardboardfence/", + "bridge_data_v1/berkeley/toykitchen2/put_banana_in_pot_cardboard_fence/", + "bridge_data_v1/berkeley/toykitchen2/put_pear_in_bowl_cardboardfence/", + "bridge_data_v1/berkeley/toykitchen2/put_knife_in_pot_cardboard_fence/", + "bridge_data_v1/berkeley/toykitchen2_room8052/put_spatula_in_pan/", + "bridge_data_v1/berkeley/toykitchen2_room8052/put_pot_or_pan_on_stove/", + "bridge_data_v1/berkeley/toykitchen2_room8052/put_potato_in_pot_or_pan/", + "bridge_data_v1/berkeley/toykitchen2_room8052/put_pear_in_bowl/", + "bridge_data_v1/berkeley/toykitchen2_room8052/put_knife_on_cutting_board/", + 
"bridge_data_v1/berkeley/toykitchen2_room8052/put_pot_or_pan_in_sink/", + "bridge_data_v1/berkeley/toykitchen2_room8052/put_strawberry_in_pot/", + "bridge_data_v1/berkeley/toykitchen2_room8052/put_lemon_on_plate/", + "bridge_data_v1/berkeley/toykitchen2_room8052/put_corn_on_plate/", + "bridge_data_v1/berkeley/toykitchen2_room8052/put_sushi_on_plate/", + "bridge_data_v1/berkeley/toykitchen2_room8052/put_can_in_pot/", + "bridge_data_v1/berkeley/toykitchen2_room8052/put_potato_on_plate/", + "bridge_data_v1/berkeley/toykitchen2_room8052/lift_bowl/", + "bridge_data_v1/berkeley/toykitchen2_room8052/put_carrot_in_pot_or_pan/", + "bridge_data_v1/berkeley/toykitchen2_room8052/put_sweet_potato_in_pot/", + "bridge_data_v1/berkeley/tool_chest/pick_up_blue_pen_and_put_into_drawer/", + "bridge_data_v1/berkeley/tool_chest/pick_up_red_srewdriver/", + "bridge_data_v1/berkeley/tool_chest/pick_up_box_cutter_and_put_into_drawer/", + "bridge_data_v1/berkeley/tool_chest/pick_up_violet_Allen_key/", + "bridge_data_v1/berkeley/tool_chest/pick_up_bit_holder/", + "bridge_data_v1/berkeley/tool_chest/pick_up_scissors_and_put_into_drawer/", + "bridge_data_v1/berkeley/tool_chest/pick_up_glue_and_put_into_drawer/", +] + +ACT_MEAN = [ + 1.9296819e-04, + 1.3667766e-04, + -1.4583133e-04, + -1.8390431e-04, + -3.0808983e-04, + 2.7425270e-04, + 5.9716219e-01, +] + +ACT_STD = [ + 0.00912848, + 0.0127196, + 0.01229497, + 0.02606696, + 0.02875283, + 0.07807977, + 0.48710242, +] + + +def get_config(config_string): + possible_structures = { + "all": ml_collections.ConfigDict( + { + "include": [ + [ + "icra/?*/?*/?*", + "flap/?*/?*/?*", + "bridge_data_v1/berkeley/?*/?*", + "rss/?*/?*/?*", + "bridge_data_v2/?*/?*/?*", + "scripted/?*", + ] + ], + "exclude": [], + "sample_weights": None, + "action_proprio_metadata": { + "action": { + "mean": ACT_MEAN, + "std": ACT_STD, + }, + "proprio": { + "mean": ACT_MEAN, + "std": ACT_STD, + }, + }, + } + ), + "test": ml_collections.ConfigDict( + { + "include": [ + [ + "rss/?*/?*/?*", + ] + ], + "exclude": [], + "sample_weights": None, + "action_proprio_metadata": { + "action": { + "mean": ACT_MEAN, + "std": ACT_STD, + }, + "proprio": { + "mean": ACT_MEAN, + "std": ACT_STD, + }, + }, + } + ), + "all_except_scripted": ml_collections.ConfigDict( + { + "include": [ + [ + "icra/?*/?*/?*", + "icra_validation/?*/?*/?*", + "flap/?*/?*/?*", + "bridge_data_v1/berkeley/?*/?*", + "rss/?*/?*/?*", + "bridge_data_v2/?*/?*/?*", + ] + ], + "exclude": [], + "sample_weights": None, + "action_metadata": { + "mean": ACT_MEAN, + "std": ACT_STD, + }, + } + ), + "franka": ml_collections.ConfigDict( + { + "include": [["?*"]], + "exclude": [], + "sample_weights": None, + "action_metadata": { + "mean": [ + 5.2401489e-01, + -6.7343891e-02, + 2.5386891e-01, + 2.6513453e00, + -8.4149389e-04, + 1.2696550e-02, + 2.9238686e-01, + ], + "std": [ + 0.08792825, + 0.08270102, + 0.11227315, + 1.5259572, + 0.09435784, + 0.16661045, + 0.41294536, + ], + }, + } + ), + "all_exclude_toykitchen7": ml_collections.ConfigDict( + { + "include": [ + [ + "icra/?*/?*/?*", + "icra_validation/?*/?*/?*", + "flap/?*/?*/?*", + "bridge_data_v1/berkeley/?*/?*", + "rss/?*/?*/?*", + "bridge_data_v2/?*/?*/?*", + "scripted/?*", + ] + ], + "exclude": [ + "*toykitchen7*", + "*tabletop_dark_wood*", + "*icra_validation/toykitchen_fixed_cam_offline_validation/tabletop*", + "*icra/toykitchen_fixed_cam_resetfree/tabletop*", + "*icra/toykitchen_fixed_cam_resetfree_combo/tabletop*", + "*icra/toykitchen_fixed_cam_resetfree_push_sweep/tabletop*", + ], + "sample_weights": 
+                "action_metadata": {
+                    "mean": ACT_MEAN,
+                    "std": ACT_STD,
+                },
+            }
+        ),
+        "all_finetune": ml_collections.ConfigDict(
+            {
+                "include": [
+                    [
+                        "icra/*/*/*",
+                        "icra_validation/?*/?*/?*",
+                        "flap/?*/?*/?*",
+                        "bridge_data_v1/berkeley/?*/?*",
+                        "rss/toykitchen2/?*/?*",
+                        "rss/toykitchen6/?*/?*",
+                    ],
+                    ["rss/toykitchen7/pnp_sweep_target_fixed/?*"],
+                ],
+                "exclude": [
+                    "*tabletop_dark_wood*",
+                    "*icra_validation/toykitchen_fixed_cam_offline_validation/tabletop*",
+                    "*icra/toykitchen_fixed_cam_resetfree/tabletop*",
+                    "*icra/toykitchen_fixed_cam_resetfree_combo/tabletop*",
+                    "*icra/toykitchen_fixed_cam_resetfree_push_sweep/tabletop*",
+                ],
+                "sample_weights": [0.9, 0.1],
+                "action_metadata": {
+                    "mean": ACT_MEAN,
+                    "std": ACT_STD,
+                },
+            }
+        ),
+        "all_finetune_autonomous": ml_collections.ConfigDict(
+            {
+                "include": [
+                    [
+                        "icra/?*/?*/?*",
+                        "icra_validation/?*/?*/?*",
+                        "flap/?*/?*/?*",
+                        "bridge_data_v1/berkeley/?*/?*",
+                        "rss/?*/?*/?*",
+                        "bridge_data_v2/?*/?*/?*",
+                        "scripted/?*",
+                    ],
+                    ["learned/toykitchen7/pnp_sweep_v2"],
+                ],
+                "exclude": [
+                    "*rss/toykitchen7*",
+                    "*tabletop_dark_wood*",
+                    "*icra_validation/toykitchen_fixed_cam_offline_validation/tabletop*",
+                    "*icra/toykitchen_fixed_cam_resetfree/tabletop*",
+                    "*icra/toykitchen_fixed_cam_resetfree_combo/tabletop*",
+                    "*icra/toykitchen_fixed_cam_resetfree_push_sweep/tabletop*",
+                    "*sweep_12-03*",
+                ],
+                "sample_weights": [0.9, 0.1],
+                "action_metadata": {
+                    "mean": ACT_MEAN,
+                    "std": ACT_STD,
+                },
+            }
+        ),
+        "all_finetune_autonomous_oracle": ml_collections.ConfigDict(
+            {
+                "include": [
+                    [
+                        "icra/*/*/*",
+                        "icra_validation/?*/?*/?*",
+                        "flap/?*/?*/?*",
+                        "bridge_data_v1/berkeley/?*/?*",
+                        "rss/toykitchen2/?*/?*",
+                        "rss/toykitchen6/?*/?*",
+                    ],
+                    [
+                        "finetuning/ours_2_22/?*",
+                        "rss/toykitchen7/pnp_sweep_target_fixed/?*",
+                    ],
+                ],
+                "exclude": [
+                    "*tabletop_dark_wood*",
+                    "*icra_validation/toykitchen_fixed_cam_offline_validation/tabletop*",
+                    "*icra/toykitchen_fixed_cam_resetfree/tabletop*",
+                    "*icra/toykitchen_fixed_cam_resetfree_combo/tabletop*",
+                    "*icra/toykitchen_fixed_cam_resetfree_push_sweep/tabletop*",
+                ],
+                "sample_weights": [0.9, 0.1],
+                "action_metadata": {
+                    "mean": ACT_MEAN,
+                    "std": ACT_STD,
+                },
+            }
+        ),
+    }
+    return possible_structures[config_string]
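For readers skimming the schema above: each element of `include` is itself a list of glob patterns that jointly define one sub-dataset, `exclude` patterns filter the matched paths, and `sample_weights` (when not `None`) carries one weight per sub-dataset. A hypothetical variant following the same schema — the pattern choices below are illustrative, not part of this change:

```python
import ml_collections

# two sub-datasets sampled 95/5; assumes ACT_MEAN/ACT_STD as defined above
custom = ml_collections.ConfigDict(
    {
        "include": [
            ["bridge_data_v2/?*/?*/?*"],  # sub-dataset 0
            ["rss/toykitchen7/pnp_sweep_target_fixed/?*"],  # sub-dataset 1
        ],
        "exclude": ["*tabletop_dark_wood*"],
        "sample_weights": [0.95, 0.05],
        "action_proprio_metadata": {
            "action": {"mean": ACT_MEAN, "std": ACT_STD},
            "proprio": {"mean": ACT_MEAN, "std": ACT_STD},
        },
    }
)
```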
diff --git a/experiments/main/configs/train_config.py b/experiments/main/configs/train_config.py
new file mode 100644
index 00000000..40a337e1
--- /dev/null
+++ b/experiments/main/configs/train_config.py
@@ -0,0 +1,179 @@
+from ml_collections import ConfigDict
+from ml_collections.config_dict import placeholder, required_placeholder
+
+
+def update_config(_prototype, **kwargs):
+    result = dict(_prototype)
+    for key, value in kwargs.items():
+        if type(result.get(key)) == dict or type(result.get(key)) == ConfigDict:
+            if not kwargs[key].get("_overwrite", False):
+                value = dict(update_config(_prototype=result[key], **kwargs[key]))
+            value.pop("_overwrite", None)
+        result[key] = value
+    result.pop("_overwrite", None)
+    return ConfigDict(result)
+
+
+def get_config(config_string):
+    base_wandb_config = dict(
+        project="orca",
+        group=placeholder(str),
+        entity=placeholder(str),
+    )
+
+    base_real_config = dict(
+        batch_size=64,
+        num_steps=int(2e6),
+        log_interval=100,
+        eval_interval=5000,
+        save_interval=5000,
+        save_dir=placeholder(str),
+        data_path=placeholder(str),
+        resume_path=placeholder(str),
+        seed=42,
+        text_processor="muse_embedding",
+        text_processor_kwargs=dict(),
+        pretrained_weights=[],
+        wandb=base_wandb_config,
+    )
+
+    # params that need to be specified in multiple places
+    normalization_type = "normal"
+
+    base_data_config = dict(
+        shuffle_buffer_size=25000,
+        prefetch_num_batches=20,
+        augment=True,
+        augment_next_obs_goal_differently=False,
+        augment_kwargs=dict(
+            random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
+            random_brightness=[0.2],
+            random_contrast=[0.8, 1.2],
+            random_saturation=[0.8, 1.2],
+            random_hue=[0.1],
+            augment_order=[
+                "random_resized_crop",
+                "random_brightness",
+                "random_contrast",
+                "random_saturation",
+                "random_hue",
+            ],
+        ),
+        goal_relabeling_strategy="uniform",
+        goal_relabeling_kwargs=dict(reached_proportion=0.0),
+        normalization_type=normalization_type,
+    )
+
+    base_optimizer_config = dict(
+        learning_rate=3e-4,
+        warmup_steps=2000,
+        decay_steps=int(2e6),
+    )
+
+    base_model_config = dict(
+        policy_kwargs=dict(
+            num_layers=4,
+            layer_size=1024,
+            vocab_size=256,
+            num_heads=8,
+            feed_forward_size=512,
+            dropout_rate=0.1,
+            normalization_type=normalization_type,
+        ),
+    )
+
+    possible_structures = {
+        "transformer_bc": ConfigDict(
+            dict(
+                agent="transformer_bc",
+                obs_horizon=1,
+                model=update_config(
+                    base_model_config,
+                    observation_tokenizers=["obs-tokenizer"],
+                    observation_tokenizer_kwargs={"obs-tokenizer": {}},
+                    task_tokenizers=["goal-obs-tokenizer"],
+                    task_tokenizer_kwargs={"goal-obs-tokenizer": {}},
+                ),
+                optimizer=base_optimizer_config,
+                dataset_kwargs=base_data_config,
+                **base_real_config,
+            )
+        ),
+        "transformer_bc_film_lang": ConfigDict(
+            dict(
+                agent="transformer_bc",
+                obs_horizon=1,
+                model=update_config(
+                    base_model_config,
+                    observation_tokenizers=["obs-film-language-tokenizer"],
+                    observation_tokenizer_kwargs={
+                        "obs-film-language-tokenizer": {"num_tokens": 64}
+                    },
+                    task_tokenizers=[],
+                    task_tokenizer_kwargs={},
+                ),
+                optimizer=base_optimizer_config,
+                dataset_kwargs=base_data_config,
+                **base_real_config,
+            )
+        ),
+        "transformer_bc_lang": ConfigDict(
+            dict(
+                agent="transformer_bc",
+                obs_horizon=1,
+                model=update_config(
+                    base_model_config,
+                    observation_tokenizers=["obs-tokenizer"],
+                    observation_tokenizer_kwargs={"obs-tokenizer": {"num_tokens": 64}},
+                    task_tokenizers=["language-tokenizer"],
+                    task_tokenizer_kwargs={"language-tokenizer": {"num_tokens": 16}},
+                ),
+                optimizer=base_optimizer_config,
+                dataset_kwargs=base_data_config,
+                **base_real_config,
+            )
+        ),
+        "transformer_bc_clip_text": ConfigDict(
+            dict(
+                agent="transformer_bc",
+                obs_horizon=1,
+                model=update_config(
+                    base_model_config,
+                    observation_tokenizers=["obs-tokenizer"],
+                    observation_tokenizer_kwargs={"obs-tokenizer": {"num_tokens": 64}},
+                    task_tokenizers=["clip-text-tokenizer"],
+                    task_tokenizer_kwargs={"clip-text-tokenizer": {"num_tokens": 64}},
+                ),
+                optimizer=base_optimizer_config,
+                dataset_kwargs=base_data_config,
+                **update_config(
+                    base_real_config,
+                    text_processor="clip_processor",
+                    pretrained_weights=["clip"],
+                ),
+            )
+        ),
+        "transformer_bc_clip_vit_and_text": ConfigDict(
+            dict(
+                agent="transformer_bc",
+                obs_horizon=1,
+                model=update_config(
+                    base_model_config,
+                    observation_tokenizers=["clip-obs-tokenizer"],
+                    observation_tokenizer_kwargs={
+                        "clip-obs-tokenizer": {"num_tokens": 50}
+                    },
+                    task_tokenizers=["clip-text-tokenizer"],
+                    task_tokenizer_kwargs={"clip-text-tokenizer": {"num_tokens": 64}},
+                ),
+                optimizer=base_optimizer_config,
+                dataset_kwargs=update_config(base_data_config, image_processor="clip"),
+                **update_config(
+                    base_real_config,
+                    text_processor="clip_processor",
+                    pretrained_weights=["clip"],
+                ),
+            )
+        ),
+    }
+
+    return possible_structures[config_string]
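`update_config` deep-merges keyword overrides into a prototype dict: nested dicts are merged key by key, unless the override subtree sets `_overwrite=True`, in which case it replaces the prototype subtree wholesale. A small sketch of the intended semantics, with made-up values:

```python
# assumes update_config as defined in the diff above
base = dict(optimizer=dict(learning_rate=3e-4, warmup_steps=2000), seed=42)

# default: recursive merge, so untouched keys survive
merged = update_config(base, optimizer=dict(learning_rate=1e-4))
assert merged.optimizer.warmup_steps == 2000

# _overwrite=True replaces the whole subtree instead of merging into it
replaced = update_config(base, optimizer=dict(_overwrite=True, learning_rate=1e-4))
assert "warmup_steps" not in replaced.optimizer
```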
diff --git a/experiments/main/train.py b/experiments/main/train.py
new file mode 100644
index 00000000..197a89fa
--- /dev/null
+++ b/experiments/main/train.py
@@ -0,0 +1,269 @@
+import datetime
+import json
+import os
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import tensorflow as tf
+import tqdm
+import wandb
+from absl import app, flags, logging
+from flax.training import checkpoints
+from flax.traverse_util import flatten_dict
+from ml_collections import config_flags
+
+from orca.data.bridge_dataset import BridgeDataset, glob_to_path_list
+from orca.data.text_processing import text_processors
+from orca.model import create_model_def
+from orca.model.tokenizers import weights_loaders
+from orca.train_utils import (
+    Timer,
+    TrainState,
+    create_train_state,
+    format_name_with_config,
+    initialize_compilation_cache,
+    shard_batch,
+)
+
+try:
+    from jax_smi import initialise_tracking  # type: ignore
+
+    initialise_tracking()
+except ImportError:
+    pass
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("name", "experiment", "Experiment name.")
+flags.DEFINE_bool("debug", False, "Debug config (no wandb logging)")
+
+config_dir = os.path.join(os.path.dirname(__file__), "configs")
+config_flags.DEFINE_config_file(
+    "config",
+    os.path.join(config_dir, "train_config.py:transformer_bc"),
+    "File path to the training hyperparameter configuration.",
+    lock_config=False,
+)
+
+config_flags.DEFINE_config_file(
+    "bridgedata_config",
+    os.path.join(config_dir, "data_config.py:all"),
+    "File path to the bridgedata configuration.",
+    lock_config=False,
+)
+
+
+def main(_):
+    initialize_compilation_cache()
+    devices = jax.local_devices()
+    num_devices = len(devices)
+    assert FLAGS.config.batch_size % num_devices == 0
+
+    # prevent tensorflow from using GPUs
+    tf.config.set_visible_devices([], "GPU")
+
+    # set up wandb and logging
+    name = format_name_with_config(
+        FLAGS.name,
+        FLAGS.config.to_dict(),
+    )
+    wandb_id = "{name}_{time}".format(
+        name=name,
+        time=datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
+    )
+    wandb.init(
+        config=FLAGS.config.to_dict(),
+        id=wandb_id,
+        name=name,
+        mode="disabled" if FLAGS.debug else None,
+        **FLAGS.config.wandb,
+    )
+    if FLAGS.config.save_dir is not None:
+        save_dir = tf.io.gfile.join(
+            FLAGS.config.save_dir,
+            FLAGS.config.wandb.project,
+            FLAGS.config.wandb.group or "",
+            wandb_id,
+        )
+        wandb.config.update(dict(save_dir=save_dir), allow_val_change=True)
+        logging.info("Saving to %s", save_dir)
+        tf.io.gfile.makedirs(save_dir)
+        with tf.io.gfile.GFile(
+            os.path.join(save_dir, "config.json"), "w"
+        ) as config_file:
+            config_file.write(FLAGS.config.to_json_best_effort())
+    else:
+        save_dir = None
+        logging.info("save_dir not passed in, not saving checkpoints")
+
+    # load datasets
+    logging.info(f"Loading data from {FLAGS.config.data_path}")
+    assert type(FLAGS.bridgedata_config.include[0]) == list
+    task_paths = [
+        glob_to_path_list(
+            path, prefix=FLAGS.config.data_path, exclude=FLAGS.bridgedata_config.exclude
+        )
+        for path in FLAGS.bridgedata_config.include
+    ]
+
+    train_paths = [
+        [os.path.join(path, "train/out.tfrecord") for path in sub_list]
+        for sub_list in task_paths
+    ]
+    val_paths = [
+        [os.path.join(path, "val/out.tfrecord") for path in sub_list]
+        for sub_list in task_paths
+    ]
+
+    obs_horizon = FLAGS.config.get("obs_horizon")
+    text_processor = text_processors[FLAGS.config.text_processor](
+        **FLAGS.config.text_processor_kwargs
+    )
+
+    train_data = BridgeDataset(
+        train_paths,
+        FLAGS.config.seed,
+        batch_size=FLAGS.config.batch_size,
+        train=True,
+        action_proprio_metadata=FLAGS.bridgedata_config.action_proprio_metadata,
+        sample_weights=FLAGS.bridgedata_config.sample_weights,
+        obs_horizon=obs_horizon,
+        text_processor=text_processor,
+        **FLAGS.config.dataset_kwargs,
+    )
+    val_data = BridgeDataset(
+        val_paths,
+        FLAGS.config.seed,
+        batch_size=FLAGS.config.batch_size,
+        action_proprio_metadata=FLAGS.bridgedata_config.action_proprio_metadata,
+        train=False,
+        obs_horizon=obs_horizon,
+        text_processor=text_processor,
+        **FLAGS.config.dataset_kwargs,
+    )
+    train_data_iter = train_data.get_iterator()
+
+    example_batch = next(train_data_iter)
+    logging.info(f"Batch size: {example_batch['observations']['image'].shape[0]}")
+    logging.info(f"Number of devices: {num_devices}")
+    logging.info(
+        f"Batch size per device: {example_batch['observations']['image'].shape[0] // num_devices}"
+    )
+
+    # we shard the leading dimension (batch dimension) across all devices evenly
+    sharding = jax.sharding.PositionalSharding(devices)
+    example_batch = shard_batch(example_batch, sharding)
+
+    model_def = create_model_def(
+        action_dim=example_batch["actions"].shape[-1],
+        time_sequence_length=example_batch["observations"]["image"].shape[1],
+        **FLAGS.config.model.to_dict(),
+    )
+
+    # pretrained weights to load
+    pretrained_loaders = [weights_loaders[w] for w in FLAGS.config.pretrained_weights]
+
+    lr_schedule = optax.warmup_cosine_decay_schedule(
+        init_value=0.0,
+        peak_value=FLAGS.config.optimizer.learning_rate,
+        warmup_steps=FLAGS.config.optimizer.warmup_steps,
+        decay_steps=FLAGS.config.optimizer.decay_steps,
+        end_value=0.0,
+    )
+    tx = optax.adam(lr_schedule)
+    rng = jax.random.PRNGKey(FLAGS.config.seed)
+    rng, construct_rng = jax.random.split(rng)
+
+    train_state = create_train_state(
+        construct_rng,
+        model_def,
+        tx,
+        init_args=(
+            example_batch["observations"],
+            example_batch["goals"],
+            example_batch["actions"],
+        ),
+        pretrained_loaders=pretrained_loaders,
+    )
+    if FLAGS.config.resume_path is not None:
+        train_state = checkpoints.restore_checkpoint(
+            FLAGS.config.resume_path, target=train_state
+        )
+
+    # replicate agent across devices
+    # need the jnp.array to avoid a bug where device_put doesn't recognize primitives
+    train_state = jax.device_put(
+        jax.tree_map(jnp.array, train_state), sharding.replicate()
+    )
+
+    def loss_fn(params, state, batch, rng, train=True):
+        info = state.apply_fn(
+            {"params": params},
+            batch["observations"],
+            batch["goals"],
+            batch["actions"],
+            train=train,
+            rngs={"dropout": rng},
+        )
+        return info["loss"], info
+
+    @jax.jit
+    def train_step(state, batch):
+        rng, dropout_rng = jax.random.split(state.rng)
+        (loss, info), grads = jax.value_and_grad(loss_fn, has_aux=True)(
+            state.params, state, batch, dropout_rng, train=True
+        )
+        new_state = state.apply_gradients(grads=grads, rng=rng)
+        return new_state, info
+
+    @jax.jit
+    def eval_step(state, batch):
+        loss, info = loss_fn(state.params, state, batch, state.rng, train=False)
+        return info
+
+    def wandb_log(info, step):
+        wandb.log(flatten_dict(info, sep="/"), step=step)
+
+    timer = Timer()
+    for i in tqdm.tqdm(range(int(FLAGS.config.num_steps))):
+        timer.tick("total")
+
+        timer.tick("dataset")
+        batch = shard_batch(next(train_data_iter), sharding)
+        timer.tock("dataset")
+
+        timer.tick("train")
+        train_state, update_info = train_step(train_state, batch)
+        timer.tock("train")
+
+        if (i + 1) % FLAGS.config.eval_interval == 0:
logging.info("Evaluating...") + timer.tick("val") + metrics = [] + for batch in val_data.get_iterator(): + metrics.append(eval_step(train_state, batch)) + metrics = jax.tree_map(lambda *xs: np.mean(xs), *metrics) + wandb_log({"validation": metrics}, step=i) + timer.tock("val") + + if (i + 1) % FLAGS.config.save_interval == 0 and save_dir is not None: + logging.info("Saving checkpoint...") + checkpoint_path = checkpoints.save_checkpoint( + save_dir, train_state, step=i + 1, keep=1e6 + ) + logging.info("Saved checkpoint to %s", checkpoint_path) + + timer.tock("total") + + if (i + 1) % FLAGS.config.log_interval == 0: + update_info = jax.device_get(update_info) + wandb_log( + {"training": update_info, "timer": timer.get_average_times()}, step=i + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/orca/__init__.py b/orca/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orca/agents/__init__.py b/orca/agents/__init__.py deleted file mode 100644 index e49a03c7..00000000 --- a/orca/agents/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .transformer_bc import TransformerBCAgent - -agents = { - "transformer_bc": TransformerBCAgent, -} diff --git a/orca/agents/transformer_bc.py b/orca/agents/transformer_bc.py deleted file mode 100644 index 22b80ec2..00000000 --- a/orca/agents/transformer_bc.py +++ /dev/null @@ -1,138 +0,0 @@ -import copy -from functools import partial -from typing import Any, Sequence - -import flax -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -import optax -from flax.core import FrozenDict -from orca.common.common import JaxRLTrainState, nonpytree_field -from orca.common.typing import Batch, PRNGKey -from orca.networks.transformer_policy import TransformerPolicy - - -class TransformerBCAgent(flax.struct.PyTreeNode): - state: JaxRLTrainState - lr_schedule: Any = nonpytree_field() - - @partial(jax.jit, static_argnames="pmap_axis") - def update(self, batch: Batch, pmap_axis: str = None): - def loss_fn(params, rng): - rng, key = jax.random.split(rng) - info = self.state.apply_fn( - {"params": params}, - batch["observations"], - batch["goals"], - batch["actions"], - train=True, - rngs={"dropout": key}, - ) - return info["loss"], info - - # compute gradients and update params - new_state, info = self.state.apply_loss_fns( - loss_fn, pmap_axis=pmap_axis, has_aux=True - ) - - # log learning rates - info["lr"] = self.lr_schedule(self.state.step) - - return self.replace(state=new_state), info - - @partial(jax.jit, static_argnames="argmax") - def sample_actions( - self, - observations: np.ndarray, - goals: np.ndarray, - *, - seed: PRNGKey, - temperature: float = 1.0, - argmax=False - ) -> jnp.ndarray: - if len(observations["image"].shape) == 4: - # unbatched input from evaluation - observations = jax.tree_map(lambda x: x[None], observations) - goals = jax.tree_map(lambda x: x[None], goals) - actions = self.state.apply_fn( - {"params": self.state.params}, - observations, - goals, - method="predict_action", - train=False, - argmax=argmax, - rng=seed, - temperature=temperature, - ) - return actions[0] - - @jax.jit - def get_debug_metrics(self, batch, **kwargs): - return self.state.apply_fn( - {"params": self.state.params}, - batch["observations"], - batch["goals"], - batch["actions"], - ) - - @classmethod - def create( - cls, - rng: PRNGKey, - observations: FrozenDict, - actions: jnp.ndarray, - goals: FrozenDict, - # Model architecture - observation_tokenizer_defs: Sequence[nn.Module], - task_tokenizer_defs: 
diff --git a/orca/__init__.py b/orca/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/orca/agents/__init__.py b/orca/agents/__init__.py
deleted file mode 100644
index e49a03c7..00000000
--- a/orca/agents/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .transformer_bc import TransformerBCAgent
-
-agents = {
-    "transformer_bc": TransformerBCAgent,
-}
diff --git a/orca/agents/transformer_bc.py b/orca/agents/transformer_bc.py
deleted file mode 100644
index 22b80ec2..00000000
--- a/orca/agents/transformer_bc.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import copy
-from functools import partial
-from typing import Any, Sequence
-
-import flax
-import flax.linen as nn
-import jax
-import jax.numpy as jnp
-import numpy as np
-import optax
-from flax.core import FrozenDict
-from orca.common.common import JaxRLTrainState, nonpytree_field
-from orca.common.typing import Batch, PRNGKey
-from orca.networks.transformer_policy import TransformerPolicy
-
-
-class TransformerBCAgent(flax.struct.PyTreeNode):
-    state: JaxRLTrainState
-    lr_schedule: Any = nonpytree_field()
-
-    @partial(jax.jit, static_argnames="pmap_axis")
-    def update(self, batch: Batch, pmap_axis: str = None):
-        def loss_fn(params, rng):
-            rng, key = jax.random.split(rng)
-            info = self.state.apply_fn(
-                {"params": params},
-                batch["observations"],
-                batch["goals"],
-                batch["actions"],
-                train=True,
-                rngs={"dropout": key},
-            )
-            return info["loss"], info
-
-        # compute gradients and update params
-        new_state, info = self.state.apply_loss_fns(
-            loss_fn, pmap_axis=pmap_axis, has_aux=True
-        )
-
-        # log learning rates
-        info["lr"] = self.lr_schedule(self.state.step)
-
-        return self.replace(state=new_state), info
-
-    @partial(jax.jit, static_argnames="argmax")
-    def sample_actions(
-        self,
-        observations: np.ndarray,
-        goals: np.ndarray,
-        *,
-        seed: PRNGKey,
-        temperature: float = 1.0,
-        argmax=False
-    ) -> jnp.ndarray:
-        if len(observations["image"].shape) == 4:
-            # unbatched input from evaluation
-            observations = jax.tree_map(lambda x: x[None], observations)
-            goals = jax.tree_map(lambda x: x[None], goals)
-        actions = self.state.apply_fn(
-            {"params": self.state.params},
-            observations,
-            goals,
-            method="predict_action",
-            train=False,
-            argmax=argmax,
-            rng=seed,
-            temperature=temperature,
-        )
-        return actions[0]
-
-    @jax.jit
-    def get_debug_metrics(self, batch, **kwargs):
-        return self.state.apply_fn(
-            {"params": self.state.params},
-            batch["observations"],
-            batch["goals"],
-            batch["actions"],
-        )
-
-    @classmethod
-    def create(
-        cls,
-        rng: PRNGKey,
-        observations: FrozenDict,
-        actions: jnp.ndarray,
-        goals: FrozenDict,
-        # Model architecture
-        observation_tokenizer_defs: Sequence[nn.Module],
-        task_tokenizer_defs: Sequence[nn.Module],
-        policy_kwargs: dict = {},
-        # Optimizer
-        learning_rate: float = 3e-4,
-        warmup_steps: int = 1000,
-        decay_steps: int = 1000000,
-        # Load pretrained weights
-        pretrained_weights: Sequence[Any] = [],
-    ):
-        # time sequence length is the observation history length
-        if len(observations["image"].shape) == 5:
-            # batched input
-            time_sequence_length = observations["image"].shape[1]
-        else:
-            # unbatched input
-            time_sequence_length = observations["image"].shape[0]
-
-        model_def = TransformerPolicy(
-            observation_tokenizers=observation_tokenizer_defs,
-            task_tokenizers=task_tokenizer_defs,
-            action_dim=actions.shape[-1],
-            time_sequence_length=time_sequence_length,
-            **policy_kwargs
-        )
-
-        lr_schedule = optax.warmup_cosine_decay_schedule(
-            init_value=0.0,
-            peak_value=learning_rate,
-            warmup_steps=warmup_steps,
-            decay_steps=decay_steps,
-            end_value=0.0,
-        )
-        tx = optax.adam(lr_schedule)
-
-        rng, init_rng = jax.random.split(rng)
-        params = model_def.init(init_rng, observations, goals, actions)["params"]
-
-        for loader in pretrained_weights:
-            params = loader(params)
-
-        rng, create_rng = jax.random.split(rng)
-        state = JaxRLTrainState.create(
-            apply_fn=model_def.apply,
-            params=params,
-            txs=tx,
-            target_params=params,
-            rng=create_rng,
-        )
-
-        return cls(state, lr_schedule)
diff --git a/orca/common/common.py b/orca/common/common.py
deleted file mode 100644
index ad8f781d..00000000
--- a/orca/common/common.py
+++ /dev/null
@@ -1,244 +0,0 @@
-import functools
-from typing import Any, Callable, Dict, Mapping, Sequence, Tuple, Union
-
-import flax
-import flax.linen as nn
-import jax
-import jax.numpy as jnp
-import optax
-from flax import struct
-from orca.common.typing import Params, PRNGKey
-
-nonpytree_field = functools.partial(flax.struct.field, pytree_node=False)
-
-default_init = nn.initializers.xavier_uniform
-
-
-def shard_batch(batch, sharding):
-    """Shards a batch across devices along its first dimension.
-
-    Args:
-        batch: A pytree of arrays.
-        sharding: A jax Sharding object with shape (num_devices,).
-    """
-    return jax.tree_map(
-        lambda x: jax.device_put(
-            x, sharding.reshape(sharding.shape[0], *((1,) * (x.ndim - 1)))
-        ),
-        batch,
-    )
-
-
-class ModuleDict(nn.Module):
-    """
-    Utility class for wrapping a dictionary of modules. This is useful when you have multiple modules that you want to
-    initialize all at once (creating a single `params` dictionary), but you want to be able to call them separately
-    later. As a bonus, the modules may have sub-modules nested inside them that share parameters (e.g. an image encoder)
-    and Flax will automatically handle this without duplicating the parameters.
-
-    To initialize the modules, call `init` with no `name` kwarg, and then pass the example arguments to each module as
-    additional kwargs. To call the modules, pass the name of the module as the `name` kwarg, and then pass the arguments
-    to the module as additional args or kwargs.
-
-    Example usage:
-    ```
-    shared_encoder = Encoder()
-    actor = Actor(encoder=shared_encoder)
-    critic = Critic(encoder=shared_encoder)
-
-    model_def = ModuleDict({"actor": actor, "critic": critic})
-    params = model_def.init(rng_key, actor=example_obs, critic=(example_obs, example_action))
-
-    actor_output = model_def.apply({"params": params}, example_obs, name="actor")
-    critic_output = model_def.apply({"params": params}, example_obs, action=example_action, name="critic")
-    ```
-    """
-
-    modules: Dict[str, nn.Module]
-
-    @nn.compact
-    def __call__(self, *args, name=None, **kwargs):
-        if name is None:
-            if kwargs.keys() != self.modules.keys():
-                raise ValueError(
-                    f"When `name` is not specified, kwargs must contain the arguments for each module. "
-                    f"Got kwargs keys {kwargs.keys()} but module keys {self.modules.keys()}"
-                )
-            out = {}
-            for key, value in kwargs.items():
-                if isinstance(value, Mapping):
-                    out[key] = self.modules[key](**value)
-                elif isinstance(value, Sequence):
-                    out[key] = self.modules[key](*value)
-                else:
-                    out[key] = self.modules[key](value)
-            return out
-
-        return self.modules[name](*args, **kwargs)
-
-
-class JaxRLTrainState(struct.PyTreeNode):
-    """
-    Custom TrainState class to replace `flax.training.train_state.TrainState`.
-
-    Adds support for holding target params and updating them via polyak
-    averaging. Adds the ability to hold an rng key for dropout.
-
-    Also generalizes the TrainState to support an arbitrary pytree of
-    optimizers, `txs`. When `apply_gradients()` is called, the `grads` argument
-    must have `txs` as a prefix. This is backwards-compatible, meaning `txs` can
-    be a single optimizer and `grads` can be a single tree with the same
-    structure as `self.params`.
-
-    Also adds a convenience method `apply_loss_fns` that takes a pytree of loss
-    functions with the same structure as `txs`, computes gradients, and applies
-    them using `apply_gradients`.
-
-    Attributes:
-        step: The current training step.
-        apply_fn: The function used to apply the model.
-        params: The model parameters.
-        target_params: The target model parameters.
-        txs: The optimizer or pytree of optimizers.
-        opt_states: The optimizer state or pytree of optimizer states.
-        rng: The internal rng state.
-    """
-
-    step: int
-    apply_fn: Callable = struct.field(pytree_node=False)
-    params: Params
-    target_params: Params
-    txs: Any = struct.field(pytree_node=False)
-    opt_states: Any
-    rng: PRNGKey
-
-    @staticmethod
-    def _tx_tree_map(*args, **kwargs):
-        return jax.tree_map(
-            *args,
-            is_leaf=lambda x: isinstance(x, optax.GradientTransformation),
-            **kwargs,
-        )
-
-    def target_update(self, tau: float) -> "JaxRLTrainState":
-        """
-        Performs an update of the target params via polyak averaging. The new
-        target params are given by:
-
-            new_target_params = tau * params + (1 - tau) * target_params
-        """
-        new_target_params = jax.tree_map(
-            lambda p, tp: p * tau + tp * (1 - tau), self.params, self.target_params
-        )
-        return self.replace(target_params=new_target_params)
-
-    def apply_gradients(self, *, grads: Any) -> "JaxRLTrainState":
-        """
-        Only difference from flax's TrainState is that `grads` must have
-        `self.txs` as a tree prefix (i.e. where `self.txs` has a leaf, `grads`
-        has a subtree with the same structure as `self.params`.)
-        """
- """ - updates_and_new_states = self._tx_tree_map( - lambda tx, opt_state, grad: tx.update(grad, opt_state, self.params), - self.txs, - self.opt_states, - grads, - ) - updates = self._tx_tree_map(lambda _, x: x[0], self.txs, updates_and_new_states) - new_opt_states = self._tx_tree_map( - lambda _, x: x[1], self.txs, updates_and_new_states - ) - - # not the cleanest, I know, but this flattens the leaves of `updates` - # into a list where leaves are defined by `self.txs` - updates_flat = [] - self._tx_tree_map( - lambda _, update: updates_flat.append(update), self.txs, updates - ) - - # apply all the updates additively - updates_acc = jax.tree_map( - lambda *xs: jnp.sum(jnp.array(xs), axis=0), *updates_flat - ) - new_params = optax.apply_updates(self.params, updates_acc) - - return self.replace( - step=self.step + 1, params=new_params, opt_states=new_opt_states - ) - - def apply_loss_fns( - self, loss_fns: Any, pmap_axis: str = None, has_aux: bool = False - ) -> Union["JaxRLTrainState", Tuple["JaxRLTrainState", Any]]: - """ - Convenience method to compute gradients based on `self.params` and apply - them using `apply_gradients`. `loss_fns` must have the same structure as - `txs`, and each leaf must be a function that takes two arguments: - `params` and `rng`. - - This method automatically provides fresh rng to each loss function and - updates this train state's internal rng key. - - Args: - loss_fns: loss function or pytree of loss functions with same - structure as `self.txs`. Each loss function must take `params` - as the first argument and `rng` as the second argument, and return - a scalar value. - pmap_axis: if not None, gradients (and optionally auxiliary values) - will be averaged over this axis - has_aux: if True, each `loss_fn` returns a tuple of (loss, aux) where - `aux` is a pytree of auxiliary values to be returned by this - method. - - Returns: - If `has_aux` is True, returns a tuple of (new_train_state, aux). - Otherwise, returns the new train state. - """ - # create a pytree of rngs with the same structure as `loss_fns` - treedef = jax.tree_util.tree_structure(loss_fns) - new_rng, *rngs = jax.random.split(self.rng, treedef.num_leaves + 1) - rngs = jax.tree_util.tree_unflatten(treedef, rngs) - - # compute gradients - grads_and_aux = jax.tree_map( - lambda loss_fn, rng: jax.grad(loss_fn, has_aux=has_aux)(self.params, rng), - loss_fns, - rngs, - ) - - # update rng state - self = self.replace(rng=new_rng) - - # average across devices if necessary - if pmap_axis is not None: - grads_and_aux = jax.lax.pmean(grads_and_aux, axis_name=pmap_axis) - - if has_aux: - grads = jax.tree_map(lambda _, x: x[0], loss_fns, grads_and_aux) - aux = jax.tree_map(lambda _, x: x[1], loss_fns, grads_and_aux) - return self.apply_gradients(grads=grads), aux - else: - return self.apply_gradients(grads=grads_and_aux) - - @classmethod - def create( - cls, *, apply_fn, params, txs, target_params=None, rng=jax.random.PRNGKey(0) - ): - """ - Initializes a new train state. - - Args: - apply_fn: The function used to apply the model, typically `model_def.apply`. - params: The model parameters, typically from `model_def.init`. - txs: The optimizer or pytree of optimizers. - target_params: The target model parameters. - rng: The rng key used to initialize the rng chain for `apply_loss_fns`. 
- """ - return cls( - step=0, - apply_fn=apply_fn, - params=params, - target_params=target_params, - txs=txs, - opt_states=cls._tx_tree_map(lambda tx: tx.init(params), txs), - rng=rng, - ) diff --git a/orca/common/evaluation.py b/orca/common/evaluation.py deleted file mode 100644 index ef3189fb..00000000 --- a/orca/common/evaluation.py +++ /dev/null @@ -1,154 +0,0 @@ -from collections import defaultdict -from typing import Dict - -import gym -import jax -import numpy as np - - -def supply_rng(f, rng=jax.random.PRNGKey(0)): - def wrapped(*args, **kwargs): - nonlocal rng - rng, key = jax.random.split(rng) - return f(*args, seed=key, **kwargs) - - return wrapped - - -def flatten(d, parent_key="", sep="."): - items = [] - for k, v in d.items(): - new_key = parent_key + sep + k if parent_key else k - if hasattr(v, "items"): - items.extend(flatten(v, new_key, sep=sep).items()) - else: - items.append((new_key, v)) - return dict(items) - - -def filter_info(info): - filter_keys = [ - "object_names", - "target_object", - "initial_positions", - "target_position", - "goal", - ] - for k in filter_keys: - if k in info: - del info[k] - return info - - -def add_to(dict_of_lists, single_dict): - for k, v in single_dict.items(): - dict_of_lists[k].append(v) - - -def evaluate(policy_fn, env: gym.Env, num_episodes: int) -> Dict[str, float]: - stats = defaultdict(list) - for _ in range(num_episodes): - observation, info = env.reset() - add_to(stats, flatten(info)) - done = False - while not done: - action = policy_fn(observation) - observation, _, terminated, truncated, info = env.step(action) - done = terminated or truncated - add_to(stats, flatten(info)) - add_to(stats, flatten(info, parent_key="final")) - - for k, v in stats.items(): - stats[k] = np.mean(v) - return stats - - -def evaluate_with_trajectories( - policy_fn, env: gym.Env, num_episodes: int -) -> Dict[str, float]: - trajectories = [] - stats = defaultdict(list) - - for _ in range(num_episodes): - trajectory = defaultdict(list) - observation, info = env.reset() - add_to(stats, flatten(info)) - done = False - while not done: - action = policy_fn(observation) - next_observation, r, terminated, truncated, info = env.step(action) - done = terminated or truncated - transition = dict( - observation=observation, - next_observation=next_observation, - action=action, - reward=r, - done=done, - info=info, - ) - add_to(trajectory, transition) - add_to(stats, flatten(info)) - observation = next_observation - add_to(stats, flatten(info, parent_key="final")) - trajectories.append(trajectory) - - for k, v in stats.items(): - stats[k] = np.mean(v) - return stats, trajectories - - -def evaluate_gc( - policy_fn, - env: gym.Env, - num_episodes: int, - return_trajectories: bool = False, -) -> Dict[str, float]: - stats = defaultdict(list) - - if return_trajectories: - trajectories = [] - - for _ in range(num_episodes): - if return_trajectories: - trajectory = defaultdict(list) - - observation, info = env.reset() - goal = info["goal"] - add_to(stats, flatten(filter_info(info))) - done = False - - while not done: - action = policy_fn(observation, goal) - next_observation, r, terminated, truncated, info = env.step(action) - goal = info["goal"] - done = terminated or truncated - transition = dict( - observation=observation, - next_observation=next_observation, - goal=goal, - action=action, - reward=r, - done=done, - info=info, - ) - - add_to(stats, flatten(filter_info(info))) - - if return_trajectories: - add_to(trajectory, transition) - - observation = next_observation 
-
-        add_to(stats, flatten(filter_info(info), parent_key="final"))
-        if return_trajectories:
-            trajectory["steps_remaining"] = list(
-                np.arange(len(trajectory["action"]))[::-1]
-            )
-            trajectories.append(trajectory)
-
-    stats = {k: np.mean(v) for k, v in stats.items() if not isinstance(v[0], str)}
-
-    if return_trajectories:
-        return stats, trajectories
-    else:
-        return stats
diff --git a/orca/common/wandb.py b/orca/common/wandb.py
deleted file mode 100644
index 106bbdaf..00000000
--- a/orca/common/wandb.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import datetime
-import tempfile
-from copy import copy
-from socket import gethostname
-
-import absl.flags as flags
-import ml_collections
-import wandb
-
-
-def _recursive_flatten_dict(d: dict):
-    keys, values = [], []
-    for key, value in d.items():
-        if isinstance(value, dict):
-            sub_keys, sub_values = _recursive_flatten_dict(value)
-            keys += [f"{key}/{k}" for k in sub_keys]
-            values += sub_values
-        else:
-            keys.append(key)
-            values.append(value)
-    return keys, values
-
-
-class WandBLogger(object):
-    @staticmethod
-    def get_default_config():
-        config = ml_collections.ConfigDict()
-        config.project = "orca"  # WandB Project Name
-        config.entity = ml_collections.config_dict.FieldReference(None, field_type=str)
-        # Which entity to log as (default: your own user)
-        config.exp_descriptor = ""  # Run name (doesn't have to be unique)
-        # Unique identifier for run (will be automatically generated unless
-        # provided)
-        config.unique_identifier = ""
-        return config
-
-    def __init__(self, wandb_config, variant, wandb_output_dir=None, debug=False):
-        self.config = wandb_config
-        if self.config.unique_identifier == "":
-            self.config.unique_identifier = datetime.datetime.now().strftime(
-                "%Y%m%d_%H%M%S"
-            )
-
-        self.config.experiment_id = (
-            self.experiment_id
-        ) = f"{self.config.exp_descriptor}_{self.config.unique_identifier}"  # NOQA
-
-        print(self.config)
-
-        if wandb_output_dir is None:
-            wandb_output_dir = tempfile.mkdtemp()
-
-        self._variant = copy(variant)
-
-        if "hostname" not in self._variant:
-            self._variant["hostname"] = gethostname()
-
-        if debug:
-            mode = "disabled"
-        else:
-            mode = "online"
-
-        self.run = wandb.init(
-            config=self._variant,
-            project=self.config.project,
-            entity=self.config.entity,
-            dir=wandb_output_dir,
-            id=self.config.experiment_id,
-            save_code=True,
-            mode=mode,
-        )
-
-        flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS}
-        for k in flag_dict:
-            if isinstance(flag_dict[k], ml_collections.ConfigDict):
-                flag_dict[k] = flag_dict[k].to_dict()
-        wandb.config.update(flag_dict)
-
-    def log(self, data: dict, step: int = None):
-        data_flat = _recursive_flatten_dict(data)
-        data = {k: v for k, v in zip(*data_flat)}
-        wandb.log(data, step=step)
diff --git a/orca/data/bridge_dataset.py b/orca/data/bridge_dataset.py
index 4746dd17..0e5a7744 100644
--- a/orca/data/bridge_dataset.py
+++ b/orca/data/bridge_dataset.py
@@ -14,6 +14,7 @@
 import tensorflow as tf
 from absl import logging
 from flax.core import FrozenDict
+
 from orca.data.text_processing import TextProcessor
 from orca.data.tf_augmentations import augment
 from orca.data.tf_goal_relabeling import GOAL_RELABELING_FUNCTIONS
diff --git a/orca/data/rlds_dataset.py b/orca/data/rlds_dataset.py
index e64c96ba..666612d6 100644
--- a/orca/data/rlds_dataset.py
+++ b/orca/data/rlds_dataset.py
@@ -7,6 +7,7 @@
 import tensorflow as tf
 import tensorflow_datasets as tfds
 import tqdm
+
 from orca.data.bridge_dataset import BridgeDataset
 from orca.utils.rlds_data_utils import RLDS_TRAJECTORY_MAP_TRANSFORMS
diff --git a/orca/model/__init__.py b/orca/model/__init__.py
new file mode 100644
index 00000000..e1c720df
--- /dev/null
+++ b/orca/model/__init__.py
@@ -0,0 +1,30 @@
+import logging
+
+from .tokenizers import tokenizers
+from .transformer_policy import TransformerPolicy
+
+
+def create_model_def(
+    observation_tokenizer_kwargs,
+    task_tokenizer_kwargs,
+    action_dim,
+    time_sequence_length,
+    policy_kwargs,
+    **kwargs,
+):
+    if len(kwargs) > 0:
+        logging.warning(f"Extra kwargs passed into create_model_def: {kwargs}")
+    observation_tokenizer_defs = tuple(
+        tokenizers[k](**kwargs) for k, kwargs in observation_tokenizer_kwargs.items()
+    )
+    task_tokenizer_defs = tuple(
+        tokenizers[k](**kwargs) for k, kwargs in task_tokenizer_kwargs.items()
+    )
+    model_def = TransformerPolicy(
+        observation_tokenizers=observation_tokenizer_defs,
+        task_tokenizers=task_tokenizer_defs,
+        action_dim=action_dim,
+        time_sequence_length=time_sequence_length,
+        **policy_kwargs,
+    )
+    return model_def
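`create_model_def` replaces the model construction that used to live in the (deleted) `TransformerBCAgent.create`: tokenizer names are looked up as keys of `orca.model.tokenizers.tokenizers`, instantiated with their kwargs, and forwarded to `TransformerPolicy`. A hypothetical call mirroring how the `transformer_bc_lang` config is wired in — the dimensions are illustrative:

```python
from orca.model import create_model_def

# tokenizer names must be keys of orca.model.tokenizers.tokenizers
model_def = create_model_def(
    observation_tokenizer_kwargs={"obs-tokenizer": {"num_tokens": 64}},
    task_tokenizer_kwargs={"language-tokenizer": {"num_tokens": 16}},
    action_dim=7,  # e.g. a 6-DoF end-effector delta plus gripper
    time_sequence_length=1,  # obs_horizon in the configs above
    policy_kwargs=dict(
        num_layers=4,
        layer_size=1024,
        vocab_size=256,
        num_heads=8,
        feed_forward_size=512,
        dropout_rate=0.1,
        normalization_type="normal",
    ),
)
```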
diff --git a/orca/networks/clip.py b/orca/model/clip.py
similarity index 100%
rename from orca/networks/clip.py
rename to orca/model/clip.py
diff --git a/orca/networks/input_tokenizers.py b/orca/model/tokenizers.py
similarity index 88%
rename from orca/networks/input_tokenizers.py
rename to orca/model/tokenizers.py
index a08fbfee..a12d1333 100644
--- a/orca/networks/input_tokenizers.py
+++ b/orca/model/tokenizers.py
@@ -1,19 +1,39 @@
 import functools as ft
+from typing import Callable, Optional, Sequence
 
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
 from jax.scipy.stats import norm
-from orca.networks.clip import (
-    CLIPTextTokenizer,
-    CLIPVisionTokenizer,
-    clip_weights_loader,
-)
-from orca.networks.mlp import MLP
-from orca.vision import encoders
+
+from orca.model.clip import CLIPTextTokenizer, CLIPVisionTokenizer, clip_weights_loader
+from orca.model.vision import encoders
 
 EPS = 1e-6
 
+
+# Originally from jaxrl_m/networks/mlp.py
+class MLP(nn.Module):
+    hidden_dims: Sequence[int]
+    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish
+    activate_final: bool = False
+    use_layer_norm: bool = False
+    dropout_rate: Optional[float] = None
+    kernel_init: Callable = nn.initializers.xavier_uniform
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray:
+        for i, size in enumerate(self.hidden_dims):
+            x = nn.Dense(size, kernel_init=self.kernel_init())(x)
+            if i + 1 < len(self.hidden_dims) or self.activate_final:
+                if self.dropout_rate is not None and self.dropout_rate > 0:
+                    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
+                if self.use_layer_norm:
+                    x = nn.LayerNorm()(x)
+                x = self.activations(x)
+        return x
+
+
 # adapted from https://github.com/google-research/robotics_transformer/blob/master/tokenizers/token_learner.py
 class TokenLearner(nn.Module):
     num_tokens: int
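Since `MLP` is now inlined here instead of imported from `orca/networks/mlp.py` (deleted below), a quick shape check of the module as it stands — sizes are arbitrary:

```python
import jax
import jax.numpy as jnp

mlp = MLP(hidden_dims=(256, 256, 10), dropout_rate=0.1, use_layer_norm=True)
variables = mlp.init(jax.random.PRNGKey(0), jnp.ones((4, 32)))
out = mlp.apply(variables, jnp.ones((4, 32)), train=False)  # dropout disabled
assert out.shape == (4, 10)  # the last hidden_dims entry is the output width
```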
diff --git a/orca/model/vision/__init__.py b/orca/model/vision/__init__.py
new file mode 100644
index 00000000..1218ef2a
--- /dev/null
+++ b/orca/model/vision/__init__.py
@@ -0,0 +1,4 @@
+from orca.model.vision.resnet_v1 import resnetv1_configs
+
+encoders = dict()
+encoders.update(resnetv1_configs)
diff --git a/orca/vision/film_conditioning_layer.py b/orca/model/vision/film_conditioning_layer.py
similarity index 98%
rename from orca/vision/film_conditioning_layer.py
rename to orca/model/vision/film_conditioning_layer.py
index 0d216abc..40d00f00 100644
--- a/orca/vision/film_conditioning_layer.py
+++ b/orca/model/vision/film_conditioning_layer.py
@@ -2,7 +2,8 @@
 import flax.linen as nn
 import jax.numpy as jnp
-from orca.common.typing import *
+
+from orca.typing import *
 
 
 class FilmConditioning(nn.Module):
diff --git a/orca/vision/resnet_v1.py b/orca/model/vision/resnet_v1.py
similarity index 99%
rename from orca/vision/resnet_v1.py
rename to orca/model/vision/resnet_v1.py
index a0332efd..01190c72 100644
--- a/orca/vision/resnet_v1.py
+++ b/orca/model/vision/resnet_v1.py
@@ -8,7 +8,8 @@
 # import flax.linen as nn
 # import jax.numpy as jnp
 import numpy as np
-from orca.vision.film_conditioning_layer import FilmConditioning
+
+from orca.model.vision.film_conditioning_layer import FilmConditioning
 
 ModuleDef = Any
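Reviewer note: `orca/model/vision/__init__.py` carries over the registry pattern from the deleted `orca/vision/__init__.py` unchanged: `encoders` is a plain dict of config name to encoder constructor, so lookup stays a dict access. A sketch, under the assumption that the entries of `resnetv1_configs` are zero-argument constructibles (e.g. `functools.partial`s over the ResNet module):

```python
from orca.model.vision import encoders

print(sorted(encoders))  # config names contributed by resnetv1_configs
name = sorted(encoders)[0]  # placeholder choice; a real config would name one
encoder_def = encoders[name]()
```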
diff --git a/orca/networks/mlp.py b/orca/networks/mlp.py
deleted file mode 100644
index 363bb245..00000000
--- a/orca/networks/mlp.py
+++ /dev/null
@@ -1,73 +0,0 @@
-from typing import Callable, Optional, Sequence
-
-import flax.linen as nn
-import jax.numpy as jnp
-from orca.common.common import default_init
-
-
-class MLP(nn.Module):
-    hidden_dims: Sequence[int]
-    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.swish
-    activate_final: bool = False
-    use_layer_norm: bool = False
-    dropout_rate: Optional[float] = None
-
-    @nn.compact
-    def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray:
-        for i, size in enumerate(self.hidden_dims):
-            x = nn.Dense(size, kernel_init=default_init())(x)
-
-            if i + 1 < len(self.hidden_dims) or self.activate_final:
-                if self.dropout_rate is not None and self.dropout_rate > 0:
-                    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
-                if self.use_layer_norm:
-                    x = nn.LayerNorm()(x)
-                x = self.activations(x)
-        return x
-
-
-class MLPResNetBlock(nn.Module):
-    features: int
-    act: Callable
-    dropout_rate: float = None
-    use_layer_norm: bool = False
-
-    @nn.compact
-    def __call__(self, x, train: bool = False):
-        residual = x
-        if self.dropout_rate is not None and self.dropout_rate > 0:
-            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
-        if self.use_layer_norm:
-            x = nn.LayerNorm()(x)
-        x = nn.Dense(self.features * 4)(x)
-        x = self.act(x)
-        x = nn.Dense(self.features)(x)
-
-        if residual.shape != x.shape:
-            residual = nn.Dense(self.features)(residual)
-
-        return residual + x
-
-
-class MLPResNet(nn.Module):
-    num_blocks: int
-    out_dim: int
-    dropout_rate: float = None
-    use_layer_norm: bool = False
-    hidden_dim: int = 256
-    activations: Callable = nn.swish
-
-    @nn.compact
-    def __call__(self, x: jnp.ndarray, train: bool = False) -> jnp.ndarray:
-        x = nn.Dense(self.hidden_dim, kernel_init=default_init())(x)
-        for _ in range(self.num_blocks):
-            x = MLPResNetBlock(
-                self.hidden_dim,
-                act=self.activations,
-                use_layer_norm=self.use_layer_norm,
-                dropout_rate=self.dropout_rate,
-            )(x, train=train)
-
-        x = self.activations(x)
-        x = nn.Dense(self.out_dim, kernel_init=default_init())(x)
-        return x
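Reviewer note: `MLP` now lives inline in `orca/model/tokenizers.py` (see the hunk above) instead of `orca/networks/mlp.py`, while `MLPResNetBlock` and `MLPResNet` are dropped without replacement. A quick shape-check sketch of the relocated module, with illustrative sizes:

```python
import jax
import jax.numpy as jnp

from orca.model.tokenizers import MLP

mlp = MLP(hidden_dims=(256, 256), use_layer_norm=True)
params = mlp.init(jax.random.PRNGKey(0), jnp.zeros((4, 32)))
out = mlp.apply(params, jnp.ones((4, 32)))
assert out.shape == (4, 256)  # the final Dense is not followed by an activation
```

One thing worth confirming: the old copy used `default_init()` from the removed `orca.common.common`, while the inlined copy defaults to `nn.initializers.xavier_uniform`; if those two differ, freshly initialized weights will differ even though the architecture is unchanged.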
diff --git a/orca/train_utils.py b/orca/train_utils.py
new file mode 100644
index 00000000..c6209e91
--- /dev/null
+++ b/orca/train_utils.py
@@ -0,0 +1,110 @@
+import logging
+import time
+from collections import defaultdict
+
+import flax
+import jax
+from flax.training import train_state
+from jax.experimental.compilation_cache import compilation_cache
+
+from orca.typing import PRNGKey
+
+
+class TrainState(train_state.TrainState):
+    rng: PRNGKey
+
+
+def shard_batch(batch, sharding):
+    """Shards a batch across devices along its first dimension.
+
+    Args:
+        batch: A pytree of arrays.
+        sharding: A jax Sharding object with shape (num_devices,).
+    """
+    return jax.tree_map(
+        lambda x: jax.device_put(
+            x, sharding.reshape(sharding.shape[0], *((1,) * (x.ndim - 1)))
+        ),
+        batch,
+    )
+
+
+def create_train_state(
+    rng, model_def, tx, init_args=(), init_kwargs=dict(), pretrained_loaders=tuple()
+):
+    """Utility to create a TrainState."""
+    init_rng, state_rng = jax.random.split(rng)
+
+    # Initializing the model in a jit avoids running the model on CPU
+    @jax.jit
+    def init(rng):
+        return model_def.init(rng, *init_args, **init_kwargs)
+
+    ev, params = flax.core.pop(init(init_rng), "params")
+    assert (
+        len(ev) == 0
+    ), "Are you forgetting to store some variables in the state? {}".format(ev.keys())
+
+    for loader in pretrained_loaders:
+        params = loader(params)
+
+    return TrainState.create(
+        apply_fn=model_def.apply,
+        params=params,
+        tx=tx,
+        rng=state_rng,
+    )
+
+
+def format_name_with_config(name, config):
+    """Formats a name string with a config dict.
+
+    Formatting keys may be specified as {key} or {full_path_to_key_with_underscores}.
+
+    Example:
+        name = "model_{model_type}_{model_size}"
+        config = {"model_type": "transformer", "model_size": "small"}
+        format_name_with_config(name, config) -> "model_transformer_small"
+    """
+    config_flat = flax.traverse_util.flatten_dict(config, sep="_")
+    config_final = {k.split("_")[-1]: v for k, v in config_flat.items()}
+    format_dict = {**config_final, **config_flat}
+    return name.format(**format_dict)
+
+
+class Timer:
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.counts = defaultdict(int)
+        self.times = defaultdict(float)
+        self.start_times = {}
+
+    def tick(self, key):
+        if key in self.start_times:
+            raise ValueError(f"Timer is already ticking for key: {key}")
+        self.start_times[key] = time.time()
+
+    def tock(self, key):
+        if key not in self.start_times:
+            raise ValueError(f"Timer is not ticking for key: {key}")
+        self.counts[key] += 1
+        self.times[key] += time.time() - self.start_times[key]
+        del self.start_times[key]
+
+    def get_average_times(self, reset=True):
+        ret = {key: self.times[key] / self.counts[key] for key in self.counts}
+        if reset:
+            self.reset()
+        return ret
+
+
+def initialize_compilation_cache(cache_dir="/tmp/jax_cache"):
+    compilation_cache.initialize_cache(cache_dir)
+
+    for logger in [logging.getLogger(name) for name in logging.root.manager.loggerDict]:
+        logger.addFilter(
+            lambda record: "Not writing persistent cache entry for"
+            not in record.getMessage()
+        )
diff --git a/orca/common/typing.py b/orca/typing.py
similarity index 100%
rename from orca/common/typing.py
rename to orca/typing.py
diff --git a/orca/vision/__init__.py b/orca/vision/__init__.py
deleted file mode 100644
index f4bbacbd..00000000
--- a/orca/vision/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from orca.vision.resnet_v1 import resnetv1_configs
-
-encoders = dict()
-encoders.update(resnetv1_configs)
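Reviewer note: `orca/train_utils.py` now collects the training-loop plumbing in one place. An end-to-end sketch of how the helpers compose; the model, optimizer, and batch below are placeholders, not the real training script, and the 1-D `PositionalSharding` is an assumption about what callers pass to `shard_batch`:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

from orca.train_utils import (
    Timer,
    create_train_state,
    format_name_with_config,
    initialize_compilation_cache,
    shard_batch,
)

initialize_compilation_cache()  # persistent jit cache under /tmp/jax_cache

# Both the leaf key ({size}) and the full flattened path ({model_size}) resolve.
run_name = format_name_with_config(
    "run_{type}_{model_size}", {"model": {"type": "bc", "size": "small"}}
)
assert run_name == "run_bc_small"

# Placeholder model/optimizer; any flax Module with a matching init signature works.
state = create_train_state(
    jax.random.PRNGKey(0),
    nn.Dense(features=8),
    optax.adam(3e-4),
    init_args=(jnp.zeros((1, 16)),),
)

# shard_batch splits every array along its leading (batch) axis, given a
# 1-D sharding such as PositionalSharding over all local devices.
sharding = jax.sharding.PositionalSharding(jax.devices())
batch = shard_batch({"obs": jnp.zeros((8, 16))}, sharding)

timer = Timer()
timer.tick("train_step")
state.apply_fn({"params": state.params}, batch["obs"])
timer.tock("train_step")
print(timer.get_average_times())  # {"train_step": <seconds per call>}
```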