Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle ClientApp exception #2846

Merged
merged 37 commits into main from handle-flower-callable-exception
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5432cfd
Handle FlowerCallable exceptions
danieljanes Jan 22, 2024
c669c68
Merge branch 'main' into handle-flower-callable-exception
danieljanes Jan 24, 2024
fe6190f
Merge branch 'main' into handle-flower-callable-exception
danieljanes Jan 24, 2024
1cd8c07
Merge branch 'main' into handle-flower-callable-exception
danieljanes Jan 29, 2024
c3f1aca
Merge branch 'handle-flower-callable-exception' of github.com:adap/fl…
danieljanes Jan 29, 2024
a56f92d
Merge branch 'main' into handle-flower-callable-exception
danieljanes Jan 29, 2024
d3531d6
Add tests
danieljanes Jan 30, 2024
8647fdd
Merge branch 'main' into handle-flower-callable-exception
danieljanes Feb 2, 2024
cb33527
Update src/py/flwr/client/app.py
danieljanes Feb 2, 2024
cc3cf7c
Merge branch 'main' into handle-flower-callable-exception
danieljanes Feb 9, 2024
fa15c5f
Merge branch 'main' into handle-flower-callable-exception
danieljanes Feb 23, 2024
8b9c854
Fix imports
danieljanes Feb 23, 2024
a6e93fd
Fix argument
danieljanes Feb 23, 2024
c19e0b4
Fix imports
danieljanes Feb 23, 2024
42cdc4b
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 1, 2024
2f8167e
Use message.error
danieljanes Mar 1, 2024
9958978
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 1, 2024
eac0dbb
small fix
jafermarq Mar 1, 2024
7f634d5
Update task validation
danieljanes Mar 1, 2024
91687f9
Merge branch 'handle-flower-callable-exception' of github.com:adap/fl…
danieljanes Mar 1, 2024
735d7db
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 1, 2024
37aa310
Fix test
danieljanes Mar 1, 2024
bbdd3d0
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 4, 2024
ae2ca3a
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 5, 2024
449ac04
Complete Handling of `ClientApp` exception (#3067)
jafermarq Mar 6, 2024
30f4a2f
merge w/ main
jafermarq Mar 7, 2024
4931937
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 7, 2024
ef1287c
Fix lint issues
danieljanes Mar 7, 2024
b1d3414
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 9, 2024
9a0376d
merge w/ main
jafermarq Mar 20, 2024
2f39cce
server_custom handles pulled errors
jafermarq Mar 20, 2024
bac30f5
Merge branch 'main' into handle-flower-callable-exception
jafermarq Mar 25, 2024
0152490
Merge branch 'main' into handle-flower-callable-exception
jafermarq Mar 27, 2024
c0288be
Merge branch 'main' into handle-flower-callable-exception
jafermarq Mar 27, 2024
2ad4e85
raise ClientApp exception for `grpc-bidi` clients
jafermarq Mar 28, 2024
927ac33
Apply suggestions from code review
jafermarq Mar 28, 2024
bb6f72c
Merge branch 'main' into handle-flower-callable-exception
jafermarq Mar 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions examples/app-pytorch/server_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,19 @@ def main(driver: Driver, context: Context) -> None:
all_replies: List[Message] = []
while True:
replies = driver.pull_messages(message_ids=message_ids)
print(f"Got {len(replies)} results")
for res in replies:
print(f"Got 1 {'result' if res.has_content() else 'error'}")
all_replies += replies
if len(all_replies) == len(message_ids):
break
print("Pulling messages...")
time.sleep(3)

# Collect correct results
# Filter correct results
all_fitres = [
recordset_to_fitres(msg.content, keep_input=True) for msg in all_replies
recordset_to_fitres(msg.content, keep_input=True)
for msg in all_replies
if msg.has_content()
]
print(f"Received {len(all_fitres)} results")

Expand All @@ -128,16 +132,21 @@ def main(driver: Driver, context: Context) -> None:
)
metrics_results.append((fitres.num_examples, fitres.metrics))

