Skip to content

Commit

Permalink
Merge branch 'main' into feat/base-gpu-image
Browse files Browse the repository at this point in the history
  • Loading branch information
Robert-Steiner authored Nov 6, 2024
2 parents 28c8359 + caafd02 commit 0d68883
Show file tree
Hide file tree
Showing 17 changed files with 65 additions and 25 deletions.
7 changes: 5 additions & 2 deletions doc/source/docker/run-quickstart-examples-docker-compose.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,16 @@ Run the Quickstart Example
.. code-block:: bash
:substitutions:
$ curl https://raw.githubusercontent.com/adap/flower/refs/tags/v|stable_flwr_version|/src/docker/complete/compose.yml \
$ curl https://raw.githubusercontent.com/adap/flower/24b2861465431a5ab234a8c4f76faea7a742b1fd/src/docker/complete/compose.yml \
-o compose.yml
3. Build and start the services using the following command:
3. Export the version of Flower that your environment uses. Then, build and start the
services using the following command:

.. code-block:: bash
:substitutions:
$ export FLWR_VERSION="|stable_flwr_version|" # update with your version
$ docker compose up --build -d
4. Append the following lines to the end of the ``pyproject.toml`` file and save it:
Expand Down
9 changes: 5 additions & 4 deletions src/proto/flwr/proto/message.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ message Message {
}

message Context {
uint64 node_id = 1;
map<string, Scalar> node_config = 2;
RecordSet state = 3;
map<string, Scalar> run_config = 4;
uint64 run_id = 1;
uint64 node_id = 2;
map<string, Scalar> node_config = 3;
RecordSet state = 4;
map<string, Scalar> run_config = 5;
}

message Metadata {
Expand Down
3 changes: 3 additions & 0 deletions src/py/flwr/client/clientapp/clientappio_servicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def test_set_inputs(self) -> None:
content=self.maker.recordset(2, 2, 1),
)
context = Context(
run_id=1,
node_id=1,
node_config={"nodeconfig1": 4.2},
state=self.maker.recordset(2, 2, 1),
Expand Down Expand Up @@ -122,6 +123,7 @@ def test_get_outputs(self) -> None:
content=self.maker.recordset(2, 2, 1),
)
context = Context(
run_id=1,
node_id=1,
node_config={"nodeconfig1": 4.2},
state=self.maker.recordset(2, 2, 1),
Expand Down Expand Up @@ -186,6 +188,7 @@ def test_push_clientapp_outputs(self) -> None:
content=self.maker.recordset(2, 2, 1),
)
context = Context(
run_id=1,
node_id=1,
node_config={"nodeconfig1": 4.2},
state=self.maker.recordset(2, 2, 1),
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ def test_client_without_get_properties() -> None:
actual_msg = handle_legacy_message_from_msgtype(
client_fn=_get_client_fn(client),
message=message,
context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}),
context=Context(
run_id=2234, node_id=1123, node_config={}, state=RecordSet(), run_config={}
),
)

# Assert
Expand Down Expand Up @@ -206,7 +208,9 @@ def test_client_with_get_properties() -> None:
actual_msg = handle_legacy_message_from_msgtype(
client_fn=_get_client_fn(client),
message=message,
context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}),
context=Context(
run_id=2234, node_id=1123, node_config={}, state=RecordSet(), run_config={}
),
)

# Assert
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def func(configs: dict[str, ConfigsRecordValues]) -> ConfigsRecord:
def _make_ctxt() -> Context:
cfg = ConfigsRecord(SecAggPlusState().to_dict())
return Context(
run_id=234,
node_id=123,
node_config={},
state=RecordSet(configs_records={RECORD_KEY_STATE: cfg}),
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/client/mod/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def test_multiple_mods(self) -> None:

state = RecordSet()
state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0})
context = Context(node_id=0, node_config={}, state=state, run_config={})
context = Context(
run_id=1, node_id=0, node_config={}, state=state, run_config={}
)
message = _get_dummy_flower_message()

# Execute
Expand All @@ -129,7 +131,9 @@ def test_filter(self) -> None:
# Prepare
footprint: list[str] = []
mock_app = make_mock_app("app", footprint)
context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={})
context = Context(
run_id=1, node_id=0, node_config={}, state=RecordSet(), run_config={}
)
message = _get_dummy_flower_message()

