Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle ClientApp exception #2846

Merged
merged 37 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5432cfd
Handle FlowerCallable exceptions
danieljanes Jan 22, 2024
c669c68
Merge branch 'main' into handle-flower-callable-exception
danieljanes Jan 24, 2024
fe6190f
Merge branch 'main' into handle-flower-callable-exception
danieljanes Jan 24, 2024
1cd8c07
Merge branch 'main' into handle-flower-callable-exception
danieljanes Jan 29, 2024
c3f1aca
Merge branch 'handle-flower-callable-exception' of github.com:adap/fl…
danieljanes Jan 29, 2024
a56f92d
Merge branch 'main' into handle-flower-callable-exception
danieljanes Jan 29, 2024
d3531d6
Add tests
danieljanes Jan 30, 2024
8647fdd
Merge branch 'main' into handle-flower-callable-exception
danieljanes Feb 2, 2024
cb33527
Update src/py/flwr/client/app.py
danieljanes Feb 2, 2024
cc3cf7c
Merge branch 'main' into handle-flower-callable-exception
danieljanes Feb 9, 2024
fa15c5f
Merge branch 'main' into handle-flower-callable-exception
danieljanes Feb 23, 2024
8b9c854
Fix imports
danieljanes Feb 23, 2024
a6e93fd
Fix argument
danieljanes Feb 23, 2024
c19e0b4
Fix imports
danieljanes Feb 23, 2024
42cdc4b
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 1, 2024
2f8167e
Use message.error
danieljanes Mar 1, 2024
9958978
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 1, 2024
eac0dbb
small fix
jafermarq Mar 1, 2024
7f634d5
Update task validation
danieljanes Mar 1, 2024
91687f9
Merge branch 'handle-flower-callable-exception' of github.com:adap/fl…
danieljanes Mar 1, 2024
735d7db
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 1, 2024
37aa310
Fix test
danieljanes Mar 1, 2024
bbdd3d0
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 4, 2024
ae2ca3a
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 5, 2024
449ac04
Complete Handling of `ClientApp` exception (#3067)
jafermarq Mar 6, 2024
30f4a2f
merge w/ main
jafermarq Mar 7, 2024
4931937
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 7, 2024
ef1287c
Fix lint issues
danieljanes Mar 7, 2024
b1d3414
Merge branch 'main' into handle-flower-callable-exception
danieljanes Mar 9, 2024
9a0376d
merge w/ main
jafermarq Mar 20, 2024
2f39cce
server_custom handles pulled errors
jafermarq Mar 20, 2024
bac30f5
Merge branch 'main' into handle-flower-callable-exception
jafermarq Mar 25, 2024
0152490
Merge branch 'main' into handle-flower-callable-exception
jafermarq Mar 27, 2024
c0288be
Merge branch 'main' into handle-flower-callable-exception
jafermarq Mar 27, 2024
2ad4e85
raise ClientApp exception for `grpc-bidi` clients
jafermarq Mar 28, 2024
927ac33
Apply suggestions from code review
jafermarq Mar 28, 2024
bb6f72c
Merge branch 'main' into handle-flower-callable-exception
jafermarq Mar 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import argparse
import sys
import time
from logging import DEBUG, INFO, WARN
from logging import DEBUG, ERROR, INFO, WARN
from pathlib import Path
from typing import Callable, ContextManager, Optional, Tuple, Union

Expand All @@ -36,6 +36,8 @@
)
from flwr.common.logger import log, warn_deprecated_feature, warn_experimental_feature
from flwr.common.message import Message
from flwr.common.recordset import RecordSet
from flwr.common.serde import message_to_taskres

from .flower import load_flower_callable
from .grpc_client.connection import grpc_connection
Expand Down Expand Up @@ -379,7 +381,23 @@ def _load_app() -> Flower:
app: Flower = load_flower_callable_fn()

# Handle task message
out_message = app(message=message, context=context)
try:
out_message = app(message=message, context=context)
except Exception as ex: # pylint: disable=broad-exception-caught
log(ERROR, "FlowerCallable raised an exception", exc_info=ex)

# Don't update/change RunState
# Return empty Message

error_out_message = Message(
metadata=message.metadata,
message=RecordSet(),
)

# Construct TaskRes from out_message
error_task_res = message_to_taskres(error_out_message)
send(error_task_res)
continue

# Update node state
node_state.update_context(
Expand Down
19 changes: 19 additions & 0 deletions src/py/flwr/driver/driver_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,27 @@ def _send_receive_recordset(
)
if len(task_res_list) == 1:
task_res = task_res_list[0]
validate_task_res(task_res=task_res)
return serde.recordset_from_proto(task_res.task.recordset)

if timeout is not None and time.time() > start_time + timeout:
raise RuntimeError("Timeout reached")
time.sleep(SLEEP_TIME)


def validate_task_res(
task_res: task_pb2.TaskRes, # pylint: disable=E1101
) -> None:
"""Validate if a TaskRes is empty or not."""
if not task_res.HasField("task"):
raise ValueError("Invalid TaskRes, field `task` missing")
if not task_res.task.HasField("recordset"):
raise ValueError("Invalid Task, field `recordset` missing")

rs = task_res.task.recordset
if (
(not rs.parameters.keys())
and (not rs.metrics.keys())
and (not rs.configs.keys())
):
raise ValueError("Exception during client-side task execution")
84 changes: 82 additions & 2 deletions src/py/flwr/driver/driver_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,13 @@
Properties,
Status,
)
from flwr.driver.driver_client_proxy import DriverClientProxy
from flwr.proto import driver_pb2, node_pb2, task_pb2 # pylint: disable=E0611
from flwr.driver.driver_client_proxy import DriverClientProxy, validate_task_res
from flwr.proto import ( # pylint: disable=E0611
driver_pb2,
node_pb2,
recordset_pb2,
task_pb2,
)

MESSAGE_PARAMETERS = Parameters(tensors=[b"abc"], tensor_type="np")

Expand Down Expand Up @@ -245,3 +250,78 @@ def test_evaluate(self) -> None:
# Assert
assert 0.0 == evaluate_res.loss
assert 0 == evaluate_res.num_examples

def test_validate_task_res_valid(self) -> None:
"""Test valid TaskRes."""
metrics_record = recordset_pb2.MetricsRecord( # pylint: disable=E1101
data={
"loss": recordset_pb2.MetricsRecordValue( # pylint: disable=E1101
double=1.0
)
}
)
task_res = task_pb2.TaskRes( # pylint: disable=E1101
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
run_id=0,
task=task_pb2.Task( # pylint: disable=E1101
recordset=recordset_pb2.RecordSet( # pylint: disable=E1101
parameters={},
metrics={"loss": metrics_record},
configs={},
)
),
)

# Execute & assert
try:
validate_task_res(task_res=task_res)
except ValueError:
self.fail()

def test_validate_task_res_missing_task(self) -> None:
"""Test invalid TaskRes (missing task)."""
# Prepare
task_res = task_pb2.TaskRes( # pylint: disable=E1101
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
run_id=0,
)

# Execute & assert
with self.assertRaises(ValueError):
validate_task_res(task_res=task_res)

def test_validate_task_res_missing_recordset(self) -> None:
"""Test invalid TaskRes (missing recordset)."""
# Prepare
task_res = task_pb2.TaskRes( # pylint: disable=E1101
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
run_id=0,
task=task_pb2.Task(), # pylint: disable=E1101
)

# Execute & assert
with self.assertRaises(ValueError):
validate_task_res(task_res=task_res)

def test_validate_task_res_missing_content(self) -> None:
"""Test invalid TaskRes (missing content)."""
# Prepare
task_res = task_pb2.TaskRes( # pylint: disable=E1101
task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012",
group_id="",
run_id=0,
task=task_pb2.Task( # pylint: disable=E1101
recordset=recordset_pb2.RecordSet( # pylint: disable=E1101
parameters={},
metrics={},
configs={},
)
),
)

# Execute & assert
with self.assertRaises(ValueError):
validate_task_res(task_res=task_res)
Loading