# Aggregate parameters (FedAvg)
parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
parameters = parameters_aggregated
if len(weights_results) > 0:
# Aggregate parameters (FedAvg)
parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
parameters = parameters_aggregated

# Aggregate metrics
metrics_aggregated = weighted_average(metrics_results)
history.add_metrics_distributed_fit(
server_round=server_round, metrics=metrics_aggregated
)
print("Round ", server_round, " metrics: ", metrics_aggregated)
# Aggregate metrics
metrics_aggregated = weighted_average(metrics_results)
history.add_metrics_distributed_fit(
server_round=server_round, metrics=metrics_aggregated
)
print("Round ", server_round, " metrics: ", metrics_aggregated)
else:
print(
f"Round {server_round} got {len(weights_results)} results. Skipping aggregation..."
)

# Slow down the start of the next round
time.sleep(sleep_time)
Expand Down
59 changes: 35 additions & 24 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
# ==============================================================================
"""Flower client app."""


import argparse
import sys
import time
from logging import DEBUG, INFO, WARN
from logging import DEBUG, ERROR, INFO, WARN
from pathlib import Path
from typing import Callable, ContextManager, Optional, Tuple, Type, Union

Expand All @@ -38,6 +37,7 @@
)
from flwr.common.exit_handlers import register_exit_handlers
from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
from flwr.common.message import Error
from flwr.common.object_ref import load_app, validate
from flwr.common.retry_invoker import RetryInvoker, exponential

Expand Down Expand Up @@ -482,32 +482,43 @@ def _load_client_app() -> ClientApp:
# Retrieve context for this run
context = node_state.retrieve_context(run_id=message.metadata.run_id)

# Load ClientApp instance
client_app: ClientApp = load_client_app_fn()
# Create an error reply message that will never be used to prevent
# the used-before-assignment linting error
reply_message = message.create_error_reply(
error=Error(code=0, reason="Unknown")
)

# Handle task message
out_message = client_app(message=message, context=context)
# Handle app loading and task message
try:
# Load ClientApp instance
client_app: ClientApp = load_client_app_fn()

# Update node state
node_state.update_context(
run_id=message.metadata.run_id,
context=context,
)
reply_message = client_app(message=message, context=context)
# Update node state
node_state.update_context(
run_id=message.metadata.run_id,
context=context,
)
except Exception as ex: # pylint: disable=broad-exception-caught
log(ERROR, "ClientApp raised an exception", exc_info=ex)

# Legacy grpc-bidi
if transport in ["grpc-bidi", None]:
# Raise exception, crash process
raise ex

# Don't update/change NodeState

# Create error message
# Reason example: "<class 'ZeroDivisionError'>:<'division by zero'>"
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
reply_message = message.create_error_reply(
error=Error(code=0, reason=reason)
)

# Send
send(out_message)
log(
INFO,
"[RUN %s, ROUND %s]",
out_message.metadata.run_id,
out_message.metadata.group_id,
)
log(
INFO,
"Sent: %s reply to message %s",
out_message.metadata.message_type,
message.metadata.message_id,
)
send(reply_message)
log(INFO, "Sent reply")

# Unregister node
if delete_node is not None:
Expand Down
16 changes: 16 additions & 0 deletions src/py/flwr/server/compat/driver_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,24 @@ def _send_receive_recordset(
)
if len(task_res_list) == 1:
task_res = task_res_list[0]

# This will raise an Exception if task_res carries an `error`
validate_task_res(task_res=task_res)

return serde.recordset_from_proto(task_res.task.recordset)

if timeout is not None and time.time() > start_time + timeout:
raise RuntimeError("Timeout reached")
time.sleep(SLEEP_TIME)


