Skip to content

Commit

Permalink
feat(framework) Capture node_id/node_config in Context via `Nod…
Browse files Browse the repository at this point in the history
…eState` (#3780)

Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
jafermarq and danieljanes authored Jul 12, 2024
1 parent 19fad01 commit 01ca846
Show file tree
Hide file tree
Showing 18 changed files with 95 additions and 30 deletions.
40 changes: 35 additions & 5 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,13 @@ def _on_backoff(retry_state: RetryState) -> None:
on_backoff=_on_backoff,
)

node_state = NodeState(partition_id=partition_id)
# Empty dict (for now)
# This will be removed once users can pass node_config via flower-supernode
node_config: Dict[str, str] = {}

# NodeState gets initialized when the first connection is established
node_state: Optional[NodeState] = None

runs: Dict[int, Run] = {}

while not app_state_tracker.interrupt:
Expand All @@ -334,9 +340,33 @@ def _on_backoff(retry_state: RetryState) -> None:
) as conn:
receive, send, create_node, delete_node, get_run = conn

# Register node
if create_node is not None:
create_node() # pylint: disable=not-callable
# Register node when connecting the first time
if node_state is None:
if create_node is None:
if transport not in ["grpc-bidi", None]:
raise NotImplementedError(
"All transports except `grpc-bidi` require "
"an implementation for `create_node()`.'"
)
# gRPC-bidi doesn't have the concept of node_id,
# so we set it to -1
node_state = NodeState(
node_id=-1,
node_config={},
partition_id=partition_id,
)
else:
# Call create_node fn to register node
node_id: Optional[int] = ( # pylint: disable=assignment-from-none
create_node()
) # pylint: disable=not-callable
if node_id is None:
raise ValueError("Node registration failed")
node_state = NodeState(
node_id=node_id,
node_config=node_config,
partition_id=partition_id,
)

app_state_tracker.register_signal_handler()
while not app_state_tracker.interrupt:
Expand Down Expand Up @@ -580,7 +610,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
Tuple[
Callable[[], Optional[Message]],
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], Optional[int]]],
Optional[Callable[[], None]],
Optional[Callable[[int], Run]],
]
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/grpc_adapter_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def grpc_adapter( # pylint: disable=R0913
Tuple[
Callable[[], Optional[Message]],
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], Optional[int]]],
Optional[Callable[[], None]],
Optional[Callable[[int], Run]],
]
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def grpc_connection( # pylint: disable=R0913, R0915
Tuple[
Callable[[], Optional[Message]],
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], Optional[int]]],
Optional[Callable[[], None]],
Optional[Callable[[int], Run]],
]
Expand Down
5 changes: 3 additions & 2 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
Tuple[
Callable[[], Optional[Message]],
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], Optional[int]]],
Optional[Callable[[], None]],
Optional[Callable[[int], Run]],
]
Expand Down Expand Up @@ -176,7 +176,7 @@ def ping() -> None:
if not ping_stop_event.is_set():
ping_stop_event.wait(next_interval)

def create_node() -> None:
def create_node() -> Optional[int]:
"""Set create_node."""
# Call FleetAPI
create_node_request = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)
Expand All @@ -189,6 +189,7 @@ def create_node() -> None:
nonlocal node, ping_thread
node = cast(Node, create_node_response.node)
ping_thread = start_ping_loop(ping, ping_stop_event)
return node.node_id

def delete_node() -> None:
"""Set delete_node."""
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_client_without_get_properties() -> None:
actual_msg = handle_legacy_message_from_msgtype(
client_fn=_get_client_fn(client),
message=message,
context=Context(state=RecordSet(), run_config={}),
context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}),
)

# Assert
Expand Down Expand Up @@ -209,7 +209,7 @@ def test_client_with_get_properties() -> None:
actual_msg = handle_legacy_message_from_msgtype(
client_fn=_get_client_fn(client),
message=message,
context=Context(state=RecordSet(), run_config={}),
context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}),
)

# Assert
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@ def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord:

def _make_ctxt() -> Context:
cfg = ConfigsRecord(SecAggPlusState().to_dict())
return Context(RecordSet(configs_records={RECORD_KEY_STATE: cfg}), run_config={})
return Context(
node_id=123,
node_config={},
state=RecordSet(configs_records={RECORD_KEY_STATE: cfg}),
run_config={},
)


def _make_set_state_fn(
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/client/mod/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_multiple_mods(self) -> None:

state = RecordSet()
state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0})
context = Context(state=state, run_config={})
context = Context(node_id=0, node_config={}, state=state, run_config={})
message = _get_dummy_flower_message()

# Execute
Expand All @@ -129,7 +129,7 @@ def test_filter(self) -> None:
# Prepare
footprint: List[str] = []
mock_app = make_mock_app("app", footprint)
context = Context(state=RecordSet(), run_config={})
context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={})
message = _get_dummy_flower_message()

