diff --git a/src/py/flwr/server/superlink/driver/serverappio_servicer.py b/src/py/flwr/server/superlink/driver/serverappio_servicer.py index b6104e294301..0bf82a21242d 100644 --- a/src/py/flwr/server/superlink/driver/serverappio_servicer.py +++ b/src/py/flwr/server/superlink/driver/serverappio_servicer.py @@ -16,14 +16,13 @@ import threading -import time from logging import DEBUG, INFO from typing import Optional from uuid import UUID import grpc -from flwr.common import ConfigsRecord +from flwr.common import ConfigsRecord, now from flwr.common.constant import Status from flwr.common.logger import log from flwr.common.serde import ( @@ -31,6 +30,10 @@ context_to_proto, fab_from_proto, fab_to_proto, + message_from_proto, + message_from_taskres, + message_to_proto, + message_to_taskins, run_status_from_proto, run_status_to_proto, run_to_proto, @@ -153,7 +156,7 @@ def PushTaskIns( ) # Set pushed_at (timestamp in seconds) - pushed_at = time.time() + pushed_at = now().timestamp() for task_ins in request.task_ins_list: task_ins.task.pushed_at = pushed_at @@ -190,7 +193,54 @@ def PushMessages( self, request: PushInsMessagesRequest, context: grpc.ServicerContext ) -> PushInsMessagesResponse: """Push a set of Messages.""" - return PushInsMessagesResponse(message_ids=[]) + log(DEBUG, "ServerAppIoServicer.PushMessages") + + # Init state + state: LinkState = self.state_factory.state() + + # Abort if the run is not running + abort_if( + request.run_id, + [Status.PENDING, Status.STARTING, Status.FINISHED], + state, + context, + ) + + # Set pushed_at (timestamp in seconds) + pushed_at = now().timestamp() + + # Validate request and insert in State + _raise_if( + validation_error=len(request.messages_list) == 0, + request_name="PushMessages", + detail="`messages_list` must not be empty", + ) + message_ids: list[Optional[UUID]] = [] + while request.messages_list: + message_proto = request.messages_list.pop(0) + message = message_from_proto(message_proto=message_proto) + task_ins = message_to_taskins(message=message) + task_ins.task.pushed_at = pushed_at + validation_errors = validate_task_ins_or_res(task_ins) + _raise_if( + validation_error=bool(validation_errors), + request_name="PushMessages", + detail=", ".join(validation_errors), + ) + _raise_if( + validation_error=request.run_id != task_ins.run_id, + request_name="PushMessages", + detail="`task_ins` has mismatched `run_id`", + ) + # Store + message_id: Optional[UUID] = state.store_task_ins(task_ins=task_ins) + message_ids.append(message_id) + + return PushInsMessagesResponse( + message_ids=[ + str(message_id) if message_id else "" for message_id in message_ids + ] + ) def PullTaskRes( self, request: PullTaskResRequest, context: grpc.ServicerContext @@ -235,7 +285,47 @@ def PullMessages( self, request: PullResMessagesRequest, context: grpc.ServicerContext ) -> PullResMessagesResponse: """Pull a set of Messages.""" - return PullResMessagesResponse(messages_list=[]) + log(DEBUG, "ServerAppIoServicer.PullMessages") + + # Init state + state: LinkState = self.state_factory.state() + + # Abort if the run is not running + abort_if( + request.run_id, + [Status.PENDING, Status.STARTING, Status.FINISHED], + state, + context, + ) + + # Convert each task_id str to UUID + message_ids: set[UUID] = { + UUID(message_id) for message_id in request.message_ids + } + + # Read from state + task_res_list: list[TaskRes] = state.get_task_res(task_ids=message_ids) + + # Convert to Messages + messages_list = [] + while task_res_list: + task_res = task_res_list.pop(0) + _raise_if( + validation_error=request.run_id != task_res.run_id, + request_name="PullMessages", + detail="`task_res` has mismatched `run_id`", + ) + message = message_from_taskres(taskres=task_res) + messages_list.append(message_to_proto(message)) + + # Delete the TaskIns/TaskRes pairs if TaskRes is found + task_ins_ids_to_delete = { + UUID(task_res.task.ancestry[0]) for task_res in task_res_list + } + + state.delete_tasks(task_ins_ids=task_ins_ids_to_delete) + + return PullResMessagesResponse(messages_list=messages_list) def GetRun( self, request: GetRunRequest, context: grpc.ServicerContext diff --git a/src/py/flwr/server/superlink/driver/serverappio_servicer_test.py b/src/py/flwr/server/superlink/driver/serverappio_servicer_test.py index 684dbb9c84ac..fdd824acb70a 100644 --- a/src/py/flwr/server/superlink/driver/serverappio_servicer_test.py +++ b/src/py/flwr/server/superlink/driver/serverappio_servicer_test.py @@ -30,6 +30,7 @@ from flwr.common.serde import context_to_proto, run_status_to_proto from flwr.common.serde_test import RecordMaker from flwr.common.typing import RunStatus +from flwr.proto.message_pb2 import Message # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import ( # pylint: disable=E0611 UpdateRunStatusRequest, @@ -293,6 +294,41 @@ def test_push_task_ins_not_successful_if_not_running( # Execute & Assert self._assert_push_task_ins_not_allowed(task_ins, run_id) + def _assert_push_ins_messages_not_allowed( + self, message: Message, run_id: int + ) -> None: + """Assert `PushInsMessages` not allowed.""" + run_status = self.state.get_run_status({run_id})[run_id] + request = PushInsMessagesRequest(messages_list=[message], run_id=run_id) + + with self.assertRaises(grpc.RpcError) as e: + self._push_messages.with_call(request=request) + assert e.exception.code() == grpc.StatusCode.PERMISSION_DENIED + assert e.exception.details() == self.status_to_msg[run_status.status] + + @parameterized.expand( + [ + (0,), # Test not successful if RunStatus is pending. + (1,), # Test not successful if RunStatus is starting. + (3,), # Test not successful if RunStatus is finished. + ] + ) # type: ignore + def test_push_ins_messages_not_successful_if_not_running( + self, num_transitions: int + ) -> None: + """Test `PushInsMessages` not successful if RunStatus is not running.""" + # Prepare + node_id = self.state.create_node(ping_interval=30) + run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) + message_ins = create_ins_message( + src_node_id=SUPERLINK_NODE_ID, dst_node_id=node_id, run_id=run_id + ) + + self._transition_run_status(run_id, num_transitions) + + # Execute & Assert + self._assert_push_ins_messages_not_allowed(message_ins, run_id) + def test_pull_task_res_successful_if_running(self) -> None: """Test `PullTaskRes` success.""" # Prepare @@ -352,6 +388,35 @@ def test_pull_task_res_not_successful_if_not_running( # Execute & Assert self._assert_pull_task_res_not_allowed(run_id) + def _assert_pull_messages_not_allowed(self, run_id: int) -> None: + """Assert `PullMessages` not allowed.""" + run_status = self.state.get_run_status({run_id})[run_id] + request = PullResMessagesRequest(run_id=run_id) + + with self.assertRaises(grpc.RpcError) as e: + self._pull_messages.with_call(request=request) + assert e.exception.code() == grpc.StatusCode.PERMISSION_DENIED + assert e.exception.details() == self.status_to_msg[run_status.status] + + @parameterized.expand( + [ + (0,), # Test not successful if RunStatus is pending. + (1,), # Test not successful if RunStatus is starting. + (3,), # Test not successful if RunStatus is finished. + ] + ) # type: ignore + def test_pull_messages_not_successful_if_not_running( + self, num_transitions: int + ) -> None: + """Test `PullMessages` not successful if RunStatus is not running.""" + # Prepare + run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) + + self._transition_run_status(run_id, num_transitions) + + # Execute & Assert + self._assert_pull_messages_not_allowed(run_id) + def test_push_serverapp_outputs_successful_if_running(self) -> None: """Test `PushServerAppOutputs` success.""" # Prepare