diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py
index 37f1e084d4c..ed3d03fd144 100644
--- a/datasets/flwr_datasets/federated_dataset.py
+++ b/datasets/flwr_datasets/federated_dataset.py
@@ -213,6 +213,27 @@ def load_split(self, split: str) -> Dataset:
         self._check_if_split_present(split)
         return self._dataset[split]
 
+    @property
+    def partitioners(self) -> Dict[str, Partitioner]:
+        """Return the mapping from each split to its partitioner.
+
+        Accessing the property prepares the dataset if needed, and the returned
+        partitioners have their splits of the dataset assigned to them. A
+        ``ValueError`` is raised if the dataset is still not loaded afterwards.
+        """
+        # Reading this property triggers the lazy dataset download and, with
+        # it, the check of the partitioner specification, which can only be
+        # carried out once the dataset is available.
+        if not self._dataset_prepared:
+            self._prepare_dataset()
+        if self._dataset is None:
+            raise ValueError("Dataset is not loaded yet.")
+        # Walk a snapshot of the keys instead of the live dictionary.
+        for split_name in tuple(self._partitioners):
+            self._check_if_split_present(split_name)
+            self._assign_dataset_to_partitioner(split_name)
+        return self._partitioners
+
     def _check_if_split_present(self, split: str) -> None:
         """Check if the split (for partitioning or full return) is in the dataset."""
         if self._dataset is None: