diff --git a/examples/custom-mods/.gitignore b/examples/custom-mods/.gitignore
new file mode 100644
index 00000000000..260d28a67c6
--- /dev/null
+++ b/examples/custom-mods/.gitignore
@@ -0,0 +1,2 @@
+wandb/
+.runs_history/
diff --git a/examples/custom-mods/README.md b/examples/custom-mods/README.md
new file mode 100644
index 00000000000..b0ad668c2de
--- /dev/null
+++ b/examples/custom-mods/README.md
@@ -0,0 +1,339 @@
+# Using custom mods 🧪
+
+> 🧪 = This example covers experimental features that might change in future versions of Flower
+> Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch.
+
+The following steps describe how to write custom Flower Mods and use them in a simple example.
+
+## Writing custom Flower Mods
+
+### Flower Mods basics
+
+As described [here](https://flower.ai/docs/framework/how-to-use-built-in-mods.html#what-are-mods), Flower Mods in their simplest form look like this:
+
+```python
+def basic_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message:
+    # Do something with incoming Message (or Context)
+    # before passing to the inner ``ClientApp``
+    reply = app(msg, context)
+    # Do something with outgoing Message (or Context)
+    # before returning
+    return reply
+```
+
+and are used when defining the `ClientApp`:
+
+```python
+app = fl.client.ClientApp(
+    client_fn=client_fn,
+    mods=[basic_mod],
+)
+```
+
+Note that in this specific case, the mod doesn't modify anything, so FL runs as usual.
+
+### WandB Flower Mod
+
+If we want to write a mod to monitor our client-side training using [Weights & Biases](https://github.com/wandb/wandb), we can follow the steps below.
+
+First, we need to initialize our W&B project with the correct parameters:
+
+```python
+wandb.init(
+    project=...,
+    group=...,
+    name=...,
+    id=...,
+    resume="allow",
+    reinit=True,
+)
+```
+
+In our case, the `group` should be the `run_id`, which is specific to a `ServerApp` run, and the `name` should be the `node_id`. This makes it easy to navigate our W&B project: for each run, we can see the computed results as a whole or for each individual client.
+
+The `id` needs to be unique, so it will be a combination of `run_id` and `node_id`.
+
+In the end we have:
+
+```python
+def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message:
+    run_id = msg.metadata.run_id
+    group_name = f"Run ID: {run_id}"
+
+    node_id = str(msg.metadata.dst_node_id)
+    run_name = f"Node ID: {node_id}"
+
+    wandb.init(
+        project="Mod Name",
+        group=group_name,
+        name=run_name,
+        id=f"{run_id}_{node_id}",
+        resume="allow",
+        reinit=True,
+    )
+```
+
+Now, before the message is processed by the inner `ClientApp`, we will store the starting time and the round number, in order to compute the time it took the client to perform its fit step.
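+
+(The round number is available here because, in this example, the server places the current round in the message's `group_id` metadata field.)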
+
+```python
+server_round = int(msg.metadata.group_id)
+start_time = time.time()
+```
+
+And then, we can pass the message on to the inner `ClientApp`:
+
+```python
+reply = app(msg, context)
+```
+
+And now, with the message we got back, we can gather our metrics:
+
+```python
+if reply.metadata.message_type == MessageType.TRAIN and reply.has_content():
+
+    time_diff = time.time() - start_time
+
+    metrics = reply.content.configs_records
+
+    results_to_log = dict(metrics.get("fitres.metrics", ConfigsRecord()))
+    results_to_log["fit_time"] = time_diff
+```
+
+Note that we store our metrics in the `results_to_log` variable, and that we only build this dictionary when the client sends back fit results (with content in them).
+
+Finally, we can send our results to W&B using:
+
+```python
+wandb.log(results_to_log, step=int(server_round), commit=True)
+```
+
+The complete mod becomes:
+
+```python
+def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message:
+    server_round = int(msg.metadata.group_id)
+
+    # Initialize W&B only once, on the first "fit" message
+    if msg.metadata.message_type == MessageType.TRAIN and server_round == 1:
+        run_id = msg.metadata.run_id
+        group_name = f"Run ID: {run_id}"
+
+        node_id = str(msg.metadata.dst_node_id)
+        run_name = f"Node ID: {node_id}"
+
+        wandb.init(
+            project="Mod Name",
+            group=group_name,
+            name=run_name,
+            id=f"{run_id}_{node_id}",
+            resume="allow",
+            reinit=True,
+        )
+
+    start_time = time.time()
+
+    reply = app(msg, context)
+
+    if reply.metadata.message_type == MessageType.TRAIN and reply.has_content():
+
+        time_diff = time.time() - start_time
+
+        metrics = reply.content.configs_records
+
+        results_to_log = dict(metrics.get("fitres.metrics", ConfigsRecord()))
+
+        results_to_log["fit_time"] = time_diff
+
+        wandb.log(results_to_log, step=int(server_round), commit=True)
+
+    return reply
+```
+
+And it can be used like this:
+
+```python
+app = fl.client.ClientApp(
+    client_fn=client_fn,
+    mods=[wandb_mod],
+)
+```
+
+If we want to pass an argument to our mod, we can use a wrapper function:
+
+```python
+def get_wandb_mod(name: str) -> Mod:
+    def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message:
+        server_round = int(msg.metadata.group_id)
+
+        # Initialize W&B only once, on the first "fit" message
+        if server_round == 1 and msg.metadata.message_type == MessageType.TRAIN:
+            run_id = msg.metadata.run_id
+            group_name = f"Run ID: {run_id}"
+
+            node_id = str(msg.metadata.dst_node_id)
+            run_name = f"Node ID: {node_id}"
+
+            wandb.init(
+                project=name,
+                group=group_name,
+                name=run_name,
+                id=f"{run_id}_{node_id}",
+                resume="allow",
+                reinit=True,
+            )
+
+        start_time = time.time()
+
+        reply = app(msg, context)
+
+        if reply.metadata.message_type == MessageType.TRAIN and reply.has_content():
+
+            time_diff = time.time() - start_time
+
+            metrics = reply.content.configs_records
+
+            results_to_log = dict(metrics.get("fitres.metrics", ConfigsRecord()))
+
+            results_to_log["fit_time"] = time_diff
+
+            wandb.log(results_to_log, step=int(server_round), commit=True)
+
+        return reply
+
+    return wandb_mod
+```
+
+And use it like this:
+
+```python
+app = fl.client.ClientApp(
+    client_fn=client_fn,
+    mods=[
+        get_wandb_mod("Custom mods example"),
+    ],
+)
+```
+
+### TensorBoard Flower Mod
+
+The [TensorBoard](https://www.tensorflow.org/tensorboard) Mod only differs in its initialization and in how the data is sent to TensorBoard:
+
+```python
+def get_tensorboard_mod(logdir) -> Mod:
+    os.makedirs(logdir, exist_ok=True)
+
+    def tensorboard_mod(
+        msg: Message, context: Context, app: ClientAppCallable
+    ) -> Message:
+        logdir_run = os.path.join(logdir, str(msg.metadata.run_id))
+
+        node_id = str(msg.metadata.dst_node_id)
+
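+        # As in the W&B mod, the current server round is carried in the
+        # message's `group_id` metadata field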
+        server_round = int(msg.metadata.group_id)
+
+        start_time = time.time()
+
+        reply = app(msg, context)
+
+        time_diff = time.time() - start_time
+
+        if reply.metadata.message_type == MessageType.TRAIN and reply.has_content():
+            writer = tf.summary.create_file_writer(os.path.join(logdir_run, node_id))
+
+            metrics = dict(
+                reply.content.configs_records.get("fitres.metrics", ConfigsRecord())
+            )
+
+            with writer.as_default(step=server_round):
+                tf.summary.scalar("fit_time", time_diff, step=server_round)
+                for metric in metrics:
+                    tf.summary.scalar(
+                        metric,
+                        metrics[metric],
+                        step=server_round,
+                    )
+                writer.flush()
+
+        return reply
+
+    return tensorboard_mod
+```
+
+For initialization, TensorBoard needs a custom log directory path, which, in this case, is passed as an argument to the wrapper function.
+
+It can be used in the following way:
+
+```python
+app = fl.client.ClientApp(
+    client_fn=client_fn,
+    mods=[get_tensorboard_mod(".runs_history/")],
+)
+```
+
+## Running the example
+
+### Preconditions
+
+Let's assume the following project structure:
+
+```bash
+$ tree .
+.
+├── client.py # <-- contains `ClientApp`
+├── server.py # <-- contains `ServerApp`
+├── task.py # <-- task-specific code (model, data)
+└── requirements.txt # <-- dependencies
+```
+
+### Install dependencies
+
+```bash
+pip install -r requirements.txt
+```
+
+For [W&B](https://wandb.ai), you will also need a valid account.
+
+### Start the long-running Flower server (SuperLink)
+
+```bash
+flower-superlink --insecure
+```
+
+### Start the long-running Flower client (SuperNode)
+
+In a new terminal window, start the first long-running Flower client using:
+
+```bash
+flower-client-app client:wandb_app --insecure
+```
+
+for W&B monitoring, or:
+
+```bash
+flower-client-app client:tb_app --insecure
+```
+
+for TensorBoard.
+
+In yet another new terminal window, start the second long-running Flower client (with the mod of your choice):
+
+```bash
+flower-client-app client:{wandb,tb}_app --insecure
+```
+
+### Run the Flower App
+
+With both the long-running server (SuperLink) and the two clients (SuperNodes) up and running, we can now run the actual Flower App:
+
+```bash
+flower-server-app server:app --insecure
+```
+
+### Check the results
+
+For W&B, you will need to log in to the [website](https://wandb.ai).
+
+For TensorBoard, you will need to run the following command in your terminal:
+
+```sh
+tensorboard --logdir <LOGDIR>
+```
+
+Where `<LOGDIR>` needs to be replaced by the directory passed as an argument to the wrapper function (`.runs_history/` by default).
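+
+For example, with the directory used by this example's `tb_app`:
+
+```sh
+tensorboard --logdir .runs_history/
+```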
diff --git a/examples/custom-mods/client.py b/examples/custom-mods/client.py
new file mode 100644
index 00000000000..ca7a0d887a3
--- /dev/null
+++ b/examples/custom-mods/client.py
@@ -0,0 +1,162 @@
+import logging
+import os
+import time
+
+import flwr as fl
+import tensorflow as tf
+import wandb
+from flwr.common import ConfigsRecord
+from flwr.client.typing import ClientAppCallable, Mod
+from flwr.common.context import Context
+from flwr.common.message import Message
+from flwr.common.constant import MessageType
+
+from task import (
+    Net,
+    DEVICE,
+    load_data,
+    get_parameters,
+    set_parameters,
+    train,
+    test,
+)
+
+
+# Logging filter that keeps only the most relevant W&B messages
+# (login info and run/project URLs)
+class WBLoggingFilter(logging.Filter):
+    def filter(self, record):
+        return (
+            "login" in record.getMessage()
+            or "View project at" in record.getMessage()
+            or "View run at" in record.getMessage()
+        )
+
+
+# Load model and data (simple CNN, CIFAR-10)
+net = Net().to(DEVICE)
+trainloader, testloader = load_data()
+
+
+# Define Flower client
+class FlowerClient(fl.client.NumPyClient):
+    def get_parameters(self, config):
+        return get_parameters(net)
+
+    def fit(self, parameters, config):
+        set_parameters(net, parameters)
+        results = train(net, trainloader, testloader, epochs=1, device=DEVICE)
+        return get_parameters(net), len(trainloader.dataset), results
+
+    def evaluate(self, parameters, config):
+        set_parameters(net, parameters)
+        loss, accuracy = test(net, testloader)
+        return loss, len(testloader.dataset), {"accuracy": accuracy}
+
+
+def client_fn(cid: str):
+    return FlowerClient().to_client()
+
+
+def get_wandb_mod(name: str) -> Mod:
+    def wandb_mod(msg: Message, context: Context, app: ClientAppCallable) -> Message:
+        """Flower Mod that logs the metrics dictionary returned by the client's
+        fit function to Weights & Biases.
+        """
+        server_round = int(msg.metadata.group_id)
+
+        if server_round == 1 and msg.metadata.message_type == MessageType.TRAIN:
+            run_id = msg.metadata.run_id
+            group_name = f"Run ID: {run_id}"
+
+            node_id = str(msg.metadata.dst_node_id)
+            run_name = f"Node ID: {node_id}"
+
+            wandb.init(
+                project=name,
+                group=group_name,
+                name=run_name,
+                id=f"{run_id}_{node_id}",
+                resume="allow",
+                reinit=True,
+            )
+
+        start_time = time.time()
+
+        reply = app(msg, context)
+
+        time_diff = time.time() - start_time
+
+        # if the `ClientApp` just processed a "fit" message, let's log some metrics to W&B
+        if reply.metadata.message_type == MessageType.TRAIN and reply.has_content():
+
+            metrics = reply.content.configs_records
+
+            results_to_log = dict(metrics.get("fitres.metrics", ConfigsRecord()))
+
+            results_to_log["fit_time"] = time_diff
+
+            wandb.log(results_to_log, step=int(server_round), commit=True)
+
+        return reply
+
+    return wandb_mod
+
+
+def get_tensorboard_mod(logdir) -> Mod:
+    os.makedirs(logdir, exist_ok=True)
+
+    def tensorboard_mod(
+        msg: Message, context: Context, app: ClientAppCallable
+    ) -> Message:
+        """Flower Mod that logs the metrics dictionary returned by the client's
+        fit function to TensorBoard.
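+
+        Writers are created under ``logdir/<run_id>/<node_id>``, so each
+        client shows up as a separate run in TensorBoard.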
+        """
+        logdir_run = os.path.join(logdir, str(msg.metadata.run_id))
+
+        node_id = str(msg.metadata.dst_node_id)
+
+        server_round = int(msg.metadata.group_id)
+
+        start_time = time.time()
+
+        reply = app(msg, context)
+
+        time_diff = time.time() - start_time
+
+        # if the `ClientApp` just processed a "fit" message, let's log some metrics to TensorBoard
+        if reply.metadata.message_type == MessageType.TRAIN and reply.has_content():
+            writer = tf.summary.create_file_writer(os.path.join(logdir_run, node_id))
+
+            metrics = dict(
+                reply.content.configs_records.get("fitres.metrics", ConfigsRecord())
+            )
+
+            with writer.as_default(step=server_round):
+                tf.summary.scalar("fit_time", time_diff, step=server_round)
+                for metric in metrics:
+                    tf.summary.scalar(
+                        metric,
+                        metrics[metric],
+                        step=server_round,
+                    )
+                writer.flush()
+
+        return reply
+
+    return tensorboard_mod
+
+
+# Run via `flower-client-app client:wandb_app`
+wandb_app = fl.client.ClientApp(
+    client_fn=client_fn,
+    mods=[
+        get_wandb_mod("Custom mods example"),
+    ],
+)
+
+# Run via `flower-client-app client:tb_app`
+tb_app = fl.client.ClientApp(
+    client_fn=client_fn,
+    mods=[
+        get_tensorboard_mod(".runs_history/"),
+    ],
+)
diff --git a/examples/custom-mods/pyproject.toml b/examples/custom-mods/pyproject.toml
new file mode 100644
index 00000000000..e690e05bab8
--- /dev/null
+++ b/examples/custom-mods/pyproject.toml
@@ -0,0 +1,18 @@
+[build-system]
+requires = ["poetry-core>=1.4.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry]
+name = "custom-mods"
+version = "0.1.0"
+description = "Using custom Flower Mods with PyTorch"
+authors = ["The Flower Authors <hello@flower.ai>"]
+
+[tool.poetry.dependencies]
+python = ">=3.8,<3.11"
+flwr = { path = "../../", develop = true, extras = ["simulation"] }
+tensorboard = "2.16.2"
+torch = "1.13.1"
+torchvision = "0.14.1"
+tqdm = "4.65.0"
+wandb = "0.16.3"
diff --git a/examples/custom-mods/requirements.txt b/examples/custom-mods/requirements.txt
new file mode 100644
index 00000000000..75b2c1135f1
--- /dev/null
+++ b/examples/custom-mods/requirements.txt
@@ -0,0 +1,6 @@
+flwr-nightly[rest,simulation]>=1.0, <2.0
+tensorboard==2.16.2
+torch==1.13.1
+torchvision==0.14.1
+tqdm==4.65.0
+wandb==0.16.3
diff --git a/examples/custom-mods/server.py b/examples/custom-mods/server.py
new file mode 100644
index 00000000000..c2d8a4fe5ee
--- /dev/null
+++ b/examples/custom-mods/server.py
@@ -0,0 +1,45 @@
+from typing import List, Tuple
+
+import flwr as fl
+from flwr.common import Metrics
+
+
+# Define metric aggregation function
+def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
+    examples = [num_examples for num_examples, _ in metrics]
+
+    # Weight each client's metrics by the number of examples it used
+    train_losses = [
+        num_examples * float(m["train_loss"]) for num_examples, m in metrics
+    ]
+    train_accuracies = [
+        num_examples * float(m["train_accuracy"]) for num_examples, m in metrics
+    ]
+    val_losses = [num_examples * float(m["val_loss"]) for num_examples, m in metrics]
+    val_accuracies = [
+        num_examples * float(m["val_accuracy"]) for num_examples, m in metrics
+    ]
+
+    # Aggregate and return custom metrics (weighted averages)
+    return {
+        "train_loss": sum(train_losses) / sum(examples),
+        "train_accuracy": sum(train_accuracies) / sum(examples),
+        "val_loss": sum(val_losses) / sum(examples),
+        "val_accuracy": sum(val_accuracies) / sum(examples),
+    }
+
+
+# Define strategy
+strategy = fl.server.strategy.FedAvg(
+    fraction_fit=1.0,  # Select all available clients
+    fraction_evaluate=0.0,  # Disable evaluation
+    min_available_clients=2,
+    fit_metrics_aggregation_fn=weighted_average,
+)
+
+
+# Run via `flower-server-app server:app`
+app = fl.server.ServerApp(
+    config=fl.server.ServerConfig(num_rounds=3),
+    strategy=strategy,
+)
diff --git a/examples/custom-mods/task.py b/examples/custom-mods/task.py
new file mode 100644
index 00000000000..276aace885d
--- /dev/null
+++ b/examples/custom-mods/task.py
@@ -0,0 +1,95 @@
+import warnings
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision.datasets import CIFAR10
+from torchvision.transforms import Compose, Normalize, ToTensor
+from tqdm import tqdm
+
+
+warnings.filterwarnings("ignore", category=UserWarning)
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+class Net(nn.Module):
+    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
+
+    def __init__(self) -> None:
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 5 * 5)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
+
+
+def train(net, trainloader, valloader, epochs, device):
+    """Train the model on the training set."""
+    print("Starting training...")
+    net.to(device)  # move model to GPU if available
+    criterion = torch.nn.CrossEntropyLoss().to(device)
+    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
+    net.train()
+    for _ in range(epochs):
+        for images, labels in trainloader:
+            images, labels = images.to(device), labels.to(device)
+            optimizer.zero_grad()
+            loss = criterion(net(images), labels)
+            loss.backward()
+            optimizer.step()
+
+    train_loss, train_acc = test(net, trainloader)
+    val_loss, val_acc = test(net, valloader)
+
+    results = {
+        "train_loss": train_loss,
+        "train_accuracy": train_acc,
+        "val_loss": val_loss,
+        "val_accuracy": val_acc,
+    }
+    return results
+
+
+def test(net, testloader):
+    """Validate the model on the test set."""
+    net.to(DEVICE)
+    criterion = torch.nn.CrossEntropyLoss()
+    correct, loss = 0, 0.0
+    with torch.no_grad():
+        for images, labels in tqdm(testloader):
+            outputs = net(images.to(DEVICE))
+            labels = labels.to(DEVICE)
+            loss += criterion(outputs, labels).item()
+            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
+    accuracy = correct / len(testloader.dataset)
+    return loss, accuracy
+
+
+def load_data():
+    """Load CIFAR-10 (training and test set)."""
+    trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+    trainset = CIFAR10("./data", train=True, download=True, transform=trf)
+    testset = CIFAR10("./data", train=False, download=True, transform=trf)
+    return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)
+
+
+def get_parameters(net):
+    return [val.cpu().numpy() for _, val in net.state_dict().items()]
+
+
+def set_parameters(net, parameters):
+    params_dict = zip(net.state_dict().keys(), parameters)
+    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+    net.load_state_dict(state_dict, strict=True)