Skip to content

Commit

Permalink
feat(framework) Add run_config to ClientApp Context (#3751)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
charlesbvll and danieljanes authored Jul 11, 2024
1 parent 2570af5 commit 2f2e346
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 16 deletions.
18 changes: 12 additions & 6 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import time
from dataclasses import dataclass
from logging import DEBUG, ERROR, INFO, WARN
from pathlib import Path
from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union

from cryptography.hazmat.primitives.asymmetric import ec
Expand Down Expand Up @@ -193,6 +194,7 @@ def _start_client_internal(
max_retries: Optional[int] = None,
max_wait_time: Optional[float] = None,
partition_id: Optional[int] = None,
flwr_dir: Optional[Path] = None,
) -> None:
"""Start a Flower client node which connects to a Flower server.
Expand Down Expand Up @@ -239,6 +241,8 @@ class `flwr.client.Client` (default: None)
partition_id: Optional[int] (default: None)
The data partition index associated with this node. Better suited for
prototyping purposes.
flwr_dir: Optional[Path] (default: None)
The fully resolved path containing installed Flower Apps.
"""
if insecure is None:
insecure = root_certificates is None
Expand Down Expand Up @@ -316,7 +320,7 @@ def _on_backoff(retry_state: RetryState) -> None:
)

node_state = NodeState(partition_id=partition_id)
run_info: Dict[int, Run] = {}
runs: Dict[int, Run] = {}

while not app_state_tracker.interrupt:
sleep_duration: int = 0
Expand Down Expand Up @@ -366,15 +370,17 @@ def _on_backoff(retry_state: RetryState) -> None:

# Get run info
run_id = message.metadata.run_id
if run_id not in run_info:
if run_id not in runs:
if get_run is not None:
run_info[run_id] = get_run(run_id)
runs[run_id] = get_run(run_id)
# If get_run is None, i.e., in grpc-bidi mode
else:
run_info[run_id] = Run(run_id, "", "", {})
runs[run_id] = Run(run_id, "", "", {})

# Register context for this run
node_state.register_context(run_id=run_id)
node_state.register_context(
run_id=run_id, run=runs[run_id], flwr_dir=flwr_dir
)

# Retrieve context for this run
context = node_state.retrieve_context(run_id=run_id)
Expand All @@ -388,7 +394,7 @@ def _on_backoff(retry_state: RetryState) -> None:
# Handle app loading and task message
try:
# Load ClientApp instance
run: Run = run_info[run_id]
run: Run = runs[run_id]
client_app: ClientApp = load_client_app_fn(
run.fab_id, run.fab_version
)
Expand Down
44 changes: 36 additions & 8 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,53 @@
"""Node state."""


from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional

from flwr.common import Context, RecordSet
from flwr.common.config import get_fused_config
from flwr.common.typing import Run


@dataclass()
class RunInfo:
"""Contains the Context and initial run_config of a Run."""

context: Context
initial_run_config: Dict[str, str]


class NodeState:
"""State of a node where client nodes execute runs."""

def __init__(self, partition_id: Optional[int]) -> None:
self._meta: Dict[str, Any] = {} # holds metadata about the node
self.run_contexts: Dict[int, Context] = {}
self.run_infos: Dict[int, RunInfo] = {}
self._partition_id = partition_id

def register_context(self, run_id: int) -> None:
def register_context(
self,
run_id: int,
run: Optional[Run] = None,
flwr_dir: Optional[Path] = None,
) -> None:
"""Register new run context for this node."""
if run_id not in self.run_contexts:
self.run_contexts[run_id] = Context(
state=RecordSet(), run_config={}, partition_id=self._partition_id
if run_id not in self.run_infos:
initial_run_config = get_fused_config(run, flwr_dir) if run else {}
self.run_infos[run_id] = RunInfo(
initial_run_config=initial_run_config,
context=Context(
state=RecordSet(),
run_config=initial_run_config.copy(),
partition_id=self._partition_id,
),
)

def retrieve_context(self, run_id: int) -> Context:
"""Get run context given a run_id."""
if run_id in self.run_contexts:
return self.run_contexts[run_id]
if run_id in self.run_infos:
return self.run_infos[run_id].context

raise RuntimeError(
f"Context for run_id={run_id} doesn't exist."
Expand All @@ -48,4 +71,9 @@ def retrieve_context(self, run_id: int) -> Context:

def update_context(self, run_id: int, context: Context) -> None:
"""Update run context."""
self.run_contexts[run_id] = context
if context.run_config != self.run_infos[run_id].initial_run_config:
raise ValueError(
"The `run_config` field of the `Context` object cannot be "
f"modified (run_id: {run_id})."
)
self.run_infos[run_id].context = context
5 changes: 3 additions & 2 deletions src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def test_multirun_in_node_state() -> None:
node_state.update_context(run_id=run_id, context=updated_state)

# Verify values
for run_id, context in node_state.run_contexts.items():
for run_id, run_info in node_state.run_infos.items():
assert (
context.state.configs_records["counter"]["count"] == expected_values[run_id]
run_info.context.state.configs_records["counter"]["count"]
== expected_values[run_id]
)
1 change: 1 addition & 0 deletions src/py/flwr/client/supernode/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def run_supernode() -> None:
max_retries=args.max_retries,
max_wait_time=args.max_wait_time,
partition_id=args.partition_id,
flwr_dir=get_flwr_dir(args.flwr_dir),
)

# Graceful shutdown
Expand Down

0 comments on commit 2f2e346

Please sign in to comment.