Skip to content

Commit

Permalink
Add Flower Client App connection error handling (#2969)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <daniel@flower.ai>
  • Loading branch information
charlesbvll and danieljanes authored Mar 5, 2024
1 parent 2eb2127 commit 1298298
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 17 deletions.
101 changes: 93 additions & 8 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import time
from logging import DEBUG, INFO, WARN
from pathlib import Path
from typing import Callable, ContextManager, Optional, Tuple, Union
from typing import Callable, ContextManager, Optional, Tuple, Type, Union

from grpc import RpcError

from flwr.client.client import Client
from flwr.client.client_app import ClientApp
Expand All @@ -36,6 +38,7 @@
)
from flwr.common.exit_handlers import register_exit_handlers
from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
from flwr.common.retry_invoker import RetryInvoker, exponential

from .client_app import load_client_app
from .grpc_client.connection import grpc_connection
Expand Down Expand Up @@ -104,6 +107,8 @@ def _load() -> ClientApp:
transport="rest" if args.rest else "grpc-rere",
root_certificates=root_certificates,
insecure=args.insecure,
max_retries=args.max_retries,
max_wait_time=args.max_wait_time,
)
register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)

Expand Down Expand Up @@ -141,6 +146,22 @@ def _parse_args_run_client_app() -> argparse.ArgumentParser:
default="0.0.0.0:9092",
help="Server address",
)
parser.add_argument(
"--max-retries",
type=int,
default=None,
help="The maximum number of times the client will try to connect to the"
"server before giving up in case of a connection error. By default,"
"it is set to None, meaning there is no limit to the number of tries.",
)
parser.add_argument(
"--max-wait-time",
type=float,
default=None,
help="The maximum duration before the client stops trying to"
"connect to the server in case of connection error. By default, it"
"is set to None, meaning there is no limit to the total time.",
)
parser.add_argument(
"--dir",
default="",
Expand Down Expand Up @@ -180,6 +201,8 @@ def start_client(
root_certificates: Optional[Union[bytes, str]] = None,
insecure: Optional[bool] = None,
transport: Optional[str] = None,
max_retries: Optional[int] = None,
max_wait_time: Optional[float] = None,
) -> None:
"""Start a Flower client node which connects to a Flower server.
Expand Down Expand Up @@ -213,6 +236,14 @@ class `flwr.client.Client` (default: None)
- 'grpc-bidi': gRPC, bidirectional streaming
- 'grpc-rere': gRPC, request-response (experimental)
- 'rest': HTTP (experimental)
max_retries: Optional[int] (default: None)
The maximum number of times the client will try to connect to the
server before giving up in case of a connection error. If set to None,
there is no limit to the number of tries.
max_wait_time: Optional[float] (default: None)
The maximum duration before the client stops trying to
connect to the server in case of connection error.
If set to None, there is no limit to the total time.
Examples
--------
Expand Down Expand Up @@ -254,6 +285,8 @@ class `flwr.client.Client` (default: None)
root_certificates=root_certificates,
insecure=insecure,
transport=transport,
max_retries=max_retries,
max_wait_time=max_wait_time,
)
event(EventType.START_CLIENT_LEAVE)

Expand All @@ -272,6 +305,8 @@ def _start_client_internal(
root_certificates: Optional[Union[bytes, str]] = None,
insecure: Optional[bool] = None,
transport: Optional[str] = None,
max_retries: Optional[int] = None,
max_wait_time: Optional[float] = None,
) -> None:
"""Start a Flower client node which connects to a Flower server.
Expand Down Expand Up @@ -299,14 +334,22 @@ class `flwr.client.Client` (default: None)
The PEM-encoded root certificates as a byte string or a path string.
If provided, a secure connection using the certificates will be
established to an SSL-enabled Flower server.
insecure : bool (default: True)
insecure : Optional[bool] (default: None)
Starts an insecure gRPC connection when True. Enables HTTPS connection
when False, using system certificates if `root_certificates` is None.
transport : Optional[str] (default: None)
Configure the transport layer. Allowed values:
- 'grpc-bidi': gRPC, bidirectional streaming
- 'grpc-rere': gRPC, request-response (experimental)
- 'rest': HTTP (experimental)
max_retries: Optional[int] (default: None)
The maximum number of times the client will try to connect to the
server before giving up in case of a connection error. If set to None,
there is no limit to the number of tries.
max_wait_time: Optional[float] (default: None)
The maximum duration before the client stops trying to
connect to the server in case of connection error.
If set to None, there is no limit to the total time.
"""
if insecure is None:
insecure = root_certificates is None
Expand Down Expand Up @@ -338,7 +381,45 @@ def _load_client_app() -> ClientApp:
# Both `client` and `client_fn` must not be used directly

# Initialize connection context manager
connection, address = _init_connection(transport, server_address)
connection, address, connection_error_type = _init_connection(
transport, server_address
)

retry_invoker = RetryInvoker(
wait_factory=exponential,
recoverable_exceptions=connection_error_type,
max_tries=max_retries,
max_time=max_wait_time,
on_giveup=lambda retry_state: (
log(
WARN,
"Giving up reconnection after %.2f seconds and %s tries.",
retry_state.elapsed_time,
retry_state.tries,
)
if retry_state.tries > 1
else None
),
on_success=lambda retry_state: (
log(
INFO,
"Connection successful after %.2f seconds and %s tries.",
retry_state.elapsed_time,
retry_state.tries,
)
if retry_state.tries > 1
else None
),
on_backoff=lambda retry_state: (
log(WARN, "Connection attempt failed, retrying...")
if retry_state.tries == 1
else log(
DEBUG,
"Connection attempt failed, retrying in %.2f seconds",
retry_state.actual_wait,
)
),
)

node_state = NodeState()

Expand All @@ -347,6 +428,7 @@ def _load_client_app() -> ClientApp:
with connection(
address,
insecure,
retry_invoker,
grpc_max_message_length,
root_certificates,
) as conn:
Expand Down Expand Up @@ -509,7 +591,7 @@ def start_numpy_client(

def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
Callable[
[str, bool, int, Union[bytes, str, None]],
[str, bool, RetryInvoker, int, Union[bytes, str, None]],
ContextManager[
Tuple[
Callable[[], Optional[Message]],
Expand All @@ -520,6 +602,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
],
],
str,
Type[Exception],
]:
# Parse IP address
parsed_address = parse_address(server_address)
Expand All @@ -535,6 +618,8 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
# Use either gRPC bidirectional streaming or REST request/response
if transport == TRANSPORT_TYPE_REST:
try:
from requests.exceptions import ConnectionError as RequestsConnectionError

from .rest_client.connection import http_request_response
except ModuleNotFoundError:
sys.exit(MISSING_EXTRA_REST)
Expand All @@ -543,14 +628,14 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
"When using the REST API, please provide `https://` or "
"`http://` before the server address (e.g. `http://127.0.0.1:8080`)"
)
connection = http_request_response
connection, error_type = http_request_response, RequestsConnectionError
elif transport == TRANSPORT_TYPE_GRPC_RERE:
connection = grpc_request_response
connection, error_type = grpc_request_response, RpcError
elif transport == TRANSPORT_TYPE_GRPC_BIDI:
connection = grpc_connection
connection, error_type = grpc_connection, RpcError
else:
raise ValueError(
f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})"
)

return connection, address
return connection, address, error_type
7 changes: 7 additions & 0 deletions src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.common.retry_invoker import RetryInvoker
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
ClientMessage,
Reason,
Expand All @@ -62,6 +63,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
def grpc_connection( # pylint: disable=R0915
server_address: str,
insecure: bool,
retry_invoker: RetryInvoker, # pylint: disable=unused-argument
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
root_certificates: Optional[Union[bytes, str]] = None,
) -> Iterator[
Expand All @@ -80,6 +82,11 @@ def grpc_connection( # pylint: disable=R0915
The IPv4 or IPv6 address of the server. If the Flower server runs on the same
machine on port 8080, then `server_address` would be `"0.0.0.0:8080"` or
`"[::]:8080"`.
insecure : bool
Starts an insecure gRPC connection when True. Enables HTTPS connection
when False, using system certificates if `root_certificates` is None.
retry_invoker: RetryInvoker
Unused argument present for compatibilty.
max_message_length : int
The maximum length of gRPC messages that can be exchanged with the Flower
server. The default should be sufficient for most models. Users who train
Expand Down
12 changes: 11 additions & 1 deletion src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from flwr.common import ConfigsRecord, Message, Metadata, RecordSet
from flwr.common import recordset_compat as compat
from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES
from flwr.common.retry_invoker import RetryInvoker, exponential
from flwr.common.typing import Code, GetPropertiesRes, Status
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
ClientMessage,
Expand Down Expand Up @@ -127,7 +128,16 @@ def test_integration_connection() -> None:
def run_client() -> int:
messages_received: int = 0

with grpc_connection(server_address=f"[::]:{port}", insecure=True) as conn:
with grpc_connection(
server_address=f"[::]:{port}",
insecure=True,
retry_invoker=RetryInvoker(
wait_factory=exponential,
recoverable_exceptions=grpc.RpcError,
max_tries=1,
max_time=None,
),
) as conn:
receive, send, _, _ = conn

# Setup processing loop
Expand Down
18 changes: 14 additions & 4 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from flwr.common.grpc import create_channel
from flwr.common.logger import log, warn_experimental_feature
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
Expand All @@ -51,6 +52,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
def grpc_request_response(
server_address: str,
insecure: bool,
retry_invoker: RetryInvoker,
max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613
root_certificates: Optional[Union[bytes, str]] = None,
) -> Iterator[
Expand All @@ -72,6 +74,13 @@ def grpc_request_response(
The IPv6 address of the server with `http://` or `https://`.
If the Flower server runs on the same machine
on port 8080, then `server_address` would be `"http://[::]:8080"`.
insecure : bool
Starts an insecure gRPC connection when True. Enables HTTPS connection
when False, using system certificates if `root_certificates` is None.
retry_invoker: RetryInvoker
`RetryInvoker` object that will try to reconnect the client to the server
after gRPC errors. If None, the client will only try to
reconnect once after a failure.
max_message_length : int
Ignored, only present to preserve API-compatibility.
root_certificates : Optional[Union[bytes, str]] (default: None)
Expand Down Expand Up @@ -113,7 +122,8 @@ def grpc_request_response(
def create_node() -> None:
"""Set create_node."""
create_node_request = CreateNodeRequest()
create_node_response = stub.CreateNode(
create_node_response = retry_invoker.invoke(
stub.CreateNode,
request=create_node_request,
)
node_store[KEY_NODE] = create_node_response.node
Expand All @@ -127,7 +137,7 @@ def delete_node() -> None:
node: Node = cast(Node, node_store[KEY_NODE])

delete_node_request = DeleteNodeRequest(node=node)
stub.DeleteNode(request=delete_node_request)
retry_invoker.invoke(stub.DeleteNode, request=delete_node_request)

del node_store[KEY_NODE]

Expand All @@ -141,7 +151,7 @@ def receive() -> Optional[Message]:

# Request instructions (task) from server
request = PullTaskInsRequest(node=node)
response = stub.PullTaskIns(request=request)
response = retry_invoker.invoke(stub.PullTaskIns, request=request)

# Get the current TaskIns
task_ins: Optional[TaskIns] = get_task_ins(response)
Expand Down Expand Up @@ -185,7 +195,7 @@ def send(message: Message) -> None:

# Serialize ProtoBuf to bytes
request = PushTaskResRequest(task_res_list=[task_res])
_ = stub.PushTaskRes(request)
_ = retry_invoker.invoke(stub.PushTaskRes, request)

state[KEY_METADATA] = None

Expand Down
Loading

0 comments on commit 1298298

Please sign in to comment.