Skip to content

Commit

Permalink
Merge branch 'replace-fwd-bwd-flowercontext' of https://github.com/ad…
Browse files Browse the repository at this point in the history
…ap/flower into replace-fwd-bwd-flowercontext
  • Loading branch information
panh99 committed Jan 26, 2024
2 parents 69a7a17 + 3446de9 commit 1543b62
Show file tree
Hide file tree
Showing 14 changed files with 143 additions and 153 deletions.
14 changes: 9 additions & 5 deletions e2e/bare/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import flwr as fl
import numpy as np

from flwr.common.configsrecord import ConfigsRecord

SUBSET_SIZE = 1000
STATE_VAR = 'timestamp'

Expand All @@ -18,13 +20,15 @@ def get_parameters(self, config):
def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
if STATE_VAR in self.state.state:
self.state.state[STATE_VAR] += f",{t_stamp}"
else:
self.state.state[STATE_VAR] = str(t_stamp)
value = str(t_stamp)
if STATE_VAR in self.context.state.configs.keys():
value = self.context.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore
value += f",{t_stamp}"

self.context.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value}))

def _retrieve_timestamp_from_state(self):
return self.state.state[STATE_VAR]
return self.context.state.get_configs(STATE_VAR)[STATE_VAR]

def fit(self, parameters, config):
model_params = parameters
Expand Down
14 changes: 8 additions & 6 deletions e2e/pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tqdm import tqdm

import flwr as fl
from flwr.common.configsrecord import ConfigsRecord

# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
Expand Down Expand Up @@ -95,14 +96,15 @@ def get_parameters(self, config):
def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
if STATE_VAR in self.state.state:
self.state.state[STATE_VAR] += f",{t_stamp}"
else:
self.state.state[STATE_VAR] = str(t_stamp)
value = str(t_stamp)
if STATE_VAR in self.context.state.configs.keys():
value = self.context.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore
value += f",{t_stamp}"

self.context.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value}))

def _retrieve_timestamp_from_state(self):
return self.state.state[STATE_VAR]

return self.context.state.get_configs(STATE_VAR)[STATE_VAR]
def fit(self, parameters, config):
set_parameters(net, parameters)
train(net, trainloader, epochs=1)
Expand Down
19 changes: 8 additions & 11 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@
TRANSPORT_TYPE_REST,
TRANSPORT_TYPES,
)
from flwr.common.context import Context
from flwr.common.logger import log, warn_experimental_feature
from flwr.common.recordset import RecordSet
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611

Expand Down Expand Up @@ -326,7 +324,6 @@ def _load_app() -> Flower:
connection, address = _init_connection(transport, server_address)

node_state = NodeState()
# TODO: make NodeState work with RecordSet

while True:
sleep_duration: int = 0
Expand Down Expand Up @@ -355,11 +352,11 @@ def _load_app() -> Flower:
send(task_res)
break

# Register state
node_state.register_runstate(run_id=task_ins.run_id)
# Register context for this run
node_state.register_context(run_id=task_ins.run_id)

# TODO: get runstate from nodestate and construct context for this run
context = Context(state=RecordSet())
# Retrieve context for this run
context = node_state.retrieve_context(run_id=task_ins.run_id)

# Get Message from TaskIns
message = message_from_taskins(task_ins)
Expand All @@ -371,10 +368,10 @@ def _load_app() -> Flower:
out_message = app(message=message, context=context)

# Update node state
# node_state.update_runstate(
# run_id=bwd_msg.task_res.run_id,
# run_state=bwd_msg.state,
# )
node_state.update_context(
run_id=message.metadata.run_id,
context=context,
)

# Construct TaskRes from out_message
task_res = message_to_taskres(out_message)
Expand Down
16 changes: 8 additions & 8 deletions src/py/flwr/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from abc import ABC

from flwr.client.run_state import RunState
from flwr.common import (
Code,
EvaluateIns,
Expand All @@ -33,12 +32,13 @@
Parameters,
Status,
)
from flwr.common.context import Context


class Client(ABC):
"""Abstract base class for Flower clients."""

state: RunState
context: Context

def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return set of client's properties.
Expand Down Expand Up @@ -141,13 +141,13 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
metrics={},
)

def get_state(self) -> RunState:
"""Get the run state from this client."""
return self.state
def get_context(self) -> Context:
"""Get the run context from this client."""
return self.context

def set_state(self, state: RunState) -> None:
"""Apply a run state to this client."""
self.state = state
def set_context(self, context: Context) -> None:
"""Apply a run context to this client."""
self.context = context

def to_client(self) -> Client:
"""Return client (itself)."""
Expand Down
31 changes: 15 additions & 16 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
get_server_message_from_task_ins,
wrap_client_message_in_task_res,
)
from flwr.client.run_state import RunState
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn
from flwr.common import serde
Expand Down Expand Up @@ -107,16 +106,16 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]:


