From df6bb893d17c18f79a25fbec8e35123c4e2fdd15 Mon Sep 17 00:00:00 2001 From: Tom Date: Fri, 13 May 2022 12:33:03 -0700 Subject: [PATCH 1/4] merged --- test/test_iterdatapipe.py | 13 +++++++++++++ torchdata/datapipes/iter/__init__.py | 8 +++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index 8fc4cb16a..bc13105b4 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -32,6 +32,7 @@ MapKeyZipper, MaxTokenBucketizer, ParagraphAggregator, + RenameKeys, Rows2Columnar, SampleMultiplexer, UnZipper, @@ -902,6 +903,18 @@ def test_mux_longest_iterdatapipe(self): with self.assertRaises(TypeError): len(output_dp) + def test_renamer(self): + + # Functional Test: verify that renaming by patterns yields correct output + stage1 = IterableWrapper([ + {"1.txt": "1", "1.bin": "1b"}, + {"2.txt": "2", "2.bin": "2b"}, + ]) + stage2 = RenameKeys(stage1, t="*.txt", b="*.bin") + output = list(iter(stage2)) + assert len(output) == 2 + assert set(output[0].keys()) == set(["t", "b"]) + def test_zip_longest_iterdatapipe(self): # Functional Test: raises TypeError when an input is not of type `IterDataPipe` diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py index a4109cced..60369865c 100644 --- a/torchdata/datapipes/iter/__init__.py +++ b/torchdata/datapipes/iter/__init__.py @@ -109,7 +109,12 @@ TFRecordLoaderIterDataPipe as TFRecordLoader, ) from torchdata.datapipes.iter.util.unzipper import UnZipperIterDataPipe as UnZipper -from torchdata.datapipes.iter.util.webdataset import WebDatasetIterDataPipe as WebDataset +from torchdata.datapipes.iter.util.webdataset import ( + WebDatasetIterDataPipe as WebDataset, +) +from torchdata.datapipes.iter.util.renamekeys import ( + RenameKeysIterDataPipe as RenameKeys, +) from torchdata.datapipes.iter.util.xzfileloader import ( XzFileLoaderIterDataPipe as XzFileLoader, XzFileReaderIterDataPipe as XzFileReader, @@ -172,6 +177,7 @@ 
"ParagraphAggregator", "ParquetDataFrameLoader", "RarArchiveLoader", + "RenameKeys", "RoutedDecoder", "Rows2Columnar", "S3FileLister", From 1598726153512a8400a90f80b7b39dca8f7969b2 Mon Sep 17 00:00:00 2001 From: Tom Date: Fri, 13 May 2022 12:40:55 -0700 Subject: [PATCH 2/4] added renamekeys --- torchdata/datapipes/iter/util/renamekeys.py | 79 +++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 torchdata/datapipes/iter/util/renamekeys.py diff --git a/torchdata/datapipes/iter/util/renamekeys.py b/torchdata/datapipes/iter/util/renamekeys.py new file mode 100644 index 000000000..cb526054c --- /dev/null +++ b/torchdata/datapipes/iter/util/renamekeys.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import re +from fnmatch import fnmatch +from typing import Dict, Iterator, List, Union + +from torchdata.datapipes import functional_datapipe +from torchdata.datapipes.iter import IterDataPipe + + +@functional_datapipe("rename_keys") +class RenameKeysIterDataPipe(IterDataPipe[Dict]): + r""" + Given a stream of dictionaries, rename keys using glob patterns. + + Args: + source_datapipe: a DataPipe yielding a stream of dictionaries. + keep_unselected: keep keys/value pairs even if they don't match any pattern (False) + must_match: all key value pairs must match (True) + duplicate_is_error: it is an error if two renamings yield the same key (True) + *args: `(renamed, pattern)` pairs + **kw: `renamed=pattern` pairs + + Returns: + a DataPipe yielding a stream of dictionaries. 
+ + Examples: + >>> dp = IterableWrapper([{"/a/b.jpg": b"data"}]).rename_keys(image="*.jpg") + """ + + def __init__( + self, + source_datapipe: IterDataPipe[List[Union[Dict, List]]], + *args, + keep_unselected=False, + must_match=True, + duplicate_is_error=True, + **kw, + ) -> None: + super().__init__() + self.source_datapipe: IterDataPipe[List[Union[Dict, List]]] = source_datapipe + self.must_match = must_match + self.keep_unselected = keep_unselected + self.duplicate_is_error = duplicate_is_error + self.renamings = [(pattern, output) for output, pattern in args] + self.renamings += [(pattern, output) for output, pattern in kw.items()] + + def __iter__(self) -> Iterator[Dict]: + for sample in self.source_datapipe: + new_sample = {} + matched = {k: False for k, _ in self.renamings} + for path, value in sample.items(): + fname = re.sub(r".*/", "", path) + new_name = None + for pattern, name in self.renamings[::-1]: + if fnmatch(fname.lower(), pattern): + matched[pattern] = True + new_name = name + break + if new_name is None: + if self.keep_unselected: + new_sample[path] = value + continue + if new_name in new_sample: + if self.duplicate_is_error: + raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.") + continue + new_sample[new_name] = value + if self.must_match and not all(matched.values()): + raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).") + + yield new_sample + + def __len__(self) -> int: + return len(self.source_datapipe) From c250c511347b91b79917ce3f4fc56104eaef2e59 Mon Sep 17 00:00:00 2001 From: Tom Date: Wed, 31 Aug 2022 17:06:07 -0700 Subject: [PATCH 3/4] renamed to KeyRenamerIterDatapipe --- test/test_iterdatapipe.py | 2 +- torchdata/datapipes/iter/__init__.py | 2 +- torchdata/datapipes/iter/util/renamekeys.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index bc13105b4..8cd138e17 100644 --- 
a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -910,7 +910,7 @@ def test_renamer(self): {"1.txt": "1", "1.bin": "1b"}, {"2.txt": "2", "2.bin": "2b"}, ]) - stage2 = RenameKeys(stage1, t="*.txt", b="*.bin") + stage2 = stage1.rename_keys(t="*.txt", b="*.bin") output = list(iter(stage2)) assert len(output) == 2 assert set(output[0].keys()) == set(["t", "b"]) diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py index 60369865c..3139a4caa 100644 --- a/torchdata/datapipes/iter/__init__.py +++ b/torchdata/datapipes/iter/__init__.py @@ -113,7 +113,7 @@ WebDatasetIterDataPipe as WebDataset, ) from torchdata.datapipes.iter.util.renamekeys import ( - RenameKeysIterDataPipe as RenameKeys, + KeyRenamerIterDataPipe as RenameKeys, ) from torchdata.datapipes.iter.util.xzfileloader import ( XzFileLoaderIterDataPipe as XzFileLoader, diff --git a/torchdata/datapipes/iter/util/renamekeys.py b/torchdata/datapipes/iter/util/renamekeys.py index cb526054c..1d36268c4 100644 --- a/torchdata/datapipes/iter/util/renamekeys.py +++ b/torchdata/datapipes/iter/util/renamekeys.py @@ -13,7 +13,7 @@ @functional_datapipe("rename_keys") -class RenameKeysIterDataPipe(IterDataPipe[Dict]): +class KeyRenamerIterDataPipe(IterDataPipe[Dict]): r""" Given a stream of dictionaries, rename keys using glob patterns. 
From 303563123079e50c67edfac8a01d1626e900cf97 Mon Sep 17 00:00:00 2001 From: Tom Date: Thu, 1 Sep 2022 01:02:54 -0700 Subject: [PATCH 4/4] resolved issues in renamekeys.py and improved tests --- test/test_iterdatapipe.py | 35 +++++++++++++++++++-- torchdata/datapipes/iter/util/renamekeys.py | 20 ++++++++++-- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index 8cd138e17..d5b7bec53 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -910,10 +910,41 @@ def test_renamer(self): {"1.txt": "1", "1.bin": "1b"}, {"2.txt": "2", "2.bin": "2b"}, ]) + stage2 = stage1.rename_keys(("t", "*.txt"), ("b", "*.bin")) + output = list(iter(stage2)) + self.assertEqual(output, [ + {"t": "1", "b": "1b"}, + {"t": "2", "b": "2b"}, + ]) + + # Functional Test: verify that renaming by patterns yields correct output stage2 = stage1.rename_keys(t="*.txt", b="*.bin") output = list(iter(stage2)) - assert len(output) == 2 - assert set(output[0].keys()) == set(["t", "b"]) + self.assertEqual(output, [ + {"t": "1", "b": "1b"}, + {"t": "2", "b": "2b"}, + ]) + + # Functional test: verify that must_match raises a ValueError + with self.assertRaisesRegex(ValueError, r"Not all patterns"): + stage2 = stage1.rename_keys(t="*.txt", b="*.bin", c="*.csv", must_match=True) + output = list(iter(stage2)) + + # Functional test: verify that duplicate_is_error raises a ValueError + with self.assertRaisesRegex(ValueError, r"Duplicate value"): + stage2 = stage1.rename_keys(("t", "*.txt"), ("t", "*.bin"), duplicate_is_error=True) + output = list(iter(stage2)) + + # Functional test: verify more complex glob patterns + dp = IterableWrapper([ + {"/a/b.input.jpg": b"data1", "/a/b.target.jpg": b"data2"}, + {"/a/b.input.png": b"data1", "/a/b.target.png": b"data2"}, + ]).rename_keys(input="*.input.*", target="*.target.*") + self.assertEqual(list(dp), [ + {'input': b'data1', 'target': b'data2'}, + {'input': b'data1', 'target': b'data2'} 
+ ]) + def test_zip_longest_iterdatapipe(self): diff --git a/torchdata/datapipes/iter/util/renamekeys.py b/torchdata/datapipes/iter/util/renamekeys.py index 1d36268c4..5ad233fdc 100644 --- a/torchdata/datapipes/iter/util/renamekeys.py +++ b/torchdata/datapipes/iter/util/renamekeys.py @@ -6,7 +6,7 @@ import re from fnmatch import fnmatch -from typing import Dict, Iterator, List, Union +from typing import Dict, Iterator, List, Union, Any from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterDataPipe @@ -17,6 +17,11 @@ class KeyRenamerIterDataPipe(IterDataPipe[Dict]): r""" Given a stream of dictionaries, rename keys using glob patterns. + This is used for quickly extracting relevant fields from a stream of dictionaries + and renaming them to a common format. + + Note that if keys contain slashes, only the part after the last slash is matched. + Args: source_datapipe: a DataPipe yielding a stream of dictionaries. keep_unselected: keep keys/value pairs even if they don't match any pattern (False) @@ -29,12 +34,20 @@ class KeyRenamerIterDataPipe(IterDataPipe[Dict]): a DataPipe yielding a stream of dictionaries. 
Examples: - >>> dp = IterableWrapper([{"/a/b.jpg": b"data"}]).rename_keys(image="*.jpg") + >>> dp = IterableWrapper([{"/a/b.jpg": b"data"}]).rename_keys(("image", "*.jpg")) + >>> list(dp) + [{'image': b'data'}] + >>> dp = IterableWrapper([ + {"/a/b.input.jpg": b"data1", "/a/b.target.jpg": b"data2"}, + {"/a/b.input.png": b"data1", "/a/b.target.png": b"data2"}, + ]).rename_keys(input="*.input.*", target="*.target.*") + >>> list(dp) + [{'input': b'data1', 'target': b'data2'}, {'input': b'data1', 'target': b'data2'}] """ def __init__( self, - source_datapipe: IterDataPipe[List[Union[Dict, List]]], + source_datapipe: IterDataPipe[Dict[Any, Any]], *args, keep_unselected=False, must_match=True, @@ -42,6 +55,7 @@ def __init__( **kw, ) -> None: super().__init__() + assert not (keep_unselected and must_match) self.source_datapipe: IterDataPipe[List[Union[Dict, List]]] = source_datapipe self.must_match = must_match self.keep_unselected = keep_unselected