Skip to content

Commit

Permalink
Update cli new command templates
Browse files Browse the repository at this point in the history
  • Loading branch information
tanertopal committed Mar 4, 2024
1 parent d16807a commit 9a2c96a
Show file tree
Hide file tree
Showing 5 changed files with 338 additions and 10 deletions.
19 changes: 10 additions & 9 deletions src/py/flwr/cli/new/new.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def new(
if value == framework_value
]
framework_str = selected_value[0]

framework_str = framework_str.lower()

# Set project directory path
cwd = os.getcwd()
Expand All @@ -106,18 +108,17 @@ def new(
"README.md": {
"template": "app/README.md.tpl",
},
"requirements.txt": {
"template": f"app/requirements.{framework_str.lower()}.txt.tpl"
},
"requirements.txt": {"template": f"app/requirements.{framework_str}.txt.tpl"},
"flower.toml": {"template": "app/flower.toml.tpl"},
f"{pnl}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{pnl}/server.py": {
"template": f"app/code/server.{framework_str.lower()}.py.tpl"
},
f"{pnl}/client.py": {
"template": f"app/code/client.{framework_str.lower()}.py.tpl"
},
f"{pnl}/server.py": {"template": f"app/code/server.{framework_str}.py.tpl"},
f"{pnl}/client.py": {"template": f"app/code/client.{framework_str}.py.tpl"},
}

# In case framework is MlFramework.PYTORCH generate additionally the utils.py file
if framework_str == MlFramework.PYTORCH.value.lower():
files[f"{pnl}/utils.py"] = {"template": f"app/code/utils.{framework_str}.py.tpl"}

context = {"project_name": project_name}

for file_path, value in files.items():
Expand Down
128 changes: 128 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl
Original file line number Diff line number Diff line change
@@ -1 +1,129 @@
"""$project_name: A Flower / PyTorch app."""

from collections import OrderedDict
from typing import Dict, Tuple, List

import torch
from torch.utils.data import DataLoader

import flwr as fl
from flwr.common import Metrics
from flwr.common.typing import Scalar

from flwr_datasets import FederatedDataset

from utils import Net, train, test, apply_transforms

NUM_CLIENTS = 100
NUM_ROUNDS = 10


# Flower client, adapted from Pytorch quickstart example
class FlowerClient(fl.client.NumPyClient):
def __init__(self, trainset, valset):
self.trainset = trainset
self.valset = valset

# Instantiate model
self.model = Net()

# Determine device
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model.to(self.device) # send model to device

def get_parameters(self, config):
return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

def fit(self, parameters, config):
set_params(self.model, parameters)

# Read from config
batch, epochs = config["batch_size"], config["epochs"]

# Construct dataloader
trainloader = DataLoader(self.trainset, batch_size=batch, shuffle=True)

# Define optimizer
optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
# Train
train(self.model, trainloader, optimizer, epochs=epochs, device=self.device)

# Return local model and statistics
return self.get_parameters({}), len(trainloader.dataset), {}

def evaluate(self, parameters, config):
set_params(self.model, parameters)

# Construct dataloader
valloader = DataLoader(self.valset, batch_size=64)

# Evaluate
loss, accuracy = test(self.model, valloader, device=self.device)

# Return statistics
return float(loss), len(valloader.dataset), {"accuracy": float(accuracy)}


def get_client_fn(dataset: FederatedDataset):
"""Return a function to construct a client.

The VirtualClientEngine will execute this function whenever a client is sampled by
the strategy to participate.
"""

def client_fn(cid: str) -> fl.client.Client:
"""Construct a FlowerClient with its own dataset partition."""

# Let's get the partition corresponding to the i-th client
client_dataset = dataset.load_partition(int(cid), "train")

# Now let's split it into train (90%) and validation (10%)
client_dataset_splits = client_dataset.train_test_split(test_size=0.1)

trainset = client_dataset_splits["train"]
valset = client_dataset_splits["test"]

# Now we apply the transform to each batch.
trainset = trainset.with_transform(apply_transforms)
valset = valset.with_transform(apply_transforms)

# Create and return client
return FlowerClient(trainset, valset).to_client()

return client_fn


def fit_config(server_round: int) -> Dict[str, Scalar]:
"""Return a configuration with static batch size and (local) epochs."""
config = {
"epochs": 1, # Number of local epochs done by clients
"batch_size": 32, # Batch size to use by clients during fit()
}
return config


def set_params(model: torch.nn.ModuleList, params: List[fl.common.NDArrays]):
"""Set model weights from a list of NumPy ndarrays."""
params_dict = zip(model.state_dict().keys(), params)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
"""Aggregation function for (federated) evaluation metrics, i.e. those returned by
the client's evaluate() method."""
# Multiply accuracy of each client by number of examples used
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
examples = [num_examples for num_examples, _ in metrics]

# Aggregate and return custom metric (weighted average)
return {"accuracy": sum(accuracies) / sum(examples)}


# Download MNIST dataset and partition it
mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS})

# ClientApp for Flower-Next
app = fl.client.ClientApp(
client_fn=get_client_fn(mnist_fds),
)
129 changes: 129 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl
Original file line number Diff line number Diff line change
@@ -1 +1,130 @@
"""$project_name: A Flower / PyTorch app."""

