Commit

Merge pull request #4 from octo-models/pytorch_dataloading_example
Add pytorch dataloading example
kpertsch authored Dec 14, 2023
2 parents ba89ada + 0483292 commit 627aed0
Showing 3 changed files with 123 additions and 6 deletions.
3 changes: 0 additions & 3 deletions examples/05_dataloading.ipynb
@@ -318,9 +318,6 @@
" window_size=2, # let's get some history\n",
" future_action_window_size=3, # let's get some future actions for action chunking\n",
" subsample_length=100, # subsampling long trajectories improves shuffling a lot\n",
" # let's filter outlier actions just in case (this is after normalization, so\n",
" # 4.0 represents 4 standard deviations from the mean)\n",
" max_action=4.0,\n",
" ),\n",
" # see `octo.data.dataset.apply_frame_transforms` for full documentation\n",
" # of these configuration options\n",
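
For context on the `max_action` option removed above: per the deleted comment, it filtered actions whose normalized magnitude exceeded 4 standard deviations from the mean. A minimal, hypothetical illustration of what that threshold means (the array values and variable names below are made up for illustration and are not octo API):

import numpy as np

normalized_actions = np.array([0.3, -1.2, 5.7, 2.0])  # example values after mean/std normalization
max_action = 4.0                                       # threshold = 4 standard deviations from the mean
is_outlier = np.abs(normalized_actions) > max_action   # only 5.7 would exceed the threshold here
print(is_outlier)  # [False False  True False]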
118 changes: 118 additions & 0 deletions examples/06_pytorch_oxe_dataloader.py
@@ -0,0 +1,118 @@
"""
This example shows how to use the `octo.data` dataloader with PyTorch by wrapping it in a simple PyTorch
dataloader. The config below also happens to be our exact pretraining config (except for the batch size and
shuffle buffer size, which are reduced for demonstration purposes).
"""
import numpy as np
from octo.data.dataset import make_interleaved_dataset
from octo.data.oxe import make_oxe_dataset_kwargs_and_weights
import tensorflow as tf
import torch
from torch.utils.data import DataLoader
import tqdm

DATA_PATH = "gs://rail-orca-central2/resize_256_256"

tf.config.set_visible_devices([], "GPU")  # keep TensorFlow off the GPU so it does not compete with PyTorch for memory


class TorchRLDSDataset(torch.utils.data.IterableDataset):
"""Thin wrapper around RLDS dataset for use with PyTorch dataloaders."""

def __init__(
self,
rlds_dataset,
train=True,
):
self._rlds_dataset = rlds_dataset
self._is_train = train

def __iter__(self):
for sample in self._rlds_dataset.as_numpy_iterator():
yield sample

def __len__(self):
lengths = np.array(
[
stats["num_transitions"]
for stats in self._rlds_dataset.dataset_statistics
]
)
        if hasattr(self._rlds_dataset, "sample_weights"):
            # multiply out-of-place: the sample weights are floats while lengths is an integer array
            lengths = lengths * np.array(self._rlds_dataset.sample_weights)
total_len = lengths.sum()
if self._is_train:
return int(0.95 * total_len)
else:
return int(0.05 * total_len)


dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights(
"oxe_magic_soup",
DATA_PATH,
load_camera_views=("primary", "wrist"),
)

dataset = make_interleaved_dataset(
dataset_kwargs_list,
sample_weights,
train=True,
    shuffle_buffer_size=1000,  # increase to 500k for training; large shuffle buffers are important, but adjust to your RAM
    batch_size=None,  # batching will be handled by the PyTorch DataLoader
balance_weights=True,
traj_transform_kwargs=dict(
goal_relabeling_strategy="uniform",
window_size=2,
future_action_window_size=3,
subsample_length=100,
),
frame_transform_kwargs=dict(
image_augment_kwargs={
"primary": dict(
random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
random_brightness=[0.1],
random_contrast=[0.9, 1.1],
random_saturation=[0.9, 1.1],
random_hue=[0.05],
augment_order=[
"random_resized_crop",
"random_brightness",
"random_contrast",
"random_saturation",
"random_hue",
],
),
"wrist": dict(
random_brightness=[0.1],
random_contrast=[0.9, 1.1],
random_saturation=[0.9, 1.1],
random_hue=[0.05],
augment_order=[
"random_brightness",
"random_contrast",
"random_saturation",
"random_hue",
],
),
},
resize_size=dict(
primary=(256, 256),
wrist=(128, 128),
),
num_parallel_calls=200,
),
traj_transform_threads=48,
traj_read_threads=48,
)


pytorch_dataset = TorchRLDSDataset(dataset)
dataloader = DataLoader(
pytorch_dataset,
batch_size=16,
    num_workers=0,  # important to keep this at 0 so that PyTorch does not interfere with the dataloader's parallelism
)

for i, sample in tqdm.tqdm(enumerate(dataloader)):
if i == 5000:
break
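
The example above builds only the training loader. Below is a minimal sketch of how the matching validation loader might be constructed with the same helpers, reusing `dataset_kwargs_list` and `sample_weights` from above; whether these exact kwargs are appropriate for evaluation is an assumption, and `val_dataset` / `val_loader` are names introduced here for illustration only.

# Sketch only: request the validation split by passing train=False to both the RLDS
# pipeline and the wrapper (assumed to be sufficient; kwargs mirror the training config).
val_dataset = make_interleaved_dataset(
    dataset_kwargs_list,
    sample_weights,
    train=False,            # validation split
    shuffle_buffer_size=1,  # assumption: no large shuffle buffer needed for evaluation
    batch_size=None,        # batching again handled by the PyTorch DataLoader
    balance_weights=True,
    traj_transform_kwargs=dict(
        goal_relabeling_strategy="uniform",
        window_size=2,
        future_action_window_size=3,
    ),
    frame_transform_kwargs=dict(
        resize_size=dict(primary=(256, 256), wrist=(128, 128)),
        num_parallel_calls=64,
    ),
    traj_transform_threads=8,
    traj_read_threads=8,
)

val_loader = DataLoader(
    TorchRLDSDataset(val_dataset, train=False),  # train=False so __len__ reports the 5% validation share
    batch_size=16,
    num_workers=0,  # as above, keep at 0
)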
8 changes: 5 additions & 3 deletions octo/data/dataset.py
@@ -454,9 +454,9 @@ def make_interleaved_dataset(
    *,
    train: bool,
    shuffle_buffer_size: int,
-    batch_size: int,
    traj_transform_kwargs: dict = {},
    frame_transform_kwargs: dict = {},
+    batch_size: Optional[int] = None,
    balance_weights: bool = False,
    traj_transform_threads: Optional[int] = None,
    traj_read_threads: Optional[int] = None,
@@ -470,10 +470,10 @@
        sample_weights: sampling weights for each dataset in list. If None, defaults to uniform.
        train: whether this is a training or validation dataset.
        shuffle_buffer_size: size of the dataset shuffle buffer (in number of frames).
-        batch_size: batch size.
        traj_transform_kwargs: kwargs passed to `apply_trajectory_transforms`. "num_parallel_calls" is
            overridden using `traj_transform_threads`.
        frame_transform_kwargs: kwargs passed to `apply_frame_transforms`.
+        batch_size: batch size; if not provided, the output is not batched.
        balance_weights: if True, the sample weights are multiplied by the number of frames in each dataset.
            This makes it so that, if all the sample weights are equal, one full iteration through the interleaved
            dataset will correspond to one full iteration through each individual dataset (only in expectation,
@@ -544,11 +544,13 @@ def make_interleaved_dataset(
    dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train)

    # sequential batch (parallel batch seems to use much more memory)
-    dataset = dataset.batch(batch_size)
+    if batch_size is not None:
+        dataset = dataset.batch(batch_size)

    # this seems to reduce memory usage without affecting speed
    dataset = dataset.with_ram_budget(1)

    # save for later
    dataset.dataset_statistics = all_dataset_statistics
+    dataset.sample_weights = sample_weights
    return dataset
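
A small numerical illustration of the `balance_weights` behavior documented above (made-up numbers, not library code): with equal user-provided weights, multiplying by each dataset's frame count and renormalizing makes the sampling probabilities proportional to dataset size, so one pass through the interleaved dataset corresponds, in expectation, to one pass through each individual dataset.

import numpy as np

num_frames = np.array([1_000, 9_000])   # frames in two hypothetical datasets
sample_weights = np.array([1.0, 1.0])   # equal user-provided sampling weights

balanced = sample_weights * num_frames  # balance_weights=True: scale by dataset size
balanced = balanced / balanced.sum()    # normalize to sampling probabilities
print(balanced)  # [0.1 0.9] -> each dataset is sampled in proportion to its size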
