refactor(framework) Update sklearn template (#4293)
jafermarq authored Oct 8, 2024
1 parent a0775af commit b6dc30f
Showing 5 changed files with 97 additions and 49 deletions.
1 change: 1 addition & 0 deletions src/py/flwr/cli/new/new.py
@@ -240,6 +240,7 @@ def new(
MlFramework.HUGGINGFACE.value,
MlFramework.MLX.value,
MlFramework.TENSORFLOW.value,
MlFramework.SKLEARN.value,
MlFramework.NUMPY.value,
]
if framework_str in frameworks_with_tasks:
62 changes: 14 additions & 48 deletions src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl
@@ -2,40 +2,17 @@

import warnings

import numpy as np
from flwr.client import NumPyClient, ClientApp
from flwr.common import Context
from flwr_datasets import FederatedDataset
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss


def get_model_parameters(model):
if model.fit_intercept:
params = [
model.coef_,
model.intercept_,
]
else:
params = [model.coef_]
return params


def set_model_params(model, params):
model.coef_ = params[0]
if model.fit_intercept:
model.intercept_ = params[1]
return model


def set_initial_params(model):
n_classes = 10 # MNIST has 10 classes
n_features = 784 # Number of features in dataset
model.classes_ = np.array([i for i in range(10)])

model.coef_ = np.zeros((n_classes, n_features))
if model.fit_intercept:
model.intercept_ = np.zeros((n_classes,))
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
from $import_name.task import (
get_model,
get_model_params,
load_data,
set_initial_params,
set_model_params,
)


class FlowerClient(NumPyClient):
@@ -46,9 +23,6 @@ class FlowerClient(NumPyClient):
self.y_train = y_train
self.y_test = y_test

def get_parameters(self, config):
return get_model_parameters(self.model)

def fit(self, parameters, config):
set_model_params(self.model, parameters)

@@ -57,7 +31,7 @@ class FlowerClient(NumPyClient):
warnings.simplefilter("ignore")
self.model.fit(self.X_train, self.y_train)

return get_model_parameters(self.model), len(self.X_train), {}
return get_model_params(self.model), len(self.X_train), {}

def evaluate(self, parameters, config):
set_model_params(self.model, parameters)
@@ -71,21 +45,13 @@ def client_fn(context: Context):
def client_fn(context: Context):
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
dataset = fds.load_partition(partition_id, "train").with_format("numpy")

X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]

# Split the on edge data: 80% train, 20% test
X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]
X_train, X_test, y_train, y_test = load_data(partition_id, num_partitions)

# Create LogisticRegression Model
model = LogisticRegression(
penalty="l2",
max_iter=1, # local epoch
warm_start=True, # prevent refreshing weights when fitting
)
penalty = context.run_config["penalty"]
local_epochs = context.run_config["local-epochs"]
model = get_model(penalty, local_epochs)

# Setting initial parameters, akin to model.compile for keras models
set_initial_params(model)
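The hunk is truncated here. For orientation only, a minimal sketch of how client_fn typically concludes in Flower client templates of this shape — the FlowerClient constructor argument order is an assumption, not the literal elided lines:

    # Hedged sketch of the elided tail of client_fn (argument order assumed)
    return FlowerClient(model, X_train, X_test, y_train, y_test).to_client()


# Register the ClientApp (standard Flower pattern)
app = ClientApp(client_fn=client_fn)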
14 changes: 13 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl
@@ -1,19 +1,31 @@
"""$project_name: A Flower / $framework_str app."""

from flwr.common import Context
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from $import_name.task import get_model, get_model_params, set_initial_params


def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]

# Create LogisticRegression Model
penalty = context.run_config["penalty"]
local_epochs = context.run_config["local-epochs"]
model = get_model(penalty, local_epochs)

# Setting initial parameters, akin to model.compile for keras models
set_initial_params(model)

initial_parameters = ndarrays_to_parameters(get_model_params(model))

# Define strategy
strategy = FedAvg(
fraction_fit=1.0,
fraction_evaluate=1.0,
min_available_clients=2,
initial_parameters=initial_parameters,
)
config = ServerConfig(num_rounds=num_rounds)

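Because the server now builds the initial global parameters itself via ndarrays_to_parameters(get_model_params(model)), FedAvg does not have to request starting weights from a client, which is why the client template above could drop its get_parameters method. The hunk is truncated here; a hedged sketch of the usual remainder, following Flower's standard ServerApp pattern rather than the literal elided lines:

    # Hedged sketch: server_fn typically ends by returning the components
    return ServerAppComponents(strategy=strategy, config=config)


# Register the ServerApp (standard Flower pattern)
app = ServerApp(server_fn=server_fn)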
67 changes: 67 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/task.sklearn.py.tpl
@@ -0,0 +1,67 @@
"""$project_name: A Flower / $framework_str app."""

import numpy as np
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
from sklearn.linear_model import LogisticRegression

fds = None # Cache FederatedDataset


def load_data(partition_id: int, num_partitions: int):
"""Load partition MNIST data."""
# Only initialize `FederatedDataset` once
global fds
if fds is None:
partitioner = IidPartitioner(num_partitions=num_partitions)
fds = FederatedDataset(
dataset="mnist",
partitioners={"train": partitioner},
)

dataset = fds.load_partition(partition_id, "train").with_format("numpy")

X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]

# Split the on edge data: 80% train, 20% test
X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]

return X_train, X_test, y_train, y_test


def get_model(penalty: str, local_epochs: int):

return LogisticRegression(
penalty=penalty,
max_iter=local_epochs,
warm_start=True,
)


def get_model_params(model):
if model.fit_intercept:
params = [
model.coef_,
model.intercept_,
]
else:
params = [model.coef_]
return params


def set_model_params(model, params):
model.coef_ = params[0]
if model.fit_intercept:
model.intercept_ = params[1]
return model


def set_initial_params(model):
n_classes = 10 # MNIST has 10 classes
n_features = 784 # Number of features in dataset
model.classes_ = np.array([i for i in range(10)])

model.coef_ = np.zeros((n_classes, n_features))
if model.fit_intercept:
model.intercept_ = np.zeros((n_classes,))
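A note on set_initial_params: scikit-learn's LogisticRegression only exposes coef_ and intercept_ after fit() has been called, so zero-filling them lets parameters be read from an unfitted model before the first round. A minimal round-trip sketch using only the helpers above, with the $import_name placeholder substituted by the generated package name:

from $import_name.task import (
    get_model,
    get_model_params,
    set_initial_params,
    set_model_params,
)

model = get_model(penalty="l2", local_epochs=1)  # unfitted estimator
set_initial_params(model)                        # zero-filled coef_ / intercept_
ndarrays = get_model_params(model)               # [coef_, intercept_] as NumPy ndarrays
set_model_params(model, ndarrays)                # apply (e.g. aggregated) parameters back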
2 changes: 2 additions & 0 deletions src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl
@@ -25,6 +25,8 @@ clientapp = "$import_name.client_app:app"

[tool.flwr.app.config]
num-server-rounds = 3
penalty = "l2"
local-epochs = 1

[tool.flwr.federations]
default = "local-simulation"
