Skip to content

Commit

Permalink
rename data_stats
Browse files Browse the repository at this point in the history
  • Loading branch information
kpertsch committed Mar 14, 2024
1 parent 546db37 commit 1940301
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions octo/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,9 +511,9 @@ def make_interleaved_dataset(
dataset_sizes = []
all_dataset_statistics = []
for dataset_kwargs in dataset_kwargs_list:
_, data_stats = make_dataset_from_rlds(**dataset_kwargs, train=train)
dataset_sizes.append(data_stats["num_transitions"])
all_dataset_statistics.append(data_stats)
_, per_dataset_stats = make_dataset_from_rlds(**dataset_kwargs, train=train)
dataset_sizes.append(per_dataset_stats["num_transitions"])
all_dataset_statistics.append(per_dataset_stats)

# balance and normalize weights
if balance_weights:
Expand All @@ -530,7 +530,7 @@ def make_interleaved_dataset(

# construct datasets
datasets = []
for dataset_kwargs, data_stats, threads, reads in zip(
for dataset_kwargs, per_dataset_stats, threads, reads in zip(
dataset_kwargs_list,
all_dataset_statistics,
threads_per_dataset,
Expand All @@ -543,7 +543,7 @@ def make_interleaved_dataset(
num_parallel_reads=reads,
dataset_statistics=dataset_statistics
if dataset_statistics is not None
else data_stats,
else per_dataset_stats,
)
dataset = apply_trajectory_transforms(
dataset.repeat(),
Expand Down

0 comments on commit 1940301

Please sign in to comment.