diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 93d654379cf..43781776f78 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -20,7 +20,9 @@ import time from logging import DEBUG, INFO, WARN from pathlib import Path -from typing import Callable, ContextManager, Optional, Tuple, Union +from typing import Callable, ContextManager, Optional, Tuple, Type, Union + +from grpc import RpcError from flwr.client.client import Client from flwr.client.client_app import ClientApp @@ -36,6 +38,7 @@ ) from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature +from flwr.common.retry_invoker import RetryInvoker, exponential from .client_app import load_client_app from .grpc_client.connection import grpc_connection @@ -104,6 +107,8 @@ def _load() -> ClientApp: transport="rest" if args.rest else "grpc-rere", root_certificates=root_certificates, insecure=args.insecure, + max_retries=args.max_retries, + max_wait_time=args.max_wait_time, ) register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE) @@ -141,6 +146,22 @@ def _parse_args_run_client_app() -> argparse.ArgumentParser: default="0.0.0.0:9092", help="Server address", ) + parser.add_argument( + "--max-retries", + type=int, + default=None, + help="The maximum number of times the client will try to connect to the" + "server before giving up in case of a connection error. By default," + "it is set to None, meaning there is no limit to the number of tries.", + ) + parser.add_argument( + "--max-wait-time", + type=float, + default=None, + help="The maximum duration before the client stops trying to" + "connect to the server in case of connection error. 
By default, it" + "is set to None, meaning there is no limit to the total time.", + ) parser.add_argument( "--dir", default="", @@ -180,6 +201,8 @@ def start_client( root_certificates: Optional[Union[bytes, str]] = None, insecure: Optional[bool] = None, transport: Optional[str] = None, + max_retries: Optional[int] = None, + max_wait_time: Optional[float] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -213,6 +236,14 @@ class `flwr.client.Client` (default: None) - 'grpc-bidi': gRPC, bidirectional streaming - 'grpc-rere': gRPC, request-response (experimental) - 'rest': HTTP (experimental) + max_retries: Optional[int] (default: None) + The maximum number of times the client will try to connect to the + server before giving up in case of a connection error. If set to None, + there is no limit to the number of tries. + max_wait_time: Optional[float] (default: None) + The maximum duration before the client stops trying to + connect to the server in case of connection error. + If set to None, there is no limit to the total time. Examples -------- @@ -254,6 +285,8 @@ class `flwr.client.Client` (default: None) root_certificates=root_certificates, insecure=insecure, transport=transport, + max_retries=max_retries, + max_wait_time=max_wait_time, ) event(EventType.START_CLIENT_LEAVE) @@ -272,6 +305,8 @@ def _start_client_internal( root_certificates: Optional[Union[bytes, str]] = None, insecure: Optional[bool] = None, transport: Optional[str] = None, + max_retries: Optional[int] = None, + max_wait_time: Optional[float] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -299,7 +334,7 @@ class `flwr.client.Client` (default: None) The PEM-encoded root certificates as a byte string or a path string. If provided, a secure connection using the certificates will be established to an SSL-enabled Flower server. 
- insecure : bool (default: True) + insecure : Optional[bool] (default: None) Starts an insecure gRPC connection when True. Enables HTTPS connection when False, using system certificates if `root_certificates` is None. transport : Optional[str] (default: None) @@ -307,6 +342,14 @@ class `flwr.client.Client` (default: None) - 'grpc-bidi': gRPC, bidirectional streaming - 'grpc-rere': gRPC, request-response (experimental) - 'rest': HTTP (experimental) + max_retries: Optional[int] (default: None) + The maximum number of times the client will try to connect to the + server before giving up in case of a connection error. If set to None, + there is no limit to the number of tries. + max_wait_time: Optional[float] (default: None) + The maximum duration before the client stops trying to + connect to the server in case of connection error. + If set to None, there is no limit to the total time. """ if insecure is None: insecure = root_certificates is None @@ -338,7 +381,45 @@ def _load_client_app() -> ClientApp: # Both `client` and `client_fn` must not be used directly # Initialize connection context manager - connection, address = _init_connection(transport, server_address) + connection, address, connection_error_type = _init_connection( + transport, server_address + ) + + retry_invoker = RetryInvoker( + wait_factory=exponential, + recoverable_exceptions=connection_error_type, + max_tries=max_retries, + max_time=max_wait_time, + on_giveup=lambda retry_state: ( + log( + WARN, + "Giving up reconnection after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + if retry_state.tries > 1 + else None + ), + on_success=lambda retry_state: ( + log( + INFO, + "Connection successful after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + if retry_state.tries > 1 + else None + ), + on_backoff=lambda retry_state: ( + log(WARN, "Connection attempt failed, retrying...") + if retry_state.tries == 1 + else log( + DEBUG, + 
"Connection attempt failed, retrying in %.2f seconds", + retry_state.actual_wait, + ) + ), + ) node_state = NodeState() @@ -347,6 +428,7 @@ def _load_client_app() -> ClientApp: with connection( address, insecure, + retry_invoker, grpc_max_message_length, root_certificates, ) as conn: @@ -509,7 +591,7 @@ def start_numpy_client( def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ Callable[ - [str, bool, int, Union[bytes, str, None]], + [str, bool, RetryInvoker, int, Union[bytes, str, None]], ContextManager[ Tuple[ Callable[[], Optional[Message]], @@ -520,6 +602,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ ], ], str, + Type[Exception], ]: # Parse IP address parsed_address = parse_address(server_address) @@ -535,6 +618,8 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ # Use either gRPC bidirectional streaming or REST request/response if transport == TRANSPORT_TYPE_REST: try: + from requests.exceptions import ConnectionError as RequestsConnectionError + from .rest_client.connection import http_request_response except ModuleNotFoundError: sys.exit(MISSING_EXTRA_REST) @@ -543,14 +628,14 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ "When using the REST API, please provide `https://` or " "`http://` before the server address (e.g. 
`http://127.0.0.1:8080`)" ) - connection = http_request_response + connection, error_type = http_request_response, RequestsConnectionError elif transport == TRANSPORT_TYPE_GRPC_RERE: - connection = grpc_request_response + connection, error_type = grpc_request_response, RpcError elif transport == TRANSPORT_TYPE_GRPC_BIDI: - connection = grpc_connection + connection, error_type = grpc_connection, RpcError else: raise ValueError( f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})" ) - return connection, address + return connection, address, error_type diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 3561626dcb3..ddbb5336b2a 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -39,6 +39,7 @@ ) from flwr.common.grpc import create_channel from flwr.common.logger import log +from flwr.common.retry_invoker import RetryInvoker from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, Reason, @@ -62,6 +63,7 @@ def on_channel_state_change(channel_connectivity: str) -> None: def grpc_connection( # pylint: disable=R0915 server_address: str, insecure: bool, + retry_invoker: RetryInvoker, # pylint: disable=unused-argument max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, ) -> Iterator[ @@ -80,6 +82,11 @@ def grpc_connection( # pylint: disable=R0915 The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"0.0.0.0:8080"` or `"[::]:8080"`. + insecure : bool + Starts an insecure gRPC connection when True. Enables HTTPS connection + when False, using system certificates if `root_certificates` is None. + retry_invoker: RetryInvoker + Unused argument present for compatibility. max_message_length : int The maximum length of gRPC messages that can be exchanged with the Flower server. 
The default should be sufficient for most models. Users who train diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index 30bff068b60..28e03979fd6 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -26,6 +26,7 @@ from flwr.common import ConfigsRecord, Message, Metadata, RecordSet from flwr.common import recordset_compat as compat from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES +from flwr.common.retry_invoker import RetryInvoker, exponential from flwr.common.typing import Code, GetPropertiesRes, Status from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, @@ -127,7 +128,16 @@ def test_integration_connection() -> None: def run_client() -> int: messages_received: int = 0 - with grpc_connection(server_address=f"[::]:{port}", insecure=True) as conn: + with grpc_connection( + server_address=f"[::]:{port}", + insecure=True, + retry_invoker=RetryInvoker( + wait_factory=exponential, + recoverable_exceptions=grpc.RpcError, + max_tries=1, + max_time=None, + ), + ) as conn: receive, send, _, _ = conn # Setup processing loop diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 00b7a864c5d..e6e22998b94 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -27,6 +27,7 @@ from flwr.common.grpc import create_channel from flwr.common.logger import log, warn_experimental_feature from flwr.common.message import Message, Metadata +from flwr.common.retry_invoker import RetryInvoker from flwr.common.serde import message_from_taskins, message_to_taskres from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -51,6 +52,7 @@ def on_channel_state_change(channel_connectivity: str) -> None: def grpc_request_response( server_address: str, insecure: bool, + 
retry_invoker: RetryInvoker, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[Union[bytes, str]] = None, ) -> Iterator[ @@ -72,6 +74,13 @@ def grpc_request_response( The IPv6 address of the server with `http://` or `https://`. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"http://[::]:8080"`. + insecure : bool + Starts an insecure gRPC connection when True. Enables HTTPS connection + when False, using system certificates if `root_certificates` is None. + retry_invoker: RetryInvoker + `RetryInvoker` object that will try to reconnect the client to the server + after gRPC errors. If None, the client will only try to + reconnect once after a failure. max_message_length : int Ignored, only present to preserve API-compatibility. root_certificates : Optional[Union[bytes, str]] (default: None) @@ -113,7 +122,8 @@ def grpc_request_response( def create_node() -> None: """Set create_node.""" create_node_request = CreateNodeRequest() - create_node_response = stub.CreateNode( + create_node_response = retry_invoker.invoke( + stub.CreateNode, request=create_node_request, ) node_store[KEY_NODE] = create_node_response.node @@ -127,7 +137,7 @@ def delete_node() -> None: node: Node = cast(Node, node_store[KEY_NODE]) delete_node_request = DeleteNodeRequest(node=node) - stub.DeleteNode(request=delete_node_request) + retry_invoker.invoke(stub.DeleteNode, request=delete_node_request) del node_store[KEY_NODE] @@ -141,7 +151,7 @@ def receive() -> Optional[Message]: # Request instructions (task) from server request = PullTaskInsRequest(node=node) - response = stub.PullTaskIns(request=request) + response = retry_invoker.invoke(stub.PullTaskIns, request=request) # Get the current TaskIns task_ins: Optional[TaskIns] = get_task_ins(response) @@ -185,7 +195,7 @@ def send(message: Message) -> None: # Serialize ProtoBuf to bytes request = PushTaskResRequest(task_res_list=[task_res]) - _ = 
stub.PushTaskRes(request) + _ = retry_invoker.invoke(stub.PushTaskRes, request) state[KEY_METADATA] = None diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index c637475551e..d2cc71ba3b3 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -27,6 +27,7 @@ from flwr.common.constant import MISSING_EXTRA_REST from flwr.common.logger import log from flwr.common.message import Message, Metadata +from flwr.common.retry_invoker import RetryInvoker from flwr.common.serde import message_from_taskins, message_to_taskres from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -61,6 +62,7 @@ def http_request_response( server_address: str, insecure: bool, # pylint: disable=unused-argument + retry_invoker: RetryInvoker, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[ Union[bytes, str] @@ -84,6 +86,12 @@ def http_request_response( The IPv6 address of the server with `http://` or `https://`. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"http://[::]:8080"`. + insecure : bool + Unused argument present for compatibility. + retry_invoker: RetryInvoker + `RetryInvoker` object that will try to reconnect the client to the server + after REST connection errors. If None, the client will only try to + reconnect once after a failure. max_message_length : int Ignored, only present to preserve API-compatibility. 
root_certificates : Optional[Union[bytes, str]] (default: None) @@ -134,7 +142,8 @@ def create_node() -> None: create_node_req_proto = CreateNodeRequest() create_node_req_bytes: bytes = create_node_req_proto.SerializeToString() - res = requests.post( + res = retry_invoker.invoke( + requests.post, url=f"{base_url}/{PATH_CREATE_NODE}", headers={ "Accept": "application/protobuf", @@ -177,7 +186,8 @@ def delete_node() -> None: node: Node = cast(Node, node_store[KEY_NODE]) delete_node_req_proto = DeleteNodeRequest(node=node) delete_node_req_req_bytes: bytes = delete_node_req_proto.SerializeToString() - res = requests.post( + res = retry_invoker.invoke( + requests.post, url=f"{base_url}/{PATH_DELETE_NODE}", headers={ "Accept": "application/protobuf", @@ -218,7 +228,8 @@ def receive() -> Optional[Message]: pull_task_ins_req_bytes: bytes = pull_task_ins_req_proto.SerializeToString() # Request instructions (task) from server - res = requests.post( + res = retry_invoker.invoke( + requests.post, url=f"{base_url}/{PATH_PULL_TASK_INS}", headers={ "Accept": "application/protobuf", @@ -298,7 +309,8 @@ def send(message: Message) -> None: ) # Send ClientMessage to server - res = requests.post( + res = retry_invoker.invoke( + requests.post, url=f"{base_url}/{PATH_PUSH_TASK_RES}", headers={ "Accept": "application/protobuf",