From f290fe2d49e12a871c33c0ce465ea1d021a4e863 Mon Sep 17 00:00:00 2001 From: Mohammad Naseri Date: Thu, 5 Sep 2024 22:31:08 +0100 Subject: [PATCH] refactor(examples) Update fl-dp-sa example (#4137) Co-authored-by: jafermarq --- examples/fl-dp-sa/README.md | 61 +++++++++++++++---- examples/fl-dp-sa/fl_dp_sa/__init__.py | 2 +- examples/fl-dp-sa/fl_dp_sa/client.py | 42 ------------- examples/fl-dp-sa/fl_dp_sa/client_app.py | 50 +++++++++++++++ .../fl_dp_sa/{server.py => server_app.py} | 59 ++++++++++-------- examples/fl-dp-sa/fl_dp_sa/task.py | 39 ++++++------ examples/fl-dp-sa/flower.toml | 13 ---- examples/fl-dp-sa/pyproject.toml | 49 ++++++++++----- examples/fl-dp-sa/requirements.txt | 4 -- 9 files changed, 188 insertions(+), 131 deletions(-) delete mode 100644 examples/fl-dp-sa/fl_dp_sa/client.py create mode 100644 examples/fl-dp-sa/fl_dp_sa/client_app.py rename examples/fl-dp-sa/fl_dp_sa/{server.py => server_app.py} (56%) delete mode 100644 examples/fl-dp-sa/flower.toml delete mode 100644 examples/fl-dp-sa/requirements.txt diff --git a/examples/fl-dp-sa/README.md b/examples/fl-dp-sa/README.md index 65c8a5b18fa..61a6c80f355 100644 --- a/examples/fl-dp-sa/README.md +++ b/examples/fl-dp-sa/README.md @@ -1,28 +1,63 @@ --- -tags: [basic, vision, fds] +tags: [DP, SecAgg, vision, fds] dataset: [MNIST] framework: [torch, torchvision] --- -# Example of Flower App with DP and SA +# Flower Example on MNIST with Differential Privacy and Secure Aggregation -This is a simple example that utilizes central differential privacy with client-side fixed clipping and secure aggregation. -Note: This example is designed for a small number of rounds and is intended for demonstration purposes. +This example demonstrates a federated learning setup using the Flower, incorporating central differential privacy (DP) with client-side fixed clipping and secure aggregation (SA). It is intended for a small number of rounds for demonstration purposes. -## Install dependencies +This example is similar to the [quickstart-pytorch example](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) and extends it by integrating central differential privacy and secure aggregation. For more details on differential privacy and secure aggregation in Flower, please refer to the documentation [here](https://flower.ai/docs/framework/how-to-use-differential-privacy.html) and [here](https://flower.ai/docs/framework/contributor-ref-secure-aggregation-protocols.html). -```bash -# Using pip -pip install . +## Set up the project + +### Clone the project + +Start by cloning the example project: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/fl-dp-sa . && rm -rf flower && cd fl-dp-sa +``` + +This will create a new directory called `fl-dp-sa` containing the following files: -# Or using Poetry -poetry install +```shell +fl-dp-sa +├── fl_dp_sa +│ ├── client_app.py # Defines your ClientApp +│ ├── server_app.py # Defines your ServerApp +│ └── task.py # Defines your model, training, and data loading +├── pyproject.toml # Project metadata like dependencies and configs +└── README.md ``` -## Run +### Install dependencies and project -The example uses the MNIST dataset with a total of 100 clients, with 20 clients sampled in each round. The hyperparameters for DP and SecAgg are specified in `server.py`. +Install the dependencies defined in `pyproject.toml` as well as the `fl_dp_sa` package. ```shell -flower-simulation --server-app fl_dp_sa.server:app --client-app fl_dp_sa.client:app --num-supernodes 100 +# From a new python environment, run: +pip install -e . +``` + +## Run the project + +You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend you using the _simulation_ mode as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine. + +### Run with the Simulation Engine + +```bash +flwr run . +``` + +You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example: + +```bash +flwr run . --run-config "noise-multiplier=0.1 clipping-norm=5" ``` + +### Run with the Deployment Engine + +> \[!NOTE\] +> An update to this example will show how to run this Flower project with the Deployment Engine and TLS certificates, or with Docker. diff --git a/examples/fl-dp-sa/fl_dp_sa/__init__.py b/examples/fl-dp-sa/fl_dp_sa/__init__.py index 741260348ab..c5c9a7e9581 100644 --- a/examples/fl-dp-sa/fl_dp_sa/__init__.py +++ b/examples/fl-dp-sa/fl_dp_sa/__init__.py @@ -1 +1 @@ -"""fl_dp_sa: A Flower / PyTorch app.""" +"""fl_dp_sa: Flower Example using Differential Privacy and Secure Aggregation.""" diff --git a/examples/fl-dp-sa/fl_dp_sa/client.py b/examples/fl-dp-sa/fl_dp_sa/client.py deleted file mode 100644 index b3b02c6e9d6..00000000000 --- a/examples/fl-dp-sa/fl_dp_sa/client.py +++ /dev/null @@ -1,42 +0,0 @@ -"""fl_dp_sa: A Flower / PyTorch app.""" - -from flwr.client import ClientApp, NumPyClient -from flwr.client.mod import fixedclipping_mod, secaggplus_mod - -from fl_dp_sa.task import DEVICE, Net, get_weights, load_data, set_weights, test, train - -# Load model and data (simple CNN, CIFAR-10) -net = Net().to(DEVICE) - - -# Define FlowerClient and client_fn -class FlowerClient(NumPyClient): - def __init__(self, trainloader, testloader) -> None: - self.trainloader = trainloader - self.testloader = testloader - - def fit(self, parameters, config): - set_weights(net, parameters) - results = train(net, self.trainloader, self.testloader, epochs=1, device=DEVICE) - return get_weights(net), len(self.trainloader.dataset), results - - def evaluate(self, parameters, config): - set_weights(net, parameters) - loss, accuracy = test(net, self.testloader) - return loss, len(self.testloader.dataset), {"accuracy": accuracy} - - -def client_fn(cid: str): - """Create and return an instance of Flower `Client`.""" - trainloader, testloader = load_data(partition_id=int(cid)) - return FlowerClient(trainloader, testloader).to_client() - - -# Flower ClientApp -app = ClientApp( - client_fn=client_fn, - mods=[ - secaggplus_mod, - fixedclipping_mod, - ], -) diff --git a/examples/fl-dp-sa/fl_dp_sa/client_app.py b/examples/fl-dp-sa/fl_dp_sa/client_app.py new file mode 100644 index 00000000000..5630d4f4d14 --- /dev/null +++ b/examples/fl-dp-sa/fl_dp_sa/client_app.py @@ -0,0 +1,50 @@ +"""fl_dp_sa: Flower Example using Differential Privacy and Secure Aggregation.""" + +import torch +from flwr.client import ClientApp, NumPyClient +from flwr.common import Context +from flwr.client.mod import fixedclipping_mod, secaggplus_mod + +from fl_dp_sa.task import Net, get_weights, load_data, set_weights, test, train + + +class FlowerClient(NumPyClient): + def __init__(self, trainloader, testloader) -> None: + self.net = Net() + self.trainloader = trainloader + self.testloader = testloader + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + def fit(self, parameters, config): + set_weights(self.net, parameters) + results = train( + self.net, + self.trainloader, + self.testloader, + epochs=1, + device=self.device, + ) + return get_weights(self.net), len(self.trainloader.dataset), results + + def evaluate(self, parameters, config): + set_weights(self.net, parameters) + loss, accuracy = test(self.net, self.testloader, self.device) + return loss, len(self.testloader.dataset), {"accuracy": accuracy} + + +def client_fn(context: Context): + partition_id = context.node_config["partition-id"] + trainloader, testloader = load_data( + partition_id=partition_id, num_partitions=context.node_config["num-partitions"] + ) + return FlowerClient(trainloader, testloader).to_client() + + +# Flower ClientApp +app = ClientApp( + client_fn=client_fn, + mods=[ + secaggplus_mod, + fixedclipping_mod, + ], +) diff --git a/examples/fl-dp-sa/fl_dp_sa/server.py b/examples/fl-dp-sa/fl_dp_sa/server_app.py similarity index 56% rename from examples/fl-dp-sa/fl_dp_sa/server.py rename to examples/fl-dp-sa/fl_dp_sa/server_app.py index 3ec0ba757b0..1704b4942ff 100644 --- a/examples/fl-dp-sa/fl_dp_sa/server.py +++ b/examples/fl-dp-sa/fl_dp_sa/server_app.py @@ -1,20 +1,22 @@ -"""fl_dp_sa: A Flower / PyTorch app.""" +"""fl_dp_sa: Flower Example using Differential Privacy and Secure Aggregation.""" from typing import List, Tuple from flwr.common import Context, Metrics, ndarrays_to_parameters -from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig +from flwr.server import ( + Driver, + LegacyContext, + ServerApp, + ServerConfig, +) from flwr.server.strategy import DifferentialPrivacyClientSideFixedClipping, FedAvg from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow from fl_dp_sa.task import Net, get_weights -# Define metric aggregation function def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: examples = [num_examples for num_examples, _ in metrics] - - # Multiply accuracy of each client by number of examples used train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] train_accuracies = [ num_examples * m["train_accuracy"] for num_examples, m in metrics @@ -22,7 +24,6 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] - # Aggregate and return custom metric (weighted average) return { "train_loss": sum(train_losses) / sum(examples), "train_accuracy": sum(train_accuracies) / sum(examples), @@ -31,30 +32,36 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: } -# Initialize model parameters -ndarrays = get_weights(Net()) -parameters = ndarrays_to_parameters(ndarrays) +app = ServerApp() -# Define strategy -strategy = FedAvg( - fraction_fit=0.2, - fraction_evaluate=0.0, # Disable evaluation for demo purpose - min_fit_clients=20, - min_available_clients=20, - fit_metrics_aggregation_fn=weighted_average, - initial_parameters=parameters, -) -strategy = DifferentialPrivacyClientSideFixedClipping( - strategy, noise_multiplier=0.2, clipping_norm=10, num_sampled_clients=20 -) +@app.main() +def main(driver: Driver, context: Context) -> None: + # Initialize global model + model_weights = get_weights(Net()) + parameters = ndarrays_to_parameters(model_weights) + + # Note: The fraction_fit value is configured based on the DP hyperparameter `num-sampled-clients`. + strategy = FedAvg( + fraction_fit=0.2, + fraction_evaluate=0.0, + min_fit_clients=20, + fit_metrics_aggregation_fn=weighted_average, + initial_parameters=parameters, + ) -app = ServerApp() + noise_multiplier = context.run_config["noise-multiplier"] + clipping_norm = context.run_config["clipping-norm"] + num_sampled_clients = context.run_config["num-sampled-clients"] + strategy = DifferentialPrivacyClientSideFixedClipping( + strategy, + noise_multiplier=noise_multiplier, + clipping_norm=clipping_norm, + num_sampled_clients=num_sampled_clients, + ) -@app.main() -def main(driver: Driver, context: Context) -> None: # Construct the LegacyContext context = LegacyContext( context=context, @@ -65,8 +72,8 @@ def main(driver: Driver, context: Context) -> None: # Create the train/evaluate workflow workflow = DefaultWorkflow( fit_workflow=SecAggPlusWorkflow( - num_shares=7, - reconstruction_threshold=4, + num_shares=context.run_config["num-shares"], + reconstruction_threshold=context.run_config["reconstruction-threshold"], ) ) diff --git a/examples/fl-dp-sa/fl_dp_sa/task.py b/examples/fl-dp-sa/fl_dp_sa/task.py index 5b4fd7dee59..c145cebe137 100644 --- a/examples/fl-dp-sa/fl_dp_sa/task.py +++ b/examples/fl-dp-sa/fl_dp_sa/task.py @@ -1,24 +1,22 @@ -"""fl_dp_sa: A Flower / PyTorch app.""" +"""fl_dp_sa: Flower Example using Differential Privacy and Secure Aggregation.""" from collections import OrderedDict -from logging import INFO import torch import torch.nn as nn import torch.nn.functional as F -from flwr.common.logger import log from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner from torch.utils.data import DataLoader from torchvision.transforms import Compose, Normalize, ToTensor -DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +fds = None # Cache FederatedDataset -class Net(nn.Module): - """Model.""" +class Net(nn.Module): def __init__(self) -> None: - super(Net, self).__init__() + super().__init__() self.conv1 = nn.Conv2d(1, 6, 3, padding=1) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) @@ -36,9 +34,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.fc3(x) -def load_data(partition_id): +def load_data(partition_id: int, num_partitions: int): """Load partition MNIST data.""" - fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) + + global fds + if fds is None: + partitioner = IidPartitioner(num_partitions=num_partitions) + fds = FederatedDataset( + dataset="ylecun/mnist", + partitioners={"train": partitioner}, + ) partition = fds.load_partition(partition_id) # Divide data on each node: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2, seed=42) @@ -70,8 +75,8 @@ def train(net, trainloader, valloader, epochs, device): loss.backward() optimizer.step() - train_loss, train_acc = test(net, trainloader) - val_loss, val_acc = test(net, valloader) + train_loss, train_acc = test(net, trainloader, device) + val_loss, val_acc = test(net, valloader, device) results = { "train_loss": train_loss, @@ -82,17 +87,17 @@ def train(net, trainloader, valloader, epochs, device): return results -def test(net, testloader): +def test(net, testloader, device): """Validate the model on the test set.""" - net.to(DEVICE) + net.to(device) criterion = torch.nn.CrossEntropyLoss() correct, loss = 0, 0.0 with torch.no_grad(): for batch in testloader: - images = batch["image"].to(DEVICE) - labels = batch["label"].to(DEVICE) - outputs = net(images.to(DEVICE)) - labels = labels.to(DEVICE) + images = batch["image"].to(device) + labels = batch["label"].to(device) + outputs = net(images.to(device)) + labels = labels.to(device) loss += criterion(outputs, labels).item() correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() accuracy = correct / len(testloader.dataset) diff --git a/examples/fl-dp-sa/flower.toml b/examples/fl-dp-sa/flower.toml deleted file mode 100644 index ea2e9820679..00000000000 --- a/examples/fl-dp-sa/flower.toml +++ /dev/null @@ -1,13 +0,0 @@ -[project] -name = "fl_dp_sa" -version = "1.0.0" -description = "" -license = "Apache-2.0" -authors = [ - "The Flower Authors ", -] -readme = "README.md" - -[flower.components] -serverapp = "fl_dp_sa.server:app" -clientapp = "fl_dp_sa.client:app" diff --git a/examples/fl-dp-sa/pyproject.toml b/examples/fl-dp-sa/pyproject.toml index 1ca343b072d..fbb463cc1c0 100644 --- a/examples/fl-dp-sa/pyproject.toml +++ b/examples/fl-dp-sa/pyproject.toml @@ -1,21 +1,40 @@ [build-system] -requires = ["poetry-core>=1.4.0"] -build-backend = "poetry.core.masonry.api" +requires = ["hatchling"] +build-backend = "hatchling.build" -[tool.poetry] +[project] name = "fl-dp-sa" -version = "0.1.0" -description = "" +version = "1.0.0" +description = "Central Differential Privacy and Secure Aggregation in Flower" license = "Apache-2.0" -authors = [ - "The Flower Authors ", +dependencies = [ + "flwr[simulation]>=1.11.0", + "flwr-datasets[vision]>=0.3.0", + "torch==2.2.1", + "torchvision==0.17.1", ] -readme = "README.md" -[tool.poetry.dependencies] -python = "^3.9" -# Mandatory dependencies -flwr = { version = "^1.8.0", extras = ["simulation"] } -flwr-datasets = { version = "0.0.2", extras = ["vision"] } -torch = "2.2.1" -torchvision = "0.17.1" +[tool.hatch.build.targets.wheel] +packages = ["."] + +[tool.flwr.app] +publisher = "flwrlabs" + +[tool.flwr.app.components] +serverapp = "fl_dp_sa.server_app:app" +clientapp = "fl_dp_sa.client_app:app" + +[tool.flwr.app.config] +# Parameters for the DP +noise-multiplier = 0.2 +clipping-norm = 10 +num-sampled-clients = 20 +# Parameters for the SecAgg+ protocol +num-shares = 7 +reconstruction-threshold = 4 + +[tool.flwr.federations] +default = "local-simulation" + +[tool.flwr.federations.local-simulation] +options.num-supernodes = 100 \ No newline at end of file diff --git a/examples/fl-dp-sa/requirements.txt b/examples/fl-dp-sa/requirements.txt deleted file mode 100644 index f20b9d71e33..00000000000 --- a/examples/fl-dp-sa/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -flwr[simulation]>=1.8.0 -flwr-datasets[vision]==0.0.2 -torch==2.2.1 -torchvision==0.17.1