Skip to content

Commit

Permalink
Replace client's state with common.Context (#2858)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
jafermarq and danieljanes authored Jan 26, 2024
1 parent 9bfd38e commit 0866770
Show file tree
Hide file tree
Showing 15 changed files with 131 additions and 121 deletions.
8 changes: 4 additions & 4 deletions e2e/bare/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
value = str(t_stamp)
if STATE_VAR in self.state.configs.keys():
value = self.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore
if STATE_VAR in self.context.state.configs.keys():
value = self.context.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore
value += f",{t_stamp}"

self.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value}))
self.context.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value}))

def _retrieve_timestamp_from_state(self):
return self.state.get_configs(STATE_VAR)[STATE_VAR]
return self.context.state.get_configs(STATE_VAR)[STATE_VAR]

def fit(self, parameters, config):
model_params = parameters
Expand Down
8 changes: 4 additions & 4 deletions e2e/pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
value = str(t_stamp)
if STATE_VAR in self.state.configs.keys():
value = self.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore
if STATE_VAR in self.context.state.configs.keys():
value = self.context.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore
value += f",{t_stamp}"

self.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value}))
self.context.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value}))

def _retrieve_timestamp_from_state(self):
return self.state.get_configs(STATE_VAR)[STATE_VAR]
return self.context.state.get_configs(STATE_VAR)[STATE_VAR]
def fit(self, parameters, config):
set_parameters(net, parameters)
train(net, trainloader, epochs=1)
Expand Down
10 changes: 5 additions & 5 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,22 +352,22 @@ def _load_app() -> Flower:
break

# Register state
node_state.register_runstate(run_id=task_ins.run_id)
node_state.register_context(run_id=task_ins.run_id)

# Load app
app: Flower = load_flower_callable_fn()

# Handle task message
fwd_msg: Fwd = Fwd(
task_ins=task_ins,
state=node_state.retrieve_runstate(run_id=task_ins.run_id),
context=node_state.retrieve_context(run_id=task_ins.run_id),
)
bwd_msg: Bwd = app(fwd=fwd_msg)

# Update node state
node_state.update_runstate(
run_id=bwd_msg.task_res.run_id,
run_state=bwd_msg.state,
node_state.update_context(
run_id=fwd_msg.task_ins.run_id,
context=bwd_msg.context,
)

# Send
Expand Down
16 changes: 8 additions & 8 deletions src/py/flwr/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@
Parameters,
Status,
)
from flwr.common.recordset import RecordSet
from flwr.common.context import Context


class Client(ABC):
"""Abstract base class for Flower clients."""

state: RecordSet
context: Context

def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return set of client's properties.
Expand Down Expand Up @@ -141,13 +141,13 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
metrics={},
)

def get_state(self) -> RecordSet:
"""Get the run state from this client."""
return self.state
def get_context(self) -> Context:
"""Get the run context from this client."""
return self.context

def set_state(self, state: RecordSet) -> None:
"""Apply a run state to this client."""
self.state = state
def set_context(self, context: Context) -> None:
"""Apply a run context to this client."""
self.context = context

