-
Notifications
You must be signed in to change notification settings - Fork 858
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(examples) Update opacus example (#4262)
Co-authored-by: jafermarq <javier@flower.ai>
- Loading branch information
1 parent
20dfa78
commit 74d64dd
Showing
8 changed files
with
285 additions
and
231 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""opacus: Training with Sample-Level Differential Privacy using Opacus Privacy Engine.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
"""opacus: Training with Sample-Level Differential Privacy using Opacus Privacy Engine.""" | ||
|
||
import warnings | ||
|
||
import torch | ||
from opacus import PrivacyEngine | ||
from opacus_fl.task import Net, get_weights, load_data, set_weights, test, train | ||
import logging | ||
|
||
from flwr.client import ClientApp, NumPyClient | ||
from flwr.common import Context | ||
|
||
warnings.filterwarnings("ignore", category=UserWarning) | ||
|
||
|
||
class FlowerClient(NumPyClient):
    """Flower client that trains with sample-level DP via the Opacus PrivacyEngine."""

    def __init__(
        self,
        train_loader,
        test_loader,
        target_delta,
        noise_multiplier,
        max_grad_norm,
    ) -> None:
        super().__init__()
        self.model = Net()
        self.train_loader = train_loader
        self.test_loader = test_loader
        # DP hyper-parameters supplied by `client_fn` from the run/node config.
        self.target_delta = target_delta
        self.noise_multiplier = noise_multiplier
        self.max_grad_norm = max_grad_norm
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def fit(self, parameters, config):
        """Train locally with DP-SGD and return the updated weights."""
        net = self.model
        set_weights(net, parameters)

        sgd = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

        # Wrap model, optimizer and loader so per-sample gradients are
        # clipped to `max_grad_norm` and noised with `noise_multiplier`.
        engine = PrivacyEngine(secure_mode=False)
        net, sgd, self.train_loader = engine.make_private(
            module=net,
            optimizer=sgd,
            data_loader=self.train_loader,
            noise_multiplier=self.noise_multiplier,
            max_grad_norm=self.max_grad_norm,
        )

        epsilon = train(
            net,
            self.train_loader,
            engine,
            sgd,
            self.target_delta,
            device=self.device,
        )

        # Report the privacy budget spent so far, when the engine provides it.
        if epsilon is not None:
            print(f"Epsilon value for delta={self.target_delta} is {epsilon:.2f}")
        else:
            print("Epsilon value not available.")

        return (get_weights(net), len(self.train_loader.dataset), {})

    def evaluate(self, parameters, config):
        """Evaluate the received global weights on the local test set."""
        set_weights(self.model, parameters)
        loss, accuracy = test(self.model, self.test_loader, self.device)
        return loss, len(self.test_loader.dataset), {"accuracy": accuracy}
|
||
|
||
def client_fn(context: Context):
    """Build a FlowerClient for this node from its node/run configuration."""
    pid = context.node_config["partition-id"]
    # Alternate the DP noise level across clients: even partitions use 1.0,
    # odd partitions use 1.5.
    noise_multiplier = 1.5 if pid % 2 != 0 else 1.0

    loaders = load_data(
        partition_id=pid, num_partitions=context.node_config["num-partitions"]
    )
    client = FlowerClient(
        loaders[0],
        loaders[1],
        context.run_config["target-delta"],
        noise_multiplier,
        context.run_config["max-grad-norm"],
    )
    return client.to_client()


# Flower ClientApp entry point.
app = ClientApp(client_fn=client_fn)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
"""opacus: Training with Sample-Level Differential Privacy using Opacus Privacy Engine.""" | ||
|
||
import logging | ||
from typing import List, Tuple | ||
|
||
from opacus_fl.task import Net, get_weights | ||
|
||
from flwr.common import Context, Metrics, ndarrays_to_parameters | ||
from flwr.server import ServerApp, ServerAppComponents, ServerConfig | ||
from flwr.server.strategy import FedAvg | ||
|
||
# Opacus logger seems to change the flwr logger to DEBUG level. Set back to INFO | ||
logging.getLogger("flwr").setLevel(logging.INFO) | ||
|
||
|
||
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate client accuracies into an example-weighted average."""
    total_examples = 0
    weighted_sum = 0.0
    for num_examples, m in metrics:
        total_examples += num_examples
        weighted_sum += num_examples * m["accuracy"]
    return {"accuracy": weighted_sum / total_examples}
|
||
|
||
def server_fn(context: Context) -> ServerAppComponents:
    """Build the server-side strategy and run config for this run."""
    num_rounds = context.run_config["num-server-rounds"]

    # Initialize the global model on the server so every client starts
    # from the same weights.
    initial_parameters = ndarrays_to_parameters(get_weights(Net()))

    strategy = FedAvg(
        evaluate_metrics_aggregation_fn=weighted_average,
        initial_parameters=initial_parameters,
    )
    return ServerAppComponents(
        config=ServerConfig(num_rounds=num_rounds), strategy=strategy
    )


# Flower ServerApp entry point.
app = ServerApp(server_fn=server_fn)
Oops, something went wrong.