-
Notifications
You must be signed in to change notification settings - Fork 858
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d16807a
commit 9a2c96a
Showing
5 changed files
with
338 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
128 changes: 128 additions & 0 deletions
128
src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,129 @@ | ||
"""$project_name: A Flower / PyTorch app.""" | ||
|
||
from collections import OrderedDict | ||
from typing import Dict, Tuple, List | ||
|
||
import torch | ||
from torch.utils.data import DataLoader | ||
|
||
import flwr as fl | ||
from flwr.common import Metrics | ||
from flwr.common.typing import Scalar | ||
|
||
from flwr_datasets import FederatedDataset | ||
|
||
from utils import Net, train, test, apply_transforms | ||
|
||
NUM_CLIENTS = 100 | ||
NUM_ROUNDS = 10 | ||
|
||
|
||
# Flower client, adapted from Pytorch quickstart example | ||
class FlowerClient(fl.client.NumPyClient): | ||
def __init__(self, trainset, valset): | ||
self.trainset = trainset | ||
self.valset = valset | ||
|
||
# Instantiate model | ||
self.model = Net() | ||
|
||
# Determine device | ||
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
self.model.to(self.device) # send model to device | ||
|
||
def get_parameters(self, config): | ||
return [val.cpu().numpy() for _, val in self.model.state_dict().items()] | ||
|
||
def fit(self, parameters, config): | ||
set_params(self.model, parameters) | ||
|
||
# Read from config | ||
batch, epochs = config["batch_size"], config["epochs"] | ||
|
||
# Construct dataloader | ||
trainloader = DataLoader(self.trainset, batch_size=batch, shuffle=True) | ||
|
||
# Define optimizer | ||
optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9) | ||
# Train | ||
train(self.model, trainloader, optimizer, epochs=epochs, device=self.device) | ||
|
||
# Return local model and statistics | ||
return self.get_parameters({}), len(trainloader.dataset), {} | ||
|
||
def evaluate(self, parameters, config): | ||
set_params(self.model, parameters) | ||
|
||
# Construct dataloader | ||
valloader = DataLoader(self.valset, batch_size=64) | ||
|
||
# Evaluate | ||
loss, accuracy = test(self.model, valloader, device=self.device) | ||
|
||
# Return statistics | ||
return float(loss), len(valloader.dataset), {"accuracy": float(accuracy)} | ||
|
||
|
||
def get_client_fn(dataset: FederatedDataset): | ||
"""Return a function to construct a client. | ||
|
||
The VirtualClientEngine will execute this function whenever a client is sampled by | ||
the strategy to participate. | ||
""" | ||
|
||
def client_fn(cid: str) -> fl.client.Client: | ||
"""Construct a FlowerClient with its own dataset partition.""" | ||
|
||
# Let's get the partition corresponding to the i-th client | ||
client_dataset = dataset.load_partition(int(cid), "train") | ||
|
||
# Now let's split it into train (90%) and validation (10%) | ||
client_dataset_splits = client_dataset.train_test_split(test_size=0.1) | ||
|
||
trainset = client_dataset_splits["train"] | ||
valset = client_dataset_splits["test"] | ||
|
||
# Now we apply the transform to each batch. | ||
trainset = trainset.with_transform(apply_transforms) | ||
valset = valset.with_transform(apply_transforms) | ||
|
||
# Create and return client | ||
return FlowerClient(trainset, valset).to_client() | ||
|
||
return client_fn | ||
|
||
|
||
def fit_config(server_round: int) -> Dict[str, Scalar]: | ||
"""Return a configuration with static batch size and (local) epochs.""" | ||
config = { | ||
"epochs": 1, # Number of local epochs done by clients | ||
"batch_size": 32, # Batch size to use by clients during fit() | ||
} | ||
return config | ||
|
||
|
||
def set_params(model: torch.nn.ModuleList, params: List[fl.common.NDArrays]): | ||
"""Set model weights from a list of NumPy ndarrays.""" | ||
params_dict = zip(model.state_dict().keys(), params) | ||
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) | ||
model.load_state_dict(state_dict, strict=True) | ||
|
||
|
||
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: | ||
"""Aggregation function for (federated) evaluation metrics, i.e. those returned by | ||
the client's evaluate() method.""" | ||
# Multiply accuracy of each client by number of examples used | ||
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] | ||
examples = [num_examples for num_examples, _ in metrics] | ||
|
||
# Aggregate and return custom metric (weighted average) | ||
return {"accuracy": sum(accuracies) / sum(examples)} | ||
|
||
|
||
# Download MNIST dataset and partition it | ||
mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) | ||
|
||
# ClientApp for Flower-Next | ||
app = fl.client.ClientApp( | ||
client_fn=get_client_fn(mnist_fds), | ||
) |
129 changes: 129 additions & 0 deletions
129
src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,130 @@ | ||
"""$project_name: A Flower / PyTorch app.""" | ||
|
||
"""$project_name: A Flower / PyTorch app.""" | ||
|
||
from collections import OrderedDict | ||
from typing import Dict, Tuple, List | ||
|
||
import torch | ||
from torch.utils.data import DataLoader | ||
|
||
import flwr as fl | ||
from flwr.common import Metrics | ||
from flwr.common.typing import Scalar | ||
|
||
from datasets import Dataset | ||
from datasets.utils.logging import disable_progress_bar | ||
from flwr_datasets import FederatedDataset | ||
|
||
from utils import Net, test, apply_transforms | ||
|
||
NUM_CLIENTS = 100 | ||
NUM_ROUNDS = 10 | ||
|
||
|
||
def get_client_fn(dataset: FederatedDataset): | ||
"""Return a function to construct a client. | ||
|
||
The VirtualClientEngine will execute this function whenever a client is sampled by | ||
the strategy to participate. | ||
""" | ||
|
||
def client_fn(cid: str) -> fl.client.Client: | ||
"""Construct a FlowerClient with its own dataset partition.""" | ||
|
||
# Let's get the partition corresponding to the i-th client | ||
client_dataset = dataset.load_partition(int(cid), "train") | ||
|
||
# Now let's split it into train (90%) and validation (10%) | ||
client_dataset_splits = client_dataset.train_test_split(test_size=0.1) | ||
|
||
trainset = client_dataset_splits["train"] | ||
valset = client_dataset_splits["test"] | ||
|
||
# Now we apply the transform to each batch. | ||
trainset = trainset.with_transform(apply_transforms) | ||
valset = valset.with_transform(apply_transforms) | ||
|
||
# Create and return client | ||
return FlowerClient(trainset, valset).to_client() | ||
|
||
return client_fn | ||
|
||
|
||
def fit_config(server_round: int) -> Dict[str, Scalar]: | ||
"""Return a configuration with static batch size and (local) epochs.""" | ||
config = { | ||
"epochs": 1, # Number of local epochs done by clients | ||
"batch_size": 32, # Batch size to use by clients during fit() | ||
} | ||
return config | ||
|
||
|
||
def set_params(model: torch.nn.ModuleList, params: List[fl.common.NDArrays]): | ||
"""Set model weights from a list of NumPy ndarrays.""" | ||
params_dict = zip(model.state_dict().keys(), params) | ||
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) | ||
model.load_state_dict(state_dict, strict=True) | ||
|
||
|
||
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: | ||
"""Aggregation function for (federated) evaluation metrics, i.e. those returned by | ||
the client's evaluate() method.""" | ||
# Multiply accuracy of each client by number of examples used | ||
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] | ||
examples = [num_examples for num_examples, _ in metrics] | ||
|
||
# Aggregate and return custom metric (weighted average) | ||
return {"accuracy": sum(accuracies) / sum(examples)} | ||
|
||
|
||
def get_evaluate_fn( | ||
centralized_testset: Dataset, | ||
): | ||
"""Return an evaluation function for centralized evaluation.""" | ||
|
||
def evaluate( | ||
server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar] | ||
): | ||
"""Use the entire CIFAR-10 test set for evaluation.""" | ||
|
||
# Determine device | ||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
model = Net() | ||
set_params(model, parameters) | ||
model.to(device) | ||
|
||
# Apply transform to dataset | ||
testset = centralized_testset.with_transform(apply_transforms) | ||
|
||
# Disable tqdm for dataset preprocessing | ||
disable_progress_bar() | ||
|
||
testloader = DataLoader(testset, batch_size=50) | ||
loss, accuracy = test(model, testloader, device=device) | ||
|
||
return loss, {"accuracy": accuracy} | ||
|
||
return evaluate | ||
|
||
|
||
# Download MNIST dataset and partition it | ||
mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS}) | ||
centralized_testset = mnist_fds.load_full("test") | ||
|
||
# Configure the strategy | ||
strategy = fl.server.strategy.FedAvg( | ||
fraction_fit=0.1, # Sample 10% of available clients for training | ||
fraction_evaluate=0.05, # Sample 5% of available clients for evaluation | ||
min_available_clients=10, | ||
on_fit_config_fn=fit_config, | ||
evaluate_metrics_aggregation_fn=weighted_average, # Aggregate federated metrics | ||
evaluate_fn=get_evaluate_fn(centralized_testset), # Global evaluation function | ||
) | ||
|
||
# ServerApp for Flower-Next | ||
server = fl.server.ServerApp( | ||
config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS), | ||
strategy=strategy, | ||
) |
64 changes: 64 additions & 0 deletions
64
src/py/flwr/cli/new/templates/app/code/utils.pytorch.py.tpl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from torchvision.transforms import ToTensor, Normalize, Compose | ||
|
||
|
||
# transformation to convert images to tensors and apply normalization | ||
def apply_transforms(batch): | ||
transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) | ||
batch["image"] = [transforms(img) for img in batch["image"]] | ||
return batch | ||
|
||
|
||
# Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz') | ||
class Net(nn.Module): | ||
def __init__(self, num_classes: int = 10) -> None: | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(1, 6, 5) | ||
self.pool = nn.MaxPool2d(2, 2) | ||
self.conv2 = nn.Conv2d(6, 16, 5) | ||
self.fc1 = nn.Linear(16 * 4 * 4, 120) | ||
self.fc2 = nn.Linear(120, 84) | ||
self.fc3 = nn.Linear(84, num_classes) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.pool(F.relu(self.conv1(x))) | ||
x = self.pool(F.relu(self.conv2(x))) | ||
x = x.view(-1, 16 * 4 * 4) | ||
x = F.relu(self.fc1(x)) | ||
x = F.relu(self.fc2(x)) | ||
x = self.fc3(x) | ||
return x | ||
|
||
|
||
# borrowed from Pytorch quickstart example | ||
def train(net, trainloader, optim, epochs, device: str): | ||
"""Train the network on the training set.""" | ||
criterion = torch.nn.CrossEntropyLoss() | ||
net.train() | ||
for _ in range(epochs): | ||
for batch in trainloader: | ||
images, labels = batch["image"].to(device), batch["label"].to(device) | ||
optim.zero_grad() | ||
loss = criterion(net(images), labels) | ||
loss.backward() | ||
optim.step() | ||
|
||
|
||
# borrowed from Pytorch quickstart example | ||
def test(net, testloader, device: str): | ||
"""Validate the network on the entire test set.""" | ||
criterion = torch.nn.CrossEntropyLoss() | ||
correct, loss = 0, 0.0 | ||
net.eval() | ||
with torch.no_grad(): | ||
for data in testloader: | ||
images, labels = data["image"].to(device), data["label"].to(device) | ||
outputs = net(images) | ||
loss += criterion(outputs, labels).item() | ||
_, predicted = torch.max(outputs.data, 1) | ||
correct += (predicted == labels).sum().item() | ||
accuracy = correct / len(testloader.dataset) | ||
return loss, accuracy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters