refactor(examples) Update fl-dp-sa example (#4137)
Co-authored-by: jafermarq <javier@flower.ai>
mohammadnaseri and jafermarq authored Sep 5, 2024
1 parent 1187c70 commit f290fe2
Showing 9 changed files with 188 additions and 131 deletions.
61 changes: 48 additions & 13 deletions examples/fl-dp-sa/README.md
@@ -1,28 +1,63 @@
---
tags: [basic, vision, fds]
tags: [DP, SecAgg, vision, fds]
dataset: [MNIST]
framework: [torch, torchvision]
---

# Example of Flower App with DP and SA
# Flower Example on MNIST with Differential Privacy and Secure Aggregation

This is a simple example that utilizes central differential privacy with client-side fixed clipping and secure aggregation.
Note: This example is designed for a small number of rounds and is intended for demonstration purposes.
This example demonstrates a federated learning setup using Flower, incorporating central differential privacy (DP) with client-side fixed clipping and secure aggregation (SA). It is intended for a small number of rounds for demonstration purposes.

## Install dependencies
This example is similar to the [quickstart-pytorch example](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) and extends it by integrating central differential privacy and secure aggregation. For more details on differential privacy and secure aggregation in Flower, please refer to the documentation [here](https://flower.ai/docs/framework/how-to-use-differential-privacy.html) and [here](https://flower.ai/docs/framework/contributor-ref-secure-aggregation-protocols.html).
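
At a glance, the client side attaches the `secaggplus_mod` and `fixedclipping_mod` mods to its `ClientApp`, while the server side wraps its `FedAvg` strategy in `DifferentialPrivacyClientSideFixedClipping` and runs each fit round through a `SecAggPlusWorkflow`. Below is a condensed sketch of how these pieces connect, using the same APIs as `fl_dp_sa/client_app.py` and `fl_dp_sa/server_app.py`; `client_fn` stands for the client factory defined in this example, and the numeric values are the defaults this example previously hard-coded.

```python
from flwr.client import ClientApp
from flwr.client.mod import fixedclipping_mod, secaggplus_mod
from flwr.server.strategy import DifferentialPrivacyClientSideFixedClipping, FedAvg

# Client: mods clip the model update and run the SecAgg+ protocol
# before anything leaves the device.
app = ClientApp(
    client_fn=client_fn,  # factory defined in fl_dp_sa/client_app.py
    mods=[secaggplus_mod, fixedclipping_mod],
)

# Server: the DP wrapper adds calibrated noise to the securely aggregated update.
strategy = DifferentialPrivacyClientSideFixedClipping(
    FedAvg(), noise_multiplier=0.2, clipping_norm=10, num_sampled_clients=20
)
```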

```bash
# Using pip
pip install .
## Set up the project

### Clone the project

Start by cloning the example project:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/fl-dp-sa . && rm -rf flower && cd fl-dp-sa
```

This will create a new directory called `fl-dp-sa` containing the following files:

# Or using Poetry
poetry install
```shell
fl-dp-sa
├── fl_dp_sa
│   ├── client_app.py   # Defines your ClientApp
│   ├── server_app.py   # Defines your ServerApp
│   └── task.py         # Defines your model, training, and data loading
├── pyproject.toml      # Project metadata like dependencies and configs
└── README.md
```

## Run
### Install dependencies and project

The example uses the MNIST dataset with a total of 100 clients, with 20 clients sampled in each round. The hyperparameters for DP and SecAgg are specified in `server.py`.
Install the dependencies defined in `pyproject.toml` as well as the `fl_dp_sa` package.

```shell
flower-simulation --server-app fl_dp_sa.server:app --client-app fl_dp_sa.client:app --num-supernodes 100
# From a new python environment, run:
pip install -e .
```

## Run the project

You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend using the _simulation_ mode as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine.

### Run with the Simulation Engine

```bash
flwr run .
```

You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example:

```bash
flwr run . --run-config "noise-multiplier=0.1 clipping-norm=5"
```
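
These keys are read inside the `ServerApp` from the run context, so a command-line override flows straight into the DP wrapper and the SecAgg+ workflow (see `server_app.py` below). A minimal sketch of that lookup is shown here; `read_dp_config` is a hypothetical helper, not part of the example, and it assumes default values for these keys are declared in the example's `pyproject.toml`.

```python
from flwr.common import Context


def read_dp_config(context: Context) -> tuple[float, float, int]:
    """Hypothetical helper: pull the DP/SecAgg hyperparameters from the run config."""
    return (
        float(context.run_config["noise-multiplier"]),
        float(context.run_config["clipping-norm"]),
        int(context.run_config["num-sampled-clients"]),
    )
```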

### Run with the Deployment Engine

> \[!NOTE\]
> An update to this example will show how to run this Flower project with the Deployment Engine and TLS certificates, or with Docker.
2 changes: 1 addition & 1 deletion examples/fl-dp-sa/fl_dp_sa/__init__.py
@@ -1 +1 @@
"""fl_dp_sa: A Flower / PyTorch app."""
"""fl_dp_sa: Flower Example using Differential Privacy and Secure Aggregation."""
42 changes: 0 additions & 42 deletions examples/fl-dp-sa/fl_dp_sa/client.py

This file was deleted.

50 changes: 50 additions & 0 deletions examples/fl-dp-sa/fl_dp_sa/client_app.py
@@ -0,0 +1,50 @@
"""fl_dp_sa: Flower Example using Differential Privacy and Secure Aggregation."""

import torch
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
from flwr.client.mod import fixedclipping_mod, secaggplus_mod

from fl_dp_sa.task import Net, get_weights, load_data, set_weights, test, train


class FlowerClient(NumPyClient):
    def __init__(self, trainloader, testloader) -> None:
        self.net = Net()
        self.trainloader = trainloader
        self.testloader = testloader
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def fit(self, parameters, config):
        set_weights(self.net, parameters)
        results = train(
            self.net,
            self.trainloader,
            self.testloader,
            epochs=1,
            device=self.device,
        )
        return get_weights(self.net), len(self.trainloader.dataset), results

    def evaluate(self, parameters, config):
        set_weights(self.net, parameters)
        loss, accuracy = test(self.net, self.testloader, self.device)
        return loss, len(self.testloader.dataset), {"accuracy": accuracy}


def client_fn(context: Context):
    partition_id = context.node_config["partition-id"]
    trainloader, testloader = load_data(
        partition_id=partition_id, num_partitions=context.node_config["num-partitions"]
    )
    return FlowerClient(trainloader, testloader).to_client()


# Flower ClientApp
app = ClientApp(
    client_fn=client_fn,
    mods=[
        secaggplus_mod,
        fixedclipping_mod,
    ],
)
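
Conceptually, `fixedclipping_mod` bounds each client's contribution by clipping the model update to the configured `clipping-norm` before it is masked and shared, and the server-side DP wrapper then adds Gaussian noise whose scale is proportional to `noise_multiplier * clipping_norm`. The snippet below illustrates that clipping rule only; `clip_update` is a hypothetical helper, not the library's implementation, and it assumes standard L2-norm clipping of the flattened update.

```python
import numpy as np


def clip_update(update: list[np.ndarray], clipping_norm: float) -> list[np.ndarray]:
    """Hypothetical sketch: scale the update so its global L2 norm is at most clipping_norm."""
    flat = np.concatenate([layer.ravel() for layer in update])
    norm = float(np.linalg.norm(flat))
    scale = min(1.0, clipping_norm / (norm + 1e-12))
    return [layer * scale for layer in update]
```
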
@@ -1,28 +1,29 @@
"""fl_dp_sa: A Flower / PyTorch app."""
"""fl_dp_sa: Flower Example using Differential Privacy and Secure Aggregation."""

from typing import List, Tuple

from flwr.common import Context, Metrics, ndarrays_to_parameters
from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig
from flwr.server import (
Driver,
LegacyContext,
ServerApp,
ServerConfig,
)
from flwr.server.strategy import DifferentialPrivacyClientSideFixedClipping, FedAvg
from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow

from fl_dp_sa.task import Net, get_weights


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
examples = [num_examples for num_examples, _ in metrics]

# Multiply accuracy of each client by number of examples used
train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
train_accuracies = [
num_examples * m["train_accuracy"] for num_examples, m in metrics
]
val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics]
val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics]

# Aggregate and return custom metric (weighted average)
return {
"train_loss": sum(train_losses) / sum(examples),
"train_accuracy": sum(train_accuracies) / sum(examples),
@@ -31,30 +32,36 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
}


# Initialize model parameters
ndarrays = get_weights(Net())
parameters = ndarrays_to_parameters(ndarrays)
app = ServerApp()


# Define strategy
strategy = FedAvg(
fraction_fit=0.2,
fraction_evaluate=0.0, # Disable evaluation for demo purpose
min_fit_clients=20,
min_available_clients=20,
fit_metrics_aggregation_fn=weighted_average,
initial_parameters=parameters,
)
strategy = DifferentialPrivacyClientSideFixedClipping(
strategy, noise_multiplier=0.2, clipping_norm=10, num_sampled_clients=20
)
@app.main()
def main(driver: Driver, context: Context) -> None:

# Initialize global model
model_weights = get_weights(Net())
parameters = ndarrays_to_parameters(model_weights)

# Note: The fraction_fit value is configured based on the DP hyperparameter `num-sampled-clients`.
strategy = FedAvg(
fraction_fit=0.2,
fraction_evaluate=0.0,
min_fit_clients=20,
fit_metrics_aggregation_fn=weighted_average,
initial_parameters=parameters,
)

app = ServerApp()
noise_multiplier = context.run_config["noise-multiplier"]
clipping_norm = context.run_config["clipping-norm"]
num_sampled_clients = context.run_config["num-sampled-clients"]

strategy = DifferentialPrivacyClientSideFixedClipping(
strategy,
noise_multiplier=noise_multiplier,
clipping_norm=clipping_norm,
num_sampled_clients=num_sampled_clients,
)

@app.main()
def main(driver: Driver, context: Context) -> None:
# Construct the LegacyContext
context = LegacyContext(
context=context,
@@ -65,8 +72,8 @@ def main(driver: Driver, context: Context) -> None:
# Create the train/evaluate workflow
workflow = DefaultWorkflow(
fit_workflow=SecAggPlusWorkflow(
num_shares=7,
reconstruction_threshold=4,
num_shares=context.run_config["num-shares"],
reconstruction_threshold=context.run_config["reconstruction-threshold"],
)
)

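The two SecAgg+ knobs moved into the run config control the secret-sharing scheme: each client's masking secret is split into `num_shares` shares, and at least `reconstruction_threshold` of them are needed to reconstruct it, so a round can tolerate roughly `num_shares - reconstruction_threshold` dropouts among a client's share holders. The snippet below is an illustration only, using the values this example previously hard-coded (7 and 4); the exact validation rules are defined by Flower's `SecAggPlusWorkflow`.

```python
# Illustration only: the relationship between the two SecAgg+ settings.
num_shares = 7
reconstruction_threshold = 4

assert reconstruction_threshold <= num_shares
tolerated_dropouts = num_shares - reconstruction_threshold  # here: 3
print(f"Up to {tolerated_dropouts} share holders may drop out per client")
```
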
39 changes: 22 additions & 17 deletions examples/fl-dp-sa/fl_dp_sa/task.py
@@ -1,24 +1,22 @@
"""fl_dp_sa: A Flower / PyTorch app."""
"""fl_dp_sa: Flower Example using Differential Privacy and Secure Aggregation."""

from collections import OrderedDict
from logging import INFO

import torch
import torch.nn as nn
import torch.nn.functional as F
from flwr.common.logger import log
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize, ToTensor

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

fds = None # Cache FederatedDataset

class Net(nn.Module):
"""Model."""

class Net(nn.Module):
def __init__(self) -> None:
super(Net, self).__init__()
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
@@ -36,9 +34,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.fc3(x)


def load_data(partition_id):
def load_data(partition_id: int, num_partitions: int):
"""Load partition MNIST data."""
fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})

global fds
if fds is None:
partitioner = IidPartitioner(num_partitions=num_partitions)
fds = FederatedDataset(
dataset="ylecun/mnist",
partitioners={"train": partitioner},
)
partition = fds.load_partition(partition_id)
# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
@@ -70,8 +75,8 @@ def train(net, trainloader, valloader, epochs, device):
loss.backward()
optimizer.step()

train_loss, train_acc = test(net, trainloader)
val_loss, val_acc = test(net, valloader)
train_loss, train_acc = test(net, trainloader, device)
val_loss, val_acc = test(net, valloader, device)

results = {
"train_loss": train_loss,
@@ -82,17 +87,17 @@ def test(net, testloader):
return results


def test(net, testloader):
def test(net, testloader, device):
"""Validate the model on the test set."""
net.to(DEVICE)
net.to(device)
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for batch in testloader:
images = batch["image"].to(DEVICE)
labels = batch["label"].to(DEVICE)
outputs = net(images.to(DEVICE))
labels = labels.to(DEVICE)
images = batch["image"].to(device)
labels = batch["label"].to(device)
outputs = net(images.to(device))
labels = labels.to(device)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
13 changes: 0 additions & 13 deletions examples/fl-dp-sa/flower.toml

This file was deleted.
