Skip to content

Commit

Permalink
Merge branch 'main' into vce-fleet-api-backends
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Feb 23, 2024
2 parents 277af9d + a50dfd0 commit 8ab5e68
Show file tree
Hide file tree
Showing 15 changed files with 178 additions and 148 deletions.
8 changes: 4 additions & 4 deletions e2e/bare/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
value = str(t_stamp)
if STATE_VAR in self.context.state.configs.keys():
value = self.context.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore
if STATE_VAR in self.context.state.configs_records.keys():
value = self.context.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore
value += f",{t_stamp}"

self.context.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value}))
self.context.state.configs_records[STATE_VAR] = ConfigsRecord({STATE_VAR: value})

def _retrieve_timestamp_from_state(self):
return self.context.state.get_configs(STATE_VAR)[STATE_VAR]
return self.context.state.configs_records[STATE_VAR][STATE_VAR]

def fit(self, parameters, config):
model_params = parameters
Expand Down
8 changes: 4 additions & 4 deletions e2e/pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
value = str(t_stamp)
if STATE_VAR in self.context.state.configs.keys():
value = self.context.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore
if STATE_VAR in self.context.state.configs_records.keys():
value = self.context.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore
value += f",{t_stamp}"

self.context.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value}))
self.context.state.configs_records[STATE_VAR] = ConfigsRecord({STATE_VAR: value})

def _retrieve_timestamp_from_state(self):
return self.context.state.get_configs(STATE_VAR)[STATE_VAR]
return self.context.state.configs_records[STATE_VAR][STATE_VAR]
def fit(self, parameters, config):
set_parameters(net, parameters)
train(net, trainloader, epochs=1)
Expand Down
10 changes: 5 additions & 5 deletions examples/secaggplus-mt/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@
RECORD_KEY_CONFIGS,
)
from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
from flwr.common.typing import ConfigsRecordValues, FitIns, ServerMessage
from flwr.common.typing import ConfigsRecordValues, FitIns
from flwr.proto.task_pb2 import Task
from flwr.common import serde
from flwr.common.constant import TASK_TYPE_FIT
from flwr.common.constant import MESSAGE_TYPE_FIT
from flwr.common import RecordSet
from flwr.common import recordset_compat as compat
from flwr.common import ConfigsRecord
Expand All @@ -79,16 +79,16 @@ def _wrap_in_task(
recordset = compat.fitins_to_recordset(fit_ins, keep_input=True)
else:
recordset = RecordSet()
recordset.set_configs(RECORD_KEY_CONFIGS, ConfigsRecord(named_values))
recordset.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(named_values)
return Task(
task_type=TASK_TYPE_FIT,
task_type=MESSAGE_TYPE_FIT,
recordset=serde.recordset_to_proto(recordset),
)


def _get_from_task(task: Task) -> Dict[str, ConfigsRecordValues]:
recordset = serde.recordset_from_proto(task.recordset)
return recordset.get_configs(RECORD_KEY_CONFIGS).data
return recordset.configs_records[RECORD_KEY_CONFIGS]


_secure_aggregation_configuration = {
Expand Down
8 changes: 5 additions & 3 deletions src/py/flwr/client/grpc_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def receive() -> Message:
message_type = MESSAGE_TYPE_EVALUATE
elif field == "reconnect_ins":
recordset = RecordSet()
recordset.set_configs(
"config", ConfigsRecord({"seconds": proto.reconnect_ins.seconds})
recordset.configs_records["config"] = ConfigsRecord(
{"seconds": proto.reconnect_ins.seconds}
)
message_type = "reconnect"
else:
Expand Down Expand Up @@ -207,7 +207,9 @@ def send(message: Message) -> None:
evalres = compat.recordset_to_evaluateres(recordset)
msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres))
elif message_type == "reconnect":
reason = cast(Reason.ValueType, recordset.get_configs("config")["reason"])
reason = cast(
Reason.ValueType, recordset.configs_records["config"]["reason"]
)
msg_proto = ClientMessage(
disconnect_res=ClientMessage.DisconnectRes(reason=reason)
)
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
ttl="",
message_type="reconnect",
),
content=RecordSet(configs={"config": ConfigsRecord({"reason": 0})}),
content=RecordSet(configs_records={"config": ConfigsRecord({"reason": 0})}),
)


