Skip to content

Commit

Permalink
Add SizePartitioner
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Aug 30, 2024
1 parent 0269569 commit fd0cc0a
Show file tree
Hide file tree
Showing 2 changed files with 253 additions and 0 deletions.
128 changes: 128 additions & 0 deletions datasets/flwr_datasets/partitioner/size_partitioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SizePartitioner class."""


import warnings
from typing import Dict, List, Sequence

import datasets
from flwr_datasets.partitioner.partitioner import Partitioner


class SizePartitioner(Partitioner):
"""Partitioner that creates each partition with the size specified by a user.
Parameters
----------
partition_sizes : Sequence[int]
The size of each partition. partition_id 0 will have partition_sizes[0]
samples, partition_id 1 will have partition_sizes[1] samples, etc.
Examples
--------
>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import SizePartitioner
>>>
>>> partition_sizes = [20_000, 10_000 30_000]
>>> partitioner = SizePartitioner(partition_sizes)
>>> fds = FederatedDataset(dataset="cifar10", partitioners={"train": partitioner})
"""

def __init__(self, partition_sizes: Sequence[int]) -> None:
super().__init__()
self._pre_ds_validate_partition_sizes(partition_sizes)
self._partition_sizes = partition_sizes
self._partition_id_to_indices: Dict[int, List[int]] = {}
self._partition_id_to_indices_determined = False

def load_partition(self, partition_id: int) -> datasets.Dataset:
"""Load a single partition of the size of partition_sizes[partition_id].
For example if given partition_sizes=[20_000, 10_000, 30_000],
then partition_id=0 will return a partition of size 20_000,
partition_id=1 will return a partition of size 10_000, etc.
Parameters
----------
partition_id : int
The index that corresponds to the requested partition.
Returns
-------
dataset_partition : Dataset
Single dataset partition.
"""
self._determine_partition_id_to_indices_if_needed()
return self.dataset.select(self._partition_id_to_indices[partition_id])

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
self._determine_partition_id_to_indices_if_needed()
return len(self._partition_sizes)

@property
def partition_id_to_indices(self) -> Dict[int, List[int]]:
"""Partition id to indices (the result of partitioning)."""
self._determine_partition_id_to_indices_if_needed()
return self._partition_id_to_indices

def _determine_partition_id_to_indices_if_needed(
self,
) -> None:
"""Create an assignment of indices to the partition indices."""
if self._partition_id_to_indices_determined:
return
self._post_ds_validate_partition_sizes()
start = 0
end = 0
for partition_id, partition_size in enumerate(self._partition_sizes):
end += partition_size
indices = list(range(start, end))
self._partition_id_to_indices[partition_id] = indices
start = end
self._partition_id_to_indices_determined = True

def _pre_ds_validate_partition_sizes(self, partition_sizes: Sequence[int]) -> None:
"""Check if the partition sizes are valid (no information about the dataset)."""
if not isinstance(partition_sizes, Sequence):
raise ValueError("Partition sizes must be a sequence.")
if len(partition_sizes) == 0:
raise ValueError("Partition sizes must not be empty.")
if not all(
isinstance(partition_size, int) for partition_size in partition_sizes
):
raise ValueError("All partition sizes must be integers.")
if not all(partition_size > 0 for partition_size in partition_sizes):
raise ValueError("All partition sizes must be greater than zero.")

def _post_ds_validate_partition_sizes(self) -> None:
"""Validate the partition sizes against the dataset size."""
desired_partition_sizes = sum(self._partition_sizes)
dataset_size = len(self.dataset)
if desired_partition_sizes > dataset_size:
raise ValueError(
f"The sum of partition sizes sum({self._partition_sizes})"
f"= {desired_partition_sizes} is greater than the size of"
f" the dataset {dataset_size}."
)
if desired_partition_sizes < dataset_size:
warnings.warn(
f"The sum of partition sizes is {desired_partition_sizes}, which is"
f"smaller than the size of the dataset: {dataset_size}. "
f"Ignore this warning if it is the desired behavior.",
stacklevel=1,
)
125 changes: 125 additions & 0 deletions datasets/flwr_datasets/partitioner/size_partitioner_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test the SizePartitioner class."""

