-
Notifications
You must be signed in to change notification settings - Fork 147
/
few_shot_dataset.py
34 lines (28 loc) · 1.11 KB
/
few_shot_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from abc import ABC, abstractmethod
from typing import List, Tuple

from torch import Tensor
from torch.utils.data import Dataset
class FewShotDataset(Dataset, ABC):
    """
    Abstract class for all datasets used in a context of Few-Shot Learning.

    The tools we use in few-shot learning, especially TaskSampler, expect an
    implementation of FewShotDataset.
    Compared to PyTorch's Dataset, FewShotDataset forces a method get_labels.
    This exposes the list of all items' labels and therefore allows sampling
    items depending on their label.

    ``ABC`` is mixed in because ``torch.utils.data.Dataset`` is not itself an
    abstract base class: without ``ABCMeta`` the ``@abstractmethod`` markers
    below would be inert. With it, instantiating this class, or a subclass
    that does not implement all three methods, raises ``TypeError``.
    """

    @abstractmethod
    def __getitem__(self, item: int) -> Tuple[Tensor, int]:
        """
        Return the item at index ``item``.

        Args:
            item: index of the item to fetch.

        Returns:
            A ``(Tensor, int)`` pair. The int is presumably the item's class
            label, consistent with ``get_labels``.

        Raises:
            NotImplementedError: if called through ``super()`` from a subclass.
        """
        raise NotImplementedError(
            "All PyTorch datasets, including few-shot datasets, need a __getitem__ method."
        )

    @abstractmethod
    def __len__(self) -> int:
        """
        Return the number of items in the dataset.

        Raises:
            NotImplementedError: if called through ``super()`` from a subclass.
        """
        raise NotImplementedError(
            "All PyTorch datasets, including few-shot datasets, need a __len__ method."
        )

    @abstractmethod
    def get_labels(self) -> List[int]:
        """
        Return the labels of all items in the dataset.

        Samplers such as TaskSampler use this list to pick items by label.
        Position ``i`` should therefore presumably hold the label of item ``i``.

        Raises:
            NotImplementedError: if called through ``super()`` from a subclass.
        """
        raise NotImplementedError(
            "Implementations of FewShotDataset need a get_labels method."
        )