From 9bfd38e55df0b763707ab2992ec38b568e21501d Mon Sep 17 00:00:00 2001 From: Javier Date: Fri, 26 Jan 2024 12:28:14 +0000 Subject: [PATCH 1/2] Replace `RunState` with `RecordSet` (#2855) --- e2e/bare/client.py | 14 +++++++---- e2e/pytorch/client.py | 14 ++++++----- src/py/flwr/client/client.py | 8 +++--- .../client/message_handler/message_handler.py | 14 +++++------ .../message_handler/message_handler_test.py | 6 ++--- src/py/flwr/client/middleware/utils_test.py | 10 ++++---- src/py/flwr/client/node_state.py | 10 ++++---- src/py/flwr/client/node_state_tests.py | 17 +++++++------ src/py/flwr/client/numpy_client.py | 12 ++++----- src/py/flwr/client/run_state.py | 25 ------------------- src/py/flwr/client/typing.py | 6 ++--- .../simulation/ray_transport/ray_actor.py | 16 ++++++------ .../ray_transport/ray_client_proxy_test.py | 14 ++++++++--- 13 files changed, 78 insertions(+), 88 deletions(-) delete mode 100644 src/py/flwr/client/run_state.py diff --git a/e2e/bare/client.py b/e2e/bare/client.py index 8e5c3adff5e6..db83919a0c11 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -3,6 +3,8 @@ import flwr as fl import numpy as np +from flwr.common.configsrecord import ConfigsRecord + SUBSET_SIZE = 1000 STATE_VAR = 'timestamp' @@ -18,13 +20,15 @@ def get_parameters(self, config): def _record_timestamp_to_state(self): """Record timestamp to client's state.""" t_stamp = datetime.now().timestamp() - if STATE_VAR in self.state.state: - self.state.state[STATE_VAR] += f",{t_stamp}" - else: - self.state.state[STATE_VAR] = str(t_stamp) + value = str(t_stamp) + if STATE_VAR in self.state.configs.keys(): + value = self.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore + value += f",{t_stamp}" + + self.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value})) def _retrieve_timestamp_from_state(self): - return self.state.state[STATE_VAR] + return self.state.get_configs(STATE_VAR)[STATE_VAR] def fit(self, parameters, config): model_params = parameters diff --git a/e2e/pytorch/client.py b/e2e/pytorch/client.py index d180ad5d4eca..53de31b7351b 100644 --- a/e2e/pytorch/client.py +++ b/e2e/pytorch/client.py @@ -11,6 +11,7 @@ from tqdm import tqdm import flwr as fl +from flwr.common.configsrecord import ConfigsRecord # ############################################################################# # 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader @@ -95,14 +96,15 @@ def get_parameters(self, config): def _record_timestamp_to_state(self): """Record timestamp to client's state.""" t_stamp = datetime.now().timestamp() - if STATE_VAR in self.state.state: - self.state.state[STATE_VAR] += f",{t_stamp}" - else: - self.state.state[STATE_VAR] = str(t_stamp) + value = str(t_stamp) + if STATE_VAR in self.state.configs.keys(): + value = self.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore + value += f",{t_stamp}" + + self.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value})) def _retrieve_timestamp_from_state(self): - return self.state.state[STATE_VAR] - + return self.state.get_configs(STATE_VAR)[STATE_VAR] def fit(self, parameters, config): set_parameters(net, parameters) train(net, trainloader, epochs=1) diff --git a/src/py/flwr/client/client.py b/src/py/flwr/client/client.py index 54b53296fd2f..8be7ff82f544 100644 --- a/src/py/flwr/client/client.py +++ b/src/py/flwr/client/client.py @@ -19,7 +19,6 @@ from abc import ABC -from flwr.client.run_state import RunState from flwr.common import ( Code, EvaluateIns, @@ -33,12 +32,13 @@ Parameters, Status, ) +from flwr.common.recordset import RecordSet class Client(ABC): """Abstract base class for Flower clients.""" - state: RunState + state: RecordSet def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Return set of client's properties. @@ -141,11 +141,11 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: metrics={}, ) - def get_state(self) -> RunState: + def get_state(self) -> RecordSet: """Get the run state from this client.""" return self.state - def set_state(self, state: RunState) -> None: + def set_state(self, state: RecordSet) -> None: """Apply a run state to this client.""" self.state = state diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 8cfe909c1738..2cb9df0f1cdd 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -28,10 +28,10 @@ get_server_message_from_task_ins, wrap_client_message_in_task_res, ) -from flwr.client.run_state import RunState from flwr.client.secure_aggregation import SecureAggregationHandler from flwr.client.typing import ClientFn from flwr.common import serde +from flwr.common.recordset import RecordSet from flwr.proto.task_pb2 import ( # pylint: disable=E0611 SecureAggregation, Task, @@ -88,15 +88,15 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: def handle( - client_fn: ClientFn, state: RunState, task_ins: TaskIns -) -> Tuple[TaskRes, RunState]: + client_fn: ClientFn, state: RecordSet, task_ins: TaskIns +) -> Tuple[TaskRes, RecordSet]: """Handle incoming TaskIns from the server. Parameters ---------- client_fn : ClientFn A callable that instantiates a Client. - state : RunState + state : RecordSet A dataclass storing the state for the run being executed by the client. task_ins: TaskIns The task instruction coming from the server, to be processed by the client. @@ -135,15 +135,15 @@ def handle( def handle_legacy_message( - client_fn: ClientFn, state: RunState, server_msg: ServerMessage -) -> Tuple[ClientMessage, RunState]: + client_fn: ClientFn, state: RecordSet, server_msg: ServerMessage +) -> Tuple[ClientMessage, RecordSet]: """Handle incoming messages from the server. Parameters ---------- client_fn : ClientFn A callable that instantiates a Client. - state : RunState + state : RecordSet A dataclass storing the state for the run being executed by the client. server_msg: ServerMessage The message coming from the server, to be processed by the client. diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 194f75fe30ca..6f73169677c5 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -18,7 +18,6 @@ import uuid from flwr.client import Client -from flwr.client.run_state import RunState from flwr.client.typing import ClientFn from flwr.common import ( EvaluateIns, @@ -33,6 +32,7 @@ serde, typing, ) +from flwr.common.recordset import RecordSet from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 @@ -141,7 +141,7 @@ def test_client_without_get_properties() -> None: ) task_res, _ = handle( client_fn=_get_client_fn(client), - state=RunState(state={}), + state=RecordSet(), task_ins=task_ins, ) @@ -209,7 +209,7 @@ def test_client_with_get_properties() -> None: ) task_res, _ = handle( client_fn=_get_client_fn(client), - state=RunState(state={}), + state=RecordSet(), task_ins=task_ins, ) diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index 006fe6db4799..37a24ea7442f 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -18,8 +18,8 @@ import unittest from typing import List -from flwr.client.run_state import RunState from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer +from flwr.common.recordset import RecordSet from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from .utils import make_ffn @@ -45,7 +45,7 @@ def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable: def app(fwd: Fwd) -> Bwd: footprint.append(name) fwd.task_ins.task_id += f"{name}" - return Bwd(task_res=TaskRes(task_id=name), state=RunState({})) + return Bwd(task_res=TaskRes(task_id=name), state=RecordSet()) return app @@ -66,7 +66,7 @@ def test_multiple_middlewares(self) -> None: # Execute wrapped_app = make_ffn(mock_app, mock_middleware_layers) - task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res + task_res = wrapped_app(Fwd(task_ins=task_ins, state=RecordSet())).task_res # Assert trace = mock_middleware_names + ["app"] @@ -86,11 +86,11 @@ def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd: footprint.append("filter") fwd.task_ins.task_id += "filter" # Skip calling app - return Bwd(task_res=TaskRes(task_id="filter"), state=RunState({})) + return Bwd(task_res=TaskRes(task_id="filter"), state=RecordSet()) # Execute wrapped_app = make_ffn(mock_app, [filter_layer]) - task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res + task_res = wrapped_app(Fwd(task_ins=task_ins, state=RecordSet())).task_res # Assert self.assertEqual(footprint, ["filter"]) diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index 0a29be511806..dd0f9913d73d 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -17,7 +17,7 @@ from typing import Any, Dict -from flwr.client.run_state import RunState +from flwr.common.recordset import RecordSet class NodeState: @@ -25,14 +25,14 @@ class NodeState: def __init__(self) -> None: self._meta: Dict[str, Any] = {} # holds metadata about the node - self.run_states: Dict[int, RunState] = {} + self.run_states: Dict[int, RecordSet] = {} def register_runstate(self, run_id: int) -> None: """Register new run state for this node.""" if run_id not in self.run_states: - self.run_states[run_id] = RunState({}) + self.run_states[run_id] = RecordSet() - def retrieve_runstate(self, run_id: int) -> RunState: + def retrieve_runstate(self, run_id: int) -> RecordSet: """Get run state given a run_id.""" if run_id in self.run_states: return self.run_states[run_id] @@ -43,6 +43,6 @@ def retrieve_runstate(self, run_id: int) -> RunState: " by a client." ) - def update_runstate(self, run_id: int, run_state: RunState) -> None: + def update_runstate(self, run_id: int, run_state: RecordSet) -> None: """Update run state.""" self.run_states[run_id] = run_state diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index 7bc0d77d16cf..d26dddc45fe2 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -16,15 +16,18 @@ from flwr.client.node_state import NodeState -from flwr.client.run_state import RunState +from flwr.common.configsrecord import ConfigsRecord +from flwr.common.recordset import RecordSet from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 -def _run_dummy_task(state: RunState) -> RunState: - if "counter" in state.state: - state.state["counter"] += "1" - else: - state.state["counter"] = "1" +def _run_dummy_task(state: RecordSet) -> RecordSet: + counter_value: str = "1" + if "counter" in state.configs.keys(): + counter_value = state.get_configs("counter")["count"] # type: ignore + counter_value += "1" + + state.set_configs(name="counter", record=ConfigsRecord({"count": counter_value})) return state @@ -56,4 +59,4 @@ def test_multirun_in_node_state() -> None: # Verify values for run_id, state in node_state.run_states.items(): - assert state.state["counter"] == expected_values[run_id] + assert state.get_configs("counter")["count"] == expected_values[run_id] diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index d67fb90512d4..d0991fc27081 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -19,7 +19,6 @@ from typing import Callable, Dict, Tuple from flwr.client.client import Client -from flwr.client.run_state import RunState from flwr.common import ( Config, NDArrays, @@ -27,6 +26,7 @@ ndarrays_to_parameters, parameters_to_ndarrays, ) +from flwr.common.recordset import RecordSet from flwr.common.typing import ( Code, EvaluateIns, @@ -70,7 +70,7 @@ class NumPyClient(ABC): """Abstract base class for Flower clients using NumPy.""" - state: RunState + state: RecordSet def get_properties(self, config: Config) -> Dict[str, Scalar]: """Return a client's set of properties. @@ -174,11 +174,11 @@ def evaluate( _ = (self, parameters, config) return 0.0, 0, {} - def get_state(self) -> RunState: + def get_state(self) -> RecordSet: """Get the run state from this client.""" return self.state - def set_state(self, state: RunState) -> None: + def set_state(self, state: RecordSet) -> None: """Apply a run state to this client.""" self.state = state @@ -278,12 +278,12 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: ) -def _get_state(self: Client) -> RunState: +def _get_state(self: Client) -> RecordSet: """Return state of underlying NumPyClient.""" return self.numpy_client.get_state() # type: ignore -def _set_state(self: Client, state: RunState) -> None: +def _set_state(self: Client, state: RecordSet) -> None: """Apply state to underlying NumPyClient.""" self.numpy_client.set_state(state) # type: ignore diff --git a/src/py/flwr/client/run_state.py b/src/py/flwr/client/run_state.py deleted file mode 100644 index c2755eb995eb..000000000000 --- a/src/py/flwr/client/run_state.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Run state.""" - -from dataclasses import dataclass -from typing import Dict - - -@dataclass -class RunState: - """State of a run executed by a client node.""" - - state: Dict[str, str] diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index 5291afb83d98..f2d159a2950c 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from typing import Callable -from flwr.client.run_state import RunState +from flwr.common.recordset import RecordSet from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from .client import Client as Client @@ -28,7 +28,7 @@ class Fwd: """.""" task_ins: TaskIns - state: RunState + state: RecordSet @dataclass @@ -36,7 +36,7 @@ class Bwd: """.""" task_res: TaskRes - state: RunState + state: RecordSet FlowerCallable = Callable[[Fwd], Bwd] diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 38af3f08daa2..aac73dfa2a62 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -27,8 +27,8 @@ from flwr import common from flwr.client import Client, ClientFn -from flwr.client.run_state import RunState from flwr.common.logger import log +from flwr.common.recordset import RecordSet from flwr.simulation.ray_transport.utils import check_clientfn_returns_client # All possible returns by a client @@ -61,8 +61,8 @@ def run( client_fn: ClientFn, job_fn: JobFn, cid: str, - state: RunState, - ) -> Tuple[str, ClientRes, RunState]: + state: RecordSet, + ) -> Tuple[str, ClientRes, RecordSet]: """Run a client run.""" # Execute tasks and return result # return also cid which is needed to ensure results @@ -237,7 +237,7 @@ def add_actors_to_pool(self, num_actors: int) -> None: self._idle_actors.extend(new_actors) self.num_actors += num_actors - def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, RunState]) -> None: + def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, RecordSet]) -> None: """Take idle actor and assign it a client run. Submit a job to an actor by first removing it from the list of idle actors, then @@ -255,7 +255,7 @@ def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, RunState]) -> None: self._cid_to_future[cid]["future"] = future_key def submit_client_job( - self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, RunState] + self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, RecordSet] ) -> None: """Submit a job while tracking client ids.""" _, _, cid, _ = job @@ -295,7 +295,7 @@ def _is_future_ready(self, cid: str) -> bool: return self._cid_to_future[cid]["ready"] # type: ignore - def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, RunState]: + def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, RecordSet]: """Fetch result and updated state for a VirtualClient from Object Store. The job submitted by the ClientProxy interfacing with client with cid=cid is @@ -305,7 +305,7 @@ def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, RunState]: future: ObjectRef[Any] = self._cid_to_future[cid]["future"] # type: ignore res_cid, res, updated_state = ray.get( future - ) # type: (str, ClientRes, RunState) + ) # type: (str, ClientRes, RecordSet) except ray.exceptions.RayActorError as ex: log(ERROR, ex) if hasattr(ex, "actor_id"): @@ -409,7 +409,7 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None: def get_client_result( self, cid: str, timeout: Optional[float] - ) -> Tuple[ClientRes, RunState]: + ) -> Tuple[ClientRes, RecordSet]: """Get result from VirtualClient with specific cid.""" # Loop until all jobs submitted to the pool are completed. Break early # if the result for the ClientProxy calling this method is ready diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 9df71635b949..99ed0f4010df 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -22,8 +22,9 @@ import ray from flwr.client import Client, NumPyClient -from flwr.client.run_state import RunState from flwr.common import Code, GetPropertiesRes, Status +from flwr.common.configsrecord import ConfigsRecord +from flwr.common.recordset import RecordSet from flwr.simulation.ray_transport.ray_actor import ( ClientRes, DefaultActor, @@ -54,7 +55,9 @@ def cid_times_pi(client: Client) -> ClientRes: # pylint: disable=unused-argumen result = int(cid) * pi # store something in state - client.numpy_client.state.state["result"] = str(result) # type: ignore + client.numpy_client.state.set_configs( # type: ignore + "result", record=ConfigsRecord({"result": str(result)}) + ) # now let's convert it to a GetPropertiesRes response return GetPropertiesRes( @@ -141,10 +144,13 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: res, updated_state = prox.actor_pool.get_client_result(prox.cid, timeout=None) prox.proxy_state.update_runstate(run_id, run_state=updated_state) res = cast(GetPropertiesRes, res) + assert int(prox.cid) * pi == res.properties["result"] assert ( str(int(prox.cid) * pi) - == prox.proxy_state.retrieve_runstate(run_id).state["result"] + == prox.proxy_state.retrieve_runstate(run_id).get_configs("result")[ + "result" + ] ) ray.shutdown() @@ -162,7 +168,7 @@ def test_cid_consistency_without_proxies() -> None: job = job_fn(cid) pool.submit_client_job( lambda a, c_fn, j_fn, cid_, state: a.run.remote(c_fn, j_fn, cid_, state), - (get_dummy_client, job, cid, RunState(state={})), + (get_dummy_client, job, cid, RecordSet()), ) # fetch results one at a time From 08667701195d52e7699d1f0c2dfa4149b228ad40 Mon Sep 17 00:00:00 2001 From: Javier Date: Fri, 26 Jan 2024 14:10:21 +0000 Subject: [PATCH 2/2] Replace client's state with `common.Context` (#2858) Co-authored-by: Daniel J. Beutel --- e2e/bare/client.py | 8 ++-- e2e/pytorch/client.py | 8 ++-- src/py/flwr/client/app.py | 10 ++--- src/py/flwr/client/client.py | 16 ++++---- src/py/flwr/client/flower.py | 6 +-- .../client/message_handler/message_handler.py | 30 +++++++-------- .../message_handler/message_handler_test.py | 5 ++- src/py/flwr/client/middleware/utils_test.py | 13 +++++-- src/py/flwr/client/node_state.py | 29 +++++++------- src/py/flwr/client/node_state_tests.py | 26 +++++++------ src/py/flwr/client/numpy_client.py | 32 ++++++++-------- src/py/flwr/client/typing.py | 6 +-- .../simulation/ray_transport/ray_actor.py | 38 +++++++++---------- .../ray_transport/ray_client_proxy.py | 8 ++-- .../ray_transport/ray_client_proxy_test.py | 17 +++++---- 15 files changed, 131 insertions(+), 121 deletions(-) diff --git a/e2e/bare/client.py b/e2e/bare/client.py index db83919a0c11..59ef2e4248ee 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -21,14 +21,14 @@ def _record_timestamp_to_state(self): """Record timestamp to client's state.""" t_stamp = datetime.now().timestamp() value = str(t_stamp) - if STATE_VAR in self.state.configs.keys(): - value = self.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore + if STATE_VAR in self.context.state.configs.keys(): + value = self.context.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore value += f",{t_stamp}" - self.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value})) + self.context.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value})) def _retrieve_timestamp_from_state(self): - return self.state.get_configs(STATE_VAR)[STATE_VAR] + return self.context.state.get_configs(STATE_VAR)[STATE_VAR] def fit(self, parameters, config): model_params = parameters diff --git a/e2e/pytorch/client.py b/e2e/pytorch/client.py index 53de31b7351b..ccd36f47d22a 100644 --- a/e2e/pytorch/client.py +++ b/e2e/pytorch/client.py @@ -97,14 +97,14 @@ def _record_timestamp_to_state(self): """Record timestamp to client's state.""" t_stamp = datetime.now().timestamp() value = str(t_stamp) - if STATE_VAR in self.state.configs.keys(): - value = self.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore + if STATE_VAR in self.context.state.configs.keys(): + value = self.context.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore value += f",{t_stamp}" - self.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value})) + self.context.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value})) def _retrieve_timestamp_from_state(self): - return self.state.get_configs(STATE_VAR)[STATE_VAR] + return self.context.state.get_configs(STATE_VAR)[STATE_VAR] def fit(self, parameters, config): set_parameters(net, parameters) train(net, trainloader, epochs=1) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index ae5beeae07d6..e1c9ec0cf9ae 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -352,7 +352,7 @@ def _load_app() -> Flower: break # Register state - node_state.register_runstate(run_id=task_ins.run_id) + node_state.register_context(run_id=task_ins.run_id) # Load app app: Flower = load_flower_callable_fn() @@ -360,14 +360,14 @@ def _load_app() -> Flower: # Handle task message fwd_msg: Fwd = Fwd( task_ins=task_ins, - state=node_state.retrieve_runstate(run_id=task_ins.run_id), + context=node_state.retrieve_context(run_id=task_ins.run_id), ) bwd_msg: Bwd = app(fwd=fwd_msg) # Update node state - node_state.update_runstate( - run_id=bwd_msg.task_res.run_id, - run_state=bwd_msg.state, + node_state.update_context( + run_id=fwd_msg.task_ins.run_id, + context=bwd_msg.context, ) # Send diff --git a/src/py/flwr/client/client.py b/src/py/flwr/client/client.py index 8be7ff82f544..6d982ecc9a9e 100644 --- a/src/py/flwr/client/client.py +++ b/src/py/flwr/client/client.py @@ -32,13 +32,13 @@ Parameters, Status, ) -from flwr.common.recordset import RecordSet +from flwr.common.context import Context class Client(ABC): """Abstract base class for Flower clients.""" - state: RecordSet + context: Context def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Return set of client's properties. @@ -141,13 +141,13 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: metrics={}, ) - def get_state(self) -> RecordSet: - """Get the run state from this client.""" - return self.state + def get_context(self) -> Context: + """Get the run context from this client.""" + return self.context - def set_state(self, state: RecordSet) -> None: - """Apply a run state to this client.""" - self.state = state + def set_context(self, context: Context) -> None: + """Apply a run context to this client.""" + self.context = context def to_client(self) -> Client: """Return client (itself).""" diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py index 535f096e5866..157642c0fea6 100644 --- a/src/py/flwr/client/flower.py +++ b/src/py/flwr/client/flower.py @@ -56,12 +56,12 @@ def __init__( ) -> None: # Create wrapper function for `handle` def ffn(fwd: Fwd) -> Bwd: # pylint: disable=invalid-name - task_res, state_updated = handle( + task_res, context_updated = handle( client_fn=client_fn, - state=fwd.state, + context=fwd.context, task_ins=fwd.task_ins, ) - return Bwd(task_res=task_res, state=state_updated) + return Bwd(task_res=task_res, context=context_updated) # Wrap middleware layers around the wrapped handle function self._call = make_ffn(ffn, layers if layers is not None else []) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 2cb9df0f1cdd..8c920dcc585e 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -31,7 +31,7 @@ from flwr.client.secure_aggregation import SecureAggregationHandler from flwr.client.typing import ClientFn from flwr.common import serde -from flwr.common.recordset import RecordSet +from flwr.common.context import Context from flwr.proto.task_pb2 import ( # pylint: disable=E0611 SecureAggregation, Task, @@ -88,16 +88,16 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: def handle( - client_fn: ClientFn, state: RecordSet, task_ins: TaskIns -) -> Tuple[TaskRes, RecordSet]: + client_fn: ClientFn, context: Context, task_ins: TaskIns +) -> Tuple[TaskRes, Context]: """Handle incoming TaskIns from the server. Parameters ---------- client_fn : ClientFn A callable that instantiates a Client. - state : RecordSet - A dataclass storing the state for the run being executed by the client. + context : Context + A dataclass storing the context for the run being executed by the client. task_ins: TaskIns The task instruction coming from the server, to be processed by the client. @@ -110,7 +110,7 @@ def handle( if server_msg is None: # Instantiate the client client = client_fn("-1") - client.set_state(state) + client.set_context(context) # Secure Aggregation if task_ins.task.HasField("sa") and isinstance( client, SecureAggregationHandler @@ -127,24 +127,24 @@ def handle( sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), ), ) - return task_res, client.get_state() + return task_res, client.get_context() raise NotImplementedError() - client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg) + client_msg, updated_context = handle_legacy_message(client_fn, context, server_msg) task_res = wrap_client_message_in_task_res(client_msg) - return task_res, updated_state + return task_res, updated_context def handle_legacy_message( - client_fn: ClientFn, state: RecordSet, server_msg: ServerMessage -) -> Tuple[ClientMessage, RecordSet]: + client_fn: ClientFn, context: Context, server_msg: ServerMessage +) -> Tuple[ClientMessage, Context]: """Handle incoming messages from the server. Parameters ---------- client_fn : ClientFn A callable that instantiates a Client. - state : RecordSet - A dataclass storing the state for the run being executed by the client. + context : Context + A dataclass storing the context for the run being executed by the client. server_msg: ServerMessage The message coming from the server, to be processed by the client. @@ -161,7 +161,7 @@ def handle_legacy_message( # Instantiate the client client = client_fn("-1") - client.set_state(state) + client.set_context(context) # Execute task message = None if field == "get_properties_ins": @@ -173,7 +173,7 @@ def handle_legacy_message( if field == "evaluate_ins": message = _evaluate(client, server_msg.evaluate_ins) if message: - return message, client.get_state() + return message, client.get_context() raise UnknownServerMessage() diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 6f73169677c5..707570cd8e57 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -32,6 +32,7 @@ serde, typing, ) +from flwr.common.context import Context from flwr.common.recordset import RecordSet from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -141,7 +142,7 @@ def test_client_without_get_properties() -> None: ) task_res, _ = handle( client_fn=_get_client_fn(client), - state=RecordSet(), + context=Context(state=RecordSet()), task_ins=task_ins, ) @@ -209,7 +210,7 @@ def test_client_with_get_properties() -> None: ) task_res, _ = handle( client_fn=_get_client_fn(client), - state=RecordSet(), + context=Context(state=RecordSet()), task_ins=task_ins, ) diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index 37a24ea7442f..fe3f7832ceb2 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -19,6 +19,7 @@ from typing import List from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer +from flwr.common.context import Context from flwr.common.recordset import RecordSet from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 @@ -45,7 +46,7 @@ def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable: def app(fwd: Fwd) -> Bwd: footprint.append(name) fwd.task_ins.task_id += f"{name}" - return Bwd(task_res=TaskRes(task_id=name), state=RecordSet()) + return Bwd(task_res=TaskRes(task_id=name), context=Context(state=RecordSet())) return app @@ -66,7 +67,8 @@ def test_multiple_middlewares(self) -> None: # Execute wrapped_app = make_ffn(mock_app, mock_middleware_layers) - task_res = wrapped_app(Fwd(task_ins=task_ins, state=RecordSet())).task_res + dummy_context = Context(state=RecordSet()) + task_res = wrapped_app(Fwd(task_ins=task_ins, context=dummy_context)).task_res # Assert trace = mock_middleware_names + ["app"] @@ -86,11 +88,14 @@ def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd: footprint.append("filter") fwd.task_ins.task_id += "filter" # Skip calling app - return Bwd(task_res=TaskRes(task_id="filter"), state=RecordSet()) + return Bwd( + task_res=TaskRes(task_id="filter"), context=Context(state=RecordSet()) + ) # Execute wrapped_app = make_ffn(mock_app, [filter_layer]) - task_res = wrapped_app(Fwd(task_ins=task_ins, state=RecordSet())).task_res + dummy_context = Context(state=RecordSet()) + task_res = wrapped_app(Fwd(task_ins=task_ins, context=dummy_context)).task_res # Assert self.assertEqual(footprint, ["filter"]) diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index dd0f9913d73d..465bbd356c1c 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -17,6 +17,7 @@ from typing import Any, Dict +from flwr.common.context import Context from flwr.common.recordset import RecordSet @@ -25,24 +26,24 @@ class NodeState: def __init__(self) -> None: self._meta: Dict[str, Any] = {} # holds metadata about the node - self.run_states: Dict[int, RecordSet] = {} + self.run_contexts: Dict[int, Context] = {} - def register_runstate(self, run_id: int) -> None: - """Register new run state for this node.""" - if run_id not in self.run_states: - self.run_states[run_id] = RecordSet() + def register_context(self, run_id: int) -> None: + """Register new run context for this node.""" + if run_id not in self.run_contexts: + self.run_contexts[run_id] = Context(state=RecordSet()) - def retrieve_runstate(self, run_id: int) -> RecordSet: - """Get run state given a run_id.""" - if run_id in self.run_states: - return self.run_states[run_id] + def retrieve_context(self, run_id: int) -> Context: + """Get run context given a run_id.""" + if run_id in self.run_contexts: + return self.run_contexts[run_id] raise RuntimeError( - f"RunState for run_id={run_id} doesn't exist." - " A run must be registered before it can be retrieved or updated " + f"Context for run_id={run_id} doesn't exist." + " A run context must be registered before it can be retrieved or updated " " by a client." ) - def update_runstate(self, run_id: int, run_state: RecordSet) -> None: - """Update run state.""" - self.run_states[run_id] = run_state + def update_context(self, run_id: int, context: Context) -> None: + """Update run context.""" + self.run_contexts[run_id] = context diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index d26dddc45fe2..11e5e74a31ec 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -17,19 +17,21 @@ from flwr.client.node_state import NodeState from flwr.common.configsrecord import ConfigsRecord -from flwr.common.recordset import RecordSet +from flwr.common.context import Context from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 -def _run_dummy_task(state: RecordSet) -> RecordSet: +def _run_dummy_task(context: Context) -> Context: counter_value: str = "1" - if "counter" in state.configs.keys(): - counter_value = state.get_configs("counter")["count"] # type: ignore + if "counter" in context.state.configs.keys(): + counter_value = context.get_configs("counter")["count"] # type: ignore counter_value += "1" - state.set_configs(name="counter", record=ConfigsRecord({"count": counter_value})) + context.state.set_configs( + name="counter", record=ConfigsRecord({"count": counter_value}) + ) - return state + return context def test_multirun_in_node_state() -> None: @@ -46,17 +48,17 @@ def test_multirun_in_node_state() -> None: run_id = task.run_id # Register - node_state.register_runstate(run_id=run_id) + node_state.register_context(run_id=run_id) # Get run state - state = node_state.retrieve_runstate(run_id=run_id) + context = node_state.retrieve_context(run_id=run_id) # Run "task" - updated_state = _run_dummy_task(state) + updated_state = _run_dummy_task(context) # Update run state - node_state.update_runstate(run_id=run_id, run_state=updated_state) + node_state.update_context(run_id=run_id, context=updated_state) # Verify values - for run_id, state in node_state.run_states.items(): - assert state.get_configs("counter")["count"] == expected_values[run_id] + for run_id, context in node_state.run_contexts.items(): + assert context.state.get_configs("counter")["count"] == expected_values[run_id] diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index d0991fc27081..a77889912a09 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -26,7 +26,7 @@ ndarrays_to_parameters, parameters_to_ndarrays, ) -from flwr.common.recordset import RecordSet +from flwr.common.context import Context from flwr.common.typing import ( Code, EvaluateIns, @@ -70,7 +70,7 @@ class NumPyClient(ABC): """Abstract base class for Flower clients using NumPy.""" - state: RecordSet + context: Context def get_properties(self, config: Config) -> Dict[str, Scalar]: """Return a client's set of properties. @@ -174,13 +174,13 @@ def evaluate( _ = (self, parameters, config) return 0.0, 0, {} - def get_state(self) -> RecordSet: - """Get the run state from this client.""" - return self.state + def get_context(self) -> Context: + """Get the run context from this client.""" + return self.context - def set_state(self, state: RecordSet) -> None: - """Apply a run state to this client.""" - self.state = state + def set_context(self, context: Context) -> None: + """Apply a run context to this client.""" + self.context = context def to_client(self) -> Client: """Convert to object to Client type and return it.""" @@ -278,21 +278,21 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: ) -def _get_state(self: Client) -> RecordSet: - """Return state of underlying NumPyClient.""" - return self.numpy_client.get_state() # type: ignore +def _get_context(self: Client) -> Context: + """Return context of underlying NumPyClient.""" + return self.numpy_client.get_context() # type: ignore -def _set_state(self: Client, state: RecordSet) -> None: - """Apply state to underlying NumPyClient.""" - self.numpy_client.set_state(state) # type: ignore +def _set_context(self: Client, context: Context) -> None: + """Apply context to underlying NumPyClient.""" + self.numpy_client.set_context(context) # type: ignore def _wrap_numpy_client(client: NumPyClient) -> Client: member_dict: Dict[str, Callable] = { # type: ignore "__init__": _constructor, - "get_state": _get_state, - "set_state": _set_state, + "get_context": _get_context, + "set_context": _set_context, } # Add wrapper type methods (if overridden) diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index f2d159a2950c..8f7940405f42 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from typing import Callable -from flwr.common.recordset import RecordSet +from flwr.common.context import Context from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from .client import Client as Client @@ -28,7 +28,7 @@ class Fwd: """.""" task_ins: TaskIns - state: RecordSet + context: Context @dataclass @@ -36,7 +36,7 @@ class Bwd: """.""" task_res: TaskRes - state: RecordSet + context: Context FlowerCallable = Callable[[Fwd], Bwd] diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index aac73dfa2a62..853566a4cbeb 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -27,8 +27,8 @@ from flwr import common from flwr.client import Client, ClientFn +from flwr.common.context import Context from flwr.common.logger import log -from flwr.common.recordset import RecordSet from flwr.simulation.ray_transport.utils import check_clientfn_returns_client # All possible returns by a client @@ -61,8 +61,8 @@ def run( client_fn: ClientFn, job_fn: JobFn, cid: str, - state: RecordSet, - ) -> Tuple[str, ClientRes, RecordSet]: + context: Context, + ) -> Tuple[str, ClientRes, Context]: """Run a client run.""" # Execute tasks and return result # return also cid which is needed to ensure results @@ -70,12 +70,12 @@ def run( try: # Instantiate client (check 'Client' type is returned) client = check_clientfn_returns_client(client_fn(cid)) - # Inject state - client.set_state(state) + # Inject context + client.set_context(context) # Run client job job_results = job_fn(client) - # Retrieve state (potentially updated) - updated_state = client.get_state() + # Retrieve context (potentially updated) + updated_context = client.get_context() except Exception as ex: client_trace = traceback.format_exc() message = ( @@ -89,7 +89,7 @@ def run( ) raise ClientException(str(message)) from ex - return cid, job_results, updated_state + return cid, job_results, updated_context @ray.remote @@ -237,16 +237,16 @@ def add_actors_to_pool(self, num_actors: int) -> None: self._idle_actors.extend(new_actors) self.num_actors += num_actors - def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, RecordSet]) -> None: + def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, Context]) -> None: """Take idle actor and assign it a client run. Submit a job to an actor by first removing it from the list of idle actors, then check if this actor was flagged to be removed from the pool """ - client_fn, job_fn, cid, state = value + client_fn, job_fn, cid, context = value actor = self._idle_actors.pop() if self._check_and_remove_actor_from_pool(actor): - future = fn(actor, client_fn, job_fn, cid, state) + future = fn(actor, client_fn, job_fn, cid, context) future_key = tuple(future) if isinstance(future, List) else future self._future_to_actor[future_key] = (self._next_task_index, actor, cid) self._next_task_index += 1 @@ -255,7 +255,7 @@ def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, RecordSet]) -> None self._cid_to_future[cid]["future"] = future_key def submit_client_job( - self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, RecordSet] + self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, Context] ) -> None: """Submit a job while tracking client ids.""" _, _, cid, _ = job @@ -295,17 +295,17 @@ def _is_future_ready(self, cid: str) -> bool: return self._cid_to_future[cid]["ready"] # type: ignore - def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, RecordSet]: - """Fetch result and updated state for a VirtualClient from Object Store. + def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, Context]: + """Fetch result and updated context for a VirtualClient from Object Store. The job submitted by the ClientProxy interfacing with client with cid=cid is ready. Here we fetch it from the object store and return. """ try: future: ObjectRef[Any] = self._cid_to_future[cid]["future"] # type: ignore - res_cid, res, updated_state = ray.get( + res_cid, res, updated_context = ray.get( future - ) # type: (str, ClientRes, RecordSet) + ) # type: (str, ClientRes, Context) except ray.exceptions.RayActorError as ex: log(ERROR, ex) if hasattr(ex, "actor_id"): @@ -322,7 +322,7 @@ def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, RecordSet]: # Reset mapping self._reset_cid_to_future_dict(cid) - return res, updated_state + return res, updated_context def _flag_actor_for_removal(self, actor_id_hex: str) -> None: """Flag actor that should be removed from pool.""" @@ -409,7 +409,7 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None: def get_client_result( self, cid: str, timeout: Optional[float] - ) -> Tuple[ClientRes, RecordSet]: + ) -> Tuple[ClientRes, Context]: """Get result from VirtualClient with specific cid.""" # Loop until all jobs submitted to the pool are completed. Break early # if the result for the ClientProxy calling this method is ready @@ -421,5 +421,5 @@ def get_client_result( break # Fetch result belonging to the VirtualClient calling this method - # Return both result from tasks and (potentially) updated run state + # Return both result from tasks and (potentially) updated run context return self._fetch_future_result(cid) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 5c05850dfd2f..894012dc6d70 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -138,20 +138,20 @@ def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: run_id = 0 # Register state - self.proxy_state.register_runstate(run_id=run_id) + self.proxy_state.register_context(run_id=run_id) # Retrieve state - state = self.proxy_state.retrieve_runstate(run_id=run_id) + state = self.proxy_state.retrieve_context(run_id=run_id) try: self.actor_pool.submit_client_job( lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), (self.client_fn, job_fn, self.cid, state), ) - res, updated_state = self.actor_pool.get_client_result(self.cid, timeout) + res, updated_context = self.actor_pool.get_client_result(self.cid, timeout) # Update state - self.proxy_state.update_runstate(run_id=run_id, run_state=updated_state) + self.proxy_state.update_context(run_id=run_id, context=updated_context) except Exception as ex: if self.actor_pool.num_actors == 0: diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 99ed0f4010df..b380d37d01c8 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -24,6 +24,7 @@ from flwr.client import Client, NumPyClient from flwr.common import Code, GetPropertiesRes, Status from flwr.common.configsrecord import ConfigsRecord +from flwr.common.context import Context from flwr.common.recordset import RecordSet from flwr.simulation.ray_transport.ray_actor import ( ClientRes, @@ -54,8 +55,8 @@ def job_fn(cid: str) -> JobFn: # pragma: no cover def cid_times_pi(client: Client) -> ClientRes: # pylint: disable=unused-argument result = int(cid) * pi - # store something in state - client.numpy_client.state.set_configs( # type: ignore + # store something in context + client.numpy_client.context.state.set_configs( # type: ignore "result", record=ConfigsRecord({"result": str(result)}) ) @@ -128,9 +129,9 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: shuffle(proxies) for prox in proxies: # Register state - prox.proxy_state.register_runstate(run_id=run_id) + prox.proxy_state.register_context(run_id=run_id) # Retrieve state - state = prox.proxy_state.retrieve_runstate(run_id=run_id) + state = prox.proxy_state.retrieve_context(run_id=run_id) job = job_fn(prox.cid) prox.actor_pool.submit_client_job( @@ -141,14 +142,14 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: # fetch results one at a time shuffle(proxies) for prox in proxies: - res, updated_state = prox.actor_pool.get_client_result(prox.cid, timeout=None) - prox.proxy_state.update_runstate(run_id, run_state=updated_state) + res, updated_context = prox.actor_pool.get_client_result(prox.cid, timeout=None) + prox.proxy_state.update_context(run_id, context=updated_context) res = cast(GetPropertiesRes, res) assert int(prox.cid) * pi == res.properties["result"] assert ( str(int(prox.cid) * pi) - == prox.proxy_state.retrieve_runstate(run_id).get_configs("result")[ + == prox.proxy_state.retrieve_context(run_id).state.get_configs("result")[ "result" ] ) @@ -168,7 +169,7 @@ def test_cid_consistency_without_proxies() -> None: job = job_fn(cid) pool.submit_client_job( lambda a, c_fn, j_fn, cid_, state: a.run.remote(c_fn, j_fn, cid_, state), - (get_dummy_client, job, cid, RecordSet()), + (get_dummy_client, job, cid, Context(state=RecordSet())), ) # fetch results one at a time