From 39ee863b6ad8380abcb70c4b3bab61276a9457a6 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Sat, 14 Dec 2024 18:39:56 +0000 Subject: [PATCH] feat(framework) Update message handler methods to raise `InvalidRunStatusException` if called but run is not running (#4694) --- src/py/flwr/common/typing.py | 8 ++++ .../grpc_adapter/grpc_adapter_servicer.py | 1 + .../fleet/grpc_rere/fleet_servicer.py | 44 ++++++++++++++----- .../grpc_rere/server_interceptor_test.py | 21 +++++++-- .../fleet/message_handler/message_handler.py | 33 +++++++++++++- .../superlink/fleet/rest_rere/rest_api.py | 5 ++- 6 files changed, 94 insertions(+), 18 deletions(-) diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index 42c7d9dab1ff..d6b940ba75e6 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -258,3 +258,11 @@ class Fab: class RunNotRunningException(BaseException): """Raised when a run is not running.""" + + +class InvalidRunStatusException(BaseException): + """Raised when an RPC is invalidated by the RunStatus.""" + + def __init__(self, message: str) -> None: + super().__init__(message) + self.message = message diff --git a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py index ffef57d89e8c..49a658b589bf 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py @@ -158,4 +158,5 @@ def _get_fab(self, request: GetFabRequest) -> GetFabResponse: return message_handler.get_fab( request=request, ffs=self.ffs_factory.ffs(), + state=self.state_factory.state(), ) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py index dacbab135057..225e5e5e225f 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +++ 
b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py @@ -20,6 +20,7 @@ import grpc from flwr.common.logger import log +from flwr.common.typing import InvalidRunStatusException from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611 from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 @@ -38,6 +39,7 @@ from flwr.server.superlink.ffs.ffs_factory import FfsFactory from flwr.server.superlink.fleet.message_handler import message_handler from flwr.server.superlink.linkstate import LinkStateFactory +from flwr.server.superlink.utils import abort_grpc_context class FleetServicer(fleet_pb2_grpc.FleetServicer): @@ -105,27 +107,45 @@ def PushTaskRes( ) else: log(INFO, "[Fleet.PushTaskRes] No task results to push") - return message_handler.push_task_res( - request=request, - state=self.state_factory.state(), - ) + + try: + res = message_handler.push_task_res( + request=request, + state=self.state_factory.state(), + ) + except InvalidRunStatusException as e: + abort_grpc_context(e.message, context) + + return res def GetRun( self, request: GetRunRequest, context: grpc.ServicerContext ) -> GetRunResponse: """Get run information.""" log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id) - return message_handler.get_run( - request=request, - state=self.state_factory.state(), - ) + + try: + res = message_handler.get_run( + request=request, + state=self.state_factory.state(), + ) + except InvalidRunStatusException as e: + abort_grpc_context(e.message, context) + + return res def GetFab( self, request: GetFabRequest, context: grpc.ServicerContext ) -> GetFabResponse: """Get FAB.""" log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str) - return message_handler.get_fab( - request=request, - ffs=self.ffs_factory.ffs(), - ) + try: + res = message_handler.get_fab( + request=request, + ffs=self.ffs_factory.ffs(), + 
state=self.state_factory.state(), + ) + except InvalidRunStatusException as e: + abort_grpc_context(e.message, context) + + return res diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index ce43fc4bae0a..23a28d0dfafc 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -21,7 +21,7 @@ import grpc from flwr.common import ConfigsRecord -from flwr.common.constant import FLEET_API_GRPC_RERE_DEFAULT_ADDRESS +from flwr.common.constant import FLEET_API_GRPC_RERE_DEFAULT_ADDRESS, Status from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( compute_hmac, generate_key_pairs, @@ -29,6 +29,7 @@ private_key_to_bytes, public_key_to_bytes, ) +from flwr.common.typing import RunStatus from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, @@ -275,8 +276,14 @@ def test_successful_push_task_res_with_metadata(self) -> None: node_id = self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) ) + run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) + # Transition status to running. PushTaskRes is only allowed in running status. 
+ _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + _ = self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) request = PushTaskResRequest( - task_res_list=[TaskRes(task=Task(producer=Node(node_id=node_id)))] + task_res_list=[ + TaskRes(task=Task(producer=Node(node_id=node_id)), run_id=run_id) + ] ) shared_secret = generate_shared_key( self._node_private_key, self._server_public_key @@ -307,6 +314,10 @@ def test_unsuccessful_push_task_res_with_metadata(self) -> None: node_id = self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) ) + run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) + # Transition status to running. PushTaskRes is only allowed in running status. + _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + _ = self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) request = PushTaskResRequest( task_res_list=[TaskRes(task=Task(producer=Node(node_id=node_id)))] ) @@ -320,7 +331,7 @@ def test_unsuccessful_push_task_res_with_metadata(self) -> None: ) # Execute & Assert - with self.assertRaises(grpc.RpcError): + with self.assertRaises(grpc.RpcError) as e: self._push_task_res.with_call( request=request, metadata=( @@ -328,6 +339,7 @@ def test_unsuccessful_push_task_res_with_metadata(self) -> None: (_AUTH_TOKEN_HEADER, hmac_value), ), ) + assert e.exception.code() == grpc.StatusCode.UNAUTHENTICATED def test_successful_get_run_with_metadata(self) -> None: """Test server interceptor for pull task ins.""" @@ -336,6 +348,9 @@ def test_successful_get_run_with_metadata(self) -> None: ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) ) run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) + # Transition status to running. GetRun is only allowed in running status. 
+ _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + _ = self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) request = GetRunRequest(run_id=run_id) shared_secret = generate_shared_key( self._node_private_key, self._server_public_key diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index 38df6f441a20..34917b8168eb 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -19,8 +19,9 @@ from typing import Optional from uuid import UUID +from flwr.common.constant import Status from flwr.common.serde import fab_to_proto, user_config_to_proto -from flwr.common.typing import Fab +from flwr.common.typing import Fab, InvalidRunStatusException from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -44,6 +45,7 @@ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.superlink.ffs.ffs import Ffs from flwr.server.superlink.linkstate import LinkState +from flwr.server.superlink.utils import check_abort def create_node( @@ -98,6 +100,15 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR task_res: TaskRes = request.task_res_list[0] # pylint: enable=no-member + # Abort if the run is not running + abort_msg = check_abort( + task_res.run_id, + [Status.PENDING, Status.STARTING, Status.FINISHED], + state, + ) + if abort_msg: + raise InvalidRunStatusException(abort_msg) + # Set pushed_at (timestamp in seconds) task_res.task.pushed_at = time.time() @@ -121,6 +132,15 @@ def get_run( if run is None: return GetRunResponse() + # Abort if the run is not running + abort_msg = check_abort( + request.run_id, + [Status.PENDING, Status.STARTING, Status.FINISHED], + 
state, + ) + if abort_msg: + raise InvalidRunStatusException(abort_msg) + return GetRunResponse( run=Run( run_id=run.run_id, @@ -133,9 +153,18 @@ def get_run( def get_fab( - request: GetFabRequest, ffs: Ffs # pylint: disable=W0613 + request: GetFabRequest, ffs: Ffs, state: LinkState # pylint: disable=W0613 ) -> GetFabResponse: """Get FAB.""" + # Abort if the run is not running + abort_msg = check_abort( + request.run_id, + [Status.PENDING, Status.STARTING, Status.FINISHED], + state, + ) + if abort_msg: + raise InvalidRunStatusException(abort_msg) + if result := ffs.get(request.hash_str): fab = Fab(request.hash_str, result[0]) return GetFabResponse(fab=fab_to_proto(fab)) diff --git a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py index a684cd9b3bf2..11b8bbfcac07 100644 --- a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py +++ b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py @@ -154,8 +154,11 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse: # Get ffs from app ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs() + # Get state from app + state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state() + # Handle message - return message_handler.get_fab(request=request, ffs=ffs) + return message_handler.get_fab(request=request, ffs=ffs, state=state) routes = [