Skip to content

Commit

Permalink
feat(framework) Add PushInsMessages and PullResMessages behaviour…
Browse files Browse the repository at this point in the history
… to `ServerAppIoServicer` (#4308)

Co-authored-by: Heng Pan <[email protected]>
  • Loading branch information
jafermarq and panh99 authored Jan 20, 2025
1 parent bf98124 commit f5b3253
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 5 deletions.
100 changes: 95 additions & 5 deletions src/py/flwr/server/superlink/driver/serverappio_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@


import threading
import time
from logging import DEBUG, INFO
from typing import Optional
from uuid import UUID

import grpc

from flwr.common import ConfigsRecord
from flwr.common import ConfigsRecord, now
from flwr.common.constant import Status
from flwr.common.logger import log
from flwr.common.serde import (
context_from_proto,
context_to_proto,
fab_from_proto,
fab_to_proto,
message_from_proto,
message_from_taskres,
message_to_proto,
message_to_taskins,
run_status_from_proto,
run_status_to_proto,
run_to_proto,
Expand Down Expand Up @@ -153,7 +156,7 @@ def PushTaskIns(
)

# Set pushed_at (timestamp in seconds)
pushed_at = time.time()
pushed_at = now().timestamp()
for task_ins in request.task_ins_list:
task_ins.task.pushed_at = pushed_at

Expand Down Expand Up @@ -190,7 +193,54 @@ def PushMessages(
self, request: PushInsMessagesRequest, context: grpc.ServicerContext
) -> PushInsMessagesResponse:
"""Push a set of Messages."""
return PushInsMessagesResponse(message_ids=[])
log(DEBUG, "ServerAppIoServicer.PushMessages")

# Init state
state: LinkState = self.state_factory.state()

# Abort if the run is not running
abort_if(
request.run_id,
[Status.PENDING, Status.STARTING, Status.FINISHED],
state,
context,
)

# Set pushed_at (timestamp in seconds)
pushed_at = now().timestamp()

# Validate request and insert in State
_raise_if(
validation_error=len(request.messages_list) == 0,
request_name="PushMessages",
detail="`messages_list` must not be empty",
)
message_ids: list[Optional[UUID]] = []
while request.messages_list:
message_proto = request.messages_list.pop(0)
message = message_from_proto(message_proto=message_proto)
task_ins = message_to_taskins(message=message)
task_ins.task.pushed_at = pushed_at
validation_errors = validate_task_ins_or_res(task_ins)
_raise_if(
validation_error=bool(validation_errors),
request_name="PushMessages",
detail=", ".join(validation_errors),
)
_raise_if(
validation_error=request.run_id != task_ins.run_id,
request_name="PushMessages",
detail="`task_ins` has mismatched `run_id`",
)
# Store
message_id: Optional[UUID] = state.store_task_ins(task_ins=task_ins)
message_ids.append(message_id)

return PushInsMessagesResponse(
message_ids=[
str(message_id) if message_id else "" for message_id in message_ids
]
)

def PullTaskRes(
self, request: PullTaskResRequest, context: grpc.ServicerContext
Expand Down Expand Up @@ -235,7 +285,47 @@ def PullMessages(
self, request: PullResMessagesRequest, context: grpc.ServicerContext
) -> PullResMessagesResponse:
"""Pull a set of Messages."""
return PullResMessagesResponse(messages_list=[])
log(DEBUG, "ServerAppIoServicer.PullMessages")

# Init state
state: LinkState = self.state_factory.state()

# Abort if the run is not running
abort_if(
request.run_id,
[Status.PENDING, Status.STARTING, Status.FINISHED],
state,
context,
)

# Convert each task_id str to UUID
message_ids: set[UUID] = {
UUID(message_id) for message_id in request.message_ids
}

# Read from state
task_res_list: list[TaskRes] = state.get_task_res(task_ids=message_ids)

# Convert to Messages
messages_list = []
while task_res_list:
task_res = task_res_list.pop(0)
_raise_if(
validation_error=request.run_id != task_res.run_id,
request_name="PullMessages",
detail="`task_res` has mismatched `run_id`",
)
message = message_from_taskres(taskres=task_res)
messages_list.append(message_to_proto(message))

# Delete the TaskIns/TaskRes pairs if TaskRes is found
task_ins_ids_to_delete = {
UUID(task_res.task.ancestry[0]) for task_res in task_res_list
}

state.delete_tasks(task_ins_ids=task_ins_ids_to_delete)

return PullResMessagesResponse(messages_list=messages_list)

def GetRun(
self, request: GetRunRequest, context: grpc.ServicerContext
Expand Down
65 changes: 65 additions & 0 deletions src/py/flwr/server/superlink/driver/serverappio_servicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from flwr.common.serde import context_to_proto, run_status_to_proto
from flwr.common.serde_test import RecordMaker
from flwr.common.typing import RunStatus
from flwr.proto.message_pb2 import Message # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
UpdateRunStatusRequest,
Expand Down Expand Up @@ -293,6 +294,41 @@ def test_push_task_ins_not_successful_if_not_running(
# Execute & Assert
self._assert_push_task_ins_not_allowed(task_ins, run_id)

def _assert_push_ins_messages_not_allowed(
self, message: Message, run_id: int
) -> None:
"""Assert `PushInsMessages` not allowed."""
run_status = self.state.get_run_status({run_id})[run_id]
request = PushInsMessagesRequest(messages_list=[message], run_id=run_id)

with self.assertRaises(grpc.RpcError) as e:
self._push_messages.with_call(request=request)
assert e.exception.code() == grpc.StatusCode.PERMISSION_DENIED
assert e.exception.details() == self.status_to_msg[run_status.status]

@parameterized.expand(
[
(0,), # Test not successful if RunStatus is pending.
(1,), # Test not successful if RunStatus is starting.
(3,), # Test not successful if RunStatus is finished.
]
) # type: ignore
def test_push_ins_messages_not_successful_if_not_running(
self, num_transitions: int
) -> None:
"""Test `PushInsMessages` not successful if RunStatus is not running."""
# Prepare
node_id = self.state.create_node(ping_interval=30)
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
message_ins = create_ins_message(
src_node_id=SUPERLINK_NODE_ID, dst_node_id=node_id, run_id=run_id
)

self._transition_run_status(run_id, num_transitions)

# Execute & Assert
self._assert_push_ins_messages_not_allowed(message_ins, run_id)

def test_pull_task_res_successful_if_running(self) -> None:
"""Test `PullTaskRes` success."""
# Prepare
Expand Down Expand Up @@ -352,6 +388,35 @@ def test_pull_task_res_not_successful_if_not_running(
# Execute & Assert
self._assert_pull_task_res_not_allowed(run_id)

def _assert_pull_messages_not_allowed(self, run_id: int) -> None:
"""Assert `PullMessages` not allowed."""
run_status = self.state.get_run_status({run_id})[run_id]
request = PullResMessagesRequest(run_id=run_id)

with self.assertRaises(grpc.RpcError) as e:
self._pull_messages.with_call(request=request)
assert e.exception.code() == grpc.StatusCode.PERMISSION_DENIED
assert e.exception.details() == self.status_to_msg[run_status.status]

@parameterized.expand(
[
(0,), # Test not successful if RunStatus is pending.
(1,), # Test not successful if RunStatus is starting.
(3,), # Test not successful if RunStatus is finished.
]
) # type: ignore
def test_pull_messages_not_successful_if_not_running(
self, num_transitions: int
) -> None:
"""Test `PullMessages` not successful if RunStatus is not running."""
# Prepare
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())

self._transition_run_status(run_id, num_transitions)

# Execute & Assert
self._assert_pull_messages_not_allowed(run_id)

def test_push_serverapp_outputs_successful_if_running(self) -> None:
"""Test `PushServerAppOutputs` success."""
# Prepare
Expand Down

0 comments on commit f5b3253

Please sign in to comment.