"""$project_name: A Flower / PyTorch app."""

from collections import OrderedDict
from typing import Dict, Tuple, List

import torch
from torch.utils.data import DataLoader

import flwr as fl
from flwr.common import Metrics
from flwr.common.typing import Scalar

from datasets import Dataset
from datasets.utils.logging import disable_progress_bar
from flwr_datasets import FederatedDataset

from utils import Net, test, apply_transforms

NUM_CLIENTS = 100
NUM_ROUNDS = 10


def get_client_fn(dataset: FederatedDataset):
"""Return a function to construct a client.

The VirtualClientEngine will execute this function whenever a client is sampled by
the strategy to participate.
"""

def client_fn(cid: str) -> fl.client.Client:
"""Construct a FlowerClient with its own dataset partition."""

# Let's get the partition corresponding to the i-th client
client_dataset = dataset.load_partition(int(cid), "train")

# Now let's split it into train (90%) and validation (10%)
client_dataset_splits = client_dataset.train_test_split(test_size=0.1)

trainset = client_dataset_splits["train"]
valset = client_dataset_splits["test"]

# Now we apply the transform to each batch.
trainset = trainset.with_transform(apply_transforms)
valset = valset.with_transform(apply_transforms)

# Create and return client
return FlowerClient(trainset, valset).to_client()

return client_fn


def fit_config(server_round: int) -> Dict[str, Scalar]:
"""Return a configuration with static batch size and (local) epochs."""
config = {
"epochs": 1, # Number of local epochs done by clients
"batch_size": 32, # Batch size to use by clients during fit()
}
return config


def set_params(model: torch.nn.ModuleList, params: List[fl.common.NDArrays]):
"""Set model weights from a list of NumPy ndarrays."""
params_dict = zip(model.state_dict().keys(), params)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
"""Aggregation function for (federated) evaluation metrics, i.e. those returned by
the client's evaluate() method."""
# Multiply accuracy of each client by number of examples used
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
examples = [num_examples for num_examples, _ in metrics]

# Aggregate and return custom metric (weighted average)
return {"accuracy": sum(accuracies) / sum(examples)}


def get_evaluate_fn(
centralized_testset: Dataset,
):
"""Return an evaluation function for centralized evaluation."""

def evaluate(
server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar]
):
"""Use the entire CIFAR-10 test set for evaluation."""

# Determine device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Net()
set_params(model, parameters)
model.to(device)

# Apply transform to dataset
testset = centralized_testset.with_transform(apply_transforms)

# Disable tqdm for dataset preprocessing
disable_progress_bar()

testloader = DataLoader(testset, batch_size=50)
loss, accuracy = test(model, testloader, device=device)

return loss, {"accuracy": accuracy}

return evaluate


# Download MNIST dataset and partition it
mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": NUM_CLIENTS})
centralized_testset = mnist_fds.load_full("test")

# Configure the strategy
strategy = fl.server.strategy.FedAvg(
fraction_fit=0.1, # Sample 10% of available clients for training
fraction_evaluate=0.05, # Sample 5% of available clients for evaluation
min_available_clients=10,
on_fit_config_fn=fit_config,
evaluate_metrics_aggregation_fn=weighted_average, # Aggregate federated metrics
evaluate_fn=get_evaluate_fn(centralized_testset), # Global evaluation function
)

# ServerApp for Flower-Next
server = fl.server.ServerApp(
config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
strategy=strategy,
)
64 changes: 64 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/utils.pytorch.py.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.transforms import ToTensor, Normalize, Compose


# transformation to convert images to tensors and apply normalization
def apply_transforms(batch):
transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
batch["image"] = [transforms(img) for img in batch["image"]]
return batch


# Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')
class Net(nn.Module):
def __init__(self, num_classes: int = 10) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, num_classes)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x


# borrowed from Pytorch quickstart example
def train(net, trainloader, optim, epochs, device: str):
"""Train the network on the training set."""
criterion = torch.nn.CrossEntropyLoss()
net.train()
for _ in range(epochs):
for batch in trainloader:
images, labels = batch["image"].to(device), batch["label"].to(device)
optim.zero_grad()
loss = criterion(net(images), labels)
loss.backward()
optim.step()


# borrowed from Pytorch quickstart example
def test(net, testloader, device: str):
"""Validate the network on the entire test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
net.eval()
with torch.no_grad():
for data in testloader:
images, labels = data["image"].to(device), data["label"].to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
correct += (predicted == labels).sum().item()
accuracy = correct / len(testloader.dataset)
return loss, accuracy
8 changes: 7 additions & 1 deletion src/py/flwr/cli/new/templates/app/flower.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ description = ""
license = "Apache-2.0"
authors = ["The Flower Authors <hello@flower.ai>"]

[components]
[flower.components]
serverapp = "$project_name.server:app"
clientapp = "$project_name.client:app"

[flower.engine]
name = "simulation" # optional

[flower.engine.simulation.super-node]
count = 10 # optional

0 comments on commit 9a2c96a

Please sign in to comment.