Skip to content

Commit

Permalink
Backports for v0.14.2 (#3063)
Browse files Browse the repository at this point in the history
* Fix `iterable.Cached`. (#3060)

* Torch: Remove double caching of dataset. (#3061)

---------

Co-authored-by: Jasper <schjaspe@amazon.de>
  • Loading branch information
lostella and Jasper authored Nov 27, 2023
1 parent 536465d commit 3c434d8
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 21 deletions.
25 changes: 15 additions & 10 deletions src/gluonts/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,9 @@ def split_into(xs: Sequence, n: int) -> Sequence:
@dataclass
class Cached:
"""
An iterable wrapper, which caches values in a list the first time it is
iterated.
An iterable wrapper, which caches values in a list while iterated.
The primary use-case for this is to avoid re-computing the element of the
The primary use-case for this is to avoid re-computing the elements of the
sequence, in case the inner iterable does it on demand.
This should be used to wrap deterministic iterables, i.e. iterables where
Expand All @@ -317,15 +316,21 @@ class Cached:
"""

iterable: SizedIterable
cache: list = field(default_factory=list, init=False)
provider: Iterable = field(init=False)
consumed: list = field(default_factory=list, init=False)

def __post_init__(self):
# ensure we only iterate once over the iterable
self.provider = iter(self.iterable)

def __iter__(self):
if not self.cache:
for element in self.iterable:
yield element
self.cache.append(element)
else:
yield from self.cache
# Yield already provided values first
yield from self.consumed

# Now yield remaining elements.
for element in self.provider:
self.consumed.append(element)
yield element

def __len__(self) -> int:
return len(self.iterable)
Expand Down
20 changes: 9 additions & 11 deletions src/gluonts/torch/model/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import NamedTuple, Optional, Iterable, Dict, Any, Union
from typing import NamedTuple, Optional, Iterable, Dict, Any
import logging

import numpy as np
Expand All @@ -24,7 +24,7 @@
from gluonts.itertools import Cached
from gluonts.model import Estimator, Predictor
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.transform import Transformation, TransformedDataset
from gluonts.transform import Transformation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -156,18 +156,16 @@ def train_model(
transformation = self.create_transformation()

with env._let(max_idle_transforms=max(len(training_data), 100)):
transformed_training_data: Union[
Cached, TransformedDataset
] = transformation.apply(training_data, is_train=True)
transformed_training_data: Dataset = transformation.apply(
training_data, is_train=True
)
if cache_data:
transformed_training_data = Cached(transformed_training_data)

training_network = self.create_lightning_module()

training_data_loader = self.create_training_data_loader(
Cached(transformed_training_data)
if cache_data
else transformed_training_data,
transformed_training_data,
training_network,
shuffle_buffer_length=shuffle_buffer_length,
)
Expand All @@ -176,9 +174,9 @@ def train_model(

if validation_data is not None:
with env._let(max_idle_transforms=max(len(validation_data), 100)):
transformed_validation_data: Union[
Cached, TransformedDataset
] = transformation.apply(validation_data, is_train=True)
transformed_validation_data: Dataset = transformation.apply(
validation_data, is_train=True
)
if cache_data:
transformed_validation_data = Cached(
transformed_validation_data
Expand Down
10 changes: 10 additions & 0 deletions test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ def test_pickle(iterable: Iterable, assert_content: bool):
assert data == data_copy


def test_cached_reentry():
data = Cached(range(10))

assert len(data) == 10
assert list(take(5, data)) == list(range(5))
assert len(data) == 10
assert list(take(10, data)) == list(range(10))
assert len(data) == 10


@pytest.mark.parametrize(
"given, expected",
[
Expand Down

0 comments on commit 3c434d8

Please sign in to comment.