diff --git a/datasets/flwr_datasets/partitioner/partitioner.py b/datasets/flwr_datasets/partitioner/partitioner.py index 24ca22bbebc..0404a11e772 100644 --- a/datasets/flwr_datasets/partitioner/partitioner.py +++ b/datasets/flwr_datasets/partitioner/partitioner.py @@ -50,6 +50,11 @@ def dataset(self, value: Dataset) -> None: "created partitions (in case the partitioning scheme needs to create " "the full partitioning also in order to return a single partition)." ) + if not isinstance(value, Dataset): + raise TypeError( + f"The dataset object you want to assign to the partitioner should be " + f"of type `datasets.Dataset` but given {type(value)}." + ) self._dataset = value @abstractmethod diff --git a/datasets/flwr_datasets/partitioner/partitioner_test.py b/datasets/flwr_datasets/partitioner/partitioner_test.py new file mode 100644 index 00000000000..be0c988e6a9 --- /dev/null +++ b/datasets/flwr_datasets/partitioner/partitioner_test.py @@ -0,0 +1,59 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================
"""Abstract partitioner tests."""


import unittest

import datasets
from datasets import Dataset
from flwr_datasets.partitioner.partitioner import Partitioner


class DummyPartitioner(Partitioner):
    """Minimal concrete `Partitioner` used to exercise base-class behavior.

    The abstract members are implemented with trivial bodies so the class can
    be instantiated; the tests below only target the `dataset` setter.
    """

    def load_partition(self, partition_id: int) -> Dataset:
        """Return the same small dummy dataset regardless of `partition_id`."""
        return datasets.Dataset.from_dict({"feature": [0, 1, 2]})

    @property
    def num_partitions(self) -> int:
        """Return always 0."""
        return 0


class TestPartitioner(unittest.TestCase):
    """Test Partitioner."""

    def test_dataset_setter_incorrect_type(self) -> None:
        """Test that assigning a non-`Dataset` object raises `TypeError`.

        A `DatasetDict` is used as the offending value since it is the object
        most easily confused with a single `Dataset` split.
        """
        train_split = datasets.Dataset.from_dict({"feature": [0, 1, 2]})
        test_split = datasets.Dataset.from_dict({"feature": [0, 1, 2]})
        dataset = datasets.DatasetDict({"train": train_split, "test": test_split})
        partitioner = DummyPartitioner()

        # Catch the specific exception type so that an unrelated failure
        # (e.g. the "dataset already assigned" guard) cannot satisfy the test.
        with self.assertRaises(TypeError) as context:
            partitioner.dataset = dataset
        # Build the expected text from `type(dataset)` so it mirrors the
        # setter's `{type(value)}` formatting and includes the offending type.
        self.assertIn(
            "The dataset object you want to assign to the partitioner should be of "
            f"type `datasets.Dataset` but given {type(dataset)}.",
            str(context.exception),
        )


if __name__ == "__main__":
    unittest.main()