From 74d64dd0e203279916f1a7553c9a05d25b16ed61 Mon Sep 17 00:00:00 2001
From: Mohammad Naseri
Date: Fri, 11 Oct 2024 13:50:12 +0100
Subject: [PATCH] refactor(examples) Update opacus example (#4262)

Co-authored-by: jafermarq
---
 examples/opacus/README.md               |  60 ++++-----
 examples/opacus/client.py               | 171 ------------------
 examples/opacus/opacus_fl/__init__.py   |   1 +
 examples/opacus/opacus_fl/client_app.py |  92 +++++++++++
 examples/opacus/opacus_fl/server_app.py |  37 +++++
 examples/opacus/opacus_fl/task.py       | 102 ++++++++++++
 examples/opacus/pyproject.toml          |  31 ++++-
 examples/opacus/server.py               |  22 ---
 8 files changed, 285 insertions(+), 231 deletions(-)
 delete mode 100644 examples/opacus/client.py
 create mode 100644 examples/opacus/opacus_fl/__init__.py
 create mode 100644 examples/opacus/opacus_fl/client_app.py
 create mode 100644 examples/opacus/opacus_fl/server_app.py
 create mode 100644 examples/opacus/opacus_fl/task.py
 delete mode 100644 examples/opacus/server.py

diff --git a/examples/opacus/README.md b/examples/opacus/README.md
index aea5d0f689f..5a816d008f9 100644
--- a/examples/opacus/README.md
+++ b/examples/opacus/README.md
@@ -1,5 +1,5 @@
 ---
-tags: [dp, security, fds]
+tags: [DP, DP-SGD, basic, vision, fds, privacy]
 dataset: [CIFAR-10]
 framework: [opacus, torch]
 ---
@@ -10,57 +10,55 @@ In this example, we demonstrate how to train a model with differential privacy (
 For more information about DP in Flower please refer to the [tutorial](https://flower.ai/docs/framework/how-to-use-differential-privacy.html). For additional information about Opacus, visit the official [website](https://opacus.ai/).
 
-## Environments Setup
+## Set up the project
 
-Start by cloning the example. We prepared a single-line command that you can copy into your shell which will checkout the example for you:
+### Clone the project
+
+Start by cloning the example project:
 
 ```shell
-git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/opacus . && rm -rf flower && cd opacus
+git clone --depth=1 https://github.com/adap/flower.git \
+        && mv flower/examples/opacus . \
+        && rm -rf flower \
+        && cd opacus
 ```
 
 This will create a new directory called `opacus` containing the following files:
 
 ```shell
--- pyproject.toml
--- client.py
--- server.py
--- README.md
+opacus
+├── opacus_fl
+│   ├── client_app.py   # Defines your ClientApp
+│   ├── server_app.py   # Defines your ServerApp
+│   └── task.py         # Defines your model, training, and data loading
+├── pyproject.toml      # Project metadata like dependencies and configs
+└── README.md
 ```
 
-### Installing dependencies
+### Install dependencies and project
 
-Project dependencies are defined in `pyproject.toml`. Install them with:
+Install the dependencies defined in `pyproject.toml` as well as the `opacus_fl` package.
 
 ```shell
-pip install .
+# From a new Python environment, run:
+pip install -e .
 ```
 
-## Run Flower with Opacus and Pytorch
-
-### 1. Start the long-running Flower server (SuperLink)
-
-```bash
-flower-superlink --insecure
-```
+## Run the project
 
-### 2. Start the long-running Flower clients (SuperNodes)
+You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend using the _simulation_ mode as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine.
 
-Start 2 Flower `SuperNodes` in 2 separate terminal windows, using:
+### Run with the Simulation Engine
 
 ```bash
-flower-client-app client:appA --insecure
+flwr run .
 ```
 
-```bash
-flower-client-app client:appB --insecure
-```
-
-Opacus hyperparameters can be passed for each client in `ClientApp` instantiation (in `client.py`). In this example, `noise_multiplier=1.5` and `noise_multiplier=1` are used for the first and second client respectively.
-
-### 3. Run the Flower App
-
-With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower App:
+You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example:
 
 ```bash
-flower-server-app server:app --insecure
+flwr run . --run-config "max-grad-norm=1.0 num-server-rounds=5"
 ```
+
+> \[!NOTE\]
+> Please note that, currently, users cannot set `NodeConfig` for simulated `ClientApp`s. For this reason, the hyperparameter `noise_multiplier` is set in the `client_fn` method based on a condition check on `partition_id`. A future version of Flower will allow setting `NodeConfig` for simulated `ClientApp`s.
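The condition check mentioned in the note is simply the parity of the partition id. As a minimal sketch of that selection logic (the helper name `select_noise_multiplier` is illustrative; in this patch the expression is inlined in `client_fn` in `opacus_fl/client_app.py`, shown further down):

```python
def select_noise_multiplier(partition_id: int) -> float:
    # Clients on even partitions train with noise_multiplier=1.0,
    # clients on odd partitions with the stronger 1.5.
    return 1.0 if partition_id % 2 == 0 else 1.5
```
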
diff --git a/examples/opacus/client.py b/examples/opacus/client.py
deleted file mode 100644
index 2771a5d78bc..00000000000
--- a/examples/opacus/client.py
+++ /dev/null
@@ -1,171 +0,0 @@
-import argparse
-import warnings
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from flwr.client import ClientApp, NumPyClient
-from flwr_datasets import FederatedDataset
-from opacus import PrivacyEngine
-from torch.utils.data import DataLoader
-from torchvision.transforms import Compose, Normalize, ToTensor
-from tqdm import tqdm
-
-warnings.filterwarnings("ignore", category=UserWarning)
-
-DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-
-class Net(nn.Module):
-    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
-
-    def __init__(self) -> None:
-        super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(3, 6, 5)
-        self.pool = nn.MaxPool2d(2, 2)
-        self.conv2 = nn.Conv2d(6, 16, 5)
-        self.fc1 = nn.Linear(16 * 5 * 5, 120)
-        self.fc2 = nn.Linear(120, 84)
-        self.fc3 = nn.Linear(84, 10)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.pool(F.relu(self.conv1(x)))
-        x = self.pool(F.relu(self.conv2(x)))
-        x = x.view(-1, 16 * 5 * 5)
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        return self.fc3(x)
-
-
-def train(net, train_loader, privacy_engine, optimizer, target_delta, epochs=1):
-    criterion = torch.nn.CrossEntropyLoss()
-    for _ in range(epochs):
-        for batch in tqdm(train_loader, "Training"):
-            images = batch["img"]
-            labels = batch["label"]
-            optimizer.zero_grad()
-            criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
-            optimizer.step()
-
-    epsilon = privacy_engine.get_epsilon(delta=target_delta)
-    return epsilon
-
-
-def test(net, test_loader):
-    criterion = torch.nn.CrossEntropyLoss()
-    correct, loss = 0, 0.0
-    with torch.no_grad():
-        for batch in tqdm(test_loader, "Testing"):
-            images = batch["img"].to(DEVICE)
-            labels = batch["label"].to(DEVICE)
-            outputs = net(images)
-            loss += criterion(outputs, labels).item()
-            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
-    accuracy = correct / len(test_loader.dataset)
-    return loss, accuracy
-
-
-def load_data(partition_id):
-    fds = FederatedDataset(dataset="cifar10", partitioners={"train": 2})
-    partition = fds.load_partition(partition_id)
-    # Divide data on each node: 80% train, 20% test
-    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
-    pytorch_transforms = Compose(
-        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
-    )
-
-    def apply_transforms(batch):
-        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
-        return batch
-
-    partition_train_test = partition_train_test.with_transform(apply_transforms)
-    train_loader = DataLoader(
-        partition_train_test["train"], batch_size=32, shuffle=True
-    )
-    test_loader = DataLoader(partition_train_test["test"], batch_size=32)
-    return train_loader, test_loader
-
-
-class FlowerClient(NumPyClient):
-    def __init__(
-        self,
-        model,
-        train_loader,
-        test_loader,
-        target_delta,
-        noise_multiplier,
-        max_grad_norm,
-    ) -> None:
-        super().__init__()
-        self.test_loader = test_loader
-        self.optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
-        self.privacy_engine = PrivacyEngine(secure_mode=False)
-        self.target_delta = target_delta
-        (
-            self.model,
-            self.optimizer,
-            self.train_loader,
-        ) = self.privacy_engine.make_private(
-            module=model,
-            optimizer=self.optimizer,
-            data_loader=train_loader,
-            noise_multiplier=noise_multiplier,
-            max_grad_norm=max_grad_norm,
-        )
-
-    def get_parameters(self, config):
-        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
-
-    def set_parameters(self, parameters):
-        params_dict = zip(self.model.state_dict().keys(), parameters)
-        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
-        self.model.load_state_dict(state_dict, strict=True)
-
-    def fit(self, parameters, config):
-        self.set_parameters(parameters)
-        epsilon = train(
-            self.model,
-            self.train_loader,
-            self.privacy_engine,
-            self.optimizer,
-            self.target_delta,
-        )
-
-        if epsilon is not None:
-            print(f"Epsilon value for delta={self.target_delta} is {epsilon:.2f}")
-        else:
-            print("Epsilon value not available.")
-        return (self.get_parameters(config={}), len(self.train_loader), {})
-
-    def evaluate(self, parameters, config):
-        self.set_parameters(parameters)
-        loss, accuracy = test(self.model, self.test_loader)
-        return loss, len(self.test_loader.dataset), {"accuracy": accuracy}
-
-
-def client_fn_parameterized(
-    partition_id, target_delta=1e-5, noise_multiplier=1.3, max_grad_norm=1.0
-):
-    def client_fn(cid: str):
-        net = Net().to(DEVICE)
-        train_loader, test_loader = load_data(partition_id=partition_id)
-        return FlowerClient(
-            net,
-            train_loader,
-            test_loader,
-            target_delta,
-            noise_multiplier,
-            max_grad_norm,
-        ).to_client()
-
-    return client_fn
-
-
-appA = ClientApp(
-    client_fn=client_fn_parameterized(partition_id=0, noise_multiplier=1.5),
-)
-
-appB = ClientApp(
-    client_fn=client_fn_parameterized(partition_id=1, noise_multiplier=1),
-)
diff --git a/examples/opacus/opacus_fl/__init__.py b/examples/opacus/opacus_fl/__init__.py
new file mode 100644
index 00000000000..91006b32e38
--- /dev/null
+++ b/examples/opacus/opacus_fl/__init__.py
@@ -0,0 +1 @@
+"""opacus: Training with Sample-Level Differential Privacy using Opacus Privacy Engine."""
diff --git a/examples/opacus/opacus_fl/client_app.py b/examples/opacus/opacus_fl/client_app.py
new file mode 100644
index 00000000000..631e9909278
--- /dev/null
+++ b/examples/opacus/opacus_fl/client_app.py
@@ -0,0 +1,92 @@
+"""opacus: Training with Sample-Level Differential Privacy using Opacus Privacy Engine."""
+
+import warnings
+
+import torch
+from opacus import PrivacyEngine
+from opacus_fl.task import Net, get_weights, load_data, set_weights, test, train
+import logging
+
+from flwr.client import ClientApp, NumPyClient
+from flwr.common import Context
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+class FlowerClient(NumPyClient):
+    def __init__(
+        self,
+        train_loader,
+        test_loader,
+        target_delta,
+        noise_multiplier,
+        max_grad_norm,
+    ) -> None:
+        super().__init__()
+        self.model = Net()
+        self.train_loader = train_loader
+        self.test_loader = test_loader
+        self.target_delta = target_delta
+        self.noise_multiplier = noise_multiplier
+        self.max_grad_norm = max_grad_norm
+
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    def fit(self, parameters, config):
+        model = self.model
+        set_weights(model, parameters)
+
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+
+        privacy_engine = PrivacyEngine(secure_mode=False)
+        (
+            model,
+            optimizer,
+            self.train_loader,
+        ) = privacy_engine.make_private(
+            module=model,
+            optimizer=optimizer,
+            data_loader=self.train_loader,
+            noise_multiplier=self.noise_multiplier,
+            max_grad_norm=self.max_grad_norm,
+        )
+
+        epsilon = train(
+            model,
+            self.train_loader,
+            privacy_engine,
+            optimizer,
+            self.target_delta,
+            device=self.device,
+        )
+
+        if epsilon is not None:
+            print(f"Epsilon value for delta={self.target_delta} is {epsilon:.2f}")
+        else:
+            print("Epsilon value not available.")
+
+        return (get_weights(model), len(self.train_loader.dataset), {})
+
+    def evaluate(self, parameters, config):
+        set_weights(self.model, parameters)
+        loss, accuracy = test(self.model, self.test_loader, self.device)
+        return loss, len(self.test_loader.dataset), {"accuracy": accuracy}
+
+
+def client_fn(context: Context):
+    partition_id = context.node_config["partition-id"]
+    noise_multiplier = 1.0 if partition_id % 2 == 0 else 1.5
+
+    train_loader, test_loader = load_data(
+        partition_id=partition_id, num_partitions=context.node_config["num-partitions"]
+    )
+    return FlowerClient(
+        train_loader,
+        test_loader,
+        context.run_config["target-delta"],
+        noise_multiplier,
+        context.run_config["max-grad-norm"],
+    ).to_client()
+
+
+app = ClientApp(client_fn=client_fn)
diff --git a/examples/opacus/opacus_fl/server_app.py b/examples/opacus/opacus_fl/server_app.py
new file mode 100644
index 00000000000..2c105d36df4
--- /dev/null
+++ b/examples/opacus/opacus_fl/server_app.py
@@ -0,0 +1,37 @@
+"""opacus: Training with Sample-Level Differential Privacy using Opacus Privacy Engine."""
+
+import logging
+from typing import List, Tuple
+
+from opacus_fl.task import Net, get_weights
+
+from flwr.common import Context, Metrics, ndarrays_to_parameters
+from flwr.server import ServerApp, ServerAppComponents, ServerConfig
+from flwr.server.strategy import FedAvg
+
+# The Opacus logger seems to change the flwr logger to DEBUG level. Set it back to INFO.
+logging.getLogger("flwr").setLevel(logging.INFO)
+
+
+def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
+    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
+    examples = [num_examples for num_examples, _ in metrics]
+    return {"accuracy": sum(accuracies) / sum(examples)}
+
+
+def server_fn(context: Context) -> ServerAppComponents:
+    num_rounds = context.run_config["num-server-rounds"]
+
+    ndarrays = get_weights(Net())
+    parameters = ndarrays_to_parameters(ndarrays)
+
+    strategy = FedAvg(
+        evaluate_metrics_aggregation_fn=weighted_average,
+        initial_parameters=parameters,
+    )
+    config = ServerConfig(num_rounds=num_rounds)
+
+    return ServerAppComponents(config=config, strategy=strategy)
+
+
+app = ServerApp(server_fn=server_fn)
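To make the aggregation in `weighted_average` concrete, here is a small worked example with hypothetical client reports (two clients holding 100 and 300 evaluation examples):

```python
from opacus_fl.server_app import weighted_average

# Each tuple is (num_examples, metrics) as reported by one client.
reports = [(100, {"accuracy": 0.80}), (300, {"accuracy": 0.60})]

# (100 * 0.80 + 300 * 0.60) / (100 + 300) = 260 / 400 = 0.65
print(weighted_average(reports))  # {'accuracy': 0.65}
```
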
diff --git a/examples/opacus/opacus_fl/task.py b/examples/opacus/opacus_fl/task.py
new file mode 100644
index 00000000000..0c7ef71dc50
--- /dev/null
+++ b/examples/opacus/opacus_fl/task.py
@@ -0,0 +1,102 @@
+"""opacus: Training with Sample-Level Differential Privacy using Opacus Privacy Engine."""
+
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
+from torch.utils.data import DataLoader
+from torchvision.transforms import Compose, Normalize, ToTensor
+from tqdm import tqdm
+
+fds = None  # Cache FederatedDataset
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 5 * 5)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
+
+
+def get_weights(net):
+    return [val.cpu().numpy() for _, val in net.state_dict().items()]
+
+
+def set_weights(net, parameters):
+    params_dict = zip(net.state_dict().keys(), parameters)
+    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+    net.load_state_dict(state_dict, strict=True)
+
+
+def load_data(partition_id: int, num_partitions: int):
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="uoft-cs/cifar10",
+            partitioners={"train": partitioner},
+        )
+
+    partition = fds.load_partition(partition_id)
+    # Divide data on each node: 80% train, 20% test
+    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
+    pytorch_transforms = Compose(
+        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    )
+
+    def apply_transforms(batch):
+        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
+        return batch
+
+    partition_train_test = partition_train_test.with_transform(apply_transforms)
+    train_loader = DataLoader(
+        partition_train_test["train"], batch_size=32, shuffle=True
+    )
+    test_loader = DataLoader(partition_train_test["test"], batch_size=32)
+    return train_loader, test_loader
+
+
+def train(net, train_loader, privacy_engine, optimizer, target_delta, device, epochs=1):
+    criterion = torch.nn.CrossEntropyLoss()
+    net.to(device)
+    net.train()
+    for _ in range(epochs):
+        for batch in tqdm(train_loader, "Training"):
+            images = batch["img"]
+            labels = batch["label"]
+            optimizer.zero_grad()
+            criterion(net(images.to(device)), labels.to(device)).backward()
+            optimizer.step()
+
+    epsilon = privacy_engine.get_epsilon(delta=target_delta)
+    return epsilon
+
+
+def test(net, test_loader, device):
+    net.to(device)
+    criterion = torch.nn.CrossEntropyLoss()
+    correct, loss = 0, 0.0
+    with torch.no_grad():
+        for batch in tqdm(test_loader, "Testing"):
+            images = batch["img"].to(device)
+            labels = batch["label"].to(device)
+            outputs = net(images)
+            loss += criterion(outputs, labels).item()
+            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
+    accuracy = correct / len(test_loader.dataset)
+    return loss, accuracy
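For readers who want to exercise the privacy machinery in isolation, the following is a minimal sketch of how the pieces of `task.py` compose outside of Flower. It assumes the `opacus_fl` package is installed (`pip install -e .`); the hyperparameter values mirror this example's defaults, and running it downloads CIFAR-10 and trains for one full epoch on CPU:

```python
import torch
from opacus import PrivacyEngine

from opacus_fl.task import Net, load_data, train

net = Net()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
train_loader, _ = load_data(partition_id=0, num_partitions=2)

# make_private wraps the model, optimizer, and dataloader so that every
# optimizer.step() clips per-sample gradients and adds calibrated noise.
privacy_engine = PrivacyEngine(secure_mode=False)
net, optimizer, train_loader = privacy_engine.make_private(
    module=net,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=1.0,  # value used by even-partition clients
    max_grad_norm=1.0,     # matches max-grad-norm in pyproject.toml
)

# train() runs one epoch and reports the epsilon spent for the given delta.
epsilon = train(net, train_loader, privacy_engine, optimizer,
                target_delta=1e-5, device=torch.device("cpu"))
print(f"Spent epsilon: {epsilon:.2f}")
```
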
diff --git a/examples/opacus/pyproject.toml b/examples/opacus/pyproject.toml
index 0489dff94c8..208ce8f4550 100644
--- a/examples/opacus/pyproject.toml
+++ b/examples/opacus/pyproject.toml
@@ -3,18 +3,35 @@ requires = ["hatchling"]
 build-backend = "hatchling.build"
 
 [project]
-name = "opacus-fl"
-version = "0.1.0"
-description = "Sample Differential Privacy with Opacus in Flower"
-authors = [{ name = "The Flower Authors", email = "hello@flower.ai" }]
+name = "opacus_fl"
+version = "1.0.0"
+description = "Sample-level Differential Privacy with Opacus in Flower"
+
 dependencies = [
-    "flwr>=1.8.0,<2.0",
-    "flwr-datasets[vision]>=0.0.2,<1.0.0",
+    "flwr[simulation]>=1.11.1",
+    "flwr-datasets[vision]>=0.3.0",
     "torch==2.1.1",
     "torchvision==0.16.1",
-    "tqdm==4.65.0",
     "opacus==v1.4.1",
 ]
 
 [tool.hatch.build.targets.wheel]
 packages = ["."]
+
+[tool.flwr.app]
+publisher = "flwrlabs"
+
+[tool.flwr.app.components]
+serverapp = "opacus_fl.server_app:app"
+clientapp = "opacus_fl.client_app:app"
+
+[tool.flwr.app.config]
+num-server-rounds = 3
+target-delta = 1e-5
+max-grad-norm = 1.0
+
+[tool.flwr.federations]
+default = "local-simulation"
+
+[tool.flwr.federations.local-simulation]
+options.num-supernodes = 2
diff --git a/examples/opacus/server.py b/examples/opacus/server.py
deleted file mode 100644
index 68c1c027d3d..00000000000
--- a/examples/opacus/server.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from typing import List, Tuple
-
-import flwr as fl
-from flwr.common import Metrics
-from flwr.server import ServerApp, ServerConfig
-from flwr.server.strategy import FedAvg
-
-
-def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
-    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
-    examples = [num_examples for num_examples, _ in metrics]
-    return {"accuracy": sum(accuracies) / sum(examples)}
-
-
-strategy = FedAvg(evaluate_metrics_aggregation_fn=weighted_average)
-
-config = ServerConfig(num_rounds=3)
-
-app = ServerApp(
-    config=config,
-    strategy=strategy,
-)