
Commit 546db37

add filter function support
kpertsch committed Mar 12, 2024
1 parent db78206 commit 546db37
Showing 2 changed files with 20 additions and 2 deletions.
13 changes: 11 additions & 2 deletions octo/data/dataset.py
@@ -19,6 +19,7 @@
     pprint_data_mixture,
     tree_map,
 )
+from octo.utils.spec import ModuleSpec
 
 
 def apply_trajectory_transforms(
@@ -213,6 +214,7 @@ def make_dataset_from_rlds(
     absolute_action_mask: Optional[Sequence[bool]] = None,
     action_normalization_mask: Optional[Sequence[bool]] = None,
     norm_skip_keys: Optional[Sequence[str]] = None,
+    filter_functions: Sequence[ModuleSpec] = (),
     num_parallel_reads: int = tf.data.AUTOTUNE,
     num_parallel_calls: int = tf.data.AUTOTUNE,
 ) -> Tuple[dl.DLataset, dict]:
@@ -274,6 +276,8 @@
             should be normalized. For example, you might not want to normalize the gripper action dimension if
             it's always exactly 0 or 1. By default, all action dimensions are normalized.
         norm_skip_keys (Sequence[str], optional): Provided keys will be skipped during normalization.
+        filter_functions (Sequence[ModuleSpec]): ModuleSpecs for filtering functions applied to the
+            raw dataset.
         num_parallel_reads (int): number of parallel read workers. Default to AUTOTUNE.
         num_parallel_calls (int): number of parallel calls for traj_map operations. Default to AUTOTUNE.
     Returns:
@@ -372,14 +376,18 @@ def restructure(traj):
     elif dataset_statistics is None:
         full_dataset = dl.DLataset.from_rlds(
             builder, split="all", shuffle=False, num_parallel_reads=num_parallel_reads
-        ).traj_map(restructure, num_parallel_calls)
+        )
+        for filter_fcn_spec in filter_functions:
+            full_dataset = full_dataset.filter(ModuleSpec.instantiate(filter_fcn_spec))
+        full_dataset = full_dataset.traj_map(restructure, num_parallel_calls)
         # tries to load from cache, otherwise computes on the fly
         dataset_statistics = get_dataset_statistics(
             full_dataset,
             hash_dependencies=(
                 str(builder.info),
                 str(state_obs_keys),
                 inspect.getsource(standardize_fn) if standardize_fn is not None else "",
+                *map(ModuleSpec.to_string, filter_functions),
             ),
             save_dir=builder.data_dir,
         )
@@ -406,7 +414,8 @@ def restructure(traj):
     dataset = dl.DLataset.from_rlds(
         builder, split=split, shuffle=shuffle, num_parallel_reads=num_parallel_reads
     )
-
+    for filter_fcn_spec in filter_functions:
+        dataset = dataset.filter(ModuleSpec.instantiate(filter_fcn_spec))
     dataset = dataset.traj_map(restructure, num_parallel_calls)
     dataset = dataset.traj_map(
         partial(
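Filter functions run on raw trajectories, before restructure and before statistics are computed, and each must be a predicate returning a scalar boolean tensor (dataset.filter follows tf.data filter semantics). A minimal usage sketch; the predicate, its min_length kwarg, the dataset name, and the data path are hypothetical, and ModuleSpec.create is assumed to be the existing constructor helper in octo/utils/spec.py:

    import tensorflow as tf

    from octo.data.dataset import make_dataset_from_rlds
    from octo.utils.spec import ModuleSpec

    # Hypothetical predicate: keep only trajectories with at least `min_length`
    # steps. It must live in an importable module, since ModuleSpec stores a
    # module:name reference rather than a pickled callable. Assumes the raw
    # trajectory dict exposes a stacked `action` tensor.
    def has_min_length(traj, min_length=5):
        return tf.shape(traj["action"])[0] >= min_length

    dataset, dataset_statistics = make_dataset_from_rlds(
        "bridge_dataset",                  # hypothetical RLDS dataset name
        data_dir="~/tensorflow_datasets",  # hypothetical data location
        train=True,
        filter_functions=(ModuleSpec.create(has_min_length, min_length=5),),
    )

Because each spec is also folded into hash_dependencies via ModuleSpec.to_string, cached dataset statistics are recomputed whenever the set of filters changes.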
9 changes: 9 additions & 0 deletions octo/utils/spec.py
@@ -66,6 +66,15 @@ def instantiate(spec: "ModuleSpec"):  # type: ignore
         cls = _import_from_string(spec["module"], spec["name"])
         return partial(cls, *spec["args"], **spec["kwargs"])
 
+    @staticmethod
+    def to_string(spec: "ModuleSpec"):  # type: ignore
+        return (
+            f"{spec['module']}:{spec['name']}"
+            f"({', '.join(spec['args'])}"
+            f"{', ' if spec['args'] and spec['kwargs'] else ''}"
+            f"{', '.join(f'{k}={v}' for k, v in spec['kwargs'].items())})"
+        )
+
 
 def _infer_full_name(o: object):
     if hasattr(o, "__module__") and hasattr(o, "__name__"):
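to_string renders a spec as module:name(args, kwargs). In dataset.py above it serves only as a hash dependency for the statistics cache, so it needs to be deterministic rather than pretty. A quick illustration with a hypothetical spec dict (to_string only reads the four standard keys):

    spec = {
        "module": "octo.data.filters",  # hypothetical module path
        "name": "has_min_length",
        "args": (),
        "kwargs": {"min_length": 5},
    }
    print(ModuleSpec.to_string(spec))
    # octo.data.filters:has_min_length(min_length=5)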