# pylint: disable=W0212
import unittest
from typing import Sequence

from parameterized import parameterized

from datasets import Dataset
from flwr_datasets.partitioner.size_partitioner import SizePartitioner


def _dummy_setup_size(partition_sizes: Sequence[int], num_rows: int) -> SizePartitioner:
"""Create a dummy dataset and SizePartitioner for testing."""
data = {
"features": list(range(num_rows)),
}
dataset = Dataset.from_dict(data)
partitioner = SizePartitioner(partition_sizes=partition_sizes)
partitioner.dataset = dataset
return partitioner


tested_valid_intits = [
((10, 20, 30), 60),
# Non growing order
((20, 40, 10), 70),
# Different lengths
((10, 10), 20),
# Single partition
((10,), 10),
]


class TestSizePartitionerSuccess(unittest.TestCase):
"""Test SizePartitioner used with no exceptions."""

@parameterized.expand(tested_valid_intits) # type: ignore
def test_valid_initialization(
self, partition_sizes: Sequence[int], dataset_size: int
) -> None:
"""Test that the SizePartitioner initializes correctly with valid sizes."""
partitioner = _dummy_setup_size(partition_sizes, dataset_size)
self.assertEqual(partitioner.num_partitions, len(partition_sizes))

@parameterized.expand(tested_valid_intits) # type: ignore
def test_partition_size_assignment(
self, partition_sizes: Sequence[int], dataset_size: int
) -> None:
"""Test that partitions are assigned the correct size."""
partitioner = _dummy_setup_size(partition_sizes, dataset_size)
partitioner._determine_partition_id_to_indices_if_needed()
self.assertEqual(
{
pid: len(indices)
for pid, indices in partitioner.partition_id_to_indices.items()
},
dict(enumerate(partition_sizes)),
)

def test_correct_partition_loading(self) -> None:
"""Test that partitions are loaded correctly."""
partition_sizes = [10, 20, 30]
partitioner = _dummy_setup_size(partition_sizes, 60)
partition = partitioner.load_partition(1)
self.assertEqual(len(partition), 20)

def test_warning_for_smaller_partition_sizes(self) -> None:
"""Test a warning is raised if sum of partition sizes < len(ds)."""
partition_sizes = [10, 5, 20]
partitioner = _dummy_setup_size(partition_sizes, 50)
with self.assertWarns(Warning):
partitioner._determine_partition_id_to_indices_if_needed()

def test_no_exception_for_exact_size(self) -> None:
"""Test no exception is raised when len(ds) == sum(patition_sizes)."""
partition_sizes = [10, 20, 30]
partitioner = _dummy_setup_size(partition_sizes, 60)
partitioner._determine_partition_id_to_indices_if_needed()


class TestSizePartitionerFailure(unittest.TestCase):
"""Test SizePartitioner failures (exceptions) by incorrect usage."""

def test_invalid_partition_size(self) -> None:
"""Test if raises ValueError when partition sizes are non-positive."""
with self.assertRaises(ValueError):
SizePartitioner(partition_sizes=[-1, 10, 20])

def test_invalid_partition_type(self) -> None:
"""Test if raises ValueError when partition sizes are non-positive."""
with self.assertRaises(ValueError):
SizePartitioner(partition_sizes=[0.2, 0.3]) # type: ignore[list-item]

def test_partition_size_exceeds_dataset(self) -> None:
"""Test if raises ValueError when partition sizes exceed dataset size."""
partition_sizes = [10, 20, 30]
partitioner = _dummy_setup_size(partition_sizes, 40)
with self.assertRaises(ValueError):
partitioner._determine_partition_id_to_indices_if_needed()

def test_load_invalid_partition_index(self) -> None:
"""Test if raises KeyError when an invalid partition index is loaded."""
partition_sizes = [10, 20, 30]
partitioner = _dummy_setup_size(partition_sizes, 60)
with self.assertRaises(KeyError):
partitioner.load_partition(3)


if __name__ == "__main__":
unittest.main()

0 comments on commit fd0cc0a

Please sign in to comment.