From 2f2e346aa9e91576f65ecb23806e9a321bbf70a0 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Thu, 11 Jul 2024 16:48:11 +0200 Subject: [PATCH] feat(framework) Add `run_config` to `ClientApp` `Context` (#3751) Co-authored-by: Daniel J. Beutel --- src/py/flwr/client/app.py | 18 +++++++---- src/py/flwr/client/node_state.py | 44 +++++++++++++++++++++----- src/py/flwr/client/node_state_tests.py | 5 +-- src/py/flwr/client/supernode/app.py | 1 + 4 files changed, 52 insertions(+), 16 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 15d384cb74a2..851083d4abb7 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -19,6 +19,7 @@ import time from dataclasses import dataclass from logging import DEBUG, ERROR, INFO, WARN +from pathlib import Path from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union from cryptography.hazmat.primitives.asymmetric import ec @@ -193,6 +194,7 @@ def _start_client_internal( max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, partition_id: Optional[int] = None, + flwr_dir: Optional[Path] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -239,6 +241,8 @@ class `flwr.client.Client` (default: None) partition_id: Optional[int] (default: None) The data partition index associated with this node. Better suited for prototyping purposes. + flwr_dir: Optional[Path] (default: None) + The fully resolved path containing installed Flower Apps. """ if insecure is None: insecure = root_certificates is None @@ -316,7 +320,7 @@ def _on_backoff(retry_state: RetryState) -> None: ) node_state = NodeState(partition_id=partition_id) - run_info: Dict[int, Run] = {} + runs: Dict[int, Run] = {} while not app_state_tracker.interrupt: sleep_duration: int = 0 @@ -366,15 +370,17 @@ def _on_backoff(retry_state: RetryState) -> None: # Get run info run_id = message.metadata.run_id - if run_id not in run_info: + if run_id not in runs: if get_run is not None: - run_info[run_id] = get_run(run_id) + runs[run_id] = get_run(run_id) # If get_run is None, i.e., in grpc-bidi mode else: - run_info[run_id] = Run(run_id, "", "", {}) + runs[run_id] = Run(run_id, "", "", {}) # Register context for this run - node_state.register_context(run_id=run_id) + node_state.register_context( + run_id=run_id, run=runs[run_id], flwr_dir=flwr_dir + ) # Retrieve context for this run context = node_state.retrieve_context(run_id=run_id) @@ -388,7 +394,7 @@ def _on_backoff(retry_state: RetryState) -> None: # Handle app loading and task message try: # Load ClientApp instance - run: Run = run_info[run_id] + run: Run = runs[run_id] client_app: ClientApp = load_client_app_fn( run.fab_id, run.fab_version ) diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index 64a0d348b23e..2b090eba9720 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -15,9 +15,21 @@ """Node state.""" +from dataclasses import dataclass +from pathlib import Path from typing import Any, Dict, Optional from flwr.common import Context, RecordSet +from flwr.common.config import get_fused_config +from flwr.common.typing import Run + + +@dataclass() +class RunInfo: + """Contains the Context and initial run_config of a Run.""" + + context: Context + initial_run_config: Dict[str, str] class NodeState: @@ -25,20 +37,31 @@ class NodeState: def __init__(self, partition_id: Optional[int]) -> None: self._meta: Dict[str, Any] = {} # holds metadata about the node - self.run_contexts: Dict[int, Context] = {} + self.run_infos: Dict[int, RunInfo] = {} self._partition_id = partition_id - def register_context(self, run_id: int) -> None: + def register_context( + self, + run_id: int, + run: Optional[Run] = None, + flwr_dir: Optional[Path] = None, + ) -> None: """Register new run context for this node.""" - if run_id not in self.run_contexts: - self.run_contexts[run_id] = Context( - state=RecordSet(), run_config={}, partition_id=self._partition_id + if run_id not in self.run_infos: + initial_run_config = get_fused_config(run, flwr_dir) if run else {} + self.run_infos[run_id] = RunInfo( + initial_run_config=initial_run_config, + context=Context( + state=RecordSet(), + run_config=initial_run_config.copy(), + partition_id=self._partition_id, + ), ) def retrieve_context(self, run_id: int) -> Context: """Get run context given a run_id.""" - if run_id in self.run_contexts: - return self.run_contexts[run_id] + if run_id in self.run_infos: + return self.run_infos[run_id].context raise RuntimeError( f"Context for run_id={run_id} doesn't exist." @@ -48,4 +71,9 @@ def retrieve_context(self, run_id: int) -> Context: def update_context(self, run_id: int, context: Context) -> None: """Update run context.""" - self.run_contexts[run_id] = context + if context.run_config != self.run_infos[run_id].initial_run_config: + raise ValueError( + "The `run_config` field of the `Context` object cannot be " + f"modified (run_id: {run_id})." + ) + self.run_infos[run_id].context = context diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index 311dbd41d742..effd64a3ae7a 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -59,7 +59,8 @@ def test_multirun_in_node_state() -> None: node_state.update_context(run_id=run_id, context=updated_state) # Verify values - for run_id, context in node_state.run_contexts.items(): + for run_id, run_info in node_state.run_infos.items(): assert ( - context.state.configs_records["counter"]["count"] == expected_values[run_id] + run_info.context.state.configs_records["counter"]["count"] + == expected_values[run_id] ) diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 4115c57d4738..355a2a13a0e5 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -68,6 +68,7 @@ def run_supernode() -> None: max_retries=args.max_retries, max_wait_time=args.max_wait_time, partition_id=args.partition_id, + flwr_dir=get_flwr_dir(args.flwr_dir), ) # Graceful shutdown