From a1f74f6dada500b9828b260c834afd87fb4a9057 Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 6 Nov 2024 12:25:06 +0000 Subject: [PATCH 1/2] docs(framework) Update instructions in quickstart compose guide (#4409) Co-authored-by: Robert Steiner --- .../docker/run-quickstart-examples-docker-compose.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/doc/source/docker/run-quickstart-examples-docker-compose.rst b/doc/source/docker/run-quickstart-examples-docker-compose.rst index a92f5fffdc3f..70e9b190faaf 100644 --- a/doc/source/docker/run-quickstart-examples-docker-compose.rst +++ b/doc/source/docker/run-quickstart-examples-docker-compose.rst @@ -39,13 +39,16 @@ Run the Quickstart Example .. code-block:: bash :substitutions: - $ curl https://raw.githubusercontent.com/adap/flower/refs/tags/v|stable_flwr_version|/src/docker/complete/compose.yml \ + $ curl https://raw.githubusercontent.com/adap/flower/24b2861465431a5ab234a8c4f76faea7a742b1fd/src/docker/complete/compose.yml \ -o compose.yml -3. Build and start the services using the following command: +3. Export the version of Flower that your environment uses. Then, build and start the + services using the following command: .. code-block:: bash + :substitutions: + $ export FLWR_VERSION="|stable_flwr_version|" # update with your version $ docker compose up --build -d 4. Append the following lines to the end of the ``pyproject.toml`` file and save it: From caafd022693149cd1fa67a304d3783b81e5499f4 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 6 Nov 2024 12:33:44 +0000 Subject: [PATCH 2/2] feat(framework:skip) Add `run_id` to `Context` (#4429) --- src/proto/flwr/proto/message.proto | 9 +++++---- .../clientapp/clientappio_servicer_test.py | 3 +++ .../message_handler/message_handler_test.py | 8 ++++++-- .../secure_aggregation/secaggplus_mod_test.py | 1 + src/py/flwr/client/mod/utils_test.py | 8 ++++++-- src/py/flwr/client/run_info_store.py | 1 + src/py/flwr/common/context.py | 13 +++++++++---- src/py/flwr/common/serde.py | 2 ++ src/py/flwr/common/serde_test.py | 1 + src/py/flwr/proto/message_pb2.py | 16 ++++++++-------- src/py/flwr/proto/message_pb2.pyi | 5 ++++- src/py/flwr/server/run_serverapp.py | 3 +++ src/py/flwr/server/server_app_test.py | 4 +++- .../server/superlink/linkstate/linkstate_test.py | 2 ++ src/py/flwr/simulation/run_simulation.py | 3 +++ src/py/flwr/superexec/deployment.py | 4 +++- 16 files changed, 60 insertions(+), 23 deletions(-) diff --git a/src/proto/flwr/proto/message.proto b/src/proto/flwr/proto/message.proto index 7066da5b7e76..cbe4bf7e027f 100644 --- a/src/proto/flwr/proto/message.proto +++ b/src/proto/flwr/proto/message.proto @@ -28,10 +28,11 @@ message Message { } message Context { - uint64 node_id = 1; - map node_config = 2; - RecordSet state = 3; - map run_config = 4; + uint64 run_id = 1; + uint64 node_id = 2; + map node_config = 3; + RecordSet state = 4; + map run_config = 5; } message Metadata { diff --git a/src/py/flwr/client/clientapp/clientappio_servicer_test.py b/src/py/flwr/client/clientapp/clientappio_servicer_test.py index 82c9f16e8201..3c862884a5f3 100644 --- a/src/py/flwr/client/clientapp/clientappio_servicer_test.py +++ b/src/py/flwr/client/clientapp/clientappio_servicer_test.py @@ -66,6 +66,7 @@ def test_set_inputs(self) -> None: content=self.maker.recordset(2, 2, 1), ) context = Context( + run_id=1, node_id=1, node_config={"nodeconfig1": 4.2}, state=self.maker.recordset(2, 2, 1), @@ -122,6 +123,7 @@ def test_get_outputs(self) -> None: content=self.maker.recordset(2, 2, 1), ) context = Context( + run_id=1, node_id=1, node_config={"nodeconfig1": 4.2}, state=self.maker.recordset(2, 2, 1), @@ -186,6 +188,7 @@ def test_push_clientapp_outputs(self) -> None: content=self.maker.recordset(2, 2, 1), ) context = Context( + run_id=1, node_id=1, node_config={"nodeconfig1": 4.2}, state=self.maker.recordset(2, 2, 1), diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 311f8c37e1b1..0be5ab30e026 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -142,7 +142,9 @@ def test_client_without_get_properties() -> None: actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}), + context=Context( + run_id=2234, node_id=1123, node_config={}, state=RecordSet(), run_config={} + ), ) # Assert @@ -206,7 +208,9 @@ def test_client_with_get_properties() -> None: actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}), + context=Context( + run_id=2234, node_id=1123, node_config={}, state=RecordSet(), run_config={} + ), ) # Assert diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index e68bf5177797..89729bca1b9c 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -74,6 +74,7 @@ def func(configs: dict[str, ConfigsRecordValues]) -> ConfigsRecord: def _make_ctxt() -> Context: cfg = ConfigsRecord(SecAggPlusState().to_dict()) return Context( + run_id=234, node_id=123, node_config={}, state=RecordSet(configs_records={RECORD_KEY_STATE: cfg}), diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py index e75fb5530b2c..248ee5bdae81 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -104,7 +104,9 @@ def test_multiple_mods(self) -> None: state = RecordSet() state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0}) - context = Context(node_id=0, node_config={}, state=state, run_config={}) + context = Context( + run_id=1, node_id=0, node_config={}, state=state, run_config={} + ) message = _get_dummy_flower_message() # Execute @@ -129,7 +131,9 @@ def test_filter(self) -> None: # Prepare footprint: list[str] = [] mock_app = make_mock_app("app", footprint) - context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) + context = Context( + run_id=1, node_id=0, node_config={}, state=RecordSet(), run_config={} + ) message = _get_dummy_flower_message() def filter_mod( diff --git a/src/py/flwr/client/run_info_store.py b/src/py/flwr/client/run_info_store.py index 6b0c3bd3a493..a5cd5129bc3a 100644 --- a/src/py/flwr/client/run_info_store.py +++ b/src/py/flwr/client/run_info_store.py @@ -83,6 +83,7 @@ def register_context( self.run_infos[run_id] = RunInfo( initial_run_config=initial_run_config, context=Context( + run_id=run_id, node_id=self.node_id, node_config=self.node_config, state=RecordSet(), diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py index 1544b96d3fa3..edf2024c2b1c 100644 --- a/src/py/flwr/common/context.py +++ b/src/py/flwr/common/context.py @@ -27,36 +27,41 @@ class Context: Parameters ---------- + run_id : int + The ID that identifies the run. node_id : int The ID that identifies the node. node_config : UserConfig A config (key/value mapping) unique to the node and independent of the `run_config`. This config persists across all runs this node participates in. state : RecordSet - Holds records added by the entity in a given run and that will stay local. + Holds records added by the entity in a given `run_id` and that will stay local. This means that the data it holds will never leave the system it's running from. This can be used as an intermediate storage or scratchpad when executing mods. It can also be used as a memory to access at different points during the lifecycle of this entity (e.g. across multiple rounds) run_config : UserConfig - A config (key/value mapping) held by the entity in a given run and that will - stay local. It can be used at any point during the lifecycle of this entity + A config (key/value mapping) held by the entity in a given `run_id` and that + will stay local. It can be used at any point during the lifecycle of this entity (e.g. across multiple rounds) """ + run_id: int node_id: int node_config: UserConfig state: RecordSet run_config: UserConfig - def __init__( # pylint: disable=too-many-arguments + def __init__( # pylint: disable=too-many-arguments, too-many-positional-arguments self, + run_id: int, node_id: int, node_config: UserConfig, state: RecordSet, run_config: UserConfig, ) -> None: + self.run_id = run_id self.node_id = node_id self.node_config = node_config self.state = state diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index acac1ca046b7..99c52289b5a1 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -840,6 +840,7 @@ def message_from_proto(message_proto: ProtoMessage) -> Message: def context_to_proto(context: Context) -> ProtoContext: """Serialize `Context` to ProtoBuf.""" proto = ProtoContext( + run_id=context.run_id, node_id=context.node_id, node_config=user_config_to_proto(context.node_config), state=recordset_to_proto(context.state), @@ -851,6 +852,7 @@ def context_to_proto(context: Context) -> ProtoContext: def context_from_proto(context_proto: ProtoContext) -> Context: """Deserialize `Context` from ProtoBuf.""" context = Context( + run_id=context_proto.run_id, node_id=context_proto.node_id, node_config=user_config_from_proto(context_proto.node_config), state=recordset_from_proto(context_proto.state), diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 19e9889158a0..38ad1894f411 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -503,6 +503,7 @@ def test_context_serialization_deserialization() -> None: # Prepare maker = RecordMaker() original = Context( + run_id=0, node_id=1, node_config=maker.user_config(), state=maker.recordset(1, 1, 1), diff --git a/src/py/flwr/proto/message_pb2.py b/src/py/flwr/proto/message_pb2.py index d2201cb07b56..92e37d3b7ed4 100644 --- a/src/py/flwr/proto/message_pb2.py +++ b/src/py/flwr/proto/message_pb2.py @@ -17,7 +17,7 @@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/message.proto\x12\nflwr.proto\x1a\x16\x66lwr/proto/error.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"{\n\x07Message\x12&\n\x08metadata\x18\x01 \x01(\x0b\x32\x14.flwr.proto.Metadata\x12&\n\x07\x63ontent\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\x03 \x01(\x0b\x32\x11.flwr.proto.Error\"\xbf\x02\n\x07\x43ontext\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\x12\x38\n\x0bnode_config\x18\x02 \x03(\x0b\x32#.flwr.proto.Context.NodeConfigEntry\x12$\n\x05state\x18\x03 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12\x36\n\nrun_config\x18\x04 \x03(\x0b\x32\".flwr.proto.Context.RunConfigEntry\x1a\x45\n\x0fNodeConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x44\n\x0eRunConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xbb\x01\n\x08Metadata\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\x13\n\x0bsrc_node_id\x18\x03 \x01(\x04\x12\x13\n\x0b\x64st_node_id\x18\x04 \x01(\x04\x12\x18\n\x10reply_to_message\x18\x05 \x01(\t\x12\x10\n\x08group_id\x18\x06 \x01(\t\x12\x0b\n\x03ttl\x18\x07 \x01(\x01\x12\x14\n\x0cmessage_type\x18\x08 \x01(\t\x12\x12\n\ncreated_at\x18\t \x01(\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/message.proto\x12\nflwr.proto\x1a\x16\x66lwr/proto/error.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"{\n\x07Message\x12&\n\x08metadata\x18\x01 \x01(\x0b\x32\x14.flwr.proto.Metadata\x12&\n\x07\x63ontent\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\x03 \x01(\x0b\x32\x11.flwr.proto.Error\"\xcf\x02\n\x07\x43ontext\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0f\n\x07node_id\x18\x02 \x01(\x04\x12\x38\n\x0bnode_config\x18\x03 \x03(\x0b\x32#.flwr.proto.Context.NodeConfigEntry\x12$\n\x05state\x18\x04 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12\x36\n\nrun_config\x18\x05 \x03(\x0b\x32\".flwr.proto.Context.RunConfigEntry\x1a\x45\n\x0fNodeConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x44\n\x0eRunConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xbb\x01\n\x08Metadata\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\x13\n\x0bsrc_node_id\x18\x03 \x01(\x04\x12\x13\n\x0b\x64st_node_id\x18\x04 \x01(\x04\x12\x18\n\x10reply_to_message\x18\x05 \x01(\t\x12\x10\n\x08group_id\x18\x06 \x01(\t\x12\x0b\n\x03ttl\x18\x07 \x01(\x01\x12\x14\n\x0cmessage_type\x18\x08 \x01(\t\x12\x12\n\ncreated_at\x18\t \x01(\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -31,11 +31,11 @@ _globals['_MESSAGE']._serialized_start=120 _globals['_MESSAGE']._serialized_end=243 _globals['_CONTEXT']._serialized_start=246 - _globals['_CONTEXT']._serialized_end=565 - _globals['_CONTEXT_NODECONFIGENTRY']._serialized_start=426 - _globals['_CONTEXT_NODECONFIGENTRY']._serialized_end=495 - _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_start=497 - _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_end=565 - _globals['_METADATA']._serialized_start=568 - _globals['_METADATA']._serialized_end=755 + _globals['_CONTEXT']._serialized_end=581 + _globals['_CONTEXT_NODECONFIGENTRY']._serialized_start=442 + _globals['_CONTEXT_NODECONFIGENTRY']._serialized_end=511 + _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_start=513 + _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_end=581 + _globals['_METADATA']._serialized_start=584 + _globals['_METADATA']._serialized_end=771 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/message_pb2.pyi b/src/py/flwr/proto/message_pb2.pyi index b352917f217e..766829a4798c 100644 --- a/src/py/flwr/proto/message_pb2.pyi +++ b/src/py/flwr/proto/message_pb2.pyi @@ -67,10 +67,12 @@ class Context(google.protobuf.message.Message): def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + RUN_ID_FIELD_NUMBER: builtins.int NODE_ID_FIELD_NUMBER: builtins.int NODE_CONFIG_FIELD_NUMBER: builtins.int STATE_FIELD_NUMBER: builtins.int RUN_CONFIG_FIELD_NUMBER: builtins.int + run_id: builtins.int node_id: builtins.int @property def node_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... @@ -80,13 +82,14 @@ class Context(google.protobuf.message.Message): def run_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... def __init__(self, *, + run_id: builtins.int = ..., node_id: builtins.int = ..., node_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., state: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ..., run_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["state",b"state"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["node_config",b"node_config","node_id",b"node_id","run_config",b"run_config","state",b"state"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["node_config",b"node_config","node_id",b"node_id","run_config",b"run_config","run_id",b"run_id","state",b"state"]) -> None: ... global___Context = Context class Metadata(google.protobuf.message.Message): diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index 9937b993fd02..2215b87295b6 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -188,6 +188,7 @@ def run_server_app() -> None: app_path = str(get_project_dir(fab_id, fab_version, run_.fab_hash, flwr_dir)) config = get_project_config(app_path) + run_id = run_.run_id else: # User provided `app_dir`, but not `--run-id` # Create run if run_id is not provided @@ -204,6 +205,7 @@ def run_server_app() -> None: res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212 # Fetch full `Run` using `run_id` driver.init_run(res.run_id) # pylint: disable=W0212 + run_id = res.run_id # Obtain server app reference and the run config server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"] @@ -221,6 +223,7 @@ def run_server_app() -> None: # Initialize Context context = Context( + run_id=run_id, node_id=0, node_config={}, state=RecordSet(), diff --git a/src/py/flwr/server/server_app_test.py b/src/py/flwr/server/server_app_test.py index b0672b3202ed..b2515f09fdac 100644 --- a/src/py/flwr/server/server_app_test.py +++ b/src/py/flwr/server/server_app_test.py @@ -29,7 +29,9 @@ def test_server_app_custom_mode() -> None: # Prepare app = ServerApp() driver = MagicMock() - context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) + context = Context( + run_id=1, node_id=0, node_config={}, state=RecordSet(), run_config={} + ) called = {"called": False} diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 9e00e4a0c49a..2cdea58a7cb7 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -1036,6 +1036,7 @@ def test_get_set_serverapp_context(self) -> None: # Prepare state: LinkState = self.state_factory() context = Context( + run_id=1, node_id=0, node_config={"mock": "mock"}, state=RecordSet(), @@ -1057,6 +1058,7 @@ def test_set_context_invalid_run_id(self) -> None: # Prepare state: LinkState = self.state_factory() context = Context( + run_id=1, node_id=0, node_config={"mock": "mock"}, state=RecordSet(), diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 929824843f54..88d3fc8b213c 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -234,6 +234,7 @@ def run_serverapp_th( f_stop: threading.Event, has_exception: threading.Event, enable_tf_gpu_growth: bool, + run_id: int, ) -> threading.Thread: """Run SeverApp in a thread.""" @@ -258,6 +259,7 @@ def server_th_with_start_checks( # Initialize Context context = Context( + run_id=run_id, node_id=0, node_config={}, state=RecordSet(), @@ -357,6 +359,7 @@ def _main_loop( f_stop=f_stop, has_exception=server_app_thread_has_exception, enable_tf_gpu_growth=enable_tf_gpu_growth, + run_id=run.run_id, ) # Buffer time so the `ServerApp` in separate thread is ready diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 96d184661048..5d31bcd5edc4 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -139,7 +139,9 @@ def _create_run( def _create_context(self, run_id: int) -> None: """Register a Context for a Run.""" # Create an empty context for the Run - context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) + context = Context( + run_id=run_id, node_id=0, node_config={}, state=RecordSet(), run_config={} + ) # Register the context at the LinkState self.linkstate.set_serverapp_context(run_id=run_id, context=context)