Add option to provide stats to make_interleaved_dataset and skip keys during norm #62

Merged · 6 commits · Mar 14, 2024
5 changes: 3 additions & 2 deletions examples/06_pytorch_oxe_dataloader.py
@@ -4,13 +4,14 @@
shuffle buffer size, which are reduced for demonstration purposes).
"""
import numpy as np
from octo.data.dataset import make_interleaved_dataset
from octo.data.oxe import make_oxe_dataset_kwargs_and_weights
import tensorflow as tf
import torch
from torch.utils.data import DataLoader
import tqdm

from octo.data.dataset import make_interleaved_dataset
from octo.data.oxe import make_oxe_dataset_kwargs_and_weights

DATA_PATH = "gs://rail-orca-central2/resize_256_256"

tf.config.set_visible_devices([], "GPU")
36 changes: 28 additions & 8 deletions octo/data/dataset.py
@@ -19,6 +19,7 @@
pprint_data_mixture,
tree_map,
)
from octo.utils.spec import ModuleSpec


def apply_trajectory_transforms(
@@ -212,6 +213,8 @@ def make_dataset_from_rlds(
dataset_statistics: Optional[Union[dict, str]] = None,
absolute_action_mask: Optional[Sequence[bool]] = None,
action_normalization_mask: Optional[Sequence[bool]] = None,
norm_skip_keys: Optional[Sequence[str]] = None,
filter_functions: Sequence[ModuleSpec] = (),
num_parallel_reads: int = tf.data.AUTOTUNE,
num_parallel_calls: int = tf.data.AUTOTUNE,
) -> Tuple[dl.DLataset, dict]:
@@ -272,6 +275,9 @@
action_normalization_mask (Sequence[bool], optional): If provided, indicates which action dimensions
should be normalized. For example, you might not want to normalize the gripper action dimension if
it's always exactly 0 or 1. By default, all action dimensions are normalized.
norm_skip_keys (Sequence[str], optional): Provided keys will be skipped during normalization.
filter_functions (Sequence[ModuleSpec]): ModuleSpecs for filtering functions applied to the
raw dataset.
num_parallel_reads (int): number of parallel read workers. Default to AUTOTUNE.
num_parallel_calls (int): number of parallel calls for traj_map operations. Default to AUTOTUNE.
Returns:
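
A minimal usage sketch for the two new arguments as they might be passed to `make_dataset_from_rlds`. The dataset name, data directory, the filter function itself, and any argument names not visible in this diff are assumptions; `ModuleSpec.create` is assumed to build the spec dict that `ModuleSpec.instantiate` consumes, and the raw-trajectory field layout is assumed as well:

import tensorflow as tf

from octo.data.dataset import make_dataset_from_rlds
from octo.utils.spec import ModuleSpec


def keep_long_trajectories(traj):
    # Runs on the *raw* RLDS trajectories, before `restructure` is applied:
    # keep only trajectories longer than 10 steps ("action" layout is an assumption).
    return tf.shape(traj["action"])[0] > 10


dataset, stats = make_dataset_from_rlds(
    name="my_rlds_dataset",          # hypothetical dataset name
    data_dir="gs://my-bucket/rlds",  # hypothetical data directory
    train=True,
    norm_skip_keys=["proprio"],      # leave proprio un-normalized
    filter_functions=[ModuleSpec.create(keep_long_trajectories)],
)

Note that the filter function has to live in an importable module for the ModuleSpec to resolve when instantiated.
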
@@ -370,14 +376,18 @@ def restructure(traj):
elif dataset_statistics is None:
full_dataset = dl.DLataset.from_rlds(
builder, split="all", shuffle=False, num_parallel_reads=num_parallel_reads
).traj_map(restructure, num_parallel_calls)
)
for filter_fcn_spec in filter_functions:
full_dataset = full_dataset.filter(ModuleSpec.instantiate(filter_fcn_spec))
full_dataset = full_dataset.traj_map(restructure, num_parallel_calls)
# tries to load from cache, otherwise computes on the fly
dataset_statistics = get_dataset_statistics(
full_dataset,
hash_dependencies=(
str(builder.info),
str(state_obs_keys),
inspect.getsource(standardize_fn) if standardize_fn is not None else "",
*map(ModuleSpec.to_string, filter_functions),
),
save_dir=builder.data_dir,
)
@@ -404,13 +414,15 @@ def restructure(traj):
dataset = dl.DLataset.from_rlds(
builder, split=split, shuffle=shuffle, num_parallel_reads=num_parallel_reads
)

for filter_fcn_spec in filter_functions:
dataset = dataset.filter(ModuleSpec.instantiate(filter_fcn_spec))
dataset = dataset.traj_map(restructure, num_parallel_calls)
dataset = dataset.traj_map(
partial(
normalize_action_and_proprio,
metadata=dataset_statistics,
normalization_type=action_proprio_normalization_type,
skip_keys=norm_skip_keys,
),
num_parallel_calls,
)
@@ -456,6 +468,7 @@ def make_interleaved_dataset(
shuffle_buffer_size: int,
traj_transform_kwargs: dict = {},
frame_transform_kwargs: dict = {},
dataset_statistics: Optional[Union[dict, str]] = None,
batch_size: Optional[int] = None,
balance_weights: bool = False,
traj_transform_threads: Optional[int] = None,
@@ -473,6 +486,9 @@
traj_transform_kwargs: kwargs passed to `apply_trajectory_transforms`. "num_parallel_calls" is
overridden using `traj_transform_threads`.
frame_transform_kwargs: kwargs passed to `apply_frame_transforms`.
dataset_statistics (dict|str, optional): dict (or path to JSON file) that contains dataset statistics
for normalization, see `make_dataset_from_rlds` for details. If set, applies *the same* normalization
statistics to all interleaved datasets. By default, each dataset is normalized by its own statistics.
batch_size: batch size, if not provided output is not batched.
balance_weights: if True, the sample weights are multiplied by the number of frames in each dataset.
This makes it so that, if all the sample weights are equal, one full iteration through the interleaved
@@ -495,9 +511,9 @@ def make_interleaved_dataset(
dataset_sizes = []
all_dataset_statistics = []
for dataset_kwargs in dataset_kwargs_list:
_, dataset_statistics = make_dataset_from_rlds(**dataset_kwargs, train=train)
dataset_sizes.append(dataset_statistics["num_transitions"])
all_dataset_statistics.append(dataset_statistics)
_, per_dataset_stats = make_dataset_from_rlds(**dataset_kwargs, train=train)
dataset_sizes.append(per_dataset_stats["num_transitions"])
all_dataset_statistics.append(per_dataset_stats)

# balance and normalize weights
if balance_weights:
@@ -514,7 +530,7 @@

# construct datasets
datasets = []
for dataset_kwargs, dataset_statistics, threads, reads in zip(
for dataset_kwargs, per_dataset_stats, threads, reads in zip(
dataset_kwargs_list,
all_dataset_statistics,
threads_per_dataset,
@@ -525,7 +541,9 @@
train=train,
num_parallel_calls=threads,
num_parallel_reads=reads,
dataset_statistics=dataset_statistics,
dataset_statistics=dataset_statistics
if dataset_statistics is not None
else per_dataset_stats,
)
dataset = apply_trajectory_transforms(
dataset.repeat(),
@@ -551,6 +569,8 @@
dataset = dataset.with_ram_budget(1)

# save for later
dataset.dataset_statistics = all_dataset_statistics
dataset.dataset_statistics = (
dataset_statistics if dataset_statistics is not None else all_dataset_statistics
)
dataset.sample_weights = sample_weights
return dataset
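
A sketch of how the new `dataset_statistics` argument documented above can be combined with `combine_dataset_statistics` (added to octo/data/utils/data_utils.py below) so that every dataset in the mixture is normalized with the same pooled statistics. The mixture name is only illustrative, and any argument names not shown in this diff are assumptions:

from octo.data.dataset import make_dataset_from_rlds, make_interleaved_dataset
from octo.data.oxe import make_oxe_dataset_kwargs_and_weights
from octo.data.utils.data_utils import combine_dataset_statistics

DATA_PATH = "gs://rail-orca-central2/resize_256_256"

dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights(
    "oxe_magic_soup", DATA_PATH  # mixture name is illustrative
)
# Compute per-dataset statistics once and pool them ...
combined_stats = combine_dataset_statistics(
    [make_dataset_from_rlds(**kwargs, train=True)[1] for kwargs in dataset_kwargs_list]
)
# ... then normalize every interleaved dataset with the shared statistics.
dataset = make_interleaved_dataset(
    dataset_kwargs_list,
    sample_weights,
    train=True,
    shuffle_buffer_size=1000,
    dataset_statistics=combined_stats,
)
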
61 changes: 59 additions & 2 deletions octo/data/utils/data_utils.py
@@ -3,7 +3,7 @@
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import dlimp as dl
import numpy as np
@@ -181,15 +181,72 @@ def get_dataset_statistics(
return metadata


def combine_dataset_statistics(
all_dataset_statistics: Sequence[dict],
) -> dict:
"""Merges dataset statistics from multiple datasets."""
merge_stat_keys = ["action", "proprio"]

num_trajectories = [stat["num_trajectories"] for stat in all_dataset_statistics]
num_transitions = [stat["num_transitions"] for stat in all_dataset_statistics]
stat_weights = [
transitions / sum(num_transitions) for transitions in num_transitions
]

combined_dataset_statistics = {}
for key in merge_stat_keys:
combined_mean = np.array(
[
stat[key]["mean"] * w
for stat, w in zip(all_dataset_statistics, stat_weights)
]
).sum(0)
# compute combined_std with denominator `n` rather than `n-1`, since numpy's std defaults to `n`
# https://stats.stackexchange.com/questions/55999/is-it-possible-to-find-the-combined-standard-deviation
combined_std = np.sqrt(
np.array(
[
n * np.array(stat[key]["std"]) ** 2
+ n * (np.array(stat[key]["mean"]) - combined_mean) ** 2
for stat, n in zip(all_dataset_statistics, num_transitions)
]
).sum(0)
/ sum(num_transitions)
)
combined_dataset_statistics[key] = {
"min": np.array([stat[key]["min"] for stat in all_dataset_statistics])
.min(0)
.tolist(),
"max": np.array([stat[key]["max"] for stat in all_dataset_statistics])
.max(0)
.tolist(),
"mean": combined_mean.tolist(),
"std": combined_std.tolist(),
}

combined_dataset_statistics["num_trajectories"] = num_trajectories
combined_dataset_statistics["num_transitions"] = num_transitions
return combined_dataset_statistics


def normalize_action_and_proprio(
traj: dict, metadata: dict, normalization_type: NormalizationType
traj: dict, metadata: dict, normalization_type: NormalizationType, skip_keys=None
):
"""Normalizes the action and proprio fields of a trajectory using the given metadata."""
# maps keys of `metadata` to corresponding keys in `traj`
keys_to_normalize = {
"action": "action",
"proprio": "observation/proprio",
}
if skip_keys is not None:
for skip_key in skip_keys:
if skip_key not in keys_to_normalize:
raise ValueError(
f"{skip_key} cannot be skipped during normalization since it's not a valid key, "
f"choose from {keys_to_normalize.keys()}"
)
keys_to_normalize.pop(skip_key)

if normalization_type == NormalizationType.NORMAL:
# normalize to mean 0, std 1
for key, traj_key in keys_to_normalize.items():
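
A small self-contained sanity check of the pooled mean/std that `combine_dataset_statistics` computes, using made-up toy data: with population std (denominator `n`, numpy's default, as noted in the comment above), the combined values should match statistics computed over the concatenated data. The statistics-dict layout below mirrors the keys the function reads and is otherwise an assumption:

import numpy as np

from octo.data.utils.data_utils import combine_dataset_statistics


def toy_stats(x):
    # Minimal per-dataset statistics; values kept as numpy arrays.
    field = {
        "mean": x.mean(0),
        "std": x.std(0),  # population std (denominator n)
        "min": x.min(0),
        "max": x.max(0),
    }
    # combine_dataset_statistics merges both "action" and "proprio"; reuse x for both here.
    return {
        "action": dict(field),
        "proprio": dict(field),
        "num_trajectories": 1,
        "num_transitions": len(x),
    }


a = np.random.randn(100, 7)
b = np.random.randn(50, 7) + 1.0
combined = combine_dataset_statistics([toy_stats(a), toy_stats(b)])
concat = np.concatenate([a, b])
assert np.allclose(combined["action"]["mean"], concat.mean(0))
assert np.allclose(combined["action"]["std"], concat.std(0))
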
9 changes: 9 additions & 0 deletions octo/utils/spec.py
@@ -66,6 +66,15 @@ def instantiate(spec: "ModuleSpec"):  # type: ignore
cls = _import_from_string(spec["module"], spec["name"])
return partial(cls, *spec["args"], **spec["kwargs"])

@staticmethod
def to_string(spec: "ModuleSpec"): # type: ignore
return (
f"{spec['module']}:{spec['name']}"
f"({', '.join(spec['args'])}"
f"{', ' if spec['args'] and spec['kwargs'] else ''}"
f"{', '.join(f'{k}={v}' for k, v in spec['kwargs'].items())})"
)


def _infer_full_name(o: object):
if hasattr(o, "__module__") and hasattr(o, "__name__"):
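
For reference, a quick sketch of the string `to_string` produces. The spec layout (module/name/args/kwargs) is taken from the fields the method reads; the concrete values are made up:

from octo.utils.spec import ModuleSpec

spec = {
    "module": "my_filters",
    "name": "keep_long_trajectories",
    "args": (),
    "kwargs": {"min_length": "10"},
}
print(ModuleSpec.to_string(spec))
# -> my_filters:keep_long_trajectories(min_length=10)

Because this string feeds into the `hash_dependencies` of `get_dataset_statistics` (see the octo/data/dataset.py diff above), adding or changing `filter_functions` changes the statistics cache key.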