diff --git a/examples/app-pytorch/server_custom.py b/examples/app-pytorch/server_custom.py index ba9cdb11d69..67c1bce99c5 100644 --- a/examples/app-pytorch/server_custom.py +++ b/examples/app-pytorch/server_custom.py @@ -103,15 +103,19 @@ def main(driver: Driver, context: Context) -> None: all_replies: List[Message] = [] while True: replies = driver.pull_messages(message_ids=message_ids) - print(f"Got {len(replies)} results") + for res in replies: + print(f"Got 1 {'result' if res.has_content() else 'error'}") all_replies += replies if len(all_replies) == len(message_ids): break + print("Pulling messages...") time.sleep(3) - # Collect correct results + # Filter correct results all_fitres = [ - recordset_to_fitres(msg.content, keep_input=True) for msg in all_replies + recordset_to_fitres(msg.content, keep_input=True) + for msg in all_replies + if msg.has_content() ] print(f"Received {len(all_fitres)} results") @@ -128,16 +132,21 @@ def main(driver: Driver, context: Context) -> None: ) metrics_results.append((fitres.num_examples, fitres.metrics)) - # Aggregate parameters (FedAvg) - parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) - parameters = parameters_aggregated + if len(weights_results) > 0: + # Aggregate parameters (FedAvg) + parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) + parameters = parameters_aggregated - # Aggregate metrics - metrics_aggregated = weighted_average(metrics_results) - history.add_metrics_distributed_fit( - server_round=server_round, metrics=metrics_aggregated - ) - print("Round ", server_round, " metrics: ", metrics_aggregated) + # Aggregate metrics + metrics_aggregated = weighted_average(metrics_results) + history.add_metrics_distributed_fit( + server_round=server_round, metrics=metrics_aggregated + ) + print("Round ", server_round, " metrics: ", metrics_aggregated) + else: + print( + f"Round {server_round} got {len(weights_results)} results. Skipping aggregation..." 
+ ) # Slow down the start of the next round time.sleep(sleep_time) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index c8287afc0fd..d4bd8e2e39e 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -14,11 +14,10 @@ # ============================================================================== """Flower client app.""" - import argparse import sys import time -from logging import DEBUG, INFO, WARN +from logging import DEBUG, ERROR, INFO, WARN from pathlib import Path from typing import Callable, ContextManager, Optional, Tuple, Type, Union @@ -38,6 +37,7 @@ ) from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature +from flwr.common.message import Error from flwr.common.object_ref import load_app, validate from flwr.common.retry_invoker import RetryInvoker, exponential @@ -482,32 +482,43 @@ def _load_client_app() -> ClientApp: # Retrieve context for this run context = node_state.retrieve_context(run_id=message.metadata.run_id) - # Load ClientApp instance - client_app: ClientApp = load_client_app_fn() + # Create an error reply message that will never be used to prevent + # the used-before-assignment linting error + reply_message = message.create_error_reply( + error=Error(code=0, reason="Unknown") + ) - # Handle task message - out_message = client_app(message=message, context=context) + # Handle app loading and task message + try: + # Load ClientApp instance + client_app: ClientApp = load_client_app_fn() - # Update node state - node_state.update_context( - run_id=message.metadata.run_id, - context=context, - ) + reply_message = client_app(message=message, context=context) + # Update node state + node_state.update_context( + run_id=message.metadata.run_id, + context=context, + ) + except Exception as ex: # pylint: disable=broad-exception-caught + log(ERROR, "ClientApp raised an exception", exc_info=ex) + + # Legacy grpc-bidi + if 
transport in ["grpc-bidi", None]: + # Raise exception, crash process + raise ex + + # Don't update/change NodeState + + # Create error message + # Reason example: "<class 'ZeroDivisionError'>:<'division by zero'>" + reason = str(type(ex)) + ":<'" + str(ex) + "'>" + reply_message = message.create_error_reply( + error=Error(code=0, reason=reason) + ) # Send - send(out_message) - log( - INFO, - "[RUN %s, ROUND %s]", - out_message.metadata.run_id, - out_message.metadata.group_id, - ) - log( - INFO, - "Sent: %s reply to message %s", - out_message.metadata.message_type, - message.metadata.message_id, - ) + send(reply_message) + log(INFO, "Sent reply") # Unregister node if delete_node is not None: diff --git a/src/py/flwr/server/compat/driver_client_proxy.py b/src/py/flwr/server/compat/driver_client_proxy.py index 7fdc07d620f..58341c7bb8f 100644 --- a/src/py/flwr/server/compat/driver_client_proxy.py +++ b/src/py/flwr/server/compat/driver_client_proxy.py @@ -170,8 +170,24 @@ def _send_receive_recordset( ) if len(task_res_list) == 1: task_res = task_res_list[0] + + # This will raise an Exception if task_res carries an `error` + validate_task_res(task_res=task_res) + return serde.recordset_from_proto(task_res.task.recordset) if timeout is not None and time.time() > start_time + timeout: raise RuntimeError("Timeout reached") time.sleep(SLEEP_TIME) + + +def validate_task_res( + task_res: task_pb2.TaskRes, # pylint: disable=E1101 +) -> None: + """Validate if a TaskRes is empty or not.""" + if not task_res.HasField("task"): + raise ValueError("Invalid TaskRes, field `task` missing") + if task_res.task.HasField("error"): + raise ValueError("Exception during client-side task execution") + if not task_res.task.HasField("recordset"): + raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing") diff --git a/src/py/flwr/server/compat/driver_client_proxy_test.py b/src/py/flwr/server/compat/driver_client_proxy_test.py index 3494049c106..57b35fc61a3 100644 --- 
a/src/py/flwr/server/compat/driver_client_proxy_test.py +++ b/src/py/flwr/server/compat/driver_client_proxy_test.py @@ -38,9 +38,14 @@ Properties, Status, ) -from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 - -from .driver_client_proxy import DriverClientProxy +from flwr.proto import ( # pylint: disable=E0611 + driver_pb2, + error_pb2, + node_pb2, + recordset_pb2, + task_pb2, +) +from flwr.server.compat.driver_client_proxy import DriverClientProxy, validate_task_res MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") @@ -243,3 +248,77 @@ def test_evaluate(self) -> None: # Assert assert 0.0 == evaluate_res.loss assert 0 == evaluate_res.num_examples + + def test_validate_task_res_valid(self) -> None: + """Test valid TaskRes.""" + metrics_record = recordset_pb2.MetricsRecord( # pylint: disable=E1101 + data={ + "loss": recordset_pb2.MetricsRecordValue( # pylint: disable=E1101 + double=1.0 + ) + } + ) + task_res = task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + task=task_pb2.Task( # pylint: disable=E1101 + recordset=recordset_pb2.RecordSet( # pylint: disable=E1101 + parameters={}, + metrics={"loss": metrics_record}, + configs={}, + ) + ), + ) + + # Execute & assert + try: + validate_task_res(task_res=task_res) + except ValueError: + self.fail() + + def test_validate_task_res_missing_task(self) -> None: + """Test invalid TaskRes (missing task).""" + # Prepare + task_res = task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + ) + + # Execute & assert + with self.assertRaises(ValueError): + validate_task_res(task_res=task_res) + + def test_validate_task_res_missing_recordset(self) -> None: + """Test invalid TaskRes (missing recordset).""" + # Prepare + task_res = task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + 
task=task_pb2.Task(), # pylint: disable=E1101 + ) + + # Execute & assert + with self.assertRaises(ValueError): + validate_task_res(task_res=task_res) + + def test_validate_task_res_missing_content(self) -> None: + """Test invalid TaskRes (missing content).""" + # Prepare + task_res = task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + task=task_pb2.Task( # pylint: disable=E1101 + error=error_pb2.Error( # pylint: disable=E1101 + code=0, + reason="Some reason", + ) + ), + ) + + # Execute & assert + with self.assertRaises(ValueError): + validate_task_res(task_res=task_res)