def to_client(self) -> Client:
"""Return client (itself)."""
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/flower.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ def __init__(
) -> None:
# Create wrapper function for `handle`
def ffn(fwd: Fwd) -> Bwd: # pylint: disable=invalid-name
task_res, state_updated = handle(
task_res, context_updated = handle(
client_fn=client_fn,
state=fwd.state,
context=fwd.context,
task_ins=fwd.task_ins,
)
return Bwd(task_res=task_res, state=state_updated)
return Bwd(task_res=task_res, context=context_updated)

# Wrap middleware layers around the wrapped handle function
self._call = make_ffn(ffn, layers if layers is not None else [])
Expand Down
30 changes: 15 additions & 15 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn
from flwr.common import serde
from flwr.common.recordset import RecordSet
from flwr.common.context import Context
from flwr.proto.task_pb2 import ( # pylint: disable=E0611
SecureAggregation,
Task,
Expand Down Expand Up @@ -88,16 +88,16 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]:


def handle(
client_fn: ClientFn, state: RecordSet, task_ins: TaskIns
) -> Tuple[TaskRes, RecordSet]:
client_fn: ClientFn, context: Context, task_ins: TaskIns
) -> Tuple[TaskRes, Context]:
"""Handle incoming TaskIns from the server.
Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : RecordSet
A dataclass storing the state for the run being executed by the client.
context : Context
A dataclass storing the context for the run being executed by the client.
task_ins: TaskIns
The task instruction coming from the server, to be processed by the client.
Expand All @@ -110,7 +110,7 @@ def handle(
if server_msg is None:
# Instantiate the client
client = client_fn("-1")
client.set_state(state)
client.set_context(context)
# Secure Aggregation
if task_ins.task.HasField("sa") and isinstance(
client, SecureAggregationHandler
Expand All @@ -127,24 +127,24 @@ def handle(
sa=SecureAggregation(named_values=serde.named_values_to_proto(res)),
),
)
return task_res, client.get_state()
return task_res, client.get_context()
raise NotImplementedError()
client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg)
client_msg, updated_context = handle_legacy_message(client_fn, context, server_msg)
task_res = wrap_client_message_in_task_res(client_msg)
return task_res, updated_state
return task_res, updated_context


def handle_legacy_message(
client_fn: ClientFn, state: RecordSet, server_msg: ServerMessage
) -> Tuple[ClientMessage, RecordSet]:
client_fn: ClientFn, context: Context, server_msg: ServerMessage
) -> Tuple[ClientMessage, Context]:
"""Handle incoming messages from the server.
Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : RecordSet
A dataclass storing the state for the run being executed by the client.
context : Context
A dataclass storing the context for the run being executed by the client.
server_msg: ServerMessage
The message coming from the server, to be processed by the client.
Expand All @@ -161,7 +161,7 @@ def handle_legacy_message(

# Instantiate the client
client = client_fn("-1")
client.set_state(state)
client.set_context(context)
# Execute task
message = None
if field == "get_properties_ins":
Expand All @@ -173,7 +173,7 @@ def handle_legacy_message(
if field == "evaluate_ins":
message = _evaluate(client, server_msg.evaluate_ins)
if message:
return message, client.get_state()
return message, client.get_context()
raise UnknownServerMessage()


Expand Down
5 changes: 3 additions & 2 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
serde,
typing,
)
from flwr.common.context import Context
from flwr.common.recordset import RecordSet
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
Expand Down Expand Up @@ -141,7 +142,7 @@ def test_client_without_get_properties() -> None:
)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=RecordSet(),
context=Context(state=RecordSet()),
task_ins=task_ins,
)

Expand Down Expand Up @@ -209,7 +210,7 @@ def test_client_with_get_properties() -> None:
)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=RecordSet(),
context=Context(state=RecordSet()),
task_ins=task_ins,
)

Expand Down
13 changes: 9 additions & 4 deletions src/py/flwr/client/middleware/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import List

from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer
from flwr.common.context import Context
from flwr.common.recordset import RecordSet
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611

Expand All @@ -45,7 +46,7 @@ def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable:
def app(fwd: Fwd) -> Bwd:
footprint.append(name)
fwd.task_ins.task_id += f"{name}"
return Bwd(task_res=TaskRes(task_id=name), state=RecordSet())
return Bwd(task_res=TaskRes(task_id=name), context=Context(state=RecordSet()))

return app

Expand All @@ -66,7 +67,8 @@ def test_multiple_middlewares(self) -> None:

# Execute
wrapped_app = make_ffn(mock_app, mock_middleware_layers)
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RecordSet())).task_res
dummy_context = Context(state=RecordSet())
task_res = wrapped_app(Fwd(task_ins=task_ins, context=dummy_context)).task_res

