Skip to content

Commit

Permalink
Merge pull request #1 from rail-berkeley/dibya_refactor
Browse files Browse the repository at this point in the history
Made large refactoring changes
  • Loading branch information
HomerW authored Aug 18, 2023
2 parents 039f57a + 3c43951 commit 65fc03f
Show file tree
Hide file tree
Showing 42 changed files with 1,057 additions and 736 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ repos:
hooks:
- id: isort
exclude: ^experiments/
args: ["--profile", "black", "--src", "jaxrl_m", "--src", "experiments"]
args: ["--profile", "black", "--src", "orca", "--src", "experiments"]
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,19 @@

A robot foundation model.

## Training

```
python experiments/main/train.py --config.data_path=<...> --config.save_dir=<...>
```

## Contributing
Experimental things and training/eval scripts should go in `experiments/`. To make any changes to files outside of `experiments/`, please open a pull request.
Experimental things and training/eval scripts should go in `experiments/<your_name>`. To make any changes to files outside of your experiments directory, please open a pull request.

To enable code checks and auto-formatting, please install pre-commit hooks:
```
pre-commit install
```

## Environment
```
Expand Down
15 changes: 15 additions & 0 deletions experiments/dibya/scripts/train_real.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash

CONFIG_NAME=${1:-transformer_bc}
DATA_CONFIG_NAME=${2:-all}
echo "Using config $CONFIG_NAME and data config $DATA_CONFIG_NAME"
NAME="test"

CMD="python experiments/main/train.py \
--config experiments/main/configs/train_config.py:$CONFIG_NAME \
--bridgedata_config experiments/main/configs/data_config.py:$DATA_CONFIG_NAME \
--name $NAME \
--config.data_path=gs://rail-tpus-homer-v4/data_new"
shift 2
echo $CMD "$@"
$CMD "$@"
File renamed without changes.
4 changes: 2 additions & 2 deletions experiments/homer/configs/data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def get_config(config_string):
"proprio": {
"mean": ACT_MEAN,
"std": ACT_STD,
}
},
},
}
),
Expand All @@ -201,7 +201,7 @@ def get_config(config_string):
"proprio": {
"mean": ACT_MEAN,
"std": ACT_STD,
}
},
},
}
),
Expand Down
9 changes: 4 additions & 5 deletions experiments/homer/configs/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def update_config(_prototype, **kwargs):
result.pop("_overwrite", None)
return ConfigDict(result)


def get_config(config_string):
base_sim_config = dict(
batch_size=256,
Expand Down Expand Up @@ -102,7 +103,7 @@ def get_config(config_string):
task_tokenizers=["sim-goal-obs-tokenizer"],
task_tokenizer_kwargs={"sim-goal-obs-tokenizer": {}},
dataset_kwargs=base_data_config,
**base_sim_config
**base_sim_config,
)
),
"transformer_bc": ConfigDict(
Expand All @@ -115,7 +116,7 @@ def get_config(config_string):
task_tokenizers=["goal-obs-tokenizer"],
task_tokenizer_kwargs={"goal-obs-tokenizer": {}},
dataset_kwargs=base_data_config,
**base_real_config
**base_real_config,
)
),
"transformer_bc_film_lang": ConfigDict(
Expand All @@ -142,9 +143,7 @@ def get_config(config_string):
observation_tokenizers=["obs-tokenizer"],
observation_tokenizer_kwargs={"obs-tokenizer": {"num_tokens": 64}},
task_tokenizers=["language-tokenizer"],
task_tokenizer_kwargs={
"language-tokenizer": {"num_tokens": 16}
},
task_tokenizer_kwargs={"language-tokenizer": {"num_tokens": 16}},
**base_real_config,
)
),
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def __init__(
env: gym.Env,
goal_sampler: Union[np.ndarray, Callable],
):

super().__init__(env)
self.env = env
self.observation_space = gym.spaces.Dict(
Expand Down Expand Up @@ -96,7 +95,6 @@ def render(self, *args, **kwargs):
return self.env.render_obs()

def reset(self, **kwargs):

if not callable(self.goal_sampler):
idx = np.random.randint(len(self.goal_sampler["observations"]["image"]))
goal_image = self.goal_sampler["observations"]["image"][idx]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def stop_recording(self):
self.num_record_episodes = None

def step(self, action: np.ndarray): # NOQA

if self.num_record_episodes is None or self.num_record_episodes == 0:
observation, reward, terminated, truncated, info = self.env.step(action)

Expand Down
35 changes: 23 additions & 12 deletions experiments/homer/eval.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import sys
from widowx_envs.widowx.widowx_env import BridgeDataRailRLPrivateWidowX
import os
import numpy as np
from PIL import Image
from flax.training import checkpoints
import sys
import traceback
import wandb
from orca.vision import encoders
from orca.agents import agents

import matplotlib
import numpy as np
import wandb
from absl import app, flags, logging
from flax.training import checkpoints
from PIL import Image
from widowx_envs.widowx.widowx_env import BridgeDataRailRLPrivateWidowX

from orca.agents import agents
from orca.vision import encoders

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import time
from collections import deque
from datetime import datetime

import jax
import time
import matplotlib.pyplot as plt
import tensorflow as tf
from widowx_envs.utils.multicam_server_rospkg.src.topic_utils import IMTopic
from collections import deque
from orca.utils.python_utils import list_of_dicts_to_dict_of_lists

np.set_printoptions(suppress=True)

Expand Down Expand Up @@ -56,6 +57,16 @@
WORKSPACE_BOUNDS = np.array([[0.1, -0.15, -0.1, -1.57, 0], [0.45, 0.25, 0.25, 1.57, 0]])


def list_of_dicts_to_dict_of_lists(list_of_dicts):
dict_of_lists = {}
for dictionary in list_of_dicts:
for key, value in dictionary.items():
if key not in dict_of_lists:
dict_of_lists[key] = []
dict_of_lists[key].append(value)
return dict_of_lists


def unnormalize_action(action, mean, std):
return action * std + mean

Expand Down
Loading

0 comments on commit 65fc03f

Please sign in to comment.