diff --git a/e2e/bare/client.py b/e2e/bare/client.py index db83919a0c11..59ef2e4248ee 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -21,14 +21,14 @@ def _record_timestamp_to_state(self): """Record timestamp to client's state.""" t_stamp = datetime.now().timestamp() value = str(t_stamp) - if STATE_VAR in self.state.configs.keys(): - value = self.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore + if STATE_VAR in self.context.state.configs.keys(): + value = self.context.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore value += f",{t_stamp}" - self.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value})) + self.context.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value})) def _retrieve_timestamp_from_state(self): - return self.state.get_configs(STATE_VAR)[STATE_VAR] + return self.context.state.get_configs(STATE_VAR)[STATE_VAR] def fit(self, parameters, config): model_params = parameters diff --git a/e2e/pytorch/client.py b/e2e/pytorch/client.py index 53de31b7351b..ccd36f47d22a 100644 --- a/e2e/pytorch/client.py +++ b/e2e/pytorch/client.py @@ -97,14 +97,14 @@ def _record_timestamp_to_state(self): """Record timestamp to client's state.""" t_stamp = datetime.now().timestamp() value = str(t_stamp) - if STATE_VAR in self.state.configs.keys(): - value = self.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore + if STATE_VAR in self.context.state.configs.keys(): + value = self.context.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore value += f",{t_stamp}" - self.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value})) + self.context.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value})) def _retrieve_timestamp_from_state(self): - return self.state.get_configs(STATE_VAR)[STATE_VAR] + return self.context.state.get_configs(STATE_VAR)[STATE_VAR] def fit(self, parameters, config): set_parameters(net, parameters) train(net, trainloader, epochs=1) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index ae5beeae07d6..e1c9ec0cf9ae 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -352,7 +352,7 @@ def _load_app() -> Flower: break # Register state - node_state.register_runstate(run_id=task_ins.run_id) + node_state.register_context(run_id=task_ins.run_id) # Load app app: Flower = load_flower_callable_fn() @@ -360,14 +360,14 @@ def _load_app() -> Flower: # Handle task message fwd_msg: Fwd = Fwd( task_ins=task_ins, - state=node_state.retrieve_runstate(run_id=task_ins.run_id), + context=node_state.retrieve_context(run_id=task_ins.run_id), ) bwd_msg: Bwd = app(fwd=fwd_msg) # Update node state - node_state.update_runstate( - run_id=bwd_msg.task_res.run_id, - run_state=bwd_msg.state, + node_state.update_context( + run_id=fwd_msg.task_ins.run_id, + context=bwd_msg.context, ) # Send diff --git a/src/py/flwr/client/client.py b/src/py/flwr/client/client.py index 8be7ff82f544..6d982ecc9a9e 100644 --- a/src/py/flwr/client/client.py +++ b/src/py/flwr/client/client.py @@ -32,13 +32,13 @@ Parameters, Status, ) -from flwr.common.recordset import RecordSet +from flwr.common.context import Context class Client(ABC): """Abstract base class for Flower clients.""" - state: RecordSet + context: Context def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Return set of client's properties. @@ -141,13 +141,13 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: metrics={}, ) - def get_state(self) -> RecordSet: - """Get the run state from this client.""" - return self.state + def get_context(self) -> Context: + """Get the run context from this client.""" + return self.context - def set_state(self, state: RecordSet) -> None: - """Apply a run state to this client.""" - self.state = state + def set_context(self, context: Context) -> None: + """Apply a run context to this client.""" + self.context = context def to_client(self) -> Client: """Return client (itself).""" diff --git a/src/py/flwr/client/flower.py b/src/py/flwr/client/flower.py index 535f096e5866..157642c0fea6 100644 --- a/src/py/flwr/client/flower.py +++ b/src/py/flwr/client/flower.py @@ -56,12 +56,12 @@ def __init__( ) -> None: # Create wrapper function for `handle` def ffn(fwd: Fwd) -> Bwd: # pylint: disable=invalid-name - task_res, state_updated = handle( + task_res, context_updated = handle( client_fn=client_fn, - state=fwd.state, + context=fwd.context, task_ins=fwd.task_ins, ) - return Bwd(task_res=task_res, state=state_updated) + return Bwd(task_res=task_res, context=context_updated) # Wrap middleware layers around the wrapped handle function self._call = make_ffn(ffn, layers if layers is not None else []) diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 2cb9df0f1cdd..8c920dcc585e 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -31,7 +31,7 @@ from flwr.client.secure_aggregation import SecureAggregationHandler from flwr.client.typing import ClientFn from flwr.common import serde -from flwr.common.recordset import RecordSet +from flwr.common.context import Context from flwr.proto.task_pb2 import ( # pylint: disable=E0611 SecureAggregation, Task, @@ -88,16 +88,16 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: def handle( - client_fn: ClientFn, state: RecordSet, task_ins: TaskIns -) -> Tuple[TaskRes, RecordSet]: + client_fn: ClientFn, context: Context, task_ins: TaskIns +) -> Tuple[TaskRes, Context]: """Handle incoming TaskIns from the server. Parameters ---------- client_fn : ClientFn A callable that instantiates a Client. - state : RecordSet - A dataclass storing the state for the run being executed by the client. + context : Context + A dataclass storing the context for the run being executed by the client. task_ins: TaskIns The task instruction coming from the server, to be processed by the client. @@ -110,7 +110,7 @@ def handle( if server_msg is None: # Instantiate the client client = client_fn("-1") - client.set_state(state) + client.set_context(context) # Secure Aggregation if task_ins.task.HasField("sa") and isinstance( client, SecureAggregationHandler @@ -127,24 +127,24 @@ def handle( sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), ), ) - return task_res, client.get_state() + return task_res, client.get_context() raise NotImplementedError() - client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg) + client_msg, updated_context = handle_legacy_message(client_fn, context, server_msg) task_res = wrap_client_message_in_task_res(client_msg) - return task_res, updated_state + return task_res, updated_context def handle_legacy_message( - client_fn: ClientFn, state: RecordSet, server_msg: ServerMessage -) -> Tuple[ClientMessage, RecordSet]: + client_fn: ClientFn, context: Context, server_msg: ServerMessage +) -> Tuple[ClientMessage, Context]: """Handle incoming messages from the server. Parameters ---------- client_fn : ClientFn A callable that instantiates a Client. - state : RecordSet - A dataclass storing the state for the run being executed by the client. + context : Context + A dataclass storing the context for the run being executed by the client. server_msg: ServerMessage The message coming from the server, to be processed by the client. @@ -161,7 +161,7 @@ def handle_legacy_message( # Instantiate the client client = client_fn("-1") - client.set_state(state) + client.set_context(context) # Execute task message = None if field == "get_properties_ins": @@ -173,7 +173,7 @@ def handle_legacy_message( if field == "evaluate_ins": message = _evaluate(client, server_msg.evaluate_ins) if message: - return message, client.get_state() + return message, client.get_context() raise UnknownServerMessage() diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 6f73169677c5..707570cd8e57 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -32,6 +32,7 @@ serde, typing, ) +from flwr.common.context import Context from flwr.common.recordset import RecordSet from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -141,7 +142,7 @@ def test_client_without_get_properties() -> None: ) task_res, _ = handle( client_fn=_get_client_fn(client), - state=RecordSet(), + context=Context(state=RecordSet()), task_ins=task_ins, ) @@ -209,7 +210,7 @@ def test_client_with_get_properties() -> None: ) task_res, _ = handle( client_fn=_get_client_fn(client), - state=RecordSet(), + context=Context(state=RecordSet()), task_ins=task_ins, ) diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index 37a24ea7442f..fe3f7832ceb2 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -19,6 +19,7 @@ from typing import List from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer +from flwr.common.context import Context from flwr.common.recordset import RecordSet from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 @@ -45,7 +46,7 @@ def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable: def app(fwd: Fwd) -> Bwd: footprint.append(name) fwd.task_ins.task_id += f"{name}" - return Bwd(task_res=TaskRes(task_id=name), state=RecordSet()) + return Bwd(task_res=TaskRes(task_id=name), context=Context(state=RecordSet())) return app @@ -66,7 +67,8 @@ def test_multiple_middlewares(self) -> None: # Execute wrapped_app = make_ffn(mock_app, mock_middleware_layers) - task_res = wrapped_app(Fwd(task_ins=task_ins, state=RecordSet())).task_res + dummy_context = Context(state=RecordSet()) + task_res = wrapped_app(Fwd(task_ins=task_ins, context=dummy_context)).task_res # Assert trace = mock_middleware_names + ["app"] @@ -86,11 +88,14 @@ def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd: footprint.append("filter") fwd.task_ins.task_id += "filter" # Skip calling app - return Bwd(task_res=TaskRes(task_id="filter"), state=RecordSet()) + return Bwd( + task_res=TaskRes(task_id="filter"), context=Context(state=RecordSet()) + ) # Execute wrapped_app = make_ffn(mock_app, [filter_layer]) - task_res = wrapped_app(Fwd(task_ins=task_ins, state=RecordSet())).task_res + dummy_context = Context(state=RecordSet()) + task_res = wrapped_app(Fwd(task_ins=task_ins, context=dummy_context)).task_res # Assert self.assertEqual(footprint, ["filter"]) diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index dd0f9913d73d..465bbd356c1c 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -17,6 +17,7 @@ from typing import Any, Dict +from flwr.common.context import Context from flwr.common.recordset import RecordSet @@ -25,24 +26,24 @@ class NodeState: def __init__(self) -> None: self._meta: Dict[str, Any] = {} # holds metadata about the node - self.run_states: Dict[int, RecordSet] = {} + self.run_contexts: Dict[int, Context] = {} - def register_runstate(self, run_id: int) -> None: - """Register new run state for this node.""" - if run_id not in self.run_states: - self.run_states[run_id] = RecordSet() + def register_context(self, run_id: int) -> None: + """Register new run context for this node.""" + if run_id not in self.run_contexts: + self.run_contexts[run_id] = Context(state=RecordSet()) - def retrieve_runstate(self, run_id: int) -> RecordSet: - """Get run state given a run_id.""" - if run_id in self.run_states: - return self.run_states[run_id] + def retrieve_context(self, run_id: int) -> Context: + """Get run context given a run_id.""" + if run_id in self.run_contexts: + return self.run_contexts[run_id] raise RuntimeError( - f"RunState for run_id={run_id} doesn't exist." - " A run must be registered before it can be retrieved or updated " + f"Context for run_id={run_id} doesn't exist." + " A run context must be registered before it can be retrieved or updated " " by a client." ) - def update_runstate(self, run_id: int, run_state: RecordSet) -> None: - """Update run state.""" - self.run_states[run_id] = run_state + def update_context(self, run_id: int, context: Context) -> None: + """Update run context.""" + self.run_contexts[run_id] = context diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index d26dddc45fe2..11e5e74a31ec 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -17,19 +17,21 @@ from flwr.client.node_state import NodeState from flwr.common.configsrecord import ConfigsRecord -from flwr.common.recordset import RecordSet +from flwr.common.context import Context from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 -def _run_dummy_task(state: RecordSet) -> RecordSet: +def _run_dummy_task(context: Context) -> Context: counter_value: str = "1" - if "counter" in state.configs.keys(): - counter_value = state.get_configs("counter")["count"] # type: ignore + if "counter" in context.state.configs.keys(): + counter_value = context.get_configs("counter")["count"] # type: ignore counter_value += "1" - state.set_configs(name="counter", record=ConfigsRecord({"count": counter_value})) + context.state.set_configs( + name="counter", record=ConfigsRecord({"count": counter_value}) + ) - return state + return context def test_multirun_in_node_state() -> None: @@ -46,17 +48,17 @@ def test_multirun_in_node_state() -> None: run_id = task.run_id # Register - node_state.register_runstate(run_id=run_id) + node_state.register_context(run_id=run_id) # Get run state - state = node_state.retrieve_runstate(run_id=run_id) + context = node_state.retrieve_context(run_id=run_id) # Run "task" - updated_state = _run_dummy_task(state) + updated_state = _run_dummy_task(context) # Update run state - node_state.update_runstate(run_id=run_id, run_state=updated_state) + node_state.update_context(run_id=run_id, context=updated_state) # Verify values - for run_id, state in node_state.run_states.items(): - assert state.get_configs("counter")["count"] == expected_values[run_id] + for run_id, context in node_state.run_contexts.items(): + assert context.state.get_configs("counter")["count"] == expected_values[run_id] diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index d0991fc27081..a77889912a09 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -26,7 +26,7 @@ ndarrays_to_parameters, parameters_to_ndarrays, ) -from flwr.common.recordset import RecordSet +from flwr.common.context import Context from flwr.common.typing import ( Code, EvaluateIns, @@ -70,7 +70,7 @@ class NumPyClient(ABC): """Abstract base class for Flower clients using NumPy.""" - state: RecordSet + context: Context def get_properties(self, config: Config) -> Dict[str, Scalar]: """Return a client's set of properties. @@ -174,13 +174,13 @@ def evaluate( _ = (self, parameters, config) return 0.0, 0, {} - def get_state(self) -> RecordSet: - """Get the run state from this client.""" - return self.state + def get_context(self) -> Context: + """Get the run context from this client.""" + return self.context - def set_state(self, state: RecordSet) -> None: - """Apply a run state to this client.""" - self.state = state + def set_context(self, context: Context) -> None: + """Apply a run context to this client.""" + self.context = context def to_client(self) -> Client: """Convert to object to Client type and return it.""" @@ -278,21 +278,21 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: ) -def _get_state(self: Client) -> RecordSet: - """Return state of underlying NumPyClient.""" - return self.numpy_client.get_state() # type: ignore +def _get_context(self: Client) -> Context: + """Return context of underlying NumPyClient.""" + return self.numpy_client.get_context() # type: ignore -def _set_state(self: Client, state: RecordSet) -> None: - """Apply state to underlying NumPyClient.""" - self.numpy_client.set_state(state) # type: ignore +def _set_context(self: Client, context: Context) -> None: + """Apply context to underlying NumPyClient.""" + self.numpy_client.set_context(context) # type: ignore def _wrap_numpy_client(client: NumPyClient) -> Client: member_dict: Dict[str, Callable] = { # type: ignore "__init__": _constructor, - "get_state": _get_state, - "set_state": _set_state, + "get_context": _get_context, + "set_context": _set_context, } # Add wrapper type methods (if overridden) diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index f2d159a2950c..8f7940405f42 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from typing import Callable -from flwr.common.recordset import RecordSet +from flwr.common.context import Context from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from .client import Client as Client @@ -28,7 +28,7 @@ class Fwd: """.""" task_ins: TaskIns - state: RecordSet + context: Context @dataclass @@ -36,7 +36,7 @@ class Bwd: """.""" task_res: TaskRes - state: RecordSet + context: Context FlowerCallable = Callable[[Fwd], Bwd] diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index aac73dfa2a62..853566a4cbeb 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -27,8 +27,8 @@ from flwr import common from flwr.client import Client, ClientFn +from flwr.common.context import Context from flwr.common.logger import log -from flwr.common.recordset import RecordSet from flwr.simulation.ray_transport.utils import check_clientfn_returns_client # All possible returns by a client @@ -61,8 +61,8 @@ def run( client_fn: ClientFn, job_fn: JobFn, cid: str, - state: RecordSet, - ) -> Tuple[str, ClientRes, RecordSet]: + context: Context, + ) -> Tuple[str, ClientRes, Context]: """Run a client run.""" # Execute tasks and return result # return also cid which is needed to ensure results @@ -70,12 +70,12 @@ def run( try: # Instantiate client (check 'Client' type is returned) client = check_clientfn_returns_client(client_fn(cid)) - # Inject state - client.set_state(state) + # Inject context + client.set_context(context) # Run client job job_results = job_fn(client) - # Retrieve state (potentially updated) - updated_state = client.get_state() + # Retrieve context (potentially updated) + updated_context = client.get_context() except Exception as ex: client_trace = traceback.format_exc() message = ( @@ -89,7 +89,7 @@ def run( ) raise ClientException(str(message)) from ex - return cid, job_results, updated_state + return cid, job_results, updated_context @ray.remote @@ -237,16 +237,16 @@ def add_actors_to_pool(self, num_actors: int) -> None: self._idle_actors.extend(new_actors) self.num_actors += num_actors - def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, RecordSet]) -> None: + def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, Context]) -> None: """Take idle actor and assign it a client run. Submit a job to an actor by first removing it from the list of idle actors, then check if this actor was flagged to be removed from the pool """ - client_fn, job_fn, cid, state = value + client_fn, job_fn, cid, context = value actor = self._idle_actors.pop() if self._check_and_remove_actor_from_pool(actor): - future = fn(actor, client_fn, job_fn, cid, state) + future = fn(actor, client_fn, job_fn, cid, context) future_key = tuple(future) if isinstance(future, List) else future self._future_to_actor[future_key] = (self._next_task_index, actor, cid) self._next_task_index += 1 @@ -255,7 +255,7 @@ def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, RecordSet]) -> None self._cid_to_future[cid]["future"] = future_key def submit_client_job( - self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, RecordSet] + self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, Context] ) -> None: """Submit a job while tracking client ids.""" _, _, cid, _ = job @@ -295,17 +295,17 @@ def _is_future_ready(self, cid: str) -> bool: return self._cid_to_future[cid]["ready"] # type: ignore - def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, RecordSet]: - """Fetch result and updated state for a VirtualClient from Object Store. + def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, Context]: + """Fetch result and updated context for a VirtualClient from Object Store. The job submitted by the ClientProxy interfacing with client with cid=cid is ready. Here we fetch it from the object store and return. """ try: future: ObjectRef[Any] = self._cid_to_future[cid]["future"] # type: ignore - res_cid, res, updated_state = ray.get( + res_cid, res, updated_context = ray.get( future - ) # type: (str, ClientRes, RecordSet) + ) # type: (str, ClientRes, Context) except ray.exceptions.RayActorError as ex: log(ERROR, ex) if hasattr(ex, "actor_id"): @@ -322,7 +322,7 @@ def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, RecordSet]: # Reset mapping self._reset_cid_to_future_dict(cid) - return res, updated_state + return res, updated_context def _flag_actor_for_removal(self, actor_id_hex: str) -> None: """Flag actor that should be removed from pool.""" @@ -409,7 +409,7 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None: def get_client_result( self, cid: str, timeout: Optional[float] - ) -> Tuple[ClientRes, RecordSet]: + ) -> Tuple[ClientRes, Context]: """Get result from VirtualClient with specific cid.""" # Loop until all jobs submitted to the pool are completed. Break early # if the result for the ClientProxy calling this method is ready @@ -421,5 +421,5 @@ def get_client_result( break # Fetch result belonging to the VirtualClient calling this method - # Return both result from tasks and (potentially) updated run state + # Return both result from tasks and (potentially) updated run context return self._fetch_future_result(cid) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index 5c05850dfd2f..894012dc6d70 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -138,20 +138,20 @@ def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: run_id = 0 # Register state - self.proxy_state.register_runstate(run_id=run_id) + self.proxy_state.register_context(run_id=run_id) # Retrieve state - state = self.proxy_state.retrieve_runstate(run_id=run_id) + state = self.proxy_state.retrieve_context(run_id=run_id) try: self.actor_pool.submit_client_job( lambda a, c_fn, j_fn, cid, state: a.run.remote(c_fn, j_fn, cid, state), (self.client_fn, job_fn, self.cid, state), ) - res, updated_state = self.actor_pool.get_client_result(self.cid, timeout) + res, updated_context = self.actor_pool.get_client_result(self.cid, timeout) # Update state - self.proxy_state.update_runstate(run_id=run_id, run_state=updated_state) + self.proxy_state.update_context(run_id=run_id, context=updated_context) except Exception as ex: if self.actor_pool.num_actors == 0: diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 99ed0f4010df..b380d37d01c8 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -24,6 +24,7 @@ from flwr.client import Client, NumPyClient from flwr.common import Code, GetPropertiesRes, Status from flwr.common.configsrecord import ConfigsRecord +from flwr.common.context import Context from flwr.common.recordset import RecordSet from flwr.simulation.ray_transport.ray_actor import ( ClientRes, @@ -54,8 +55,8 @@ def job_fn(cid: str) -> JobFn: # pragma: no cover def cid_times_pi(client: Client) -> ClientRes: # pylint: disable=unused-argument result = int(cid) * pi - # store something in state - client.numpy_client.state.set_configs( # type: ignore + # store something in context + client.numpy_client.context.state.set_configs( # type: ignore "result", record=ConfigsRecord({"result": str(result)}) ) @@ -128,9 +129,9 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: shuffle(proxies) for prox in proxies: # Register state - prox.proxy_state.register_runstate(run_id=run_id) + prox.proxy_state.register_context(run_id=run_id) # Retrieve state - state = prox.proxy_state.retrieve_runstate(run_id=run_id) + state = prox.proxy_state.retrieve_context(run_id=run_id) job = job_fn(prox.cid) prox.actor_pool.submit_client_job( @@ -141,14 +142,14 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: # fetch results one at a time shuffle(proxies) for prox in proxies: - res, updated_state = prox.actor_pool.get_client_result(prox.cid, timeout=None) - prox.proxy_state.update_runstate(run_id, run_state=updated_state) + res, updated_context = prox.actor_pool.get_client_result(prox.cid, timeout=None) + prox.proxy_state.update_context(run_id, context=updated_context) res = cast(GetPropertiesRes, res) assert int(prox.cid) * pi == res.properties["result"] assert ( str(int(prox.cid) * pi) - == prox.proxy_state.retrieve_runstate(run_id).get_configs("result")[ + == prox.proxy_state.retrieve_context(run_id).state.get_configs("result")[ "result" ] ) @@ -168,7 +169,7 @@ def test_cid_consistency_without_proxies() -> None: job = job_fn(cid) pool.submit_client_job( lambda a, c_fn, j_fn, cid_, state: a.run.remote(c_fn, j_fn, cid_, state), - (get_dummy_client, job, cid, RecordSet()), + (get_dummy_client, job, cid, Context(state=RecordSet())), ) # fetch results one at a time