
Commit 70b663b

resolved #329
hvgazula committed Apr 25, 2024
1 parent def607e commit 70b663b
Showing 1 changed file with 4 additions and 7 deletions.
nobrainer/tfrecord.py: 11 changes (4 additions & 7 deletions)
@@ -26,7 +26,6 @@ def write(
     to_ras=True,
     compressed=True,
     processes=None,
-    chunksize=1,
     multi_resolution=False,
     resolutions=None,
     verbose=1,
@@ -53,15 +52,13 @@
             writing to multiple TFRecord files (i.e.,
             `examples_per_shard` < `len(features_labels)`). If `None`, uses all available
             cores.
-        chunksize: int, multiprocessing chunksize.
         multi_resolution: boolean, if `True`, different tfrecords for each resolution in each shard
         resolutions: list of ints, if multi_resolution is `True`, set resolutions for
             which tfrecords are created. For example, [4, 8, 16, 32, 64, 128, 256]
         verbose: int, if 1, print progress bar. If 0, print nothing.
     """
     n_examples = len(features_labels)
-    n_shards = math.ceil(n_examples / examples_per_shard)
-    shards = np.array_split(features_labels, n_shards)
+    shards = np.array_split(features_labels, np.arange(examples_per_shard, n_examples, examples_per_shard))
 
     # Test that the `filename_template` has a `shard` formatting key.
     try:
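
The key change in this hunk is the sharding logic: instead of asking `np.array_split` for `n_shards` roughly equal pieces, the new code passes explicit split indices, so every shard except possibly the last holds exactly `examples_per_shard` examples. A minimal sketch of the difference, using made-up sizes (10 examples, 4 per shard) purely for illustration:

import math
import numpy as np

# Hypothetical sizes, for illustration only.
examples_per_shard = 4
features_labels = np.arange(10)  # stand-in for the list of (feature, label) pairs
n_examples = len(features_labels)

# Old behavior: split into n_shards roughly equal parts.
n_shards = math.ceil(n_examples / examples_per_shard)          # 3
old_shards = np.array_split(features_labels, n_shards)
print([len(s) for s in old_shards])                            # [4, 3, 3]

# New behavior: split at explicit boundaries, so every shard except
# possibly the last holds exactly `examples_per_shard` examples.
new_shards = np.array_split(
    features_labels, np.arange(examples_per_shard, n_examples, examples_per_shard)
)
print([len(s) for s in new_shards])                            # [4, 4, 2]
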
@@ -80,7 +77,7 @@
     # This is the object that returns a protocol buffer string of the feature and label
     # on each iteration. It is pickle-able, unlike a generator.
     proto_iterators = [
-        _ProtoIterator(s, multi_resolution=multi_resolution, resolutions=resolutions)
+        _ProtoIterator(s, to_ras=to_ras, multi_resolution=multi_resolution, resolutions=resolutions)
         for s in shards
     ]
     # Set up positional arguments for the core writer function.
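
The change here threads the caller's `to_ras` flag through to `_ProtoIterator` instead of dropping it. The surrounding comment also explains why a class is used rather than a generator: process-based joblib backends must pickle the work they ship to workers, and generator objects cannot be pickled. A minimal, hypothetical illustration of that distinction (`ExampleIterator` is not nobrainer's `_ProtoIterator`):

import pickle

# A minimal, hypothetical stand-in for a pickle-able iterator: it holds only
# plain state (a list and a flag), so it can be serialized and sent to worker
# processes. A generator object cannot be pickled this way.
class ExampleIterator:
    def __init__(self, items, to_ras=True):
        self.items = list(items)
        self.to_ras = to_ras

    def __iter__(self):
        for item in self.items:
            # In the real code this would build a serialized protocol buffer.
            yield f"{item} (to_ras={self.to_ras})"

pickle.dumps(ExampleIterator([1, 2, 3]))    # works
# pickle.dumps(x for x in [1, 2, 3])        # TypeError: cannot pickle 'generator' object
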
@@ -90,14 +87,14 @@
     # Set keyword arguments so the resulting function accepts one positional argument.
     map_fn = functools.partial(
         _write_tfrecords,
-        compressed=True,
+        compressed=compressed,
         multi_resolution=multi_resolution,
         resolutions=resolutions,
     )
 
     if processes is None:
         processes = get_num_parallel()
-    Parallel(n_jobs=processes, verbose=10)(
+    Parallel(n_jobs=processes, verbose=verbose)(
         delayed(__writer_func)(val, map_fn) for val in iterable
     )
     from joblib.externals.loky import get_reusable_executor
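
This hunk stops hard-coding `compressed=True` and `verbose=10`, threading the caller's `compressed` and `verbose` arguments through `functools.partial` and joblib's `Parallel` (whose `verbose` is an integer verbosity level). A simplified sketch of the pattern, with a hypothetical `write_shard` worker standing in for `_write_tfrecords` and the module's `__writer_func` indirection omitted:

import functools
from joblib import Parallel, delayed

# Hypothetical worker: one positional argument plus keyword-only options.
def write_shard(args, compressed=True):
    shard_id, n_items = args
    return f"shard {shard_id}: {n_items} items (compressed={compressed})"

# Bind the caller's options once, so the mapped function needs only `args`.
map_fn = functools.partial(write_shard, compressed=False)

iterable = [(0, 4), (1, 4), (2, 2)]
results = Parallel(n_jobs=2, verbose=0)(delayed(map_fn)(val) for val in iterable)
print(results)
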