# Assert
trace = mock_middleware_names + ["app"]
Expand All @@ -86,11 +88,14 @@ def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd:
footprint.append("filter")
fwd.task_ins.task_id += "filter"
# Skip calling app
return Bwd(task_res=TaskRes(task_id="filter"), state=RecordSet())
return Bwd(
task_res=TaskRes(task_id="filter"), context=Context(state=RecordSet())
)

# Execute
wrapped_app = make_ffn(mock_app, [filter_layer])
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RecordSet())).task_res
dummy_context = Context(state=RecordSet())
task_res = wrapped_app(Fwd(task_ins=task_ins, context=dummy_context)).task_res

# Assert
self.assertEqual(footprint, ["filter"])
Expand Down
29 changes: 15 additions & 14 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from typing import Any, Dict

from flwr.common.context import Context
from flwr.common.recordset import RecordSet


Expand All @@ -25,24 +26,24 @@ class NodeState:

def __init__(self) -> None:
self._meta: Dict[str, Any] = {} # holds metadata about the node
self.run_states: Dict[int, RecordSet] = {}
self.run_contexts: Dict[int, Context] = {}

def register_runstate(self, run_id: int) -> None:
"""Register new run state for this node."""
if run_id not in self.run_states:
self.run_states[run_id] = RecordSet()
def register_context(self, run_id: int) -> None:
"""Register new run context for this node."""
if run_id not in self.run_contexts:
self.run_contexts[run_id] = Context(state=RecordSet())

def retrieve_runstate(self, run_id: int) -> RecordSet:
"""Get run state given a run_id."""
if run_id in self.run_states:
return self.run_states[run_id]
def retrieve_context(self, run_id: int) -> Context:
"""Get run context given a run_id."""
if run_id in self.run_contexts:
return self.run_contexts[run_id]

raise RuntimeError(
f"RunState for run_id={run_id} doesn't exist."
" A run must be registered before it can be retrieved or updated "
f"Context for run_id={run_id} doesn't exist."
" A run context must be registered before it can be retrieved or updated "
" by a client."
)

def update_runstate(self, run_id: int, run_state: RecordSet) -> None:
"""Update run state."""
self.run_states[run_id] = run_state
def update_context(self, run_id: int, context: Context) -> None:
"""Update run context."""
self.run_contexts[run_id] = context
26 changes: 14 additions & 12 deletions src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@

from flwr.client.node_state import NodeState
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.recordset import RecordSet
from flwr.common.context import Context
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611


def _run_dummy_task(state: RecordSet) -> RecordSet:
def _run_dummy_task(context: Context) -> Context:
counter_value: str = "1"
if "counter" in state.configs.keys():
counter_value = state.get_configs("counter")["count"] # type: ignore
if "counter" in context.state.configs.keys():
counter_value = context.get_configs("counter")["count"] # type: ignore
counter_value += "1"

state.set_configs(name="counter", record=ConfigsRecord({"count": counter_value}))
context.state.set_configs(
name="counter", record=ConfigsRecord({"count": counter_value})
)

return state
return context


def test_multirun_in_node_state() -> None:
Expand All @@ -46,17 +48,17 @@ def test_multirun_in_node_state() -> None:
run_id = task.run_id

# Register
node_state.register_runstate(run_id=run_id)
node_state.register_context(run_id=run_id)

# Get run state
state = node_state.retrieve_runstate(run_id=run_id)
context = node_state.retrieve_context(run_id=run_id)

# Run "task"
updated_state = _run_dummy_task(state)
updated_state = _run_dummy_task(context)

# Update run state
node_state.update_runstate(run_id=run_id, run_state=updated_state)
node_state.update_context(run_id=run_id, context=updated_state)

# Verify values
for run_id, state in node_state.run_states.items():
assert state.get_configs("counter")["count"] == expected_values[run_id]
for run_id, context in node_state.run_contexts.items():
assert context.state.get_configs("counter")["count"] == expected_values[run_id]
Loading

0 comments on commit 0866770

Please sign in to comment.