Skip to content

Commit

Permalink
Feat: Update loader.py
Browse files Browse the repository at this point in the history
  • Loading branch information
speedcell4 committed Aug 26, 2024
1 parent 35af853 commit 4bf48f3
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions torchglyph/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABCMeta
from logging import getLogger
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Tuple, Union

from datasets import Dataset, DownloadConfig, DownloadManager
from torch.utils import data
Expand Down Expand Up @@ -67,22 +67,34 @@ def new(cls, **kwargs):
raise NotImplementedError


def unpack(obj: Union[Any, Tuple[Any, ...]]):
    """Normalize a scalar-or-sequence argument into something iterable.

    A ``list`` or ``tuple`` is handed back untouched, so its items are
    consumed one by one. Any other value (including a ``str``) is wrapped in
    an endless ``itertools.repeat`` so that the same value is yielded on
    every step, e.g. when zipped against a finite sequence.
    """
    if isinstance(obj, (list, tuple)):
        return obj

    # Not a per-item sequence: broadcast the single value indefinitely.
    return itertools.repeat(obj)


class DataLoader(data.DataLoader):
dataset: Dataset

@classmethod
def new(cls, data_sources: Tuple[Dataset, ...],
collate_fn, batch_size: Union[int, Tuple[int, ...]],
drop_last: bool = False, section_size: int = 1 << 12,
sortish_key: str = 'size') -> List['DataLoader']:
collate_fn: Union[Callable, Tuple[Callable, ...]],
batch_size: Union[int, Tuple[int, ...]],
drop_last: Union[bool, Tuple[bool, ...]] = False,
section_size: Union[int, Tuple[int, ...]] = 1 << 12,
sortish_key: Union[str, Tuple[str, ...]] = 'size') -> List['DataLoader']:
assert len(data_sources) > 0

batch_sizes = batch_size
if isinstance(batch_size, int):
batch_sizes = itertools.repeat(batch_size)

loaders = []
for index, (data_source, batch_size) in enumerate(zip(data_sources, batch_sizes)):
for index, (data_source, collate_fn, batch_size, drop_last, section_size, sortish_key) in enumerate(zip(
data_sources,
unpack(collate_fn),
unpack(batch_size),
unpack(drop_last),
unpack(section_size),
unpack(sortish_key),
)):
training = index == 0
if not training:
data_source = data_source.select(range(get_rank(), len(data_source), get_world_size()))
Expand Down

0 comments on commit 4bf48f3

Please sign in to comment.