def filter_mod(
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/client/run_info_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def register_context(
self.run_infos[run_id] = RunInfo(
initial_run_config=initial_run_config,
context=Context(
run_id=run_id,
node_id=self.node_id,
node_config=self.node_config,
state=RecordSet(),
Expand Down
13 changes: 9 additions & 4 deletions src/py/flwr/common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,36 +27,41 @@ class Context:
Parameters
----------
run_id : int
The ID that identifies the run.
node_id : int
The ID that identifies the node.
node_config : UserConfig
A config (key/value mapping) unique to the node and independent of the
`run_config`. This config persists across all runs this node participates in.
state : RecordSet
Holds records added by the entity in a given run and that will stay local.
Holds records added by the entity in a given `run_id` and that will stay local.
This means that the data it holds will never leave the system it's running from.
This can be used as an intermediate storage or scratchpad when
executing mods. It can also be used as a memory to access
at different points during the lifecycle of this entity (e.g. across
multiple rounds)
run_config : UserConfig
A config (key/value mapping) held by the entity in a given run and that will
stay local. It can be used at any point during the lifecycle of this entity
A config (key/value mapping) held by the entity in a given `run_id` and that
will stay local. It can be used at any point during the lifecycle of this entity
(e.g. across multiple rounds)
"""

run_id: int
node_id: int
node_config: UserConfig
state: RecordSet
run_config: UserConfig

def __init__( # pylint: disable=too-many-arguments
def __init__( # pylint: disable=too-many-arguments, too-many-positional-arguments
self,
run_id: int,
node_id: int,
node_config: UserConfig,
state: RecordSet,
run_config: UserConfig,
) -> None:
self.run_id = run_id
self.node_id = node_id
self.node_config = node_config
self.state = state
Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,7 @@ def message_from_proto(message_proto: ProtoMessage) -> Message:
def context_to_proto(context: Context) -> ProtoContext:
"""Serialize `Context` to ProtoBuf."""
proto = ProtoContext(
run_id=context.run_id,
node_id=context.node_id,
node_config=user_config_to_proto(context.node_config),
state=recordset_to_proto(context.state),
Expand All @@ -851,6 +852,7 @@ def context_to_proto(context: Context) -> ProtoContext:
def context_from_proto(context_proto: ProtoContext) -> Context:
"""Deserialize `Context` from ProtoBuf."""
context = Context(
run_id=context_proto.run_id,
node_id=context_proto.node_id,
node_config=user_config_from_proto(context_proto.node_config),
state=recordset_from_proto(context_proto.state),
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ def test_context_serialization_deserialization() -> None:
# Prepare
maker = RecordMaker()
original = Context(
run_id=0,
node_id=1,
node_config=maker.user_config(),
state=maker.recordset(1, 1, 1),
Expand Down
16 changes: 8 additions & 8 deletions src/py/flwr/proto/message_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion src/py/flwr/proto/message_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@ class Context(google.protobuf.message.Message):
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

RUN_ID_FIELD_NUMBER: builtins.int
NODE_ID_FIELD_NUMBER: builtins.int
NODE_CONFIG_FIELD_NUMBER: builtins.int
STATE_FIELD_NUMBER: builtins.int
RUN_CONFIG_FIELD_NUMBER: builtins.int
run_id: builtins.int
node_id: builtins.int
@property
def node_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
Expand All @@ -80,13 +82,14 @@ class Context(google.protobuf.message.Message):
def run_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
def __init__(self,
*,
run_id: builtins.int = ...,
node_id: builtins.int = ...,
node_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
state: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ...,
run_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["state",b"state"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["node_config",b"node_config","node_id",b"node_id","run_config",b"run_config","state",b"state"]) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["node_config",b"node_config","node_id",b"node_id","run_config",b"run_config","run_id",b"run_id","state",b"state"]) -> None: ...
global___Context = Context

class Metadata(google.protobuf.message.Message):
Expand Down
3 changes: 3 additions & 0 deletions src/py/flwr/server/run_serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def run_server_app() -> None:

app_path = str(get_project_dir(fab_id, fab_version, run_.fab_hash, flwr_dir))
config = get_project_config(app_path)
run_id = run_.run_id
else:
# User provided `app_dir`, but not `--run-id`
# Create run if run_id is not provided
Expand All @@ -204,6 +205,7 @@ def run_server_app() -> None:
res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
# Fetch full `Run` using `run_id`
driver.init_run(res.run_id) # pylint: disable=W0212
run_id = res.run_id

# Obtain server app reference and the run config
server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
Expand All @@ -221,6 +223,7 @@ def run_server_app() -> None:

# Initialize Context
context = Context(
run_id=run_id,
node_id=0,
node_config={},
state=RecordSet(),
Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/server/server_app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def test_server_app_custom_mode() -> None:
# Prepare
app = ServerApp()
driver = MagicMock()
context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={})
context = Context(
run_id=1, node_id=0, node_config={}, state=RecordSet(), run_config={}
)

called = {"called": False}

Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/server/superlink/linkstate/linkstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,7 @@ def test_get_set_serverapp_context(self) -> None:
# Prepare
state: LinkState = self.state_factory()
context = Context(
run_id=1,
node_id=0,
node_config={"mock": "mock"},
state=RecordSet(),
Expand All @@ -1057,6 +1058,7 @@ def test_set_context_invalid_run_id(self) -> None:
# Prepare
state: LinkState = self.state_factory()
context = Context(
run_id=1,
node_id=0,
node_config={"mock": "mock"},
state=RecordSet(),
Expand Down
3 changes: 3 additions & 0 deletions src/py/flwr/simulation/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def run_serverapp_th(
f_stop: threading.Event,
has_exception: threading.Event,
enable_tf_gpu_growth: bool,
run_id: int,
) -> threading.Thread:
"""Run SeverApp in a thread."""

Expand All @@ -258,6 +259,7 @@ def server_th_with_start_checks(

# Initialize Context
context = Context(
run_id=run_id,
node_id=0,
node_config={},
state=RecordSet(),
Expand Down Expand Up @@ -357,6 +359,7 @@ def _main_loop(
f_stop=f_stop,
has_exception=server_app_thread_has_exception,
enable_tf_gpu_growth=enable_tf_gpu_growth,
run_id=run.run_id,
)

# Buffer time so the `ServerApp` in separate thread is ready
Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/superexec/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ def _create_run(
def _create_context(self, run_id: int) -> None:
"""Register a Context for a Run."""
# Create an empty context for the Run
context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={})
context = Context(
run_id=run_id, node_id=0, node_config={}, state=RecordSet(), run_config={}
)

# Register the context at the LinkState
self.linkstate.set_serverapp_context(run_id=run_id, context=context)
Expand Down

0 comments on commit 0d68883

Please sign in to comment.