-
Notifications
You must be signed in to change notification settings - Fork 154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DataPipe] key renamer #402
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -121,7 +121,12 @@ | |||||
TFRecordLoaderIterDataPipe as TFRecordLoader, | ||||||
) | ||||||
from torchdata.datapipes.iter.util.unzipper import UnZipperIterDataPipe as UnZipper | ||||||
from torchdata.datapipes.iter.util.webdataset import WebDatasetIterDataPipe as WebDataset | ||||||
from torchdata.datapipes.iter.util.webdataset import ( | ||||||
WebDatasetIterDataPipe as WebDataset, | ||||||
) | ||||||
from torchdata.datapipes.iter.util.renamekeys import ( | ||||||
KeyRenamerIterDataPipe as RenameKeys, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
) | ||||||
from torchdata.datapipes.iter.util.xzfileloader import ( | ||||||
XzFileLoaderIterDataPipe as XzFileLoader, | ||||||
XzFileReaderIterDataPipe as XzFileReader, | ||||||
|
@@ -192,6 +197,7 @@ | |||||
"ParquetDataFrameLoader", | ||||||
"RandomSplitter", | ||||||
"RarArchiveLoader", | ||||||
"RenameKeys", | ||||||
"Repeater", | ||||||
"RoutedDecoder", | ||||||
"Rows2Columnar", | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,93 @@ | ||||||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
# All rights reserved. | ||||||
# | ||||||
# This source code is licensed under the BSD-style license found in the | ||||||
# LICENSE file in the root directory of this source tree. | ||||||
|
||||||
import re | ||||||
from fnmatch import fnmatch | ||||||
from typing import Dict, Iterator, List, Union, Any | ||||||
|
||||||
from torchdata.datapipes import functional_datapipe | ||||||
from torchdata.datapipes.iter import IterDataPipe | ||||||
|
||||||
|
||||||
@functional_datapipe("rename_keys") | ||||||
class KeyRenamerIterDataPipe(IterDataPipe[Dict]): | ||||||
r""" | ||||||
Given a stream of dictionaries, rename keys using glob patterns. | ||||||
|
||||||
This is used for quickly extracting relevant fields from a stream of dictionaries | ||||||
and renaming them to a common format. | ||||||
|
||||||
Note that if keys contain slashes, only the part after the last slash is matched. | ||||||
|
||||||
Args: | ||||||
source_datapipe: a DataPipe yielding a stream of dictionaries. | ||||||
keep_unselected: keep keys/value pairs even if they don't match any pattern (False) | ||||||
must_match: all key value pairs must match (True) | ||||||
duplicate_is_error: it is an error if two renamings yield the same key (True) | ||||||
Comment on lines
+27
to
+29
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Should we move these after There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
*args: `(renamed, pattern)` pairs | ||||||
**kw: `renamed=pattern` pairs | ||||||
|
||||||
Returns: | ||||||
a DataPipe yielding a stream of dictionaries. | ||||||
|
||||||
Examples: | ||||||
>>> dp = IterableWrapper([{"/a/b.jpg": b"data"}]).rename_keys(("image", "*.jpg")) | ||||||
>>> list(dp) | ||||||
[{'image': b'data'}] | ||||||
>>> dp = IterableWrapper([ | ||||||
{"/a/b.input.jpg": b"data1", "/a/b.target.jpg": b"data2"}, | ||||||
{"/a/b.input.png": b"data1", "/a/b.target.png": b"data2"}, | ||||||
]).rename_keys(input="*.input.*", output="*.target.*") | ||||||
>>> list(dp) | ||||||
[{'input': b'data1', 'target': b'data2'}, {'input': b'data1', 'target': b'data2'}] | ||||||
""" | ||||||
|
||||||
def __init__( | ||||||
self, | ||||||
source_datapipe: IterDataPipe[Dict[Any, Any]], | ||||||
*args, | ||||||
keep_unselected=False, | ||||||
must_match=True, | ||||||
duplicate_is_error=True, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
nit: might be a better name but feel free to ignore |
||||||
**kw, | ||||||
) -> None: | ||||||
super().__init__() | ||||||
assert not (keep_unselected and must_match) | ||||||
self.source_datapipe: IterDataPipe[List[Union[Dict, List]]] = source_datapipe | ||||||
tmbdev marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
self.must_match = must_match | ||||||
self.keep_unselected = keep_unselected | ||||||
self.duplicate_is_error = duplicate_is_error | ||||||
self.renamings = [(pattern, output) for output, pattern in args] | ||||||
self.renamings += [(pattern, output) for output, pattern in kw.items()] | ||||||
|
||||||
def __iter__(self) -> Iterator[Dict]: | ||||||
for sample in self.source_datapipe: | ||||||
new_sample = {} | ||||||
matched = {k: False for k, _ in self.renamings} | ||||||
for path, value in sample.items(): | ||||||
fname = re.sub(r".*/", "", path) | ||||||
tmbdev marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
new_name = None | ||||||
for pattern, name in self.renamings[::-1]: | ||||||
if fnmatch(fname.lower(), pattern): | ||||||
matched[pattern] = True | ||||||
new_name = name | ||||||
break | ||||||
if new_name is None: | ||||||
if self.keep_unselected: | ||||||
new_sample[path] = value | ||||||
continue | ||||||
if new_name in new_sample: | ||||||
if self.duplicate_is_error: | ||||||
raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.") | ||||||
continue | ||||||
new_sample[new_name] = value | ||||||
if self.must_match and not all(matched.values()): | ||||||
raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).") | ||||||
|
||||||
yield new_sample | ||||||
|
||||||
def __len__(self) -> int: | ||||||
return len(self.source_datapipe) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please test other boolean flags.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done