Expand Down
18 changes: 15 additions & 3 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Client-side message handler."""


from logging import WARN
from typing import Optional, Tuple, cast

from flwr.client.client import (
Expand All @@ -23,8 +24,9 @@
maybe_call_get_parameters,
maybe_call_get_properties,
)
from flwr.client.numpy_client import NumPyClient
from flwr.client.typing import ClientFn
from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet
from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log
from flwr.common.constant import (
MESSAGE_TYPE_EVALUATE,
MESSAGE_TYPE_FIT,
Expand Down Expand Up @@ -75,15 +77,15 @@ def handle_control_message(message: Message) -> Tuple[Optional[Message], int]:
if message.metadata.message_type == "reconnect":
# Retrieve ReconnectIns from recordset
recordset = message.content
seconds = cast(int, recordset.get_configs("config")["seconds"])
seconds = cast(int, recordset.configs_records["config"]["seconds"])
# Construct ReconnectIns and call _reconnect
disconnect_msg, sleep_duration = _reconnect(
ServerMessage.ReconnectIns(seconds=seconds)
)
# Store DisconnectRes in recordset
reason = cast(int, disconnect_msg.disconnect_res.reason)
recordset = RecordSet()
recordset.set_configs("config", ConfigsRecord({"reason": reason}))
recordset.configs_records["config"] = ConfigsRecord({"reason": reason})
out_message = message.create_reply(recordset, ttl="")
# Return TaskRes and sleep duration
return out_message, sleep_duration
Expand All @@ -98,6 +100,16 @@ def handle_legacy_message_from_msgtype(
"""Handle legacy message in the inner most mod."""
client = client_fn(str(message.metadata.dst_node_id))

# Check if NumPyClient is returend
if isinstance(client, NumPyClient):
client = client.to_client()
log(
WARN,
"Deprecation Warning: The `client_fn` function must return an instance "
"of `Client`, but an instance of `NumpyClient` was returned. "
"Please use `NumPyClient.to_client()` method to convert it to `Client`.",
)

client.set_context(context)

message_type = message.metadata.message_type
Expand Down
12 changes: 6 additions & 6 deletions src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,13 @@ def secaggplus_mod(
return call_next(msg, ctxt)

# Retrieve local state
if RECORD_KEY_STATE not in ctxt.state.configs:
ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord({}))
state_dict = ctxt.state.get_configs(RECORD_KEY_STATE)
if RECORD_KEY_STATE not in ctxt.state.configs_records:
ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord({})
state_dict = ctxt.state.configs_records[RECORD_KEY_STATE]
state = SecAggPlusState(**state_dict)

# Retrieve incoming configs
configs = msg.content.get_configs(RECORD_KEY_CONFIGS)
configs = msg.content.configs_records[RECORD_KEY_CONFIGS]

# Check the validity of the next stage
check_stage(state.current_stage, configs)
Expand All @@ -206,10 +206,10 @@ def secaggplus_mod(
raise ValueError(f"Unknown secagg stage: {state.current_stage}")

# Save state
ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord(state.to_dict()))
ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict())

# Return message
content = RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)})
content = RecordSet(configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)})
return msg.create_reply(content, ttl="")


Expand Down
12 changes: 7 additions & 5 deletions src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,29 @@ def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord:
ttl="",
message_type=MESSAGE_TYPE_FIT,
),
content=RecordSet(configs={RECORD_KEY_CONFIGS: ConfigsRecord(configs)}),
content=RecordSet(
configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(configs)}
),
)
out_msg = app(in_msg, ctxt)
return out_msg.content.get_configs(RECORD_KEY_CONFIGS)
return out_msg.content.configs_records[RECORD_KEY_CONFIGS]

return func


def _make_ctxt() -> Context:
cfg = ConfigsRecord(SecAggPlusState().to_dict())
return Context(RecordSet(configs={RECORD_KEY_STATE: cfg}))
return Context(RecordSet(configs_records={RECORD_KEY_STATE: cfg}))


def _make_set_state_fn(
ctxt: Context,
) -> Callable[[str], None]:
def set_stage(stage: str) -> None:
state_dict = ctxt.state.get_configs(RECORD_KEY_STATE)
state_dict = ctxt.state.configs_records[RECORD_KEY_STATE]
state = SecAggPlusState(**state_dict)
state.current_stage = stage
ctxt.state.set_configs(RECORD_KEY_STATE, ConfigsRecord(state.to_dict()))
ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict())

return set_stage

Expand Down
33 changes: 18 additions & 15 deletions src/py/flwr/client/mod/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


import unittest
from typing import List
from typing import List, cast

from flwr.client.typing import ClientAppCallable, Mod
from flwr.common import (
Expand All @@ -36,10 +36,10 @@

def _increment_context_counter(context: Context) -> None:
# Read from context
current_counter: int = context.state.get_metrics(METRIC)[COUNTER] # type: ignore
current_counter = cast(int, context.state.metrics_records[METRIC][COUNTER])
# update and override context
current_counter += 1
context.state.set_metrics(METRIC, record=MetricsRecord({COUNTER: current_counter}))
context.state.metrics_records[METRIC] = MetricsRecord({COUNTER: current_counter})


def make_mock_mod(name: str, footprint: List[str]) -> Mod:
Expand All @@ -48,13 +48,13 @@ def make_mock_mod(name: str, footprint: List[str]) -> Mod:
def mod(message: Message, context: Context, app: ClientAppCallable) -> Message:
footprint.append(name)
# add empty ConfigRecord to in_message for this mod
message.content.set_configs(name=name, record=ConfigsRecord())
message.content.configs_records[name] = ConfigsRecord()
_increment_context_counter(context)
out_message: Message = app(message, context)
footprint.append(name)
_increment_context_counter(context)
# add empty ConfigRegcord to out_message for this mod
out_message.content.set_configs(name=name, record=ConfigsRecord())
out_message.content.configs_records[name] = ConfigsRecord()
return out_message

return mod
Expand All @@ -65,9 +65,9 @@ def make_mock_app(name: str, footprint: List[str]) -> ClientAppCallable:

def app(message: Message, context: Context) -> Message:
footprint.append(name)
message.content.set_configs(name=name, record=ConfigsRecord())
message.content.configs_records[name] = ConfigsRecord()
out_message = Message(metadata=message.metadata, content=RecordSet())
out_message.content.set_configs(name=name, record=ConfigsRecord())
out_message.content.configs_records[name] = ConfigsRecord()
print(context)
return out_message

Expand Down Expand Up @@ -102,7 +102,7 @@ def test_multiple_mods(self) -> None:
mock_mods = [make_mock_mod(name, footprint) for name in mock_mod_names]

state = RecordSet()
state.set_metrics(METRIC, record=MetricsRecord({COUNTER: 0.0}))
state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0})
context = Context(state=state)
message = _get_dummy_flower_message()

Expand All @@ -114,11 +114,14 @@ def test_multiple_mods(self) -> None:
trace = mock_mod_names + ["app"]
self.assertEqual(footprint, trace + list(reversed(mock_mod_names)))
# pylint: disable-next=no-member
self.assertEqual("".join(message.content.configs.keys()), "".join(trace))
self.assertEqual(
"".join(out_message.content.configs.keys()), "".join(reversed(trace))
"".join(message.content.configs_records.keys()), "".join(trace)
)
self.assertEqual(state.get_metrics(METRIC)[COUNTER], 2 * len(mock_mods))
self.assertEqual(
"".join(out_message.content.configs_records.keys()),
"".join(reversed(trace)),
)
self.assertEqual(state.metrics_records[METRIC][COUNTER], 2 * len(mock_mods))

def test_filter(self) -> None:
"""Test if a mod can filter incoming TaskIns."""
Expand All @@ -134,9 +137,9 @@ def filter_mod(
_2: ClientAppCallable,
) -> Message:
footprint.append("filter")
message.content.set_configs(name="filter", record=ConfigsRecord())
message.content.configs_records["filter"] = ConfigsRecord()
out_message = Message(metadata=message.metadata, content=RecordSet())
out_message.content.set_configs(name="filter", record=ConfigsRecord())
out_message.content.configs_records["filter"] = ConfigsRecord()
# Skip calling app
return out_message

Expand All @@ -147,5 +150,5 @@ def filter_mod(
# Assert
self.assertEqual(footprint, ["filter"])
# pylint: disable-next=no-member
self.assertEqual(list(message.content.configs.keys())[0], "filter")
self.assertEqual(list(out_message.content.configs.keys())[0], "filter")
self.assertEqual(list(message.content.configs_records.keys())[0], "filter")
self.assertEqual(list(out_message.content.configs_records.keys())[0], "filter")
14 changes: 8 additions & 6 deletions src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@
"""Node state tests."""


from typing import cast

from flwr.client.node_state import NodeState
from flwr.common import ConfigsRecord, Context
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611


def _run_dummy_task(context: Context) -> Context:
counter_value: str = "1"
if "counter" in context.state.configs.keys():
counter_value = context.get_configs("counter")["count"] # type: ignore
if "counter" in context.state.configs_records.keys():
counter_value = cast(str, context.state.configs_records["counter"]["count"])
counter_value += "1"

context.state.set_configs(
name="counter", record=ConfigsRecord({"count": counter_value})
)
context.state.configs_records["counter"] = ConfigsRecord({"count": counter_value})

return context

Expand Down Expand Up @@ -60,4 +60,6 @@ def test_multirun_in_node_state() -> None:

# Verify values
for run_id, context in node_state.run_contexts.items():
assert context.state.get_configs("counter")["count"] == expected_values[run_id]
assert (
context.state.configs_records["counter"]["count"] == expected_values[run_id]
)
Loading

0 comments on commit 8ab5e68

Please sign in to comment.