From 70b663be3bbb278815e598c9698603639e0ec25d Mon Sep 17 00:00:00 2001 From: Harsha Date: Thu, 25 Apr 2024 10:52:00 -0400 Subject: [PATCH] resolved https://github.com/neuronets/nobrainer/issues/329 --- nobrainer/tfrecord.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/nobrainer/tfrecord.py b/nobrainer/tfrecord.py index 405ba87d..0f3bbd95 100644 --- a/nobrainer/tfrecord.py +++ b/nobrainer/tfrecord.py @@ -26,7 +26,6 @@ def write( to_ras=True, compressed=True, processes=None, - chunksize=1, multi_resolution=False, resolutions=None, verbose=1, @@ -53,15 +52,13 @@ def write( writing to multiple TFRecord files (i.e., `examples_per_shard` < `len(features_labels)`). If `None`, uses all available cores. - chunksize: int, multiprocessing chunksize. multi_resolution: boolean, if `True`, different tfrecords for each resolution in each shard resolutions: list of ints, if multi_resolution is `True`, set resolutions for which tfrecords are created. For example, [4, 8, 16, 32, 64, 128, 256] verbose: int, if 1, print progress bar. If 0, print nothing. """ n_examples = len(features_labels) - n_shards = math.ceil(n_examples / examples_per_shard) - shards = np.array_split(features_labels, n_shards) + shards = np.array_split(features_labels, np.arange(examples_per_shard, n_examples, examples_per_shard)) # Test that the `filename_template` has a `shard` formatting key. try: @@ -80,7 +77,7 @@ def write( # This is the object that returns a protocol buffer string of the feature and label # on each iteration. It is pickle-able, unlike a generator. proto_iterators = [ - _ProtoIterator(s, multi_resolution=multi_resolution, resolutions=resolutions) + _ProtoIterator(s, to_ras=to_ras, multi_resolution=multi_resolution, resolutions=resolutions) for s in shards ] # Set up positional arguments for the core writer function. @@ -90,14 +87,14 @@ def write( # Set keyword arguments so the resulting function accepts one positional argument. map_fn = functools.partial( _write_tfrecords, - compressed=True, + compressed=compressed, multi_resolution=multi_resolution, resolutions=resolutions, ) if processes is None: processes = get_num_parallel() - Parallel(n_jobs=processes, verbose=10)( + Parallel(n_jobs=processes, verbose=verbose)( delayed(__writer_func)(val, map_fn) for val in iterable ) from joblib.externals.loky import get_reusable_executor