From efdb900671d84cc189af489bea3ac04006fa15d7 Mon Sep 17 00:00:00 2001 From: Robert Steiner Date: Thu, 19 Sep 2024 11:09:38 +0200 Subject: [PATCH 01/11] docs(framework) Add quickstart examples Docker compose guide (#4189) Signed-off-by: Robert Steiner Co-authored-by: Javier --- doc/source/docker/index.rst | 1 + ...run-quickstart-examples-docker-compose.rst | 122 ++++++++++++++++++ .../tutorial-quickstart-docker-compose.rst | 5 + 3 files changed, 128 insertions(+) create mode 100644 doc/source/docker/run-quickstart-examples-docker-compose.rst diff --git a/doc/source/docker/index.rst b/doc/source/docker/index.rst index ac6124b4c138..968f01581b34 100644 --- a/doc/source/docker/index.rst +++ b/doc/source/docker/index.rst @@ -44,3 +44,4 @@ Run Flower using Docker Compose :maxdepth: 1 tutorial-quickstart-docker-compose + run-quickstart-examples-docker-compose diff --git a/doc/source/docker/run-quickstart-examples-docker-compose.rst b/doc/source/docker/run-quickstart-examples-docker-compose.rst new file mode 100644 index 000000000000..b279fb66c45b --- /dev/null +++ b/doc/source/docker/run-quickstart-examples-docker-compose.rst @@ -0,0 +1,122 @@ +Run Flower Quickstart Examples with Docker Compose +================================================== + +Flower provides a set of `quickstart examples `_ +to help you get started with the framework. These examples are designed to demonstrate the +capabilities of Flower and by default run using the Simulation Engine. This guide demonstrates +how to run them using Flower's Deployment Engine via Docker Compose. + +.. important:: + + Some quickstart examples may have limitations or requirements that prevent them from running + on every environment. For more information, please see `Limitations`_. + +Prerequisites +------------- + +Before you start, make sure that: + +- The ``flwr`` CLI is :doc:`installed <../how-to-install-flower>` locally. +- The Docker daemon is running. +- Docker Compose is `installed `_. + +Run the Quickstart Example +-------------------------- + +#. Clone the quickstart example you like to run. For example, ``quickstart-pytorch``: + + .. code-block:: bash + + $ git clone --depth=1 https://github.com/adap/flower.git \ + && mv flower/examples/quickstart-pytorch . \ + && rm -rf flower && cd quickstart-pytorch + +#. Download the `compose.yml `_ file into the example directory: + + .. code-block:: bash + + $ curl https://raw.githubusercontent.com/adap/flower/refs/heads/main/src/docker/complete/compose.yml \ + -o compose.yml + +#. Build and start the services using the following command: + + .. code-block:: bash + + $ docker compose up --build -d + +#. Append the following lines to the end of the ``pyproject.toml`` file and save it: + + .. code-block:: toml + :caption: pyproject.toml + + [tool.flwr.federations.local-deployment] + address = "127.0.0.1:9093" + insecure = true + + .. note:: + + You can customize the string that follows ``tool.flwr.federations.`` to fit your needs. + However, please note that the string cannot contain a dot (``.``). + + In this example, ``local-deployment`` has been used. Just remember to replace + ``local-deployment`` with your chosen name in both the ``tool.flwr.federations.`` string + and the corresponding ``flwr run .`` command. + +#. Run the example: + + .. code-block:: bash + + $ flwr run . local-deployment + +#. Follow the logs of the SuperExec service: + + .. code-block:: bash + + $ docker compose logs superexec -f + +That is all it takes! You can monitor the progress of the run through the logs of the SuperExec. + +Run a Different Quickstart Example +---------------------------------- + +To run a different quickstart example, such as ``quickstart-tensorflow``, first, shut down the Docker +Compose services of the current example: + +.. code-block:: bash + + $ docker compose down + +After that, you can repeat the steps above. + +Limitations +----------- + +.. list-table:: + :header-rows: 1 + + * - Quickstart Example + - Limitations + * - quickstart-fastai + - None + * - examples/quickstart-huggingface + - For CPU-only environments, it requires at least 32GB of memory. + * - quickstart-jax + - The example has not yet been updated to work with the latest ``flwr`` version. + * - quickstart-mlcube + - The example has not yet been updated to work with the latest ``flwr`` version. + * - quickstart-mlx + - `Requires to run on macOS with Apple Silicon `_. + * - quickstart-monai + - None + * - quickstart-pandas + - The example has not yet been updated to work with the latest ``flwr`` version. + * - quickstart-pytorch-lightning + - Requires an older pip version that is not supported by the Flower Docker images. + * - quickstart-pytorch + - None + * - quickstart-sklearn-tabular + - None + * - quickstart-tabnet + - The example has not yet been updated to work with the latest ``flwr`` version. + * - quickstart-tensorflow + - Only runs on AMD64. diff --git a/doc/source/docker/tutorial-quickstart-docker-compose.rst b/doc/source/docker/tutorial-quickstart-docker-compose.rst index 36d9bd58e29e..7aeae1e2fb6b 100644 --- a/doc/source/docker/tutorial-quickstart-docker-compose.rst +++ b/doc/source/docker/tutorial-quickstart-docker-compose.rst @@ -396,3 +396,8 @@ Remove all services and volumes: $ docker compose down -v $ docker compose -f certs.yml down -v + +Where to Go Next +---------------- + +* :doc:`run-quickstart-examples-docker-compose` From a70449d431ddda26a7461bb66300f84a4b7cbc2a Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 19 Sep 2024 10:33:41 +0100 Subject: [PATCH 02/11] fix(framework:skip) Fix the `SqliteState.get_run` method (#4222) --- src/py/flwr/server/superlink/state/sqlite_state.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 286ab881f891..28d957a90bd3 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -782,8 +782,9 @@ def get_run(self, run_id: int) -> Optional[Run]: # Convert the uint64 value to sint64 for SQLite sint64_run_id = convert_uint64_to_sint64(run_id) query = "SELECT * FROM run WHERE run_id = ?;" - try: - row = self.query(query, (sint64_run_id,))[0] + rows = self.query(query, (sint64_run_id,)) + if rows: + row = rows[0] return Run( run_id=convert_sint64_to_uint64(row["run_id"]), fab_id=row["fab_id"], @@ -791,9 +792,8 @@ def get_run(self, run_id: int) -> Optional[Run]: fab_hash=row["fab_hash"], override_config=json.loads(row["override_config"]), ) - except sqlite3.IntegrityError: - log(ERROR, "`run_id` does not exist.") - return None + log(ERROR, "`run_id` does not exist.") + return None def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat.""" From ebb78ae705e413718231ffb873846405c88a733a Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Thu, 19 Sep 2024 12:19:12 +0100 Subject: [PATCH 03/11] feat(framework) Add SuperExec logcatcher (#3584) Signed-off-by: Robert Steiner Co-authored-by: Robert Steiner Co-authored-by: Chong Shen Ng Co-authored-by: Charles Beauville Co-authored-by: jafermarq Co-authored-by: Taner Topal Co-authored-by: Daniel J. Beutel --- src/py/flwr/superexec/exec_servicer.py | 64 ++++++++++++++++++++- src/py/flwr/superexec/exec_servicer_test.py | 21 ++++++- src/py/flwr/superexec/executor.py | 3 +- 3 files changed, 82 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index 5b729dbc2b8e..8bf384312fca 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -15,6 +15,9 @@ """SuperExec API servicer.""" +import select +import threading +import time from collections.abc import Generator from logging import ERROR, INFO from typing import Any @@ -33,6 +36,8 @@ from .executor import Executor, RunTracker +SELECT_TIMEOUT = 1 # Timeout for selecting ready-to-read file descriptors (in seconds) + class ExecServicer(exec_pb2_grpc.ExecServicer): """SuperExec API servicer.""" @@ -59,13 +64,66 @@ def StartRun( self.runs[run.run_id] = run + # Start a background thread to capture the log output + capture_thread = threading.Thread( + target=_capture_logs, args=(run,), daemon=True + ) + capture_thread.start() + return StartRunResponse(run_id=run.run_id) - def StreamLogs( + def StreamLogs( # pylint: disable=C0103 self, request: StreamLogsRequest, context: grpc.ServicerContext ) -> Generator[StreamLogsResponse, Any, None]: """Get logs.""" - logs = ["a", "b", "c"] + log(INFO, "ExecServicer.StreamLogs") + + # Exit if `run_id` not found + if request.run_id not in self.runs: + context.abort(grpc.StatusCode.NOT_FOUND, "Run ID not found") + + last_sent_index = 0 while context.is_active(): - for i in range(len(logs)): # pylint: disable=C0200 + # Yield n'th row of logs, if n'th row < len(logs) + logs = self.runs[request.run_id].logs + for i in range(last_sent_index, len(logs)): yield StreamLogsResponse(log_output=logs[i]) + last_sent_index = len(logs) + + # Wait for and continue to yield more log responses only if the + # run isn't completed yet. If the run is finished, the entire log + # is returned at this point and the server ends the stream. + if self.runs[request.run_id].proc.poll() is not None: + log(INFO, "All logs for run ID `%s` returned", request.run_id) + return + + time.sleep(1.0) # Sleep briefly to avoid busy waiting + + +def _capture_logs( + run: RunTracker, +) -> None: + while True: + # Explicitly check if Popen.poll() is None. Required for `pytest`. + if run.proc.poll() is None: + # Select streams only when ready to read + ready_to_read, _, _ = select.select( + [run.proc.stdout, run.proc.stderr], + [], + [], + SELECT_TIMEOUT, + ) + # Read from std* and append to RunTracker.logs + for stream in ready_to_read: + line = stream.readline().rstrip() + if line: + run.logs.append(f"{line}") + + # Close std* to prevent blocking + elif run.proc.poll() is not None: + log(INFO, "Subprocess finished, exiting log capture") + if run.proc.stdout: + run.proc.stdout.close() + if run.proc.stderr: + run.proc.stderr.close() + break diff --git a/src/py/flwr/superexec/exec_servicer_test.py b/src/py/flwr/superexec/exec_servicer_test.py index 83717d63a36e..b777bc806fe5 100644 --- a/src/py/flwr/superexec/exec_servicer_test.py +++ b/src/py/flwr/superexec/exec_servicer_test.py @@ -16,11 +16,11 @@ import subprocess -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611 -from .exec_servicer import ExecServicer +from .exec_servicer import ExecServicer, _capture_logs def test_start_run() -> None: @@ -50,3 +50,20 @@ def test_start_run() -> None: response = servicer.StartRun(request, context_mock) assert response.run_id == 10 + + +def test_capture_logs() -> None: + """Test capture_logs function.""" + run_res = Mock() + run_res.logs = [] + with subprocess.Popen( + ["echo", "success"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) as proc: + run_res.proc = proc + _capture_logs(run_res) + + assert len(run_res.logs) == 1 + assert run_res.logs[0] == "success" diff --git a/src/py/flwr/superexec/executor.py b/src/py/flwr/superexec/executor.py index 8d630d108b66..08b66a438e4d 100644 --- a/src/py/flwr/superexec/executor.py +++ b/src/py/flwr/superexec/executor.py @@ -15,7 +15,7 @@ """Execute and monitor a Flower run.""" from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from subprocess import Popen from typing import Optional @@ -28,6 +28,7 @@ class RunTracker: run_id: int proc: Popen # type: ignore + logs: list[str] = field(default_factory=list) class Executor(ABC): From 839e5d9e75e7c5b31c44d355a741022c28086214 Mon Sep 17 00:00:00 2001 From: Yan Gao Date: Thu, 19 Sep 2024 12:34:34 +0100 Subject: [PATCH 04/11] refactor(framework) Update dependency and fix a print issue for FlowerTune template (#4220) --- src/py/flwr/cli/new/new.py | 2 +- src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index d2f7179b45b4..e7dbee894314 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -275,7 +275,7 @@ def new( ) ) - _add = " huggingface-cli login\n" if framework_str == "flowertune" else "" + _add = " huggingface-cli login\n" if llm_challenge_str else "" print( typer.style( f" cd {package_name}\n" + " pip install -e .\n" + _add + " flwr run\n", diff --git a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl index 8c15739e92f6..8a4d49e7fd84 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl @@ -17,6 +17,7 @@ dependencies = [ "transformers==4.39.3", "sentencepiece==0.2.0", "omegaconf==2.3.0", + "hf_transfer==0.1.8", ] [tool.hatch.build.targets.wheel] From 94ba2375bb5d8bff94410bbbb1828263fbce18d2 Mon Sep 17 00:00:00 2001 From: Javier Date: Thu, 19 Sep 2024 16:13:36 +0200 Subject: [PATCH 05/11] refactor(examples) Remove mention of Docker in `quickstart-mlx` example (#4202) --- examples/quickstart-mlx/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/quickstart-mlx/README.md b/examples/quickstart-mlx/README.md index ef28c3728279..5914ce5f31dd 100644 --- a/examples/quickstart-mlx/README.md +++ b/examples/quickstart-mlx/README.md @@ -67,4 +67,4 @@ flwr run . --run-config "num-server-rounds=5 learning-rate=0.05" ### Run with the Deployment Engine > \[!NOTE\] -> An update to this example will show how to run this Flower project with the Deployment Engine and TLS certificates, or with Docker. +> An update to this example will show how to run this Flower project with the Deployment Engine and TLS certificates. From 1073b56bdb6725e2a4339ae0756ad910b6d508e5 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 19 Sep 2024 16:12:59 +0100 Subject: [PATCH 06/11] feat(framework) Add minimal `Control` service (#4239) --- src/proto/flwr/proto/control.proto | 25 ++++++++++ src/py/flwr/proto/control_pb2.py | 27 +++++++++++ src/py/flwr/proto/control_pb2.pyi | 7 +++ src/py/flwr/proto/control_pb2_grpc.py | 67 ++++++++++++++++++++++++++ src/py/flwr/proto/control_pb2_grpc.pyi | 27 +++++++++++ src/py/flwr_tool/protoc_test.py | 2 +- 6 files changed, 154 insertions(+), 1 deletion(-) create mode 100644 src/proto/flwr/proto/control.proto create mode 100644 src/py/flwr/proto/control_pb2.py create mode 100644 src/py/flwr/proto/control_pb2.pyi create mode 100644 src/py/flwr/proto/control_pb2_grpc.py create mode 100644 src/py/flwr/proto/control_pb2_grpc.pyi diff --git a/src/proto/flwr/proto/control.proto b/src/proto/flwr/proto/control.proto new file mode 100644 index 000000000000..4e0867bbf9e4 --- /dev/null +++ b/src/proto/flwr/proto/control.proto @@ -0,0 +1,25 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== + +syntax = "proto3"; + +package flwr.proto; + +import "flwr/proto/driver.proto"; + +service Control { + // Request to create a new run + rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {} +} diff --git a/src/py/flwr/proto/control_pb2.py b/src/py/flwr/proto/control_pb2.py new file mode 100644 index 000000000000..d206e75cb388 --- /dev/null +++ b/src/py/flwr/proto/control_pb2.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: flwr/proto/control.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from flwr.proto import driver_pb2 as flwr_dot_proto_dot_driver__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x17\x66lwr/proto/driver.proto2U\n\x07\x43ontrol\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.control_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_CONTROL']._serialized_start=65 + _globals['_CONTROL']._serialized_end=150 +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/control_pb2.pyi b/src/py/flwr/proto/control_pb2.pyi new file mode 100644 index 000000000000..e08fa11c2caa --- /dev/null +++ b/src/py/flwr/proto/control_pb2.pyi @@ -0,0 +1,7 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import google.protobuf.descriptor + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor diff --git a/src/py/flwr/proto/control_pb2_grpc.py b/src/py/flwr/proto/control_pb2_grpc.py new file mode 100644 index 000000000000..9c671be88a47 --- /dev/null +++ b/src/py/flwr/proto/control_pb2_grpc.py @@ -0,0 +1,67 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from flwr.proto import driver_pb2 as flwr_dot_proto_dot_driver__pb2 + + +class ControlStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.CreateRun = channel.unary_unary( + '/flwr.proto.Control/CreateRun', + request_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, + ) + + +class ControlServicer(object): + """Missing associated documentation comment in .proto file.""" + + def CreateRun(self, request, context): + """Request to create a new run + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_ControlServicer_to_server(servicer, server): + rpc_method_handlers = { + 'CreateRun': grpc.unary_unary_rpc_method_handler( + servicer.CreateRun, + request_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.FromString, + response_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'flwr.proto.Control', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Control(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def CreateRun(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/CreateRun', + flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, + flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/control_pb2_grpc.pyi b/src/py/flwr/proto/control_pb2_grpc.pyi new file mode 100644 index 000000000000..f4613fa0e2f3 --- /dev/null +++ b/src/py/flwr/proto/control_pb2_grpc.pyi @@ -0,0 +1,27 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import abc +import flwr.proto.driver_pb2 +import grpc + +class ControlStub: + def __init__(self, channel: grpc.Channel) -> None: ... + CreateRun: grpc.UnaryUnaryMultiCallable[ + flwr.proto.driver_pb2.CreateRunRequest, + flwr.proto.driver_pb2.CreateRunResponse] + """Request to create a new run""" + + +class ControlServicer(metaclass=abc.ABCMeta): + @abc.abstractmethod + def CreateRun(self, + request: flwr.proto.driver_pb2.CreateRunRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.driver_pb2.CreateRunResponse: + """Request to create a new run""" + pass + + +def add_ControlServicer_to_server(servicer: ControlServicer, server: grpc.Server) -> None: ... diff --git a/src/py/flwr_tool/protoc_test.py b/src/py/flwr_tool/protoc_test.py index 6f9127304f25..f0784a4498d2 100644 --- a/src/py/flwr_tool/protoc_test.py +++ b/src/py/flwr_tool/protoc_test.py @@ -28,4 +28,4 @@ def test_directories() -> None: def test_proto_file_count() -> None: """Test if the correct number of proto files were captured by the glob.""" - assert len(PROTO_FILES) == 13 + assert len(PROTO_FILES) == 14 From fe16ff9c2c4d859d050f7c8a8ed3e7384801a756 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 19 Sep 2024 16:50:56 +0100 Subject: [PATCH 07/11] feat(framework) Move run-related request/response to `run.proto` (#4240) --- src/proto/flwr/proto/control.proto | 2 +- src/proto/flwr/proto/driver.proto | 10 ---- src/proto/flwr/proto/run.proto | 12 +++++ src/py/flwr/proto/control_pb2.py | 8 +-- src/py/flwr/proto/control_pb2_grpc.py | 14 ++--- src/py/flwr/proto/control_pb2_grpc.pyi | 10 ++-- src/py/flwr/proto/driver_pb2.py | 39 ++++++-------- src/py/flwr/proto/driver_pb2.pyi | 52 ------------------- src/py/flwr/proto/driver_pb2_grpc.py | 12 ++--- src/py/flwr/proto/driver_pb2_grpc.pyi | 8 +-- src/py/flwr/proto/run_pb2.py | 27 ++++++---- src/py/flwr/proto/run_pb2.pyi | 52 +++++++++++++++++++ src/py/flwr/server/run_serverapp.py | 4 +- .../superlink/driver/driver_servicer.py | 4 +- src/py/flwr/superexec/deployment.py | 2 +- 15 files changed, 129 insertions(+), 127 deletions(-) diff --git a/src/proto/flwr/proto/control.proto b/src/proto/flwr/proto/control.proto index 4e0867bbf9e4..9747c2e3ea11 100644 --- a/src/proto/flwr/proto/control.proto +++ b/src/proto/flwr/proto/control.proto @@ -17,7 +17,7 @@ syntax = "proto3"; package flwr.proto; -import "flwr/proto/driver.proto"; +import "flwr/proto/run.proto"; service Control { // Request to create a new run diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index c7ae7dcf30f0..e26003862a76 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -21,7 +21,6 @@ import "flwr/proto/node.proto"; import "flwr/proto/task.proto"; import "flwr/proto/run.proto"; import "flwr/proto/fab.proto"; -import "flwr/proto/transport.proto"; service Driver { // Request run_id @@ -43,15 +42,6 @@ service Driver { rpc GetFab(GetFabRequest) returns (GetFabResponse) {} } -// CreateRun -message CreateRunRequest { - string fab_id = 1; - string fab_version = 2; - map override_config = 3; - Fab fab = 4; -} -message CreateRunResponse { uint64 run_id = 1; } - // GetNodes messages message GetNodesRequest { uint64 run_id = 1; } message GetNodesResponse { repeated Node nodes = 1; } diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index fc3294f7a583..ada72610e182 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -17,6 +17,7 @@ syntax = "proto3"; package flwr.proto; +import "flwr/proto/fab.proto"; import "flwr/proto/transport.proto"; message Run { @@ -26,5 +27,16 @@ message Run { map override_config = 4; string fab_hash = 5; } + +// CreateRun +message CreateRunRequest { + string fab_id = 1; + string fab_version = 2; + map override_config = 3; + Fab fab = 4; +} +message CreateRunResponse { uint64 run_id = 1; } + +// GetRun message GetRunRequest { uint64 run_id = 1; } message GetRunResponse { Run run = 1; } diff --git a/src/py/flwr/proto/control_pb2.py b/src/py/flwr/proto/control_pb2.py index d206e75cb388..2b8776509d32 100644 --- a/src/py/flwr/proto/control_pb2.py +++ b/src/py/flwr/proto/control_pb2.py @@ -12,16 +12,16 @@ _sym_db = _symbol_database.Default() -from flwr.proto import driver_pb2 as flwr_dot_proto_dot_driver__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x17\x66lwr/proto/driver.proto2U\n\x07\x43ontrol\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/run.proto2U\n\x07\x43ontrol\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.control_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CONTROL']._serialized_start=65 - _globals['_CONTROL']._serialized_end=150 + _globals['_CONTROL']._serialized_start=62 + _globals['_CONTROL']._serialized_end=147 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/control_pb2_grpc.py b/src/py/flwr/proto/control_pb2_grpc.py index 9c671be88a47..987b9d8d7433 100644 --- a/src/py/flwr/proto/control_pb2_grpc.py +++ b/src/py/flwr/proto/control_pb2_grpc.py @@ -2,7 +2,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -from flwr.proto import driver_pb2 as flwr_dot_proto_dot_driver__pb2 +from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 class ControlStub(object): @@ -16,8 +16,8 @@ def __init__(self, channel): """ self.CreateRun = channel.unary_unary( '/flwr.proto.Control/CreateRun', - request_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, + request_serializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, ) @@ -36,8 +36,8 @@ def add_ControlServicer_to_server(servicer, server): rpc_method_handlers = { 'CreateRun': grpc.unary_unary_rpc_method_handler( servicer.CreateRun, - request_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.FromString, - response_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.SerializeToString, + request_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -61,7 +61,7 @@ def CreateRun(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/CreateRun', - flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, - flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, + flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/control_pb2_grpc.pyi b/src/py/flwr/proto/control_pb2_grpc.pyi index f4613fa0e2f3..c38b7f15d125 100644 --- a/src/py/flwr/proto/control_pb2_grpc.pyi +++ b/src/py/flwr/proto/control_pb2_grpc.pyi @@ -3,23 +3,23 @@ isort:skip_file """ import abc -import flwr.proto.driver_pb2 +import flwr.proto.run_pb2 import grpc class ControlStub: def __init__(self, channel: grpc.Channel) -> None: ... CreateRun: grpc.UnaryUnaryMultiCallable[ - flwr.proto.driver_pb2.CreateRunRequest, - flwr.proto.driver_pb2.CreateRunResponse] + flwr.proto.run_pb2.CreateRunRequest, + flwr.proto.run_pb2.CreateRunResponse] """Request to create a new run""" class ControlServicer(metaclass=abc.ABCMeta): @abc.abstractmethod def CreateRun(self, - request: flwr.proto.driver_pb2.CreateRunRequest, + request: flwr.proto.run_pb2.CreateRunRequest, context: grpc.ServicerContext, - ) -> flwr.proto.driver_pb2.CreateRunResponse: + ) -> flwr.proto.run_pb2.CreateRunResponse: """Request to create a new run""" pass diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index 1322660cfac1..d294b03be5af 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -16,36 +16,27 @@ from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 -from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xc7\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xc7\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' - _globals['_CREATERUNREQUEST']._serialized_start=158 - _globals['_CREATERUNREQUEST']._serialized_end=393 - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=320 - _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=393 - _globals['_CREATERUNRESPONSE']._serialized_start=395 - _globals['_CREATERUNRESPONSE']._serialized_end=430 - _globals['_GETNODESREQUEST']._serialized_start=432 - _globals['_GETNODESREQUEST']._serialized_end=465 - _globals['_GETNODESRESPONSE']._serialized_start=467 - _globals['_GETNODESRESPONSE']._serialized_end=518 - _globals['_PUSHTASKINSREQUEST']._serialized_start=520 - _globals['_PUSHTASKINSREQUEST']._serialized_end=584 - _globals['_PUSHTASKINSRESPONSE']._serialized_start=586 - _globals['_PUSHTASKINSRESPONSE']._serialized_end=625 - _globals['_PULLTASKRESREQUEST']._serialized_start=627 - _globals['_PULLTASKRESREQUEST']._serialized_end=697 - _globals['_PULLTASKRESRESPONSE']._serialized_start=699 - _globals['_PULLTASKRESRESPONSE']._serialized_end=764 - _globals['_DRIVER']._serialized_start=767 - _globals['_DRIVER']._serialized_end=1222 + _globals['_GETNODESREQUEST']._serialized_start=129 + _globals['_GETNODESREQUEST']._serialized_end=162 + _globals['_GETNODESRESPONSE']._serialized_start=164 + _globals['_GETNODESRESPONSE']._serialized_end=215 + _globals['_PUSHTASKINSREQUEST']._serialized_start=217 + _globals['_PUSHTASKINSREQUEST']._serialized_end=281 + _globals['_PUSHTASKINSRESPONSE']._serialized_start=283 + _globals['_PUSHTASKINSRESPONSE']._serialized_end=322 + _globals['_PULLTASKRESREQUEST']._serialized_start=324 + _globals['_PULLTASKRESREQUEST']._serialized_end=394 + _globals['_PULLTASKRESRESPONSE']._serialized_start=396 + _globals['_PULLTASKRESRESPONSE']._serialized_end=461 + _globals['_DRIVER']._serialized_start=464 + _globals['_DRIVER']._serialized_end=919 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/driver_pb2.pyi b/src/py/flwr/proto/driver_pb2.pyi index d025e00474eb..77ceb496d70c 100644 --- a/src/py/flwr/proto/driver_pb2.pyi +++ b/src/py/flwr/proto/driver_pb2.pyi @@ -3,10 +3,8 @@ isort:skip_file """ import builtins -import flwr.proto.fab_pb2 import flwr.proto.node_pb2 import flwr.proto.task_pb2 -import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers import google.protobuf.message @@ -15,56 +13,6 @@ import typing_extensions DESCRIPTOR: google.protobuf.descriptor.FileDescriptor -class CreateRunRequest(google.protobuf.message.Message): - """CreateRun""" - DESCRIPTOR: google.protobuf.descriptor.Descriptor - class OverrideConfigEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: typing.Text - @property - def value(self) -> flwr.proto.transport_pb2.Scalar: ... - def __init__(self, - *, - key: typing.Text = ..., - value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... - - FAB_ID_FIELD_NUMBER: builtins.int - FAB_VERSION_FIELD_NUMBER: builtins.int - OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int - FAB_FIELD_NUMBER: builtins.int - fab_id: typing.Text - fab_version: typing.Text - @property - def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... - @property - def fab(self) -> flwr.proto.fab_pb2.Fab: ... - def __init__(self, - *, - fab_id: typing.Text = ..., - fab_version: typing.Text = ..., - override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., - fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ..., - ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config"]) -> None: ... -global___CreateRunRequest = CreateRunRequest - -class CreateRunResponse(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - RUN_ID_FIELD_NUMBER: builtins.int - run_id: builtins.int - def __init__(self, - *, - run_id: builtins.int = ..., - ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... -global___CreateRunResponse = CreateRunResponse - class GetNodesRequest(google.protobuf.message.Message): """GetNodes messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor diff --git a/src/py/flwr/proto/driver_pb2_grpc.py b/src/py/flwr/proto/driver_pb2_grpc.py index 6745bc7af62a..91e9fd8b9bdd 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.py +++ b/src/py/flwr/proto/driver_pb2_grpc.py @@ -18,8 +18,8 @@ def __init__(self, channel): """ self.CreateRun = channel.unary_unary( '/flwr.proto.Driver/CreateRun', - request_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, + request_serializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, ) self.GetNodes = channel.unary_unary( '/flwr.proto.Driver/GetNodes', @@ -98,8 +98,8 @@ def add_DriverServicer_to_server(servicer, server): rpc_method_handlers = { 'CreateRun': grpc.unary_unary_rpc_method_handler( servicer.CreateRun, - request_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.FromString, - response_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.SerializeToString, + request_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.SerializeToString, ), 'GetNodes': grpc.unary_unary_rpc_method_handler( servicer.GetNodes, @@ -148,8 +148,8 @@ def CreateRun(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/CreateRun', - flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, - flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, + flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/driver_pb2_grpc.pyi b/src/py/flwr/proto/driver_pb2_grpc.pyi index 7f9fd0acbd82..8f665301073d 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.pyi +++ b/src/py/flwr/proto/driver_pb2_grpc.pyi @@ -11,8 +11,8 @@ import grpc class DriverStub: def __init__(self, channel: grpc.Channel) -> None: ... CreateRun: grpc.UnaryUnaryMultiCallable[ - flwr.proto.driver_pb2.CreateRunRequest, - flwr.proto.driver_pb2.CreateRunResponse] + flwr.proto.run_pb2.CreateRunRequest, + flwr.proto.run_pb2.CreateRunResponse] """Request run_id""" GetNodes: grpc.UnaryUnaryMultiCallable[ @@ -44,9 +44,9 @@ class DriverStub: class DriverServicer(metaclass=abc.ABCMeta): @abc.abstractmethod def CreateRun(self, - request: flwr.proto.driver_pb2.CreateRunRequest, + request: flwr.proto.run_pb2.CreateRunRequest, context: grpc.ServicerContext, - ) -> flwr.proto.driver_pb2.CreateRunResponse: + ) -> flwr.proto.run_pb2.CreateRunResponse: """Request run_id""" pass diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py index 13fc43f90f8c..99ca4df5c44c 100644 --- a/src/py/flwr/proto/run_pb2.py +++ b/src/py/flwr/proto/run_pb2.py @@ -12,10 +12,11 @@ _sym_db = _symbol_database.Default() +from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2 from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd5\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd5\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -24,12 +25,20 @@ DESCRIPTOR._options = None _globals['_RUN_OVERRIDECONFIGENTRY']._options = None _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' - _globals['_RUN']._serialized_start=65 - _globals['_RUN']._serialized_end=278 - _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=205 - _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=278 - _globals['_GETRUNREQUEST']._serialized_start=280 - _globals['_GETRUNREQUEST']._serialized_end=311 - _globals['_GETRUNRESPONSE']._serialized_start=313 - _globals['_GETRUNRESPONSE']._serialized_end=359 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_RUN']._serialized_start=87 + _globals['_RUN']._serialized_end=300 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=227 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=300 + _globals['_CREATERUNREQUEST']._serialized_start=303 + _globals['_CREATERUNREQUEST']._serialized_end=538 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=227 + _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=300 + _globals['_CREATERUNRESPONSE']._serialized_start=540 + _globals['_CREATERUNRESPONSE']._serialized_end=575 + _globals['_GETRUNREQUEST']._serialized_start=577 + _globals['_GETRUNREQUEST']._serialized_end=608 + _globals['_GETRUNRESPONSE']._serialized_start=610 + _globals['_GETRUNRESPONSE']._serialized_end=656 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi index e65feee9c518..26b69e7eed27 100644 --- a/src/py/flwr/proto/run_pb2.pyi +++ b/src/py/flwr/proto/run_pb2.pyi @@ -3,6 +3,7 @@ isort:skip_file """ import builtins +import flwr.proto.fab_pb2 import flwr.proto.transport_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers @@ -51,7 +52,58 @@ class Run(google.protobuf.message.Message): def ClearField(self, field_name: typing_extensions.Literal["fab_hash",b"fab_hash","fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ... global___Run = Run +class CreateRunRequest(google.protobuf.message.Message): + """CreateRun""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class OverrideConfigEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + @property + def value(self) -> flwr.proto.transport_pb2.Scalar: ... + def __init__(self, + *, + key: typing.Text = ..., + value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + FAB_ID_FIELD_NUMBER: builtins.int + FAB_VERSION_FIELD_NUMBER: builtins.int + OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int + FAB_FIELD_NUMBER: builtins.int + fab_id: typing.Text + fab_version: typing.Text + @property + def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... + @property + def fab(self) -> flwr.proto.fab_pb2.Fab: ... + def __init__(self, + *, + fab_id: typing.Text = ..., + fab_version: typing.Text = ..., + override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., + fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config"]) -> None: ... +global___CreateRunRequest = CreateRunRequest + +class CreateRunResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_ID_FIELD_NUMBER: builtins.int + run_id: builtins.int + def __init__(self, + *, + run_id: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... +global___CreateRunResponse = CreateRunResponse + class GetRunRequest(google.protobuf.message.Message): + """GetRun""" DESCRIPTOR: google.protobuf.descriptor.Descriptor RUN_ID_FIELD_NUMBER: builtins.int run_id: builtins.int diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index c0628db04360..a9ec05fe90e0 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -35,11 +35,11 @@ from flwr.common.logger import log, update_console_handler, warn_deprecated_feature from flwr.common.object_ref import load_app from flwr.common.typing import UserConfig -from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 +from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 +from flwr.proto.run_pb2 import ( # pylint: disable=E0611 CreateRunRequest, CreateRunResponse, ) -from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from .driver import Driver from .driver.grpc_driver import GrpcDriver diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 4d7d6cb6ce89..3cafb3e71f7c 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -32,8 +32,6 @@ from flwr.common.typing import Fab from flwr.proto import driver_pb2_grpc # pylint: disable=E0611 from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 - CreateRunRequest, - CreateRunResponse, GetNodesRequest, GetNodesResponse, PullTaskResRequest, @@ -44,6 +42,8 @@ from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import ( # pylint: disable=E0611 + CreateRunRequest, + CreateRunResponse, GetRunRequest, GetRunResponse, Run, diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 55e519cf5f27..d60870995d39 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -28,8 +28,8 @@ from flwr.common.logger import log from flwr.common.serde import fab_to_proto, user_config_to_proto from flwr.common.typing import Fab, UserConfig -from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611 from flwr.proto.driver_pb2_grpc import DriverStub +from flwr.proto.run_pb2 import CreateRunRequest # pylint: disable=E0611 from .executor import Executor, RunTracker From 02e181382df4b8f880bf92f92aafb357de50f161 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Thu, 19 Sep 2024 18:17:12 +0100 Subject: [PATCH 08/11] fix(framework:skip) Fix SuperExec log capturing (#4242) --- src/py/flwr/superexec/deployment.py | 2 ++ src/py/flwr/superexec/exec_servicer.py | 11 +++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index d60870995d39..331fd817228e 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -167,6 +167,8 @@ def start_run( # Execute the command proc = subprocess.Popen( # pylint: disable=consider-using-with command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, ) log(INFO, "Started run %s", str(run_id)) diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index 8bf384312fca..ebb12b5ddbd2 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -16,6 +16,7 @@ import select +import sys import threading import time from collections.abc import Generator @@ -95,7 +96,8 @@ def StreamLogs( # pylint: disable=C0103 # is returned at this point and the server ends the stream. if self.runs[request.run_id].proc.poll() is not None: log(INFO, "All logs for run ID `%s` returned", request.run_id) - return + context.set_code(grpc.StatusCode.OK) + context.cancel() time.sleep(1.0) # Sleep briefly to avoid busy waiting @@ -115,7 +117,12 @@ def _capture_logs( ) # Read from std* and append to RunTracker.logs for stream in ready_to_read: - line = stream.readline().rstrip() + # Flush stdout to view output in real time + readline = stream.readline() + sys.stdout.write(readline) + sys.stdout.flush() + # Append to logs + line = readline.rstrip() if line: run.logs.append(f"{line}") From 247cada644c059df4575063d2ed3010a0c6a753b Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Fri, 20 Sep 2024 16:32:04 +0100 Subject: [PATCH 09/11] fix(examples:skip) Specify `min_fit_clients` in `FedAvg` in the Secure Aggregation example (#4247) --- examples/flower-secure-aggregation/secaggexample/server_app.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/flower-secure-aggregation/secaggexample/server_app.py b/examples/flower-secure-aggregation/secaggexample/server_app.py index 0f1b594317fa..0b95d68e4183 100644 --- a/examples/flower-secure-aggregation/secaggexample/server_app.py +++ b/examples/flower-secure-aggregation/secaggexample/server_app.py @@ -40,6 +40,7 @@ def main(driver: Driver, context: Context) -> None: strategy = FedAvg( # Select all available clients fraction_fit=1.0, + min_fit_clients=5, # Disable evaluation in demo fraction_evaluate=(0.0 if is_demo else context.run_config["fraction-evaluate"]), min_available_clients=5, From b6babe95addcbd1380f3894cc3ea71772119b72c Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 23 Sep 2024 11:37:28 +0100 Subject: [PATCH 10/11] feat(framework) Add new RPCs to `Control` service (#4241) --- src/proto/flwr/proto/control.proto | 7 +++ src/proto/flwr/proto/run.proto | 20 ++++++ src/py/flwr/proto/control_pb2.py | 6 +- src/py/flwr/proto/control_pb2_grpc.py | 68 ++++++++++++++++++++ src/py/flwr/proto/control_pb2_grpc.pyi | 26 ++++++++ src/py/flwr/proto/run_pb2.py | 32 +++++++--- src/py/flwr/proto/run_pb2.pyi | 86 ++++++++++++++++++++++++++ 7 files changed, 233 insertions(+), 12 deletions(-) diff --git a/src/proto/flwr/proto/control.proto b/src/proto/flwr/proto/control.proto index 9747c2e3ea11..8b75c66fccaa 100644 --- a/src/proto/flwr/proto/control.proto +++ b/src/proto/flwr/proto/control.proto @@ -22,4 +22,11 @@ import "flwr/proto/run.proto"; service Control { // Request to create a new run rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {} + + // Get the status of a given run + rpc GetRunStatus(GetRunStatusRequest) returns (GetRunStatusResponse) {} + + // Update the status of a given run + rpc UpdateRunStatus(UpdateRunStatusRequest) + returns (UpdateRunStatusResponse) {} } diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index ada72610e182..2c9bd877f66c 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -28,6 +28,15 @@ message Run { string fab_hash = 5; } +message RunStatus { + // "starting", "running", "finished" + string status = 1; + // "completed", "failed", "stopped" or "" (non-finished) + string sub_status = 2; + // failure details + string details = 3; +} + // CreateRun message CreateRunRequest { string fab_id = 1; @@ -40,3 +49,14 @@ message CreateRunResponse { uint64 run_id = 1; } // GetRun message GetRunRequest { uint64 run_id = 1; } message GetRunResponse { Run run = 1; } + +// UpdateRunStatus +message UpdateRunStatusRequest { + uint64 run_id = 1; + RunStatus run_status = 2; +} +message UpdateRunStatusResponse {} + +// GetRunStatus +message GetRunStatusRequest { repeated uint64 run_ids = 1; } +message GetRunStatusResponse { map run_status_dict = 1; } diff --git a/src/py/flwr/proto/control_pb2.py b/src/py/flwr/proto/control_pb2.py index 2b8776509d32..eb1c18d8dcff 100644 --- a/src/py/flwr/proto/control_pb2.py +++ b/src/py/flwr/proto/control_pb2.py @@ -15,13 +15,13 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/run.proto2U\n\x07\x43ontrol\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/control.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/run.proto2\x88\x02\n\x07\x43ontrol\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12S\n\x0cGetRunStatus\x12\x1f.flwr.proto.GetRunStatusRequest\x1a .flwr.proto.GetRunStatusResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.control_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_CONTROL']._serialized_start=62 - _globals['_CONTROL']._serialized_end=147 + _globals['_CONTROL']._serialized_start=63 + _globals['_CONTROL']._serialized_end=327 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/control_pb2_grpc.py b/src/py/flwr/proto/control_pb2_grpc.py index 987b9d8d7433..a59f90f15935 100644 --- a/src/py/flwr/proto/control_pb2_grpc.py +++ b/src/py/flwr/proto/control_pb2_grpc.py @@ -19,6 +19,16 @@ def __init__(self, channel): request_serializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString, response_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, ) + self.GetRunStatus = channel.unary_unary( + '/flwr.proto.Control/GetRunStatus', + request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString, + ) + self.UpdateRunStatus = channel.unary_unary( + '/flwr.proto.Control/UpdateRunStatus', + request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString, + ) class ControlServicer(object): @@ -31,6 +41,20 @@ def CreateRun(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetRunStatus(self, request, context): + """Get the status of a given run + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def UpdateRunStatus(self, request, context): + """Update the status of a given run + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_ControlServicer_to_server(servicer, server): rpc_method_handlers = { @@ -39,6 +63,16 @@ def add_ControlServicer_to_server(servicer, server): request_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.FromString, response_serializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.SerializeToString, ), + 'GetRunStatus': grpc.unary_unary_rpc_method_handler( + servicer.GetRunStatus, + request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString, + ), + 'UpdateRunStatus': grpc.unary_unary_rpc_method_handler( + servicer.UpdateRunStatus, + request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString, + response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'flwr.proto.Control', rpc_method_handlers) @@ -65,3 +99,37 @@ def CreateRun(request, flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GetRunStatus(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/GetRunStatus', + flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def UpdateRunStatus(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/UpdateRunStatus', + flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString, + flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/control_pb2_grpc.pyi b/src/py/flwr/proto/control_pb2_grpc.pyi index c38b7f15d125..7817e2b12e31 100644 --- a/src/py/flwr/proto/control_pb2_grpc.pyi +++ b/src/py/flwr/proto/control_pb2_grpc.pyi @@ -13,6 +13,16 @@ class ControlStub: flwr.proto.run_pb2.CreateRunResponse] """Request to create a new run""" + GetRunStatus: grpc.UnaryUnaryMultiCallable[ + flwr.proto.run_pb2.GetRunStatusRequest, + flwr.proto.run_pb2.GetRunStatusResponse] + """Get the status of a given run""" + + UpdateRunStatus: grpc.UnaryUnaryMultiCallable[ + flwr.proto.run_pb2.UpdateRunStatusRequest, + flwr.proto.run_pb2.UpdateRunStatusResponse] + """Update the status of a given run""" + class ControlServicer(metaclass=abc.ABCMeta): @abc.abstractmethod @@ -23,5 +33,21 @@ class ControlServicer(metaclass=abc.ABCMeta): """Request to create a new run""" pass + @abc.abstractmethod + def GetRunStatus(self, + request: flwr.proto.run_pb2.GetRunStatusRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.run_pb2.GetRunStatusResponse: + """Get the status of a given run""" + pass + + @abc.abstractmethod + def UpdateRunStatus(self, + request: flwr.proto.run_pb2.UpdateRunStatusRequest, + context: grpc.ServicerContext, + ) -> flwr.proto.run_pb2.UpdateRunStatusResponse: + """Update the status of a given run""" + pass + def add_ControlServicer_to_server(servicer: ControlServicer, server: grpc.Server) -> None: ... diff --git a/src/py/flwr/proto/run_pb2.py b/src/py/flwr/proto/run_pb2.py index 99ca4df5c44c..d59cc26fbb48 100644 --- a/src/py/flwr/proto/run_pb2.py +++ b/src/py/flwr/proto/run_pb2.py @@ -16,7 +16,7 @@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd5\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/fab.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd5\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\xeb\x01\n\x10\x43reateRunRequest\x12\x0e\n\x06\x66\x61\x62_id\x18\x01 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x02 \x01(\t\x12I\n\x0foverride_config\x18\x03 \x03(\x0b\x32\x30.flwr.proto.CreateRunRequest.OverrideConfigEntry\x12\x1c\n\x03\x66\x61\x62\x18\x04 \x01(\x0b\x32\x0f.flwr.proto.Fab\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"&\n\x13GetRunStatusRequest\x12\x0f\n\x07run_ids\x18\x01 \x03(\x04\"\xb1\x01\n\x14GetRunStatusResponse\x12L\n\x0frun_status_dict\x18\x01 \x03(\x0b\x32\x33.flwr.proto.GetRunStatusResponse.RunStatusDictEntry\x1aK\n\x12RunStatusDictEntry\x12\x0b\n\x03key\x18\x01 \x01(\x04\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus:\x02\x38\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -27,18 +27,32 @@ _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._options = None _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._options = None + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_options = b'8\001' _globals['_RUN']._serialized_start=87 _globals['_RUN']._serialized_end=300 _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=227 _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=300 - _globals['_CREATERUNREQUEST']._serialized_start=303 - _globals['_CREATERUNREQUEST']._serialized_end=538 + _globals['_RUNSTATUS']._serialized_start=302 + _globals['_RUNSTATUS']._serialized_end=366 + _globals['_CREATERUNREQUEST']._serialized_start=369 + _globals['_CREATERUNREQUEST']._serialized_end=604 _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_start=227 _globals['_CREATERUNREQUEST_OVERRIDECONFIGENTRY']._serialized_end=300 - _globals['_CREATERUNRESPONSE']._serialized_start=540 - _globals['_CREATERUNRESPONSE']._serialized_end=575 - _globals['_GETRUNREQUEST']._serialized_start=577 - _globals['_GETRUNREQUEST']._serialized_end=608 - _globals['_GETRUNRESPONSE']._serialized_start=610 - _globals['_GETRUNRESPONSE']._serialized_end=656 + _globals['_CREATERUNRESPONSE']._serialized_start=606 + _globals['_CREATERUNRESPONSE']._serialized_end=641 + _globals['_GETRUNREQUEST']._serialized_start=643 + _globals['_GETRUNREQUEST']._serialized_end=674 + _globals['_GETRUNRESPONSE']._serialized_start=676 + _globals['_GETRUNRESPONSE']._serialized_end=722 + _globals['_UPDATERUNSTATUSREQUEST']._serialized_start=724 + _globals['_UPDATERUNSTATUSREQUEST']._serialized_end=807 + _globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=809 + _globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=834 + _globals['_GETRUNSTATUSREQUEST']._serialized_start=836 + _globals['_GETRUNSTATUSREQUEST']._serialized_end=874 + _globals['_GETRUNSTATUSRESPONSE']._serialized_start=877 + _globals['_GETRUNSTATUSRESPONSE']._serialized_end=1054 + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_start=979 + _globals['_GETRUNSTATUSRESPONSE_RUNSTATUSDICTENTRY']._serialized_end=1054 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/run_pb2.pyi b/src/py/flwr/proto/run_pb2.pyi index 26b69e7eed27..cec90c4d2d4c 100644 --- a/src/py/flwr/proto/run_pb2.pyi +++ b/src/py/flwr/proto/run_pb2.pyi @@ -52,6 +52,29 @@ class Run(google.protobuf.message.Message): def ClearField(self, field_name: typing_extensions.Literal["fab_hash",b"fab_hash","fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ... global___Run = Run +class RunStatus(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + STATUS_FIELD_NUMBER: builtins.int + SUB_STATUS_FIELD_NUMBER: builtins.int + DETAILS_FIELD_NUMBER: builtins.int + status: typing.Text + """"starting", "running", "finished" """ + + sub_status: typing.Text + """"completed", "failed", "stopped" or "" (non-finished)""" + + details: typing.Text + """failure details""" + + def __init__(self, + *, + status: typing.Text = ..., + sub_status: typing.Text = ..., + details: typing.Text = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["details",b"details","status",b"status","sub_status",b"sub_status"]) -> None: ... +global___RunStatus = RunStatus + class CreateRunRequest(google.protobuf.message.Message): """CreateRun""" DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -126,3 +149,66 @@ class GetRunResponse(google.protobuf.message.Message): def HasField(self, field_name: typing_extensions.Literal["run",b"run"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["run",b"run"]) -> None: ... global___GetRunResponse = GetRunResponse + +class UpdateRunStatusRequest(google.protobuf.message.Message): + """UpdateRunStatus""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_ID_FIELD_NUMBER: builtins.int + RUN_STATUS_FIELD_NUMBER: builtins.int + run_id: builtins.int + @property + def run_status(self) -> global___RunStatus: ... + def __init__(self, + *, + run_id: builtins.int = ..., + run_status: typing.Optional[global___RunStatus] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["run_status",b"run_status"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id","run_status",b"run_status"]) -> None: ... +global___UpdateRunStatusRequest = UpdateRunStatusRequest + +class UpdateRunStatusResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + def __init__(self, + ) -> None: ... +global___UpdateRunStatusResponse = UpdateRunStatusResponse + +class GetRunStatusRequest(google.protobuf.message.Message): + """GetRunStatus""" + DESCRIPTOR: google.protobuf.descriptor.Descriptor + RUN_IDS_FIELD_NUMBER: builtins.int + @property + def run_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + def __init__(self, + *, + run_ids: typing.Optional[typing.Iterable[builtins.int]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_ids",b"run_ids"]) -> None: ... +global___GetRunStatusRequest = GetRunStatusRequest + +class GetRunStatusResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class RunStatusDictEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.int + @property + def value(self) -> global___RunStatus: ... + def __init__(self, + *, + key: builtins.int = ..., + value: typing.Optional[global___RunStatus] = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + RUN_STATUS_DICT_FIELD_NUMBER: builtins.int + @property + def run_status_dict(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, global___RunStatus]: ... + def __init__(self, + *, + run_status_dict: typing.Optional[typing.Mapping[builtins.int, global___RunStatus]] = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_status_dict",b"run_status_dict"]) -> None: ... +global___GetRunStatusResponse = GetRunStatusResponse From b370b0618fe53d6f01e5486c3269ee09dc0dcc6e Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Mon, 23 Sep 2024 13:54:22 +0100 Subject: [PATCH 11/11] feat(framework) Add `flwr log` (#3577) Co-authored-by: Charles Beauville Co-authored-by: jafermarq Co-authored-by: Taner Topal --- src/py/flwr/cli/app.py | 2 + src/py/flwr/cli/log.py | 196 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 198 insertions(+) create mode 100644 src/py/flwr/cli/log.py diff --git a/src/py/flwr/cli/app.py b/src/py/flwr/cli/app.py index 93effea6df98..8baccb4638fc 100644 --- a/src/py/flwr/cli/app.py +++ b/src/py/flwr/cli/app.py @@ -19,6 +19,7 @@ from .build import build from .install import install +from .log import log from .new import new from .run import run @@ -35,6 +36,7 @@ app.command()(run) app.command()(build) app.command()(install) +app.command()(log) typer_click_object = get_command(app) diff --git a/src/py/flwr/cli/log.py b/src/py/flwr/cli/log.py new file mode 100644 index 000000000000..6915de1e00c5 --- /dev/null +++ b/src/py/flwr/cli/log.py @@ -0,0 +1,196 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower command line interface `log` command.""" + +import sys +import time +from logging import DEBUG, ERROR, INFO +from pathlib import Path +from typing import Annotated, Optional + +import grpc +import typer + +from flwr.cli.config_utils import load_and_validate +from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel +from flwr.common.logger import log as logger + +CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds) + + +# pylint: disable=unused-argument +def stream_logs(run_id: int, channel: grpc.Channel, period: int) -> None: + """Stream logs from the beginning of a run with connection refresh.""" + + +# pylint: disable=unused-argument +def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None: + """Print logs from the beginning of a run.""" + + +def on_channel_state_change(channel_connectivity: str) -> None: + """Log channel connectivity.""" + logger(DEBUG, channel_connectivity) + + +def log( + run_id: Annotated[ + int, + typer.Argument(help="The Flower run ID to query"), + ], + app: Annotated[ + Path, + typer.Argument(help="Path of the Flower project to run"), + ] = Path("."), + federation: Annotated[ + Optional[str], + typer.Argument(help="Name of the federation to run the app on"), + ] = None, + stream: Annotated[ + bool, + typer.Option( + "--stream/--show", + help="Flag to stream or print logs from the Flower run", + ), + ] = True, +) -> None: + """Get logs from a Flower project run.""" + typer.secho("Loading project configuration... ", fg=typer.colors.BLUE) + + pyproject_path = app / "pyproject.toml" if app else None + config, errors, warnings = load_and_validate(path=pyproject_path) + + if config is None: + typer.secho( + "Project configuration could not be loaded.\n" + "pyproject.toml is invalid:\n" + + "\n".join([f"- {line}" for line in errors]), + fg=typer.colors.RED, + bold=True, + ) + sys.exit() + + if warnings: + typer.secho( + "Project configuration is missing the following " + "recommended properties:\n" + "\n".join([f"- {line}" for line in warnings]), + fg=typer.colors.RED, + bold=True, + ) + + typer.secho("Success", fg=typer.colors.GREEN) + + federation = federation or config["tool"]["flwr"]["federations"].get("default") + + if federation is None: + typer.secho( + "❌ No federation name was provided and the project's `pyproject.toml` " + "doesn't declare a default federation (with a SuperExec address or an " + "`options.num-supernodes` value).", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + + # Validate the federation exists in the configuration + federation_config = config["tool"]["flwr"]["federations"].get(federation) + if federation_config is None: + available_feds = { + fed for fed in config["tool"]["flwr"]["federations"] if fed != "default" + } + typer.secho( + f"❌ There is no `{federation}` federation declared in the " + "`pyproject.toml`.\n The following federations were found:\n\n" + + "\n".join(available_feds), + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + + if "address" not in federation_config: + typer.secho( + "❌ `flwr log` currently works with `SuperExec`. Ensure that the correct" + "`SuperExec` address is provided in the `pyproject.toml`.", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + + _log_with_superexec(federation_config, run_id, stream) + + +# pylint: disable-next=too-many-branches +def _log_with_superexec( + federation_config: dict[str, str], + run_id: int, + stream: bool, +) -> None: + insecure_str = federation_config.get("insecure") + if root_certificates := federation_config.get("root-certificates"): + root_certificates_bytes = Path(root_certificates).read_bytes() + if insecure := bool(insecure_str): + typer.secho( + "❌ `root_certificates` were provided but the `insecure` parameter" + "is set to `True`.", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + else: + root_certificates_bytes = None + if insecure_str is None: + typer.secho( + "❌ To disable TLS, set `insecure = true` in `pyproject.toml`.", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + if not (insecure := bool(insecure_str)): + typer.secho( + "❌ No certificate were given yet `insecure` is set to `False`.", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) + + channel = create_channel( + server_address=federation_config["address"], + insecure=insecure, + root_certificates=root_certificates_bytes, + max_message_length=GRPC_MAX_MESSAGE_LENGTH, + interceptors=None, + ) + channel.subscribe(on_channel_state_change) + + if stream: + try: + while True: + logger(INFO, "Starting logstream for run_id `%s`", run_id) + stream_logs(run_id, channel, CONN_REFRESH_PERIOD) + time.sleep(2) + logger(DEBUG, "Reconnecting to logstream") + except KeyboardInterrupt: + logger(INFO, "Exiting logstream") + except grpc.RpcError as e: + # pylint: disable=E1101 + if e.code() == grpc.StatusCode.NOT_FOUND: + logger(ERROR, "Invalid run_id `%s`, exiting", run_id) + if e.code() == grpc.StatusCode.CANCELLED: + pass + finally: + channel.close() + else: + logger(INFO, "Printing logstream for run_id `%s`", run_id) + print_logs(run_id, channel, timeout=5)