def handle(
client_fn: ClientFn, state: RunState, task_ins: TaskIns
) -> Tuple[TaskRes, RunState]:
client_fn: ClientFn, context: Context, task_ins: TaskIns
) -> Tuple[TaskRes, Context]:
"""Handle incoming TaskIns from the server.
Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : RunState
A dataclass storing the state for the run being executed by the client.
context : Context
A dataclass storing the context for the run being executed by the client.
task_ins: TaskIns
The task instruction coming from the server, to be processed by the client.
Expand All @@ -129,7 +128,7 @@ def handle(
if server_msg is None:
# Instantiate the client
client = client_fn("-1")
client.set_state(state)
client.set_context(context)
# Secure Aggregation
if task_ins.task.HasField("sa") and isinstance(
client, SecureAggregationHandler
Expand All @@ -146,24 +145,24 @@ def handle(
sa=SecureAggregation(named_values=serde.named_values_to_proto(res)),
),
)
return task_res, client.get_state()
return task_res, client.get_context()
raise NotImplementedError()
client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg)
client_msg, updated_context = handle_legacy_message(client_fn, context, server_msg)
task_res = wrap_client_message_in_task_res(client_msg)
return task_res, updated_state
return task_res, updated_context


def handle_legacy_message(
client_fn: ClientFn, state: RunState, server_msg: ServerMessage
) -> Tuple[ClientMessage, RunState]:
client_fn: ClientFn, context: Context, server_msg: ServerMessage
) -> Tuple[ClientMessage, Context]:
"""Handle incoming messages from the server.
Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : RunState
A dataclass storing the state for the run being executed by the client.
context : Context
A dataclass storing the context for the run being executed by the client.
server_msg: ServerMessage
The message coming from the server, to be processed by the client.
Expand All @@ -180,7 +179,7 @@ def handle_legacy_message(

# Instantiate the client
client = client_fn("-1")
client.set_state(state)
client.set_context(context)
# Execute task
message = None
if field == "get_properties_ins":
Expand All @@ -192,7 +191,7 @@ def handle_legacy_message(
if field == "evaluate_ins":
message = _evaluate(client, server_msg.evaluate_ins)
if message:
return message, client.get_state()
return message, client.get_context()
raise UnknownServerMessage()


Expand All @@ -202,7 +201,7 @@ def handle_legacy_message_from_tasktype(
"""Handle legacy message in the inner most middleware layer."""
client = client_fn("-1")

# TODO: inject state (i.e. context.state) into client?
client.set_context(context)

task_type = message.metadata.task_type

Expand Down
7 changes: 4 additions & 3 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import uuid

from flwr.client import Client
from flwr.client.run_state import RunState
from flwr.client.typing import ClientFn
from flwr.common import (
EvaluateIns,
Expand All @@ -33,6 +32,8 @@
serde,
typing,
)
from flwr.common.context import Context
from flwr.common.recordset import RecordSet
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
Expand Down Expand Up @@ -141,7 +142,7 @@ def test_client_without_get_properties() -> None:
)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=RunState(state={}),
context=Context(state=RecordSet()),
task_ins=task_ins,
)

Expand Down Expand Up @@ -209,7 +210,7 @@ def test_client_with_get_properties() -> None:
)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=RunState(state={}),
context=Context(state=RecordSet()),
task_ins=task_ins,
)

Expand Down
31 changes: 16 additions & 15 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,33 @@

from typing import Any, Dict

from flwr.client.run_state import RunState
from flwr.common.context import Context
from flwr.common.recordset import RecordSet


class NodeState:
"""State of a node where client nodes execute runs."""

def __init__(self) -> None:
self._meta: Dict[str, Any] = {} # holds metadata about the node
self.run_states: Dict[int, RunState] = {}
self.run_contexts: Dict[int, Context] = {}

def register_runstate(self, run_id: int) -> None:
"""Register new run state for this node."""
if run_id not in self.run_states:
self.run_states[run_id] = RunState({})
def register_context(self, run_id: int) -> None:
"""Register new run context for this node."""
if run_id not in self.run_contexts:
self.run_contexts[run_id] = Context(state=RecordSet())

def retrieve_runstate(self, run_id: int) -> RunState:
"""Get run state given a run_id."""
if run_id in self.run_states:
return self.run_states[run_id]
def retrieve_context(self, run_id: int) -> Context:
"""Get run context given a run_id."""
if run_id in self.run_contexts:
return self.run_contexts[run_id]

raise RuntimeError(
f"RunState for run_id={run_id} doesn't exist."
" A run must be registered before it can be retrieved or updated "
f"Context for run_id={run_id} doesn't exist."
" A run context must be registered before it can be retrieved or updated "
" by a client."
)

def update_runstate(self, run_id: int, run_state: RunState) -> None:
"""Update run state."""
self.run_states[run_id] = run_state
def update_context(self, run_id: int, context: Context) -> None:
"""Update run context."""
self.run_contexts[run_id] = context
31 changes: 18 additions & 13 deletions src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@


from flwr.client.node_state import NodeState
from flwr.client.run_state import RunState
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.context import Context
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611


def _run_dummy_task(state: RunState) -> RunState:
if "counter" in state.state:
state.state["counter"] += "1"
else:
state.state["counter"] = "1"
def _run_dummy_task(context: Context) -> Context:
counter_value: str = "1"
if "counter" in context.state.configs.keys():
counter_value = context.get_configs("counter")["count"] # type: ignore
counter_value += "1"

return state
context.state.set_configs(
name="counter", record=ConfigsRecord({"count": counter_value})
)

return context


def test_multirun_in_node_state() -> None:
Expand All @@ -43,17 +48,17 @@ def test_multirun_in_node_state() -> None:
run_id = task.run_id

# Register
node_state.register_runstate(run_id=run_id)
node_state.register_context(run_id=run_id)

# Get run state
state = node_state.retrieve_runstate(run_id=run_id)
context = node_state.retrieve_context(run_id=run_id)

# Run "task"
updated_state = _run_dummy_task(state)
updated_state = _run_dummy_task(context)

# Update run state
node_state.update_runstate(run_id=run_id, run_state=updated_state)
node_state.update_context(run_id=run_id, context=updated_state)

# Verify values
for run_id, state in node_state.run_states.items():
assert state.state["counter"] == expected_values[run_id]
for run_id, context in node_state.run_contexts.items():
assert context.state.get_configs("counter")["count"] == expected_values[run_id]
Loading

0 comments on commit 1543b62

Please sign in to comment.