def validate_task_res(
    task_res: task_pb2.TaskRes,  # pylint: disable=E1101
) -> None:
    """Validate that a TaskRes carries a usable `recordset`.

    Three conditions are checked, in order: the `task` field must be set,
    the task must not carry an `error`, and the task must carry a
    `recordset`.

    Parameters
    ----------
    task_res : task_pb2.TaskRes
        The reply to validate.

    Raises
    ------
    ValueError
        If `task` is missing, if the task carries an `error` (the message
        includes the reported code and reason), or if both `error` and
        `recordset` are missing.
    """
    if not task_res.HasField("task"):
        raise ValueError("Invalid TaskRes, field `task` missing")

    task = task_res.task
    if task.HasField("error"):
        # Keep the original message as a prefix and append the reported
        # code/reason so the failure can be diagnosed from this exception.
        error = task.error
        raise ValueError(
            "Exception during client-side task execution "
            f"(code={error.code}, reason={error.reason!r})"
        )
    if not task.HasField("recordset"):
        raise ValueError("Invalid TaskRes, both `recordset` and `error` are missing")
85 changes: 82 additions & 3 deletions src/py/flwr/server/compat/driver_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,14 @@
Properties,
Status,
)
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611

from .driver_client_proxy import DriverClientProxy
from flwr.proto import ( # pylint: disable=E0611
driver_pb2,
error_pb2,
node_pb2,
recordset_pb2,
task_pb2,
)
from flwr.server.compat.driver_client_proxy import DriverClientProxy, validate_task_res

MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np")

Expand Down Expand Up @@ -243,3 +248,77 @@ def test_evaluate(self) -> None:
# Assert
assert 0.0 == evaluate_res.loss
assert 0 == evaluate_res.num_examples

def test_validate_task_res_valid(self) -> None:
    """Check that a TaskRes whose task holds a `recordset` is accepted."""
    # Prepare: build the message bottom-up, one layer per statement
    loss_value = recordset_pb2.MetricsRecordValue(  # pylint: disable=E1101
        double=1.0
    )
    metrics_record = recordset_pb2.MetricsRecord(  # pylint: disable=E1101
        data={"loss": loss_value}
    )
    recordset = recordset_pb2.RecordSet(  # pylint: disable=E1101
        parameters={},
        metrics={"loss": metrics_record},
        configs={},
    )
    task = task_pb2.Task(recordset=recordset)  # pylint: disable=E1101
    task_res = task_pb2.TaskRes(  # pylint: disable=E1101
        task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
        group_id="",
        run_id=0,
        task=task,
    )

    # Execute & assert: a ValueError here is reported as a test failure
    try:
        validate_task_res(task_res=task_res)
    except ValueError:
        self.fail()

def test_validate_task_res_missing_task(self) -> None:
    """Check that a TaskRes without the `task` field is rejected."""
    # Prepare: only the top-level identifiers are set, `task` is left unset
    fields = {
        "task_id": "554bd3c8-8474-4b93-a7db-c7bec1bf0012",
        "group_id": "",
        "run_id": 0,
    }
    task_res = task_pb2.TaskRes(**fields)  # pylint: disable=E1101

    # Execute & assert
    self.assertRaises(ValueError, validate_task_res, task_res=task_res)

def test_validate_task_res_missing_recordset(self) -> None:
    """Check that a TaskRes whose task has no `recordset` is rejected."""
    # Prepare: the task is present but empty (neither recordset nor error)
    empty_task = task_pb2.Task()  # pylint: disable=E1101
    task_res = task_pb2.TaskRes(  # pylint: disable=E1101
        task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
        group_id="",
        run_id=0,
        task=empty_task,
    )

    # Execute & assert
    self.assertRaises(ValueError, validate_task_res, task_res=task_res)

def test_validate_task_res_missing_content(self) -> None:
    """Check that a TaskRes whose task carries an `error` is rejected."""
    # Prepare: the task holds an error instead of a recordset
    error = error_pb2.Error(  # pylint: disable=E1101
        code=0,
        reason="Some reason",
    )
    failed_task = task_pb2.Task(error=error)  # pylint: disable=E1101
    task_res = task_pb2.TaskRes(  # pylint: disable=E1101
        task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
        group_id="",
        run_id=0,
        task=failed_task,
    )

    # Execute & assert
    self.assertRaises(ValueError, validate_task_res, task_res=task_res)