Skip to content

Commit

Permalink
refactor(framework) Rename State to LinkState (#4347)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Oct 22, 2024
1 parent c268296 commit 2a80496
Show file tree
Hide file tree
Showing 27 changed files with 171 additions and 157 deletions.
2 changes: 1 addition & 1 deletion src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from flwr.common.typing import Fab, Run, UserConfig
from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server
from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
from flwr.server.superlink.state.utils import generate_rand_int_from_bytes
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes

from .clientapp.clientappio_servicer import ClientAppInputs, ClientAppIoServicer
from .grpc_adapter_client.connection import grpc_adapter
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/clientapp/clientappio_servicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
from flwr.proto.message_pb2 import Context as ProtoContext
from flwr.proto.run_pb2 import Run as ProtoRun
from flwr.server.superlink.state.utils import generate_rand_int_from_bytes
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes

from .clientappio_servicer import ClientAppInputs, ClientAppIoServicer, ClientAppOutputs

Expand Down
10 changes: 5 additions & 5 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
)
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
from .superlink.state import StateFactory
from .superlink.linkstate import LinkStateFactory

DATABASE = ":flwr-in-memory-state:"
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
Expand Down Expand Up @@ -216,7 +216,7 @@ def run_superlink() -> None:
certificates = _try_obtain_certificates(args)

# Initialize StateFactory
state_factory = StateFactory(args.database)
state_factory = LinkStateFactory(args.database)

# Initialize FfsFactory
ffs_factory = FfsFactory(args.storage_dir)
Expand Down Expand Up @@ -504,7 +504,7 @@ def _try_obtain_certificates(

def _run_fleet_api_grpc_rere(
address: str,
state_factory: StateFactory,
state_factory: LinkStateFactory,
ffs_factory: FfsFactory,
certificates: Optional[tuple[bytes, bytes, bytes]],
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
Expand Down Expand Up @@ -532,7 +532,7 @@ def _run_fleet_api_grpc_rere(

def _run_fleet_api_grpc_adapter(
address: str,
state_factory: StateFactory,
state_factory: LinkStateFactory,
ffs_factory: FfsFactory,
certificates: Optional[tuple[bytes, bytes, bytes]],
) -> grpc.Server:
Expand Down Expand Up @@ -563,7 +563,7 @@ def _run_fleet_api_rest(
port: int,
ssl_keyfile: Optional[str],
ssl_certfile: Optional[str],
state_factory: StateFactory,
state_factory: LinkStateFactory,
ffs_factory: FfsFactory,
num_workers: int,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/server/driver/inmemory_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from flwr.common.serde import message_from_taskres, message_to_taskins
from flwr.common.typing import Run
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.server.superlink.state import StateFactory
from flwr.server.superlink.linkstate import LinkStateFactory

from .driver import Driver

Expand All @@ -46,7 +46,7 @@ class InMemoryDriver(Driver):
def __init__(
self,
run_id: int,
state_factory: StateFactory,
state_factory: LinkStateFactory,
pull_interval: float = 0.1,
) -> None:
self._run_id = run_id
Expand Down
16 changes: 10 additions & 6 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@
)
from flwr.common.typing import Run
from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611
from flwr.server.superlink.state import InMemoryState, SqliteState, StateFactory
from flwr.server.superlink.state.utils import generate_rand_int_from_bytes
from flwr.server.superlink.linkstate import (
InMemoryLinkState,
LinkStateFactory,
SqliteLinkState,
)
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes

from .inmemory_driver import InMemoryDriver

Expand Down Expand Up @@ -227,12 +231,12 @@ def test_send_and_receive_messages_timeout(self) -> None:
def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None:
"""Test tasks are deleted in sqlite state once messages are pulled."""
# Prepare
state = StateFactory("").state()
state = LinkStateFactory("").state()
self.driver = InMemoryDriver(
state.create_run("", "", "", {}), MagicMock(state=lambda: state)
)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, SqliteState)
assert isinstance(state, SqliteLinkState)

# Check recorded
task_ins = state.query("SELECT * FROM task_ins;")
Expand All @@ -253,11 +257,11 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None:
def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None:
"""Test tasks are deleted in in-memory state once messages are pulled."""
# Prepare
state_factory = StateFactory(":flwr-in-memory-state:")
state_factory = LinkStateFactory(":flwr-in-memory-state:")
state = state_factory.state()
self.driver = InMemoryDriver(state.create_run("", "", "", {}), state_factory)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, InMemoryState)
assert isinstance(state, InMemoryLinkState)

# Check recorded
self.assertEqual(len(state.task_ins_store), len(list(msg_ids)))
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/server/superlink/driver/driver_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@
add_DriverServicer_to_server,
)
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.state import StateFactory
from flwr.server.superlink.linkstate import LinkStateFactory

from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
from .driver_servicer import DriverServicer


def run_driver_api_grpc(
address: str,
state_factory: StateFactory,
state_factory: LinkStateFactory,
ffs_factory: FfsFactory,
certificates: Optional[tuple[bytes, bytes, bytes]],
) -> grpc.Server:
Expand Down
16 changes: 9 additions & 7 deletions src/py/flwr/server/superlink/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,16 @@
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
from flwr.server.superlink.ffs.ffs import Ffs
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.state import State, StateFactory
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
from flwr.server.utils.validator import validate_task_ins_or_res


class DriverServicer(driver_pb2_grpc.DriverServicer):
"""Driver API servicer."""

