diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index 2a941315b..e92e7ff2e 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -31,6 +31,7 @@ MapKeyZipper, MaxTokenBucketizer, ParagraphAggregator, + RenameKeys, Repeater, Rows2Columnar, SampleMultiplexer, @@ -951,6 +952,49 @@ def test_mux_longest_iterdatapipe(self): with self.assertRaises(TypeError): len(output_dp) + def test_renamer(self): + + # Functional Test: verify that renaming by patterns yields correct output + stage1 = IterableWrapper([ + {"1.txt": "1", "1.bin": "1b"}, + {"2.txt": "2", "2.bin": "2b"}, + ]) + stage2 = stage1.rename_keys(("t", "*.txt"), ("b", "*.bin")) + output = list(iter(stage2)) + self.assertEqual(output, [ + {"t": "1", "b": "1b"}, + {"t": "2", "b": "2b"}, + ]) + + # Functional Test: verify that renaming by patterns yields correct output + stage2 = stage1.rename_keys(t="*.txt", b="*.bin") + output = list(iter(stage2)) + self.assertEqual(output, [ + {"t": "1", "b": "1b"}, + {"t": "2", "b": "2b"}, + ]) + + # Functional test: verify that must_match raises a ValueError + with self.assertRaisesRegex(ValueError, r"Not all patterns"): + stage2 = stage1.rename_keys(t="*.txt", b="*.bin", c="*.csv", must_match=True) + output = list(iter(stage2)) + + # Functional test: verify that duplicate_is_error raises a ValueError + with self.assertRaisesRegex(ValueError, r"Duplicate value"): + stage2 = stage1.rename_keys(("t", "*.txt"), ("t", "*.bin"), duplicate_is_error=True) + output = list(iter(stage2)) + + # Functional test: verify more complex glob patterns + dp = IterableWrapper([ + {"/a/b.input.jpg": b"data1", "/a/b.target.jpg": b"data2"}, + {"/a/b.input.png": b"data1", "/a/b.target.png": b"data2"}, + ]).rename_keys(input="*.input.*", target="*.target.*") + self.assertEqual(list(dp), [ + {'input': b'data1', 'target': b'data2'}, + {'input': b'data1', 'target': b'data2'} + ]) + + def test_zip_longest_iterdatapipe(self): # Functional Test: raises TypeError when an input is not of type `IterDataPipe` diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py index 4a2265d65..979adf9da 100644 --- a/torchdata/datapipes/iter/__init__.py +++ b/torchdata/datapipes/iter/__init__.py @@ -121,7 +121,12 @@ TFRecordLoaderIterDataPipe as TFRecordLoader, ) from torchdata.datapipes.iter.util.unzipper import UnZipperIterDataPipe as UnZipper -from torchdata.datapipes.iter.util.webdataset import WebDatasetIterDataPipe as WebDataset +from torchdata.datapipes.iter.util.webdataset import ( + WebDatasetIterDataPipe as WebDataset, +) +from torchdata.datapipes.iter.util.renamekeys import ( + KeyRenamerIterDataPipe as RenameKeys, +) from torchdata.datapipes.iter.util.xzfileloader import ( XzFileLoaderIterDataPipe as XzFileLoader, XzFileReaderIterDataPipe as XzFileReader, @@ -192,6 +197,7 @@ "ParquetDataFrameLoader", "RandomSplitter", "RarArchiveLoader", + "RenameKeys", "Repeater", "RoutedDecoder", "Rows2Columnar", diff --git a/torchdata/datapipes/iter/util/renamekeys.py b/torchdata/datapipes/iter/util/renamekeys.py new file mode 100644 index 000000000..5ad233fdc --- /dev/null +++ b/torchdata/datapipes/iter/util/renamekeys.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import re +from fnmatch import fnmatch +from typing import Dict, Iterator, List, Union, Any + +from torchdata.datapipes import functional_datapipe +from torchdata.datapipes.iter import IterDataPipe + + +@functional_datapipe("rename_keys") +class KeyRenamerIterDataPipe(IterDataPipe[Dict]): + r""" + Given a stream of dictionaries, rename keys using glob patterns. + + This is used for quickly extracting relevant fields from a stream of dictionaries + and renaming them to a common format. + + Note that if keys contain slashes, only the part after the last slash is matched. + + Args: + source_datapipe: a DataPipe yielding a stream of dictionaries. + keep_unselected: keep keys/value pairs even if they don't match any pattern (False) + must_match: all key value pairs must match (True) + duplicate_is_error: it is an error if two renamings yield the same key (True) + *args: `(renamed, pattern)` pairs + **kw: `renamed=pattern` pairs + + Returns: + a DataPipe yielding a stream of dictionaries. + + Examples: + >>> dp = IterableWrapper([{"/a/b.jpg": b"data"}]).rename_keys(("image", "*.jpg")) + >>> list(dp) + [{'image': b'data'}] + >>> dp = IterableWrapper([ + {"/a/b.input.jpg": b"data1", "/a/b.target.jpg": b"data2"}, + {"/a/b.input.png": b"data1", "/a/b.target.png": b"data2"}, + ]).rename_keys(input="*.input.*", output="*.target.*") + >>> list(dp) + [{'input': b'data1', 'target': b'data2'}, {'input': b'data1', 'target': b'data2'}] + """ + + def __init__( + self, + source_datapipe: IterDataPipe[Dict[Any, Any]], + *args, + keep_unselected=False, + must_match=True, + duplicate_is_error=True, + **kw, + ) -> None: + super().__init__() + assert not (keep_unselected and must_match) + self.source_datapipe: IterDataPipe[List[Union[Dict, List]]] = source_datapipe + self.must_match = must_match + self.keep_unselected = keep_unselected + self.duplicate_is_error = duplicate_is_error + self.renamings = [(pattern, output) for output, pattern in args] + self.renamings += [(pattern, output) for output, pattern in kw.items()] + + def __iter__(self) -> Iterator[Dict]: + for sample in self.source_datapipe: + new_sample = {} + matched = {k: False for k, _ in self.renamings} + for path, value in sample.items(): + fname = re.sub(r".*/", "", path) + new_name = None + for pattern, name in self.renamings[::-1]: + if fnmatch(fname.lower(), pattern): + matched[pattern] = True + new_name = name + break + if new_name is None: + if self.keep_unselected: + new_sample[path] = value + continue + if new_name in new_sample: + if self.duplicate_is_error: + raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.") + continue + new_sample[new_name] = value + if self.must_match and not all(matched.values()): + raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).") + + yield new_sample + + def __len__(self) -> int: + return len(self.source_datapipe)