def filter_mod(
Expand Down
11 changes: 8 additions & 3 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Dict, Optional

from flwr.common import Context, RecordSet
from flwr.common.config import get_fused_config
Expand All @@ -35,8 +35,11 @@ class RunInfo:
class NodeState:
"""State of a node where client nodes execute runs."""

def __init__(self, partition_id: Optional[int]) -> None:
self._meta: Dict[str, Any] = {} # holds metadata about the node
def __init__(
self, node_id: int, node_config: Dict[str, str], partition_id: Optional[int]
) -> None:
self.node_id = node_id
self.node_config = node_config
self.run_infos: Dict[int, RunInfo] = {}
self._partition_id = partition_id

Expand All @@ -52,6 +55,8 @@ def register_context(
self.run_infos[run_id] = RunInfo(
initial_run_config=initial_run_config,
context=Context(
node_id=self.node_id,
node_config=self.node_config,
state=RecordSet(),
run_config=initial_run_config.copy(),
partition_id=self._partition_id,
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None:
expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}

# NodeState
node_state = NodeState(partition_id=None)
node_state = NodeState(node_id=0, node_config={}, partition_id=None)

for task in tasks:
run_id = task.run_id
Expand Down
7 changes: 4 additions & 3 deletions src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915
Tuple[
Callable[[], Optional[Message]],
Callable[[Message], None],
Optional[Callable[[], None]],
Optional[Callable[[], Optional[int]]],
Optional[Callable[[], None]],
Optional[Callable[[int], Run]],
]
Expand Down Expand Up @@ -237,19 +237,20 @@ def ping() -> None:
if not ping_stop_event.is_set():
ping_stop_event.wait(next_interval)

def create_node() -> None:
def create_node() -> Optional[int]:
"""Set create_node."""
req = CreateNodeRequest(ping_interval=PING_DEFAULT_INTERVAL)

# Send the request
res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
if res is None:
return
return None

# Remember the node and the ping-loop thread
nonlocal node, ping_thread
node = res.node
ping_thread = start_ping_loop(ping, ping_stop_event)
return node.node_id

def delete_node() -> None:
"""Set delete_node."""
Expand Down
15 changes: 13 additions & 2 deletions src/py/flwr/common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ class Context:
Parameters
----------
node_id : int
The ID that identifies the node.
node_config : Dict[str, str]
A config (key/value mapping) unique to the node and independent of the
`run_config`. This config persists across all runs this node participates in.
state : RecordSet
Holds records added by the entity in a given run and that will stay local.
This means that the data it holds will never leave the system it's running from.
Expand All @@ -44,16 +49,22 @@ class Context:
simulation or proto typing setups.
"""

node_id: int
node_config: Dict[str, str]
state: RecordSet
partition_id: Optional[int]
run_config: Dict[str, str]
partition_id: Optional[int]

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
node_id: int,
node_config: Dict[str, str],
state: RecordSet,
run_config: Dict[str, str],
partition_id: Optional[int] = None,
) -> None:
self.node_id = node_id
self.node_config = node_config
self.state = state
self.run_config = run_config
self.partition_id = partition_id
2 changes: 1 addition & 1 deletion src/py/flwr/server/compat/legacy_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ def __init__(
self.strategy = strategy
self.client_manager = client_manager
self.history = History()
super().__init__(state, run_config={})
super().__init__(node_id=0, node_config={}, state=state, run_config={})
4 changes: 3 additions & 1 deletion src/py/flwr/server/run_serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def _load() -> ServerApp:
server_app = _load()

# Initialize Context
context = Context(state=RecordSet(), run_config=server_app_run_config)
context = Context(
node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
)

# Call ServerApp
server_app(driver=driver, context=context)
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/server_app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_server_app_custom_mode() -> None:
# Prepare
app = ServerApp()
driver = MagicMock()
context = Context(state=RecordSet(), run_config={})
context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={})

called = {"called": False}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _create_message_and_context() -> Tuple[Message, Context, float]:
)

# Construct emtpy Context
context = Context(state=RecordSet(), run_config={})
context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={})

# Expected output
expected_output = pi * mult_factor
Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/server/superlink/fleet/vce/vce_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ def start_vce(
# Construct mapping of NodeStates
node_states: Dict[int, NodeState] = {}
for node_id, partition_id in nodes_mapping.items():
node_states[node_id] = NodeState(partition_id=partition_id)
node_states[node_id] = NodeState(
node_id=node_id, node_config={}, partition_id=partition_id
)

# Load backend config
log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/simulation/ray_transport/ray_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def _load_app() -> ClientApp:

self.app_fn = _load_app
self.actor_pool = actor_pool
self.proxy_state = NodeState(partition_id=self.partition_id)
self.proxy_state = NodeState(
node_id=node_id, node_config={}, partition_id=self.partition_id
)

def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
"""Submit a message to the ActorPool."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,13 @@ def _load_app() -> ClientApp:
_load_app,
message,
str(node_id),
Context(state=RecordSet(), run_config={}, partition_id=node_id),
Context(
node_id=0,
node_config={},
state=RecordSet(),
run_config={},
partition_id=node_id,
),
),
)

Expand Down

0 comments on commit 01ca846

Please sign in to comment.