From 5432cfd54d0086f2114c9f035545d727bc4325a8 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Mon, 22 Jan 2024 13:14:17 +0100 Subject: [PATCH 01/15] Handle FlowerCallable exceptions --- src/py/flwr/client/app.py | 22 +++++++++++++++++++--- src/py/flwr/driver/driver_client_proxy.py | 4 ++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index ae5beeae07d..ee4d90db695 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -18,7 +18,7 @@ import argparse import sys import time -from logging import INFO, WARN +from logging import ERROR, INFO, WARN from pathlib import Path from typing import Callable, ContextManager, Optional, Tuple, Union @@ -35,7 +35,7 @@ TRANSPORT_TYPES, ) from flwr.common.logger import log, warn_experimental_feature -from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 +from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from .flower import load_flower_callable from .grpc_client.connection import grpc_connection @@ -362,7 +362,23 @@ def _load_app() -> Flower: task_ins=task_ins, state=node_state.retrieve_runstate(run_id=task_ins.run_id), ) - bwd_msg: Bwd = app(fwd=fwd_msg) + try: + bwd_msg: Bwd = app(fwd=fwd_msg) + except Exception as ex: + log(ERROR, "FlowerCallable raised the following exception:\n\n", ex) + + # Don't update/change RunState + # Return empty TaskRes + error_task_res = TaskRes( + task_id="", + group_id="", + run_id=0, + task=Task( + ancestry=[], + ), + ) + send(error_task_res) + continue # Update node state node_state.update_runstate( diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py index 8b2e51c17ea..75bbba6d5e3 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/driver/driver_client_proxy.py @@ -155,6 +155,10 @@ def _send_receive_msg( ) if len(task_res_list) == 1: task_res = task_res_list[0] + if not task_res.HasField("task") or not task_res.HasField( + "legacy_client_message" + ): + raise ValueError("Exception during client-side task execution") return serde.client_message_from_proto( # type: ignore task_res.task.legacy_client_message ) From d3531d6fef03d257a079fe94f1716bffaecf3e9a Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Tue, 30 Jan 2024 21:49:16 +0100 Subject: [PATCH 02/15] Add tests --- .../flwr/driver/driver_client_proxy_test.py | 84 ++++++++++++++++++- 1 file changed, 82 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/driver/driver_client_proxy_test.py b/src/py/flwr/driver/driver_client_proxy_test.py index 4e9a02a6cbf..9eaa7fcc021 100644 --- a/src/py/flwr/driver/driver_client_proxy_test.py +++ b/src/py/flwr/driver/driver_client_proxy_test.py @@ -43,8 +43,13 @@ Properties, Status, ) -from flwr.driver.driver_client_proxy import DriverClientProxy -from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611 +from flwr.driver.driver_client_proxy import DriverClientProxy, validate_task_res +from flwr.proto import ( # pylint: disable=E0611 + driver_pb2, + node_pb2, + recordset_pb2, + task_pb2, +) MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") @@ -245,3 +250,78 @@ def test_evaluate(self) -> None: # Assert assert 0.0 == evaluate_res.loss assert 0 == evaluate_res.num_examples + + def test_validate_task_res_valid(self) -> None: + """Test valid TaskRes.""" + metrics_record = recordset_pb2.MetricsRecord( # pylint: disable=E1101 + data={ + "loss": recordset_pb2.MetricsRecordValue( # pylint: disable=E1101 + double=1.0 + ) + } + ) + task_res = task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + task=task_pb2.Task( # pylint: disable=E1101 + recordset=recordset_pb2.RecordSet( # pylint: disable=E1101 + parameters={}, + metrics={"loss": metrics_record}, + configs={}, + ) + ), + ) + + # Execute & assert + try: + validate_task_res(task_res=task_res) + except ValueError: + self.fail() + + def test_validate_task_res_missing_task(self) -> None: + """Test invalid TaskRes (missing task).""" + # Prepare + task_res = task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + ) + + # Execute & assert + with self.assertRaises(ValueError): + validate_task_res(task_res=task_res) + + def test_validate_task_res_missing_recordset(self) -> None: + """Test invalid TaskRes (missing recordset).""" + # Prepare + task_res = task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + task=task_pb2.Task(), # pylint: disable=E1101 + ) + + # Execute & assert + with self.assertRaises(ValueError): + validate_task_res(task_res=task_res) + + def test_validate_task_res_missing_content(self) -> None: + """Test invalid TaskRes (missing content).""" + # Prepare + task_res = task_pb2.TaskRes( # pylint: disable=E1101 + task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", + group_id="", + run_id=0, + task=task_pb2.Task( # pylint: disable=E1101 + recordset=recordset_pb2.RecordSet( # pylint: disable=E1101 + parameters={}, + metrics={}, + configs={}, + ) + ), + ) + + # Execute & assert + with self.assertRaises(ValueError): + validate_task_res(task_res=task_res) From cb335278f9b0b09e9cbc27135f308b400d21dac0 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Fri, 2 Feb 2024 17:11:00 +0100 Subject: [PATCH 03/15] Update src/py/flwr/client/app.py --- src/py/flwr/client/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 054f56b841e..99ff5f29b0f 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -386,7 +386,7 @@ def _load_app() -> Flower: except Exception as ex: # pylint: disable=broad-exception-caught log(ERROR, "FlowerCallable raised an exception", exc_info=ex) - # Don't update/change RunState + # Don't update/change NodeState # Return empty Message error_out_message = Message( From 8b9c8545276cb76bb72adc1e0ad69efed5601980 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Fri, 23 Feb 2024 20:38:03 +0100 Subject: [PATCH 04/15] Fix imports --- src/py/flwr/client/app.py | 7 +------ src/py/flwr/server/compat/driver_client_proxy_test.py | 2 -- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index a50623b26af..7dbf5675066 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -25,7 +25,7 @@ from flwr.client.client import Client from flwr.client.clientapp import ClientApp from flwr.client.typing import ClientFn -from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event +from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, RecordSet, event from flwr.common.address import parse_address from flwr.common.constant import ( MISSING_EXTRA_REST, @@ -35,13 +35,8 @@ TRANSPORT_TYPES, ) from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature - - -from flwr.common.message import Message -from flwr.common.recordset import RecordSet from flwr.common.serde import message_to_taskres - from .clientapp import load_client_app from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response diff --git a/src/py/flwr/server/compat/driver_client_proxy_test.py b/src/py/flwr/server/compat/driver_client_proxy_test.py index 0797a887ad8..3523a3c6fee 100644 --- a/src/py/flwr/server/compat/driver_client_proxy_test.py +++ b/src/py/flwr/server/compat/driver_client_proxy_test.py @@ -51,8 +51,6 @@ ) from flwr.server.compat.driver_client_proxy import DriverClientProxy, validate_task_res -from .driver_client_proxy import DriverClientProxy - MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np") CLIENT_PROPERTIES = cast(Properties, {"tensor_type": "numpy.ndarray"}) From a6e93fd245b4a8d3bdab54752a290e21f0429e8d Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Fri, 23 Feb 2024 20:46:50 +0100 Subject: [PATCH 05/15] Fix argument --- src/py/flwr/client/app.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 7dbf5675066..21033f349a6 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -388,12 +388,9 @@ def _load_client_app() -> ClientApp: # Return empty Message error_out_message = Message( metadata=message.metadata, - message=RecordSet(), + content=RecordSet(), ) - - # Construct TaskRes from out_message - error_task_res = message_to_taskres(error_out_message) - send(error_task_res) + send(error_out_message) continue # Update node state From c19e0b46fd810b7ada23f2c1be294e707a1a05c1 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Fri, 23 Feb 2024 20:55:34 +0100 Subject: [PATCH 06/15] Fix imports --- src/py/flwr/client/app.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 21033f349a6..e12560cd82b 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -35,7 +35,6 @@ TRANSPORT_TYPES, ) from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature -from flwr.common.serde import message_to_taskres from .clientapp import load_client_app from .grpc_client.connection import grpc_connection From 2f8167eb1b3fbee15d3dd5877f514e3d162cdfe2 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Fri, 1 Mar 2024 17:51:34 +0100 Subject: [PATCH 07/15] Use message.error --- src/py/flwr/client/app.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index e79c8ed3695..71c7ae0ff80 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -14,7 +14,6 @@ # ============================================================================== """Flower client app.""" - import argparse import sys import time @@ -25,7 +24,7 @@ from flwr.client.client import Client from flwr.client.client_app import ClientApp from flwr.client.typing import ClientFn -from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, RecordSet, event +from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, Message, event from flwr.common.address import parse_address from flwr.common.constant import ( MISSING_EXTRA_REST, @@ -36,6 +35,7 @@ ) from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature +from flwr.common.message import Error from .client_app import load_client_app from .grpc_client.connection import grpc_connection @@ -385,12 +385,16 @@ def _load_client_app() -> ClientApp: log(ERROR, "ClientApp raised an exception", exc_info=ex) # Don't update/change NodeState - # Return empty Message - error_out_message = Message( - metadata=message.metadata, - content=RecordSet(), - ) + + # Create error message + # Reason example: ":<'division by zero'>" + reason = str(type(ex)) + ":<'" + str(ex) + "'>" + error = Error(code=0, reason=reason) + error_out_message = message.create_error_reply(error=error) + + # Return error message send(error_out_message) + continue # Update node state From eac0dbbe4ac27c78d7125db5f95e4c20a1ff87aa Mon Sep 17 00:00:00 2001 From: jafermarq Date: Fri, 1 Mar 2024 17:08:43 +0000 Subject: [PATCH 08/15] small fix --- src/py/flwr/client/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 71c7ae0ff80..21d3b7837a3 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -390,7 +390,7 @@ def _load_client_app() -> ClientApp: # Reason example: ":<'division by zero'>" reason = str(type(ex)) + ":<'" + str(ex) + "'>" error = Error(code=0, reason=reason) - error_out_message = message.create_error_reply(error=error) + error_out_message = message.create_error_reply(error=error, ttl="") # Return error message send(error_out_message) From 7f634d5776404a2a28bf34cdb201e0e2f370ba70 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Fri, 1 Mar 2024 19:09:56 +0100 Subject: [PATCH 09/15] Update task validation --- src/py/flwr/server/compat/driver_client_proxy.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/py/flwr/server/compat/driver_client_proxy.py b/src/py/flwr/server/compat/driver_client_proxy.py index fccdd914072..8067041e8fd 100644 --- a/src/py/flwr/server/compat/driver_client_proxy.py +++ b/src/py/flwr/server/compat/driver_client_proxy.py @@ -168,7 +168,10 @@ def _send_receive_recordset( ) if len(task_res_list) == 1: task_res = task_res_list[0] + + # This will raise an Exception if task_res carries an `error` validate_task_res(task_res=task_res) + return serde.recordset_from_proto(task_res.task.recordset) if timeout is not None and time.time() > start_time + timeout: @@ -182,13 +185,7 @@ def validate_task_res( """Validate if a TaskRes is empty or not.""" if not task_res.HasField("task"): raise ValueError("Invalid TaskRes, field `task` missing") - if not task_res.task.HasField("recordset"): - raise ValueError("Invalid Task, field `recordset` missing") - - rs = task_res.task.recordset - if ( - (not rs.parameters.keys()) - and (not rs.metrics.keys()) - and (not rs.configs.keys()) - ): + if task_res.task.HasField("error"): raise ValueError("Exception during client-side task execution") + if not task_res.task.HasField("recordset"): + raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing") From 37aa3105e473f2d0d36012c07a51e08479f8867e Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Fri, 1 Mar 2024 22:26:01 +0100 Subject: [PATCH 10/15] Fix test --- src/py/flwr/server/compat/driver_client_proxy_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/server/compat/driver_client_proxy_test.py b/src/py/flwr/server/compat/driver_client_proxy_test.py index 8ec76bfcb45..fb0313ebf69 100644 --- a/src/py/flwr/server/compat/driver_client_proxy_test.py +++ b/src/py/flwr/server/compat/driver_client_proxy_test.py @@ -45,6 +45,7 @@ ) from flwr.proto import ( # pylint: disable=E0611 driver_pb2, + error_pb2, node_pb2, recordset_pb2, task_pb2, @@ -316,10 +317,9 @@ def test_validate_task_res_missing_content(self) -> None: group_id="", run_id=0, task=task_pb2.Task( # pylint: disable=E1101 - recordset=recordset_pb2.RecordSet( # pylint: disable=E1101 - parameters={}, - metrics={}, - configs={}, + error=error_pb2.Error( # pylint: disable=E1101 + code=0, + reason="Some reason", ) ), ) From 449ac04b7c4d4aa300d653eeba0336d0e4b5f1e8 Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 6 Mar 2024 22:39:02 +0100 Subject: [PATCH 11/15] Complete Handling of `ClientApp` exception (#3067) --- src/py/flwr/client/app.py | 31 ++++++++++++--------------- src/py/flwr/server/utils/validator.py | 14 ++++++++---- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 21d3b7837a3..d797c65946c 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -375,12 +375,18 @@ def _load_client_app() -> ClientApp: # Retrieve context for this run context = node_state.retrieve_context(run_id=message.metadata.run_id) - # Load ClientApp instance - client_app: ClientApp = load_client_app_fn() - - # Handle task message + # Handle app loading and task message try: + + # Load ClientApp instance + client_app: ClientApp = load_client_app_fn() + out_message = client_app(message=message, context=context) + # Update node state + node_state.update_context( + run_id=message.metadata.run_id, + context=context, + ) except Exception as ex: # pylint: disable=broad-exception-caught log(ERROR, "ClientApp raised an exception", exc_info=ex) @@ -390,21 +396,12 @@ def _load_client_app() -> ClientApp: # Reason example: ":<'division by zero'>" reason = str(type(ex)) + ":<'" + str(ex) + "'>" error = Error(code=0, reason=reason) - error_out_message = message.create_error_reply(error=error, ttl="") + out_message = message.create_error_reply(error=error, ttl="") - # Return error message - send(error_out_message) + finally: - continue - - # Update node state - node_state.update_context( - run_id=message.metadata.run_id, - context=context, - ) - - # Send - send(out_message) + # Send + send(out_message) # Unregister node if delete_node is not None: diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index f9b271beafd..846217b085a 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -66,8 +66,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str # Content check if tasks_ins_res.task.task_type == "": validation_errors.append("`task_type` MUST be set") - if not tasks_ins_res.task.HasField("recordset"): - validation_errors.append("`recordset` MUST be set") + if not ( + tasks_ins_res.task.HasField("recordset") + ^ tasks_ins_res.task.HasField("error") + ): + validation_errors.append("Either `recordset` or `error` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) != 0: @@ -106,8 +109,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str # Content check if tasks_ins_res.task.task_type == "": validation_errors.append("`task_type` MUST be set") - if not tasks_ins_res.task.HasField("recordset"): - validation_errors.append("`recordset` MUST be set") + if not ( + tasks_ins_res.task.HasField("recordset") + ^ tasks_ins_res.task.HasField("error") + ): + validation_errors.append("Either `recordset` or `error` MUST be set") # Ancestors if len(tasks_ins_res.task.ancestry) == 0: From ef1287c46982dae84e464e03c4e10e83a5dc5071 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Thu, 7 Mar 2024 23:00:40 +0100 Subject: [PATCH 12/15] Fix lint issues --- src/py/flwr/client/app.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index e96629cc239..1da00d83b00 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -457,13 +457,18 @@ def _load_client_app() -> ClientApp: # Retrieve context for this run context = node_state.retrieve_context(run_id=message.metadata.run_id) + # Create an error reply message that will never be used to prevent + # the used-before-assignment linting error + reply_message = message.create_error_reply( + error=Error(code=0, reason="Unknown"), ttl=message.metadata.ttl + ) + # Handle app loading and task message try: - # Load ClientApp instance client_app: ClientApp = load_client_app_fn() - out_message = client_app(message=message, context=context) + reply_message = client_app(message=message, context=context) # Update node state node_state.update_context( run_id=message.metadata.run_id, @@ -477,13 +482,13 @@ def _load_client_app() -> ClientApp: # Create error message # Reason example: ":<'division by zero'>" reason = str(type(ex)) + ":<'" + str(ex) + "'>" - error = Error(code=0, reason=reason) - out_message = message.create_error_reply(error=error, ttl="") + reply_message = message.create_error_reply( + error=Error(code=0, reason=reason), ttl=message.metadata.ttl + ) finally: - # Send - send(out_message) + send(reply_message) # Unregister node if delete_node is not None: From 2f39cce04bf546174f0e72fb5c8ee1313b3a3842 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 20 Mar 2024 15:44:26 +0100 Subject: [PATCH 13/15] server_custom handles pulled errors --- examples/app-pytorch/server_custom.py | 33 +++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/examples/app-pytorch/server_custom.py b/examples/app-pytorch/server_custom.py index 0c2851e2afe..51ac8a5c006 100644 --- a/examples/app-pytorch/server_custom.py +++ b/examples/app-pytorch/server_custom.py @@ -102,15 +102,19 @@ def main(driver: Driver, context: Context) -> None: all_replies: List[Message] = [] while True: replies = driver.pull_messages(message_ids=message_ids) - print(f"Got {len(replies)} results") + for res in replies: + print(f"Got 1 {'result' if res.has_content() else 'error'}") all_replies += replies if len(all_replies) == len(message_ids): break + print("Pulling messages...") time.sleep(3) - # Collect correct results + # Filter correct results all_fitres = [ - recordset_to_fitres(msg.content, keep_input=True) for msg in all_replies + recordset_to_fitres(msg.content, keep_input=True) + for msg in all_replies + if msg.has_content() ] print(f"Received {len(all_fitres)} results") @@ -127,16 +131,21 @@ def main(driver: Driver, context: Context) -> None: ) metrics_results.append((fitres.num_examples, fitres.metrics)) - # Aggregate parameters (FedAvg) - parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) - parameters = parameters_aggregated + if len(weights_results) > 0: + # Aggregate parameters (FedAvg) + parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) + parameters = parameters_aggregated - # Aggregate metrics - metrics_aggregated = weighted_average(metrics_results) - history.add_metrics_distributed_fit( - server_round=server_round, metrics=metrics_aggregated - ) - print("Round ", server_round, " metrics: ", metrics_aggregated) + # Aggregate metrics + metrics_aggregated = weighted_average(metrics_results) + history.add_metrics_distributed_fit( + server_round=server_round, metrics=metrics_aggregated + ) + print("Round ", server_round, " metrics: ", metrics_aggregated) + else: + print( + f"Round {server_round} got {len(weights_results)} results. Skipping aggregation..." + ) # Slow down the start of the next round time.sleep(sleep_time) From 2ad4e853714a46874c2d2ab13bc5dfad0c2e70e7 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 28 Mar 2024 11:56:55 +0100 Subject: [PATCH 14/15] raise ClientApp exception for `grpc-bidi` clients --- src/py/flwr/client/app.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 4a632b3bfe4..304a0b1f1de 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -502,6 +502,11 @@ def _load_client_app() -> ClientApp: except Exception as ex: # pylint: disable=broad-exception-caught log(ERROR, "ClientApp raised an exception", exc_info=ex) + # Legacy grpc-bidi + if transport in ["grpc-bidi", None]: + # Rasie exception, crashes process + raise ex + # Don't update/change NodeState # Create error message @@ -511,10 +516,9 @@ def _load_client_app() -> ClientApp: error=Error(code=0, reason=reason), ttl=message.metadata.ttl ) - finally: - # Send - send(reply_message) - log(INFO, "Sent reply") + # Send + send(reply_message) + log(INFO, "Sent reply") # Unregister node if delete_node is not None: From 927ac3369a6489a916c3966a193e2b4fae04469d Mon Sep 17 00:00:00 2001 From: Javier Date: Thu, 28 Mar 2024 11:15:55 +0000 Subject: [PATCH 15/15] Apply suggestions from code review Co-authored-by: Daniel J. Beutel --- src/py/flwr/client/app.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 304a0b1f1de..d4bd8e2e39e 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -485,7 +485,7 @@ def _load_client_app() -> ClientApp: # Create an error reply message that will never be used to prevent # the used-before-assignment linting error reply_message = message.create_error_reply( - error=Error(code=0, reason="Unknown"), ttl=message.metadata.ttl + error=Error(code=0, reason="Unknown") ) # Handle app loading and task message @@ -504,7 +504,7 @@ def _load_client_app() -> ClientApp: # Legacy grpc-bidi if transport in ["grpc-bidi", None]: - # Rasie exception, crashes process + # Raise exception, crash process raise ex # Don't update/change NodeState @@ -513,7 +513,7 @@ def _load_client_app() -> ClientApp: # Reason example: ":<'division by zero'>" reason = str(type(ex)) + ":<'" + str(ex) + "'>" reply_message = message.create_error_reply( - error=Error(code=0, reason=reason), ttl=message.metadata.ttl + error=Error(code=0, reason=reason) ) # Send