From 514f670948776ee3468388ca22540e8e4d64ee4f Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 26 Feb 2024 22:03:20 +0100 Subject: [PATCH 1/6] Make `InMemoryState` thread-safe when handling `TaskIns` (#3012) --- .../server/superlink/state/in_memory_state.py | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index ecb39f18300..690fadc032d 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -16,6 +16,7 @@ import os +import threading from datetime import datetime, timedelta from logging import ERROR from typing import Dict, List, Optional, Set @@ -35,6 +36,7 @@ def __init__(self) -> None: self.run_ids: Set[int] = set() self.task_ins_store: Dict[UUID, TaskIns] = {} self.task_res_store: Dict[UUID, TaskRes] = {} + self.lock = threading.Lock() def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: """Store one TaskIns.""" @@ -57,7 +59,8 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: task_ins.task_id = str(task_id) task_ins.task.created_at = created_at.isoformat() task_ins.task.ttl = ttl.isoformat() - self.task_ins_store[task_id] = task_ins + with self.lock: + self.task_ins_store[task_id] = task_ins # Return the new task_id return task_id @@ -71,22 +74,23 @@ def get_task_ins( # Find TaskIns for node_id that were not delivered yet task_ins_list: List[TaskIns] = [] - for _, task_ins in self.task_ins_store.items(): - # pylint: disable=too-many-boolean-expressions - if ( - node_id is not None # Not anonymous - and task_ins.task.consumer.anonymous is False - and task_ins.task.consumer.node_id == node_id - and task_ins.task.delivered_at == "" - ) or ( - node_id is None # Anonymous - and task_ins.task.consumer.anonymous is True - and task_ins.task.consumer.node_id == 0 - and task_ins.task.delivered_at == "" - ): - task_ins_list.append(task_ins) - if limit and len(task_ins_list) == limit: - break + with self.lock: + for _, task_ins in self.task_ins_store.items(): + # pylint: disable=too-many-boolean-expressions + if ( + node_id is not None # Not anonymous + and task_ins.task.consumer.anonymous is False + and task_ins.task.consumer.node_id == node_id + and task_ins.task.delivered_at == "" + ) or ( + node_id is None # Anonymous + and task_ins.task.consumer.anonymous is True + and task_ins.task.consumer.node_id == 0 + and task_ins.task.delivered_at == "" + ): + task_ins_list.append(task_ins) + if limit and len(task_ins_list) == limit: + break # Mark all of them as delivered delivered_at = now().isoformat() @@ -164,7 +168,8 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None: task_res_to_be_deleted.add(task_res_id) for task_id in task_ins_to_be_deleted: - del self.task_ins_store[task_id] + with self.lock: + del self.task_ins_store[task_id] for task_id in task_res_to_be_deleted: del self.task_res_store[task_id] From 39ef78bf6e742ea7f7a9098ef7e265cb8d60372d Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 26 Feb 2024 22:28:40 +0100 Subject: [PATCH 2/6] Fix incorrect link in FedPara README (#3014) --- baselines/fedpara/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baselines/fedpara/README.md b/baselines/fedpara/README.md index 0ff9203a9a5..89cca76a2aa 100644 --- a/baselines/fedpara/README.md +++ b/baselines/fedpara/README.md @@ -93,7 +93,7 @@ As for the parameters ratio ($\gamma$) we use the following model sizes. 
As in t ## Environment Setup To construct the Python environment follow these steps: -It is assumed that `pyenv` is installed, `poetry` is installed and python 3.10.6 is installed using `pyenv`. Refer to this [documentation](https://flower.ai/docs/baselines/how-to-usef-baselines.html#setting-up-your-machine) to ensure that your machine is ready. +It is assumed that `pyenv` is installed, `poetry` is installed and python 3.10.6 is installed using `pyenv`. Refer to this [documentation](https://flower.ai/docs/baselines/how-to-use-baselines.html#setting-up-your-machine) to ensure that your machine is ready. ```bash # Set Python 3.10 From ccb0b35ce0477a9048162a1a43069eabc951c66b Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 26 Feb 2024 22:46:20 +0100 Subject: [PATCH 3/6] Add `partition_id` to `Metadata` (#3013) --- .../client/message_handler/message_handler.py | 2 +- .../message_handler/message_handler_test.py | 2 ++ src/py/flwr/common/message.py | 19 ++++++++++++++++++- .../ray_transport/ray_client_proxy.py | 1 + .../ray_transport/ray_client_proxy_test.py | 3 ++- 5 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index e7e6c7e05c7..87cace88ec2 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -98,7 +98,7 @@ def handle_legacy_message_from_msgtype( client_fn: ClientFn, message: Message, context: Context ) -> Message: """Handle legacy message in the inner most mod.""" - client = client_fn(str(message.metadata.dst_node_id)) + client = client_fn(str(message.metadata.partition_id)) # Check if NumPyClient is returend if isinstance(client, NumPyClient): diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 9fc126f2792..c24b51972f3 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -269,6 +269,8 @@ def test_invalid_message_run_id(self) -> None: invalid_metadata_list: List[Metadata] = [] attrs = list(vars(self.valid_out_metadata).keys()) for attr in attrs: + if attr == "_partition_id": + continue if attr == "_ttl": # Skip configurable ttl continue # Make an invalid metadata diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index 14dae0f6ee5..1e1132e42e2 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -14,7 +14,6 @@ # ============================================================================== """Message.""" - from __future__ import annotations from dataclasses import dataclass @@ -46,6 +45,10 @@ class Metadata: # pylint: disable=too-many-instance-attributes message_type : str A string that encodes the action to be executed on the receiving end. + partition_id : Optional[int] + An identifier that can be used when loading a particular + data partition for a ClientApp. Making use of this identifier + is more relevant when conducting simulations. 
""" _run_id: int @@ -56,6 +59,7 @@ class Metadata: # pylint: disable=too-many-instance-attributes _group_id: str _ttl: str _message_type: str + _partition_id: int | None def __init__( # pylint: disable=too-many-arguments self, @@ -67,6 +71,7 @@ def __init__( # pylint: disable=too-many-arguments group_id: str, ttl: str, message_type: str, + partition_id: int | None = None, ) -> None: self._run_id = run_id self._message_id = message_id @@ -76,6 +81,7 @@ def __init__( # pylint: disable=too-many-arguments self._group_id = group_id self._ttl = ttl self._message_type = message_type + self._partition_id = partition_id @property def run_id(self) -> int: @@ -137,6 +143,16 @@ def message_type(self, value: str) -> None: """Set message_type.""" self._message_type = value + @property + def partition_id(self) -> int | None: + """An identifier telling which data partition a ClientApp should use.""" + return self._partition_id + + @partition_id.setter + def partition_id(self, value: int) -> None: + """Set patition_id.""" + self._partition_id = value + @dataclass class Message: @@ -202,6 +218,7 @@ def create_reply(self, content: RecordSet, ttl: str) -> Message: group_id=self.metadata.group_id, ttl=ttl, message_type=self.metadata.message_type, + partition_id=self.metadata.partition_id, ), content=content, ) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 405e0920c5a..a45321ed236 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -111,6 +111,7 @@ def _wrap_recordset_in_message( reply_to_message="", ttl=str(timeout) if timeout else "", message_type=message_type, + partition_id=int(self.cid), ), ) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 3eeabe0292c..24fe3546e7d 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -198,10 +198,11 @@ def _load_app() -> ClientApp: message_id="", group_id="", src_node_id=0, - dst_node_id=int(cid), + dst_node_id=12345, reply_to_message="", ttl="", message_type=MESSAGE_TYPE_GET_PROPERTIES, + partition_id=int(cid), ), ) pool.submit_client_job( From 7bfc58ad98eef114f6c876e6a364082e5507f38f Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 26 Feb 2024 22:57:29 +0100 Subject: [PATCH 4/6] Fix incorrect URLs for baseline doc page (#3015) --- doc/locales/fr/LC_MESSAGES/framework-docs.po | 6 +++--- doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po | 2 +- doc/source/ref-changelog.md | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/locales/fr/LC_MESSAGES/framework-docs.po b/doc/locales/fr/LC_MESSAGES/framework-docs.po index 920a47abab3..ba5ea5ec070 100644 --- a/doc/locales/fr/LC_MESSAGES/framework-docs.po +++ b/doc/locales/fr/LC_MESSAGES/framework-docs.po @@ -1325,7 +1325,7 @@ msgid "" msgstr "" "Si tu n'es pas familier avec les Flower Baselines, tu devrais " "probablement consulter notre `guide de contribution pour les baselines " -"`_." +"`_." #: ../../source/contributor-ref-good-first-contributions.rst:27 msgid "" @@ -15862,7 +15862,7 @@ msgstr "" "l'utilisation de [Flower Baselines](https://flower.ai/docs/using-" "baselines.html). 
Avec cette première version préliminaire, nous invitons " "également la communauté à [contribuer à leurs propres lignes de " -"base](https://flower.ai/docs/contributing-baselines.html)." +"base](https://flower.ai/docs/baselines/how-to-contribute-baselines.html)." #: ../../source/ref-changelog.md:662 msgid "" @@ -25474,7 +25474,7 @@ msgstr "" #~ " papers. If you want to add a" #~ " new baseline or experiment, please " #~ "check the `Contributing Baselines " -#~ "`_ " +#~ "`_ " #~ "section." #~ msgstr "" diff --git a/doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po b/doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po index ab1c8dc39e6..b6c32f99459 100644 --- a/doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po +++ b/doc/locales/zh_Hans/LC_MESSAGES/framework-docs.po @@ -1230,7 +1230,7 @@ msgid "" "/contributing-baselines.html>`_." msgstr "" "如果您对 Flower Baselines 还不熟悉,也许可以看看我们的 `Baselines贡献指南 " -"`_。" +"`_。" #: ../../source/contributor-ref-good-first-contributions.rst:27 msgid "" diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 41dc91873c6..54092e15a56 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -657,7 +657,7 @@ We would like to give our **special thanks** to all the contributors who made Fl - **Flower Baselines (preview): FedOpt, FedBN, FedAvgM** ([#919](https://github.com/adap/flower/pull/919), [#1127](https://github.com/adap/flower/pull/1127), [#914](https://github.com/adap/flower/pull/914)) - The first preview release of Flower Baselines has arrived! We're kickstarting Flower Baselines with implementations of FedOpt (FedYogi, FedAdam, FedAdagrad), FedBN, and FedAvgM. Check the documentation on how to use [Flower Baselines](https://flower.ai/docs/using-baselines.html). With this first preview release we're also inviting the community to [contribute their own baselines](https://flower.ai/docs/contributing-baselines.html). + The first preview release of Flower Baselines has arrived! We're kickstarting Flower Baselines with implementations of FedOpt (FedYogi, FedAdam, FedAdagrad), FedBN, and FedAvgM. Check the documentation on how to use [Flower Baselines](https://flower.ai/docs/using-baselines.html). With this first preview release we're also inviting the community to [contribute their own baselines](https://flower.ai/docs/baselines/how-to-contribute-baselines.html). 
- **C++ client SDK (preview) and code example** ([#1111](https://github.com/adap/flower/pull/1111)) From 4abfd066b444e6dbc01b2944ba5934dfa3e39b03 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Tue, 27 Feb 2024 10:27:19 +0100 Subject: [PATCH 5/6] Add DirichletPartitioner (#2795) Co-authored-by: Javier --- .../flwr_datasets/partitioner/__init__.py | 2 + .../partitioner/dirichlet_partitioner.py | 323 ++++++++++++++++++ .../partitioner/dirichlet_partitioner_test.py | 170 +++++++++ 3 files changed, 495 insertions(+) create mode 100644 datasets/flwr_datasets/partitioner/dirichlet_partitioner.py create mode 100644 datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py index 5e7c86718f6..6a85f8a1174 100644 --- a/datasets/flwr_datasets/partitioner/__init__.py +++ b/datasets/flwr_datasets/partitioner/__init__.py @@ -15,6 +15,7 @@ """Flower Datasets Partitioner package.""" +from .dirichlet_partitioner import DirichletPartitioner from .exponential_partitioner import ExponentialPartitioner from .iid_partitioner import IidPartitioner from .linear_partitioner import LinearPartitioner @@ -27,6 +28,7 @@ "IidPartitioner", "Partitioner", "NaturalIdPartitioner", + "DirichletPartitioner", "SizePartitioner", "LinearPartitioner", "SquarePartitioner", diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py new file mode 100644 index 00000000000..5f1df71991b --- /dev/null +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -0,0 +1,323 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dirichlet partitioner class that works with Hugging Face Datasets.""" + + +import warnings +from typing import Dict, List, Optional, Union + +import numpy as np + +import datasets +from flwr_datasets.common.typing import NDArrayFloat +from flwr_datasets.partitioner.partitioner import Partitioner + + +# pylint: disable=R0902, R0912 +class DirichletPartitioner(Partitioner): + """Partitioner based on Dirichlet distribution. + + Implementation based on Bayesian Nonparametric Federated Learning of Neural Networks + https://arxiv.org/abs/1905.12022. + + The algorithm sequentially divides the data with each label. The fractions of the + data with each label is drawn from Dirichlet distribution and adjusted in case of + balancing. The data is assigned. In case the `min_partition_size` is not satisfied + the algorithm is run again (the fractions will change since it is a random process + even though the alpha stays the same). + + The notion of balancing is explicitly introduced here (not mentioned in paper but + implemented in the code). 
It is a mechanism that stops assigning new samples to a node once the current
+    number of samples on that node exceeds the average number that the node would
+    get under an even data distribution. It is controlled by the `self_balancing`
+    parameter.
+
+    Parameters
+    ----------
+    num_partitions : int
+        The total number of partitions that the data will be divided into.
+    partition_by : str
+        Column name of the labels (targets) based on which Dirichlet sampling works.
+    alpha : Union[int, float, List[float], NDArrayFloat]
+        Concentration parameter of the Dirichlet distribution.
+    min_partition_size : int
+        The minimum number of samples that each partition will have (the sampling
+        process is repeated if any partition is too small).
+    self_balancing : bool
+        Whether to stop assigning further samples to a partition once its number of
+        samples exceeds the average number of samples per partition. (True in the
+        original paper's code although not mentioned in the paper itself.)
+    shuffle : bool
+        Whether to randomize the order of samples. Shuffling is applied after the
+        samples are assigned to nodes.
+    seed : int
+        Seed used for dataset shuffling. It has no effect if `shuffle` is False.
+
+    Examples
+    --------
+    >>> from flwr_datasets import FederatedDataset
+    >>> from flwr_datasets.partitioner import DirichletPartitioner
+    >>>
+    >>> partitioner = DirichletPartitioner(num_partitions=10, partition_by="label",
+    >>>                                    alpha=0.5, min_partition_size=10,
+    >>>                                    self_balancing=True)
+    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
+    >>> partition = fds.load_partition(0)
+    >>> print(partition[0])  # Print the first example
+    {'image': ,
+    'label': 4}
+    >>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(10)]
+    >>> print(sorted(partition_sizes))
+    [2134, 2615, 3646, 6011, 6170, 6386, 6715, 7653, 8435, 10235]
+    """
+
+    def __init__(  # pylint: disable=R0913
+        self,
+        num_partitions: int,
+        partition_by: str,
+        alpha: Union[int, float, List[float], NDArrayFloat],
+        min_partition_size: int = 10,
+        self_balancing: bool = True,
+        shuffle: bool = True,
+        seed: Optional[int] = 42,
+    ) -> None:
+        super().__init__()
+        # Attributes based on the constructor
+        self._num_partitions = num_partitions
+        self._check_num_partitions_greater_than_zero()
+        self._alpha: NDArrayFloat = self._initialize_alpha(alpha)
+        self._partition_by = partition_by
+        self._min_partition_size: int = min_partition_size
+        self._self_balancing = self_balancing
+        self._shuffle = shuffle
+        self._seed = seed
+        self._rng = np.random.default_rng(seed=self._seed)  # NumPy random generator
+
+        # Utility attributes
+        # The attributes below are determined during the first call to load_partition
+        self._avg_num_of_samples_per_node: Optional[float] = None
+        self._unique_classes: Optional[Union[List[int], List[str]]] = None
+        self._node_id_to_indices: Dict[int, List[int]] = {}
+        self._node_id_to_indices_determined = False
+
+    def load_partition(self, node_id: int) -> datasets.Dataset:
+        """Load a partition based on the partition index.
+
+        Parameters
+        ----------
+        node_id : int
+            The index that corresponds to the requested partition.
+
+        Returns
+        -------
+        dataset_partition : Dataset
+            A single partition of the dataset.
+        """
+        # The partitioning is done lazily - only when the first partition is
+        # requested. Only the first call creates the index assignments for all the
+        # partition indices.
+ self._check_num_partitions_correctness_if_needed() + self._determine_node_id_to_indices_if_needed() + return self.dataset.select(self._node_id_to_indices[node_id]) + + def _initialize_alpha( + self, alpha: Union[int, float, List[float], NDArrayFloat] + ) -> NDArrayFloat: + """Convert alpha to the used format in the code a NDArrayFloat. + + The alpha can be provided in constructor can be in different format for user + convenience. The format into which it's transformed here is used throughout the + code for computation. + + Parameters + ---------- + alpha : Union[int, float, List[float], NDArrayFloat] + Concentration parameter to the Dirichlet distribution + + Returns + ------- + alpha : NDArrayFloat + Concentration parameter in a format ready to used in computation. + """ + if isinstance(alpha, int): + alpha = np.array([float(alpha)], dtype=float).repeat(self._num_partitions) + elif isinstance(alpha, float): + alpha = np.array([alpha], dtype=float).repeat(self._num_partitions) + elif isinstance(alpha, List): + if len(alpha) != self._num_partitions: + raise ValueError( + "If passing alpha as a List, it needs to be of length of equal to " + "num_partitions." + ) + alpha = np.asarray(alpha) + elif isinstance(alpha, np.ndarray): + # pylint: disable=R1720 + if alpha.ndim == 1 and alpha.shape[0] != self._num_partitions: + raise ValueError( + "If passing alpha as an NDArray, its length needs to be of length " + "equal to num_partitions." + ) + elif alpha.ndim == 2: + alpha = alpha.flatten() + if alpha.shape[0] != self._num_partitions: + raise ValueError( + "If passing alpha as an NDArray, its size needs to be of length" + " equal to num_partitions." + ) + else: + raise ValueError("The given alpha format is not supported.") + if not (alpha > 0).all(): + raise ValueError( + f"Alpha values should be strictly greater than zero. " + f"Instead it'd be converted to {alpha}" + ) + return alpha + + def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0914 + """Create an assignment of indices to the partition indices.""" + if self._node_id_to_indices_determined: + return + + # Generate information needed for Dirichlet partitioning + self._unique_classes = self.dataset.unique(self._partition_by) + assert self._unique_classes is not None + # This is needed only if self._self_balancing is True (the default option) + self._avg_num_of_samples_per_node = self.dataset.num_rows / self._num_partitions + + # Change targets list data type to numpy + targets = np.array(self.dataset[self._partition_by]) + + # Repeat the sampling procedure based on the Dirichlet distribution until the + # min_partition_size is reached. 
+ sampling_try = 0 + while True: + # Prepare data structure to store indices assigned to node ids + node_id_to_indices: Dict[int, List[int]] = {} + for nid in range(self._num_partitions): + node_id_to_indices[nid] = [] + + # Iterated over all unique labels (they are not necessarily of type int) + for k in self._unique_classes: + # Access all the indices associated with class k + indices_representing_class_k = np.nonzero(targets == k)[0] + # Determine division (the fractions) of the data representing class k + # among the partitions + class_k_division_proportions = self._rng.dirichlet(self._alpha) + nid_to_proportion_of_k_samples = {} + for nid in range(self._num_partitions): + nid_to_proportion_of_k_samples[nid] = class_k_division_proportions[ + nid + ] + # Balancing (not mentioned in the paper but implemented) + # Do not assign additional samples to the node if it already has more + # than the average numbers of samples per partition. Note that it might + # especially affect classes that are later in the order. This is the + # reason for more sparse division that the alpha might suggest. + if self._self_balancing: + assert self._avg_num_of_samples_per_node is not None + for nid in nid_to_proportion_of_k_samples.copy(): + if ( + len(node_id_to_indices[nid]) + > self._avg_num_of_samples_per_node + ): + nid_to_proportion_of_k_samples[nid] = 0 + + # Normalize the proportions such that they sum up to 1 + sum_proportions = sum(nid_to_proportion_of_k_samples.values()) + for nid, prop in nid_to_proportion_of_k_samples.copy().items(): + nid_to_proportion_of_k_samples[nid] = prop / sum_proportions + + # Determine the split indices + cumsum_division_fractions = np.cumsum( + list(nid_to_proportion_of_k_samples.values()) + ) + cumsum_division_numbers = cumsum_division_fractions * len( + indices_representing_class_k + ) + # [:-1] is because the np.split requires the division indices but the + # last element represents the sum = total number of samples + indices_on_which_split = cumsum_division_numbers.astype(int)[:-1] + + split_indices = np.split( + indices_representing_class_k, indices_on_which_split + ) + + # Append new indices (coming from class k) to the existing indices + for nid, indices in node_id_to_indices.items(): + indices.extend(split_indices[nid].tolist()) + + # Determine if the indices assignment meets the min_partition_size + # If it does not mean the requirement repeat the Dirichlet sampling process + # Otherwise break the while loop + min_sample_size_on_client = min( + len(indices) for indices in node_id_to_indices.values() + ) + if min_sample_size_on_client >= self._min_partition_size: + break + sample_sizes = [len(indices) for indices in node_id_to_indices.values()] + alpha_not_met = [ + self._alpha[i] + for i, ss in enumerate(sample_sizes) + if ss == min(sample_sizes) + ] + mssg_list_alphas = ( + ( + "Generating partitions by sampling from a list of very wide range " + "of alpha values can be hard to achieve. Try reducing the range " + f"between maximum ({max(self._alpha)}) and minimum alpha " + f"({min(self._alpha)}) values or increasing all the values." + ) + if len(self._alpha.flatten().tolist()) > 0 + else "" + ) + warnings.warn( + f"The specified min_partition_size ({self._min_partition_size}) was " + f"not satisfied for alpha ({alpha_not_met}) after " + f"{sampling_try} attempts at sampling from the Dirichlet " + f"distribution. The probability sampling from the Dirichlet " + f"distribution will be repeated. Note: This is not a desired " + f"behavior. 
It is recommended to adjust the alpha or " + f"min_partition_size instead. {mssg_list_alphas}", + stacklevel=1, + ) + if sampling_try == 10: + raise ValueError( + "The max number of attempts (10) was reached. " + "Please update the values of alpha and try again." + ) + sampling_try += 1 + + # Shuffle the indices not to have the datasets with targets in sequences like + # [00000, 11111, ...]) if the shuffle is True + if self._shuffle: + for indices in node_id_to_indices.values(): + # In place shuffling + self._rng.shuffle(indices) + self._node_id_to_indices = node_id_to_indices + self._node_id_to_indices_determined = True + + def _check_num_partitions_correctness_if_needed(self) -> None: + """Test num_partitions when the dataset is given (in load_partition).""" + if not self._node_id_to_indices_determined: + if self._num_partitions > self.dataset.num_rows: + raise ValueError( + "The number of partitions needs to be smaller than the number of " + "samples in the dataset." + ) + + def _check_num_partitions_greater_than_zero(self) -> None: + """Test num_partition left sides correctness.""" + if not self._num_partitions > 0: + raise ValueError("The number of partitions needs to be greater than zero.") diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py new file mode 100644 index 00000000000..c123f84effb --- /dev/null +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py @@ -0,0 +1,170 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Test DirichletPartitioner.""" + + +# pylint: disable=W0212 +import unittest +from typing import Tuple, Union + +import numpy as np +from numpy.typing import NDArray +from parameterized import parameterized + +from datasets import Dataset +from flwr_datasets.partitioner.dirichlet_partitioner import DirichletPartitioner + + +def _dummy_setup( + num_partitions: int, + alpha: Union[float, NDArray[np.float_]], + num_rows: int, + partition_by: str, + self_balancing: bool = True, +) -> Tuple[Dataset, DirichletPartitioner]: + """Create a dummy dataset and partitioner for testing.""" + data = { + partition_by: [i % 3 for i in range(num_rows)], + "features": list(range(num_rows)), + } + dataset = Dataset.from_dict(data) + partitioner = DirichletPartitioner( + num_partitions=num_partitions, + alpha=alpha, + partition_by=partition_by, + self_balancing=self_balancing, + ) + partitioner.dataset = dataset + return dataset, partitioner + + +class TestDirichletPartitionerSuccess(unittest.TestCase): + """Test DirichletPartitioner used with no exceptions.""" + + @parameterized.expand( # type: ignore + [ + # num_partitions, alpha, num_rows, partition_by + (3, 0.5, 100, "labels"), + (5, 1.0, 150, "labels"), + ] + ) + def test_valid_initialization( + self, num_partitions: int, alpha: float, num_rows: int, partition_by: str + ) -> None: + """Test if alpha is correct scaled based on the given num_partitions.""" + _, partitioner = _dummy_setup(num_partitions, alpha, num_rows, partition_by) + self.assertEqual( + ( + partitioner._num_partitions, + len(partitioner._alpha), + partitioner._partition_by, + ), + (num_partitions, num_partitions, partition_by), + ) + + def test_min_partition_size_requirement(self) -> None: + """Test if partitions are created with min partition size required.""" + _, partitioner = _dummy_setup(3, 0.5, 100, "labels") + partition_list = [partitioner.load_partition(node_id) for node_id in [0, 1, 2]] + self.assertTrue( + all(len(p) > partitioner._min_partition_size for p in partition_list) + ) + + def test_alpha_in_ndarray_initialization(self) -> None: + """Test alpha does not change when in NDArrayFloat format.""" + _, partitioner = _dummy_setup(3, np.array([1.0, 1.0, 1.0]), 100, "labels") + self.assertTrue(np.all(partitioner._alpha == np.array([1.0, 1.0, 1.0]))) + + def test__determine_node_id_to_indices(self) -> None: + """Test the determine_nod_id_to_indices matches the flag after the call.""" + num_partitions, alpha, num_rows, partition_by = 3, 0.5, 100, "labels" + _, partitioner = _dummy_setup(num_partitions, alpha, num_rows, partition_by) + partitioner._determine_node_id_to_indices_if_needed() + self.assertTrue( + partitioner._node_id_to_indices_determined + and len(partitioner._node_id_to_indices) == num_partitions + ) + + +class TestDirichletPartitionerFailure(unittest.TestCase): + """Test DirichletPartitioner failures (exceptions) by incorrect usage.""" + + @parameterized.expand([(-2,), (-1,), (3,), (4,), (100,)]) # type: ignore + def test_load_invalid_partition_index(self, partition_id): + """Test if raises when the load_partition is above the num_partitions.""" + _, partitioner = _dummy_setup(3, 0.5, 100, "labels") + with self.assertRaises(KeyError): + partitioner.load_partition(partition_id) + + @parameterized.expand( # type: ignore + [ + # alpha, num_partitions + (-0.5, 1), + (-0.5, 2), + (-0.5, 3), + (-0.5, 10), + ([0.5, 0.5, -0.5], 3), + ([-0.5, 0.5, -0.5], 3), + ([-0.5, 0.5, 0.5], 3), + 
([-0.5, -0.5, -0.5], 3), + ([0.5, 0.5, -0.5, -0.5, 0.5], 5), + (np.array([0.5, 0.5, -0.5]), 3), + (np.array([-0.5, 0.5, -0.5]), 3), + (np.array([-0.5, 0.5, 0.5]), 3), + (np.array([-0.5, -0.5, -0.5]), 3), + (np.array([0.5, 0.5, -0.5, -0.5, 0.5]), 5), + ] + ) + def test_negative_values_in_alpha(self, alpha, num_partitions): + """Test if giving the negative value of alpha raises error.""" + num_rows, partition_by = 100, "labels" + with self.assertRaises(ValueError): + _, _ = _dummy_setup(num_partitions, alpha, num_rows, partition_by) + + @parameterized.expand( # type: ignore + [ + # alpha, num_partitions + # alpha greater than the num_partitions + ([0.5, 0.5], 1), + ([0.5, 0.5, 0.5], 2), + (np.array([0.5, 0.5]), 1), + (np.array([0.5, 0.5, 0.5]), 2), + (np.array([0.5, 0.5, 0.5, 0.5]), 3), + ] + ) + def test_incorrect_alpha_shape(self, alpha, num_partitions): + """Test alpha list len not matching the num_partitions.""" + with self.assertRaises(ValueError): + DirichletPartitioner( + num_partitions=num_partitions, alpha=alpha, partition_by="labels" + ) + + @parameterized.expand( # type: ignore + [(0,), (-1,), (11,), (100,)] + ) # num_partitions, + def test_invalid_num_partitions(self, num_partitions): + """Test if 0 is invalid num_partitions.""" + with self.assertRaises(ValueError): + _, partitioner = _dummy_setup( + num_partitions=num_partitions, + alpha=1.0, + num_rows=10, + partition_by="labels", + ) + partitioner.load_partition(0) + + +if __name__ == "__main__": + unittest.main() From 65f77a98bfb8fe418ddf96641f4cdbf1ba3a07f0 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Tue, 27 Feb 2024 10:44:22 +0100 Subject: [PATCH 6/6] Add ShardPartitioner (#2792) Co-authored-by: Javier --- .../flwr_datasets/partitioner/__init__.py | 2 + .../partitioner/shard_partitioner.py | 354 ++++++++++++++++ .../partitioner/shard_partitioner_test.py | 392 ++++++++++++++++++ 3 files changed, 748 insertions(+) create mode 100644 datasets/flwr_datasets/partitioner/shard_partitioner.py create mode 100644 datasets/flwr_datasets/partitioner/shard_partitioner_test.py diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py index 6a85f8a1174..73d048ddf3f 100644 --- a/datasets/flwr_datasets/partitioner/__init__.py +++ b/datasets/flwr_datasets/partitioner/__init__.py @@ -21,6 +21,7 @@ from .linear_partitioner import LinearPartitioner from .natural_id_partitioner import NaturalIdPartitioner from .partitioner import Partitioner +from .shard_partitioner import ShardPartitioner from .size_partitioner import SizePartitioner from .square_partitioner import SquarePartitioner @@ -32,5 +33,6 @@ "SizePartitioner", "LinearPartitioner", "SquarePartitioner", + "ShardPartitioner", "ExponentialPartitioner", ] diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py new file mode 100644 index 00000000000..7c86570fe48 --- /dev/null +++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py @@ -0,0 +1,354 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Shard partitioner class."""
+
+
+# pylint: disable=R0912
+import math
+from typing import Dict, List, Optional
+
+import numpy as np
+
+import datasets
+from flwr_datasets.partitioner.partitioner import Partitioner
+
+
+class ShardPartitioner(Partitioner):  # pylint: disable=R0902
+    """Partitioner based on shards of (typically) unique classes.
+
+    The algorithm works as follows: the dataset is sorted by label, e.g. [samples with
+    label 1, samples with label 2, ...], then the shards are created, each of size
+    `shard_size` if provided, or automatically calculated as:
+    shard_size = len(dataset) / (`num_partitions` * `num_shards_per_node`).
+
+    A shard is just a block (chunk) of the `dataset` that contains `shard_size`
+    consecutive samples. A shard might contain samples associated with more than a
+    single unique label. The first case (recall that the preprocessing step sorts the
+    dataset by label) is when a shard is constructed from samples at the boundary
+    between two classes in the sorted dataset and therefore contains e.g. the
+    "leftover" samples of class 1 and the majority of class 2. The other scenario in
+    which a shard has samples with more than one unique label is when the shard size
+    is bigger than the number of samples of a certain class.
+
+    Each partition is created from `num_shards_per_node` shards that are chosen
+    randomly.
+
+    There are a few ways of partitioning data that result in certain properties
+    (depending on the parameter specification):
+    1) the same number of shards per node + the same shard size (specify:
+    a) `num_shards_per_node` and `shard_size`; or b) `num_shards_per_node` only).
+    In case b) the `shard_size` is calculated as floor(len(dataset) /
+    (`num_shards_per_node` * `num_partitions`)).
+    2) possibly different numbers of shards per node (use nearly all data) + the same
+    shard size (specify: `shard_size` + `keep_incomplete_shard=False`)
+    3) possibly different numbers of shards per node (use all data) + possibly
+    different shard sizes (specify: `shard_size` + `keep_incomplete_shard=True`)
+
+
+    The algorithm is based on the description in Communication-Efficient Learning of
+    Deep Networks from Decentralized Data https://arxiv.org/abs/1602.05629. This
+    implementation expands on the initial idea by enabling more hyperparameter
+    specification, therefore providing more control over how partitions are created.
+    It still enables the division used in the original paper.
+
+    Parameters
+    ----------
+    num_partitions : int
+        The total number of partitions that the data will be divided into.
+    partition_by : str
+        Column name of the labels (targets) based on which the shards are created.
+    num_shards_per_node : Optional[int]
+        Number of shards to assign to a single node. If not specified, `shard_size`
+        must be given.
+    shard_size : Optional[int]
+        Size of a single shard (a partition has one or more shards). If the size is
+        not given, it will be computed automatically.
+    keep_incomplete_shard : bool
+        Whether to keep the last shard, which might be incomplete (smaller than the
+        others). If it is dropped, each shard has equal size. (This does not mean that
+        each client gets an equal number of shards, which only happens if
+        `num_shards` % `num_partitions` = 0.) This parameter has no effect if
+        `num_shards_per_node` and `shard_size` are specified.
+    shuffle : bool
+        Whether to randomize the order of samples. Shuffling is applied after the
+        samples are assigned to nodes.
+    seed : int
+        Seed used for dataset shuffling. It has no effect if `shuffle` is False.
+
+    Examples
+    --------
+    1) If you need the same number of shards per node + the same shard size (and you
+    know both of these values):
+    >>> from flwr_datasets import FederatedDataset
+    >>> from flwr_datasets.partitioner import ShardPartitioner
+    >>>
+    >>> partitioner = ShardPartitioner(num_partitions=10, partition_by="label",
+    >>>                                num_shards_per_node=2, shard_size=1_000)
+    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
+    >>> partition = fds.load_partition(0)
+    >>> print(partition[0])  # Print the first example
+    {'image': ,
+    'label': 3}
+    >>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(10)]
+    >>> print(partition_sizes)
+    [2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000]
+
+    2) If you want to use nearly all the data and do not need the number of shards
+    per node to be the same:
+    >>> from flwr_datasets import FederatedDataset
+    >>> from flwr_datasets.partitioner import ShardPartitioner
+    >>>
+    >>> partitioner = ShardPartitioner(num_partitions=9, partition_by="label",
+    >>>                                shard_size=1_000)
+    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
+    >>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(9)]
+    >>> print(partition_sizes)
+    [7000, 7000, 7000, 7000, 7000, 7000, 6000, 6000, 6000]
+
+    3) If you want to use all the data:
+    >>> from flwr_datasets import FederatedDataset
+    >>> from flwr_datasets.partitioner import ShardPartitioner
+    >>>
+    >>> partitioner = ShardPartitioner(num_partitions=10, partition_by="label",
+    >>>                                shard_size=990, keep_incomplete_shard=True)
+    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
+    >>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(10)]
+    >>> print(sorted(partition_sizes))
+    [5550, 5940, 5940, 5940, 5940, 5940, 5940, 5940, 5940, 6930]
+    """
+
+    def __init__(  # pylint: disable=R0913
+        self,
+        num_partitions: int,
+        partition_by: str,
+        num_shards_per_node: Optional[int] = None,
+        shard_size: Optional[int] = None,
+        keep_incomplete_shard: bool = False,
+        shuffle: bool = True,
+        seed: Optional[int] = 42,
+    ) -> None:
+        super().__init__()
+        # Attributes based on the constructor
+        _check_if_natual_number(num_partitions, "num_partitions")
+        self._num_partitions = num_partitions
+        self._partition_by = partition_by
+        _check_if_natual_number(num_shards_per_node, "num_shards_per_node", True)
+        self._num_shards_per_node = num_shards_per_node
+        self._num_shards_used: Optional[int] = None
+        _check_if_natual_number(shard_size, "shard_size", True)
+        self._shard_size = shard_size
+        self._keep_incomplete_shard = keep_incomplete_shard
+        self._shuffle = shuffle
+        self._seed = seed
+
+        # Utility attributes
+        self._rng = np.random.default_rng(seed=self._seed)  # NumPy random generator
+        self._node_id_to_indices: Dict[int, List[int]] = {}
+        self._node_id_to_indices_determined = False
+
+    def load_partition(self, 
node_id: int) -> datasets.Dataset: + """Load a partition based on the partition index. + + Parameters + ---------- + node_id : int + the index that corresponds to the requested partition + + Returns + ------- + dataset_partition : Dataset + single partition of a dataset + """ + # The partitioning is done lazily - only when the first partition is + # requested. Only the first call creates the indices assignments for all the + # partition indices. + self._check_num_partitions_correctness_if_needed() + self._check_possibility_of_partitions_creation() + self._sort_dataset_if_needed() + self._determine_node_id_to_indices_if_needed() + return self.dataset.select(self._node_id_to_indices[node_id]) + + def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0914 + """Assign sample indices to each node id. + + This method works on sorted datasets. A "shard" is a part of the dataset of + consecutive samples (if self._keep_incomplete_shard is False, each shard is same + size). + """ + # No need to do anything if that node_id_to_indices are already determined + if self._node_id_to_indices_determined: + return + + # One of the specification allows to skip the `num_shards_per_node` param + if self._num_shards_per_node is not None: + self._num_shards_used = int( + self._num_partitions * self._num_shards_per_node + ) + num_shards_per_node_array = ( + np.ones(self._num_partitions) * self._num_shards_per_node + ) + if self._shard_size is None: + self._compute_shard_size_if_missing() + assert self._shard_size is not None + if self._keep_incomplete_shard: + num_usable_shards_in_dataset = int( + math.ceil(len(self.dataset) / self._shard_size) + ) + else: + num_usable_shards_in_dataset = int( + math.floor(len(self.dataset) / self._shard_size) + ) + else: + num_usable_shards_in_dataset = int( + math.floor(len(self.dataset) / self._shard_size) + ) + elif self._num_shards_per_node is None: + if self._shard_size is None: + raise ValueError( + "The shard_size needs to be specified if the " + "num_shards_per_node is None" + ) + if self._keep_incomplete_shard is False: + self._num_shards_used = int( + math.floor(len(self.dataset) / self._shard_size) + ) + num_usable_shards_in_dataset = self._num_shards_used + elif self._keep_incomplete_shard is True: + self._num_shards_used = int( + math.ceil(len(self.dataset) / self._shard_size) + ) + num_usable_shards_in_dataset = self._num_shards_used + if num_usable_shards_in_dataset < self._num_partitions: + raise ValueError( + "Based on the given arguments the creation of the partitions " + "is impossible. The implied number of partitions that can be " + "used is lower than the number of requested partitions " + "resulting in empty partitions. Please decrease the size of " + "shards: `shard_size`." + ) + else: + raise ValueError( + "The keep_incomplete_shards need to be specified " + "when _num_shards_per_node is None." + ) + num_shards_per_node = int(self._num_shards_used / self._num_partitions) + # Assign the shards per nodes (so far, the same as in ideal case) + num_shards_per_node_array = ( + np.ones(self._num_partitions) * num_shards_per_node + ) + num_shards_assigned = self._num_partitions * num_shards_per_node + num_shards_to_assign = self._num_shards_used - num_shards_assigned + # Assign the "missing" shards + for i in range(num_shards_to_assign): + num_shards_per_node_array[i] += 1 + + else: + raise ValueError( + "The specification of nm_shards_per_node and " + "keep_incomplete_shards is not correct." 
+ ) + + if num_usable_shards_in_dataset < self._num_partitions: + raise ValueError( + "The specified configuration results in empty partitions because the " + "number of usable shards is smaller that the number partitions. " + "Try decreasing the shard size or the number of partitions. " + ) + + indices_on_which_to_split_shards = np.cumsum( + num_shards_per_node_array, dtype=int + ) + + shard_indices_array = self._rng.permutation(num_usable_shards_in_dataset)[ + : self._num_shards_used + ] + # Randomly assign shards to node_id + nid_to_shard_indices = np.split( + shard_indices_array, indices_on_which_to_split_shards + )[:-1] + node_id_to_indices: Dict[int, List[int]] = { + cid: [] for cid in range(self._num_partitions) + } + # Compute node_id to sample indices based on the shard indices + for node_id in range(self._num_partitions): + for shard_idx in nid_to_shard_indices[node_id]: + start_id = int(shard_idx * self._shard_size) + end_id = min(int((shard_idx + 1) * self._shard_size), len(self.dataset)) + node_id_to_indices[node_id].extend(list(range(start_id, end_id))) + if self._shuffle: + for indices in node_id_to_indices.values(): + # In place shuffling + self._rng.shuffle(indices) + self._node_id_to_indices = node_id_to_indices + self._node_id_to_indices_determined = True + + def _check_num_partitions_correctness_if_needed(self) -> None: + """Test num_partitions when the dataset is given (in load_partition).""" + if not self._node_id_to_indices_determined: + if self._num_partitions > self.dataset.num_rows: + raise ValueError( + "The number of partitions needs to be smaller than the number of " + "samples in the dataset." + ) + + def _sort_dataset_if_needed(self) -> None: + """Sort dataset prior to determining the partitions. + + Operation only needed to be performed one time. It's required for the creation + of shards with the same labels. + """ + if self._node_id_to_indices_determined: + return + self._dataset = self.dataset.sort(self._partition_by) + + def _compute_shard_size_if_missing(self) -> None: + """Compute the parameters needed to perform sharding. + + This method should be called after the dataset is assigned. + """ + if self._shard_size is None: + # If shard size is not specified it needs to be computed + num_rows = self.dataset.num_rows + self._shard_size = int(num_rows / self._num_shards_used) + + def _check_possibility_of_partitions_creation(self) -> None: + if self._shard_size is not None and self._num_shards_per_node is not None: + implied_min_dataset_size = ( + self._shard_size * self._num_shards_per_node * self._num_partitions + ) + if implied_min_dataset_size > len(self.dataset): + raise ValueError( + f"Based on the given arguments the creation of the " + "partitions is impossible. The implied minimum dataset" + f"size is {implied_min_dataset_size} but the dataset" + f"size is {len(self.dataset)}" + ) + + +def _check_if_natual_number( + number: Optional[int], parameter_name: str, none_acceptable: bool = False +) -> None: + if none_acceptable and number is None: + return + if not isinstance(number, int): + raise TypeError( + f"The expected type of {parameter_name} is int but given: {number} of type " + f"{type(number)}. Please specify the correct type." + ) + if not number >= 1: + raise ValueError( + f"The expected value of {parameter_name} is >= 1 (greater or equal to 1) " + f"but given: {number} which does not meet this condition. Please " + f"provide a correct number." 
+ ) diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner_test.py b/datasets/flwr_datasets/partitioner/shard_partitioner_test.py new file mode 100644 index 00000000000..47968699bba --- /dev/null +++ b/datasets/flwr_datasets/partitioner/shard_partitioner_test.py @@ -0,0 +1,392 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test ShardPartitioner.""" + + +# pylint: disable=W0212, R0913 +import unittest +from typing import Optional, Tuple + +from datasets import Dataset +from flwr_datasets.partitioner.shard_partitioner import ShardPartitioner + + +def _dummy_setup( + num_rows: int, + partition_by: str, + num_partitions: int, + num_shards_per_node: Optional[int], + shard_size: Optional[int], + keep_incomplete_shard: bool = False, +) -> Tuple[Dataset, ShardPartitioner]: + """Create a dummy dataset for testing.""" + data = { + partition_by: [i % 3 for i in range(num_rows)], + "features": list(range(num_rows)), + } + dataset = Dataset.from_dict(data) + partitioner = ShardPartitioner( + num_partitions=num_partitions, + num_shards_per_node=num_shards_per_node, + partition_by=partition_by, + shard_size=shard_size, + keep_incomplete_shard=keep_incomplete_shard, + ) + partitioner.dataset = dataset + return dataset, partitioner + + +class TestShardPartitionerSpec1(unittest.TestCase): + """Test first possible initialization of ShardPartitioner. + + Specify num_shards_per_node and shard_size arguments. + """ + + def test_correct_num_partitions(self) -> None: + """Test the correct number of partitions is created.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + _ = partitioner.load_partition(0) + num_partitions_created = len(partitioner._node_id_to_indices.keys()) + self.assertEqual(num_partitions_created, num_partitions) + + def test_correct_partition_sizes(self) -> None: + """Test if the partitions sizes are as theoretically calculated.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + sizes = [len(partitioner.load_partition(i)) for i in range(num_partitions)] + sizes = sorted(sizes) + self.assertEqual(sizes, [30, 30, 30]) + + def test_unique_samples(self) -> None: + """Test if each partition has unique samples. + + (No duplicates along partitions). 
+ """ + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + partitions = [ + partitioner.load_partition(i)["features"] for i in range(num_partitions) + ] + combined_list = [item for sublist in partitions for item in sublist] + combined_set = set(combined_list) + self.assertEqual(len(combined_list), len(combined_set)) + + +class TestShardPartitionerSpec2(unittest.TestCase): + """Test second possible initialization of ShardPartitioner. + + Specify shard_size and keep_incomplete_shard=False. This setting creates partitions + that might have various sizes (each shard is same size). + """ + + def test_correct_num_partitions(self) -> None: + """Test the correct number of partitions is created.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + _ = partitioner.load_partition(0) + num_partitions_created = len(partitioner._node_id_to_indices.keys()) + self.assertEqual(num_partitions_created, num_partitions) + + def test_correct_partition_sizes(self) -> None: + """Test if the partitions sizes are as theoretically calculated.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + sizes = [len(partitioner.load_partition(i)) for i in range(num_partitions)] + sizes = sorted(sizes) + self.assertEqual(sizes, [30, 40, 40]) + + def test_unique_samples(self) -> None: + """Test if each partition has unique samples. + + (No duplicates along partitions). + """ + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + partitions = [ + partitioner.load_partition(i)["features"] for i in range(num_partitions) + ] + combined_list = [item for sublist in partitions for item in sublist] + combined_set = set(combined_list) + self.assertEqual(len(combined_list), len(combined_set)) + + +class TestShardPartitionerSpec3(unittest.TestCase): + """Test third possible initialization of ShardPartitioner. + + Specify shard_size and keep_incomplete_shard=True. This setting creates partitions + that might have various sizes (each shard is same size). 
+ """ + + def test_correct_num_partitions(self) -> None: + """Test the correct number of partitions is created.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = True + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + _ = partitioner.load_partition(0) + num_partitions_created = len(partitioner._node_id_to_indices.keys()) + self.assertEqual(num_partitions_created, num_partitions) + + def test_correct_partition_sizes(self) -> None: + """Test if the partitions sizes are as theoretically calculated.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = True + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + sizes = [len(partitioner.load_partition(i)) for i in range(num_partitions)] + sizes = sorted(sizes) + self.assertEqual(sizes, [33, 40, 40]) + + def test_unique_samples(self) -> None: + """Test if each partition has unique samples. + + (No duplicates along partitions). + """ + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = True + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + partitions = [ + partitioner.load_partition(i)["features"] for i in range(num_partitions) + ] + combined_list = [item for sublist in partitions for item in sublist] + combined_set = set(combined_list) + self.assertEqual(len(combined_list), len(combined_set)) + + +class TestShardPartitionerSpec4(unittest.TestCase): + """Test fourth possible initialization of ShardPartitioner. + + Specify num_shards_per_node but not shard_size arguments. + """ + + def test_correct_num_partitions(self) -> None: + """Test the correct number of partitions is created.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = None + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + _ = partitioner.load_partition(0) + num_partitions_created = len(partitioner._node_id_to_indices.keys()) + self.assertEqual(num_partitions_created, num_partitions) + + def test_correct_partition_sizes(self) -> None: + """Test if the partitions sizes are as theoretically calculated.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = None + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + sizes = [len(partitioner.load_partition(i)) for i in range(num_partitions)] + sizes = sorted(sizes) + self.assertEqual(sizes, [36, 36, 36]) + + def test_unique_samples(self) -> None: + """Test if each partition has unique samples. + + (No duplicates along partitions). 
+ """ + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = None + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + partitions = [ + partitioner.load_partition(i)["features"] for i in range(num_partitions) + ] + combined_list = [item for sublist in partitions for item in sublist] + combined_set = set(combined_list) + self.assertEqual(len(combined_list), len(combined_set)) + + +class TestShardPartitionerIncorrectSpec(unittest.TestCase): + """Test the incorrect specification cases. + + The lack of correctness can be caused by the num_partitions, shard_size and + num_shards_per_partition can create. + """ + + def test_incorrect_specification(self) -> None: + """Test if the given specification makes the partitioning possible.""" + partition_by = "label" + num_rows = 10 + num_partitions = 3 + num_shards_per_node = 2 + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + with self.assertRaises(ValueError): + _ = partitioner.load_partition(0) + + def test_too_big_shard_size(self) -> None: + """Test if it is impossible to create an empty partition.""" + partition_by = "label" + num_rows = 20 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + with self.assertRaises(ValueError): + _ = partitioner.load_partition(2).num_rows + + +if __name__ == "__main__": + unittest.main()