def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
def __init__(
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
) -> None:
self.state_factory = state_factory
self.ffs_factory = ffs_factory

Expand All @@ -67,7 +69,7 @@ def GetNodes(
) -> GetNodesResponse:
"""Get available nodes."""
log(DEBUG, "DriverServicer.GetNodes")
state: State = self.state_factory.state()
state: LinkState = self.state_factory.state()
all_ids: set[int] = state.get_nodes(request.run_id)
nodes: list[Node] = [
Node(node_id=node_id, anonymous=False) for node_id in all_ids
Expand All @@ -79,7 +81,7 @@ def CreateRun(
) -> CreateRunResponse:
"""Create run ID."""
log(DEBUG, "DriverServicer.CreateRun")
state: State = self.state_factory.state()
state: LinkState = self.state_factory.state()
if request.HasField("fab"):
fab = fab_from_proto(request.fab)
ffs: Ffs = self.ffs_factory.ffs()
Expand Down Expand Up @@ -116,7 +118,7 @@ def PushTaskIns(
_raise_if(bool(validation_errors), ", ".join(validation_errors))

# Init state
state: State = self.state_factory.state()
state: LinkState = self.state_factory.state()

# Store each TaskIns
task_ids: list[Optional[UUID]] = []
Expand All @@ -138,7 +140,7 @@ def PullTaskRes(
task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids}

# Init state
state: State = self.state_factory.state()
state: LinkState = self.state_factory.state()

# Register callback
def on_rpc_done() -> None:
Expand Down Expand Up @@ -167,7 +169,7 @@ def GetRun(
log(DEBUG, "DriverServicer.GetRun")

# Init state
state: State = self.state_factory.state()
state: LinkState = self.state_factory.state()

# Retrieve run information
run = state.get_run(request.run_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.fleet.message_handler import message_handler
from flwr.server.superlink.state import StateFactory
from flwr.server.superlink.linkstate import LinkStateFactory

T = TypeVar("T", bound=GrpcMessage)

Expand Down Expand Up @@ -77,7 +77,9 @@ def _handle(
class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer):
"""Fleet API via GrpcAdapter servicer."""

def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
def __init__(
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
) -> None:
self.state_factory = state_factory
self.ffs_factory = ffs_factory

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.fleet.message_handler import message_handler
from flwr.server.superlink.state import StateFactory
from flwr.server.superlink.linkstate import LinkStateFactory


class FleetServicer(fleet_pb2_grpc.FleetServicer):
"""Fleet API servicer."""

def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
def __init__(
self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
) -> None:
self.state_factory = state_factory
self.ffs_factory = ffs_factory

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
)
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.server.superlink.state import State
from flwr.server.superlink.linkstate import LinkState

_PUBLIC_KEY_HEADER = "public-key"
_AUTH_TOKEN_HEADER = "auth-token"
Expand Down Expand Up @@ -84,7 +84,7 @@ def _get_value_from_tuples(
class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
"""Server interceptor for node authentication."""

def __init__(self, state: State):
def __init__(self, state: LinkState):
self.state = state

self.node_public_keys = state.get_node_public_keys()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611
from flwr.server.app import _run_fleet_api_grpc_rere
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.state.state_factory import StateFactory
from flwr.server.superlink.linkstate.linkstate_factory import LinkStateFactory

from .server_interceptor import (
_AUTH_TOKEN_HEADER,
Expand All @@ -62,7 +62,7 @@ def setUp(self) -> None:
self._node_private_key, self._node_public_key = generate_key_pairs()
self._server_private_key, self._server_public_key = generate_key_pairs()

state_factory = StateFactory(":flwr-in-memory-state:")
state_factory = LinkStateFactory(":flwr-in-memory-state:")
self.state = state_factory.state()
ffs_factory = FfsFactory(".")
self.ffs = ffs_factory.ffs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,20 @@
)
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.superlink.ffs.ffs import Ffs
from flwr.server.superlink.state import State
from flwr.server.superlink.linkstate import LinkState


def create_node(
request: CreateNodeRequest, # pylint: disable=unused-argument
state: State,
state: LinkState,
) -> CreateNodeResponse:
"""."""
# Create node
node_id = state.create_node(ping_interval=request.ping_interval)
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))


def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:
def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse:
"""."""
# Validate node_id
if request.node.anonymous or request.node.node_id == 0:
Expand All @@ -69,14 +69,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse:

def ping(
request: PingRequest, # pylint: disable=unused-argument
state: State, # pylint: disable=unused-argument
state: LinkState, # pylint: disable=unused-argument
) -> PingResponse:
"""."""
res = state.acknowledge_ping(request.node.node_id, request.ping_interval)
return PingResponse(success=res)


def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse:
def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse:
"""Pull TaskIns handler."""
# Get node_id if client node is not anonymous
node = request.node # pylint: disable=no-member
Expand All @@ -92,7 +92,7 @@ def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsRespo
return response


def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResResponse:
def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse:
"""Push TaskRes handler."""
# pylint: disable=no-member
task_res: TaskRes = request.task_res_list[0]
Expand All @@ -113,7 +113,7 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo


def get_run(
request: GetRunRequest, state: State # pylint: disable=W0613
request: GetRunRequest, state: LinkState # pylint: disable=W0613
) -> GetRunResponse:
"""Get run information."""
run = state.get_run(request.run_id)
Expand Down
Loading

0 comments on commit 2a80496

Please sign in to comment.