From 4bf48f321f36e5bea3e0694f5ca1711cb95ca621 Mon Sep 17 00:00:00 2001
From: speedcell4 <3585459+speedcell4@users.noreply.github.com>
Date: Tue, 27 Aug 2024 04:10:30 +0900
Subject: [PATCH] Feat: Update loader.py

---
 torchglyph/data/loader.py | 30 +++++++++++++++++++++---------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/torchglyph/data/loader.py b/torchglyph/data/loader.py
index 264210b..c484669 100644
--- a/torchglyph/data/loader.py
+++ b/torchglyph/data/loader.py
@@ -2,7 +2,7 @@
 from abc import ABCMeta
 from logging import getLogger
 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Tuple, Union
 
 from datasets import Dataset, DownloadConfig, DownloadManager
 from torch.utils import data
@@ -67,22 +67,34 @@ def new(cls, **kwargs):
         raise NotImplementedError
 
 
+def unpack(obj: Union[Any, Tuple[Any, ...]]):
+    if not isinstance(obj, (list, tuple)):
+        return itertools.repeat(obj)
+
+    return obj
+
+
 class DataLoader(data.DataLoader):
     dataset: Dataset
 
     @classmethod
     def new(cls, data_sources: Tuple[Dataset, ...],
-            collate_fn, batch_size: Union[int, Tuple[int, ...]],
-            drop_last: bool = False, section_size: int = 1 << 12,
-            sortish_key: str = 'size') -> List['DataLoader']:
+            collate_fn: Union[Callable, Tuple[Callable, ...]],
+            batch_size: Union[int, Tuple[int, ...]],
+            drop_last: Union[bool, Tuple[bool, ...]] = False,
+            section_size: Union[int, Tuple[int, ...]] = 1 << 12,
+            sortish_key: Union[str, Tuple[str, ...]] = 'size') -> List['DataLoader']:
         assert len(data_sources) > 0
 
-        batch_sizes = batch_size
-        if isinstance(batch_size, int):
-            batch_sizes = itertools.repeat(batch_size)
-
         loaders = []
-        for index, (data_source, batch_size) in enumerate(zip(data_sources, batch_sizes)):
+        for index, (data_source, collate_fn, batch_size, drop_last, section_size, sortish_key) in enumerate(zip(
+                data_sources,
+                unpack(collate_fn),
+                unpack(batch_size),
+                unpack(drop_last),
+                unpack(section_size),
+                unpack(sortish_key),
+        )):
             training = index == 0
             if not training:
                 data_source = data_source.select(range(get_rank(), len(data_source), get_world_size()))