Skip to content

Commit

Permalink
feat(framework) Update message handler methods to raise `AbortRunExce…
Browse files Browse the repository at this point in the history
…ption` if called but run is not running (#4694)
  • Loading branch information
chongshenng authored Dec 14, 2024
1 parent 76809af commit 39ee863
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 18 deletions.
8 changes: 8 additions & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,11 @@ class Fab:

class RunNotRunningException(BaseException):
"""Raised when a run is not running."""


class InvalidRunStatusException(BaseException):
"""Raised when an RPC is invalidated by the RunStatus."""

def __init__(self, message: str) -> None:
super().__init__(message)
self.message = message
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,5 @@ def _get_fab(self, request: GetFabRequest) -> GetFabResponse:
return message_handler.get_fab(
request=request,
ffs=self.ffs_factory.ffs(),
state=self.state_factory.state(),
)
44 changes: 32 additions & 12 deletions src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import grpc

from flwr.common.logger import log
from flwr.common.typing import InvalidRunStatusException
from flwr.proto import fleet_pb2_grpc # pylint: disable=E0611
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
Expand All @@ -38,6 +39,7 @@
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.fleet.message_handler import message_handler
from flwr.server.superlink.linkstate import LinkStateFactory
from flwr.server.superlink.utils import abort_grpc_context


class FleetServicer(fleet_pb2_grpc.FleetServicer):
Expand Down Expand Up @@ -105,27 +107,45 @@ def PushTaskRes(
)
else:
log(INFO, "[Fleet.PushTaskRes] No task results to push")
return message_handler.push_task_res(
request=request,
state=self.state_factory.state(),
)

try:
res = message_handler.push_task_res(
request=request,
state=self.state_factory.state(),
)
except InvalidRunStatusException as e:
abort_grpc_context(e.message, context)

return res

def GetRun(
self, request: GetRunRequest, context: grpc.ServicerContext
) -> GetRunResponse:
"""Get run information."""
log(INFO, "[Fleet.GetRun] Requesting `Run` for run_id=%s", request.run_id)
return message_handler.get_run(
request=request,
state=self.state_factory.state(),
)

try:
res = message_handler.get_run(
request=request,
state=self.state_factory.state(),
)
except InvalidRunStatusException as e:
abort_grpc_context(e.message, context)

return res

def GetFab(
self, request: GetFabRequest, context: grpc.ServicerContext
) -> GetFabResponse:
"""Get FAB."""
log(INFO, "[Fleet.GetFab] Requesting FAB for fab_hash=%s", request.hash_str)
return message_handler.get_fab(
request=request,
ffs=self.ffs_factory.ffs(),
)
try:
res = message_handler.get_fab(
request=request,
ffs=self.ffs_factory.ffs(),
state=self.state_factory.state(),
)
except InvalidRunStatusException as e:
abort_grpc_context(e.message, context)

return res
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
import grpc

from flwr.common import ConfigsRecord
from flwr.common.constant import FLEET_API_GRPC_RERE_DEFAULT_ADDRESS
from flwr.common.constant import FLEET_API_GRPC_RERE_DEFAULT_ADDRESS, Status
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
compute_hmac,
generate_key_pairs,
generate_shared_key,
private_key_to_bytes,
public_key_to_bytes,
)
from flwr.common.typing import RunStatus
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
CreateNodeResponse,
Expand Down Expand Up @@ -275,8 +276,14 @@ def test_successful_push_task_res_with_metadata(self) -> None:
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
# Transition status to running. PushTaskRes is only allowed in running status.
_ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", ""))
_ = self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", ""))
request = PushTaskResRequest(
task_res_list=[TaskRes(task=Task(producer=Node(node_id=node_id)))]
task_res_list=[
TaskRes(task=Task(producer=Node(node_id=node_id)), run_id=run_id)
]
)
shared_secret = generate_shared_key(
self._node_private_key, self._server_public_key
Expand Down Expand Up @@ -307,6 +314,10 @@ def test_unsuccessful_push_task_res_with_metadata(self) -> None:
node_id = self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
# Transition status to running. PushTaskRes is only allowed in running status.
_ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", ""))
_ = self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", ""))
request = PushTaskResRequest(
task_res_list=[TaskRes(task=Task(producer=Node(node_id=node_id)))]
)
Expand All @@ -320,14 +331,15 @@ def test_unsuccessful_push_task_res_with_metadata(self) -> None:
)

# Execute & Assert
with self.assertRaises(grpc.RpcError):
with self.assertRaises(grpc.RpcError) as e:
self._push_task_res.with_call(
request=request,
metadata=(
(_PUBLIC_KEY_HEADER, public_key_bytes),
(_AUTH_TOKEN_HEADER, hmac_value),
),
)
assert e.exception.code() == grpc.StatusCode.UNAUTHENTICATED

def test_successful_get_run_with_metadata(self) -> None:
"""Test server interceptor for pull task ins."""
Expand All @@ -336,6 +348,9 @@ def test_successful_get_run_with_metadata(self) -> None:
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
# Transition status to running. GetRun is only allowed in running status.
_ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", ""))
_ = self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", ""))
request = GetRunRequest(run_id=run_id)
shared_secret = generate_shared_key(
self._node_private_key, self._server_public_key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
from typing import Optional
from uuid import UUID

from flwr.common.constant import Status
from flwr.common.serde import fab_to_proto, user_config_to_proto
from flwr.common.typing import Fab
from flwr.common.typing import Fab, InvalidRunStatusException
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
Expand All @@ -44,6 +45,7 @@
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.superlink.ffs.ffs import Ffs
from flwr.server.superlink.linkstate import LinkState
from flwr.server.superlink.utils import check_abort


def create_node(
Expand Down Expand Up @@ -98,6 +100,15 @@ def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResR
task_res: TaskRes = request.task_res_list[0]
# pylint: enable=no-member

# Abort if the run is not running
abort_msg = check_abort(
task_res.run_id,
[Status.PENDING, Status.STARTING, Status.FINISHED],
state,
)
if abort_msg:
raise InvalidRunStatusException(abort_msg)

# Set pushed_at (timestamp in seconds)
task_res.task.pushed_at = time.time()

Expand All @@ -121,6 +132,15 @@ def get_run(
if run is None:
return GetRunResponse()

# Abort if the run is not running
abort_msg = check_abort(
request.run_id,
[Status.PENDING, Status.STARTING, Status.FINISHED],
state,
)
if abort_msg:
raise InvalidRunStatusException(abort_msg)

return GetRunResponse(
run=Run(
run_id=run.run_id,
Expand All @@ -133,9 +153,18 @@ def get_run(


def get_fab(
request: GetFabRequest, ffs: Ffs # pylint: disable=W0613
request: GetFabRequest, ffs: Ffs, state: LinkState # pylint: disable=W0613
) -> GetFabResponse:
"""Get FAB."""
# Abort if the run is not running
abort_msg = check_abort(
request.run_id,
[Status.PENDING, Status.STARTING, Status.FINISHED],
state,
)
if abort_msg:
raise InvalidRunStatusException(abort_msg)

if result := ffs.get(request.hash_str):
fab = Fab(request.hash_str, result[0])
return GetFabResponse(fab=fab_to_proto(fab))
Expand Down
5 changes: 4 additions & 1 deletion src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,11 @@ async def get_fab(request: GetFabRequest) -> GetFabResponse:
# Get ffs from app
ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()

# Get state from app
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()

# Handle message
return message_handler.get_fab(request=request, ffs=ffs)
return message_handler.get_fab(request=request, ffs=ffs, state=state)


routes = [
Expand Down

0 comments on commit 39ee863

Please sign in to comment.