From 1fbfd82124f6f0d37aae67d1a3c24dc8e8ad4e1c Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Mon, 21 Oct 2024 09:14:51 +0200 Subject: [PATCH 01/10] docs(datasets) Update FDS version to 0.4.0 (#4112) --- datasets/doc/source/conf.py | 2 +- datasets/doc/source/how-to-install-flwr-datasets.rst | 2 +- datasets/pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datasets/doc/source/conf.py b/datasets/doc/source/conf.py index dcba63dd221c..92d59d7df370 100644 --- a/datasets/doc/source/conf.py +++ b/datasets/doc/source/conf.py @@ -38,7 +38,7 @@ author = "The Flower Authors" # The full version, including alpha/beta/rc tags -release = "0.3.0" +release = "0.4.0" # -- General configuration --------------------------------------------------- diff --git a/datasets/doc/source/how-to-install-flwr-datasets.rst b/datasets/doc/source/how-to-install-flwr-datasets.rst index 2068fc11da85..3f79daceb753 100644 --- a/datasets/doc/source/how-to-install-flwr-datasets.rst +++ b/datasets/doc/source/how-to-install-flwr-datasets.rst @@ -42,5 +42,5 @@ If everything worked, it should print the version of Flower Datasets to the comm .. code-block:: none - 0.3.0 + 0.4.0 diff --git a/datasets/pyproject.toml b/datasets/pyproject.toml index 27303c0547c9..af7c1f1bde2a 100644 --- a/datasets/pyproject.toml +++ b/datasets/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "flwr-datasets" -version = "0.3.0" +version = "0.4.0" description = "Flower Datasets" license = "Apache-2.0" authors = ["The Flower Authors "] From cdc8c43d63e1dbb50d960662caedd205afee7429 Mon Sep 17 00:00:00 2001 From: Robert Steiner Date: Mon, 21 Oct 2024 13:32:17 +0200 Subject: [PATCH 02/10] docs(framework:skip) Fix Docker TLS chown cmd (#4342) Signed-off-by: Robert Steiner --- doc/source/docker/enable-tls.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/docker/enable-tls.rst b/doc/source/docker/enable-tls.rst index f50edb8c651d..7225f86a5ddb 100644 --- a/doc/source/docker/enable-tls.rst +++ b/doc/source/docker/enable-tls.rst @@ -21,7 +21,7 @@ For example, to change the user ID of all files in the ``certificates/`` directo can run ``sudo chown -R 49999:49999 certificates/*``. If you later want to delete the directory, you can change the user ID back to the -current user ID by running ``sudo chown -R $USER:$(id -gn) state``. +current user ID by running ``sudo chown -R $USER:$(id -gn) certificates``. 
SuperLink --------- From c2682963a43f07a514a24f4a8e9e3454169bed28 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 22 Oct 2024 13:42:41 +0100 Subject: [PATCH 03/10] refactor(framework) Run exec servicer in SuperLink (#4283) --- .github/workflows/e2e.yml | 4 +- e2e/{test_superexec.sh => test_exec_api.sh} | 34 +++++++------- e2e/test_superlink.sh | 2 +- src/py/flwr/common/constant.py | 5 +-- src/py/flwr/server/app.py | 50 +++++++++++++++++++-- src/py/flwr/superexec/app.py | 4 +- 6 files changed, 69 insertions(+), 30 deletions(-) rename e2e/{test_superexec.sh => test_exec_api.sh} (73%) diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 012f584561ac..5e93da349602 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -102,12 +102,12 @@ jobs: python -m pip install "${WHEEL_URL}" fi - name: > - Run SuperExec test / + Run Exec API test / ${{ matrix.connection }} / ${{ matrix.authentication }} / ${{ matrix.engine }} working-directory: e2e/${{ matrix.directory }} - run: ./../test_superexec.sh "${{ matrix.connection }}" "${{ matrix.authentication}}" "${{ matrix.engine }}" + run: ./../test_exec_api.sh "${{ matrix.connection }}" "${{ matrix.authentication}}" "${{ matrix.engine }}" frameworks: runs-on: ubuntu-22.04 diff --git a/e2e/test_superexec.sh b/e2e/test_exec_api.sh similarity index 73% rename from e2e/test_superexec.sh rename to e2e/test_exec_api.sh index ae79128c6ac1..fd5e8c69d1de 100755 --- a/e2e/test_superexec.sh +++ b/e2e/test_exec_api.sh @@ -9,14 +9,13 @@ case "$1" in --ssl-certfile ../certificates/server.pem --ssl-keyfile ../certificates/server.key' client_arg='--root-certificates ../certificates/ca.crt' - # For $superexec_arg, note special ordering of single- and double-quotes - superexec_arg='--executor-config 'root-certificates=\"../certificates/ca.crt\"'' - superexec_arg="$server_arg $superexec_arg" + # For $executor_config, note special ordering of single- and double-quotes + executor_config='root-certificates="../certificates/ca.crt"' ;; insecure) server_arg='--insecure' client_arg=$server_arg - superexec_arg=$server_arg + executor_config='' ;; esac @@ -43,11 +42,11 @@ esac # Set engine case "$3" in deployment-engine) - superexec_engine_arg='--executor flwr.superexec.deployment:executor' + executor_arg="--executor flwr.superexec.deployment:executor" ;; simulation-engine) - superexec_engine_arg='--executor flwr.superexec.simulation:executor - --executor-config 'num-supernodes=10'' + executor_config="$executor_config num-supernodes=10" + executor_arg="--executor flwr.superexec.simulation:executor" ;; esac @@ -69,14 +68,17 @@ pip install -e . --no-deps # Check if the first argument is 'insecure' if [ "$1" == "insecure" ]; then # If $1 is 'insecure', append the first line - echo -e $"\n[tool.flwr.federations.superexec]\naddress = \"127.0.0.1:9093\"\ninsecure = true" >> pyproject.toml + echo -e $"\n[tool.flwr.federations.e2e]\naddress = \"127.0.0.1:9093\"\ninsecure = true" >> pyproject.toml else # Otherwise, append the second line - echo -e $"\n[tool.flwr.federations.superexec]\naddress = \"127.0.0.1:9093\"\nroot-certificates = \"../certificates/ca.crt\"" >> pyproject.toml + echo -e $"\n[tool.flwr.federations.e2e]\naddress = \"127.0.0.1:9093\"\nroot-certificates = \"../certificates/ca.crt\"" >> pyproject.toml fi -timeout 2m flower-superlink $server_arg $server_auth & -sl_pid=$! 
+# Combine the arguments into a single command for flower-superlink +combined_args="$server_arg $server_auth $exec_api_arg $executor_arg" + +timeout 2m flower-superlink $combined_args --executor-config "$executor_config" 2>&1 | tee flwr_output.log & +sl_pid=$(pgrep -f "flower-superlink") sleep 2 timeout 2m flower-supernode ./ $client_arg \ @@ -91,11 +93,7 @@ timeout 2m flower-supernode ./ $client_arg \ cl2_pid=$! sleep 2 -timeout 2m flower-superexec $superexec_arg $superexec_engine_arg 2>&1 | tee flwr_output.log & -se_pid=$(pgrep -f "flower-superexec") -sleep 2 - -timeout 1m flwr run --run-config num-server-rounds=1 ../e2e-tmp-test superexec +timeout 1m flwr run --run-config num-server-rounds=1 ../e2e-tmp-test e2e # Initialize a flag to track if training is successful found_success=false @@ -107,7 +105,7 @@ while [ "$found_success" = false ] && [ $elapsed -lt $timeout ]; do if grep -q "Run finished" flwr_output.log; then echo "Training worked correctly!" found_success=true - kill $cl1_pid; kill $cl2_pid; sleep 1; kill $sl_pid; kill $se_pid; + kill $cl1_pid; kill $cl2_pid; sleep 1; kill $sl_pid; else echo "Waiting for training ... ($elapsed seconds elapsed)" fi @@ -118,5 +116,5 @@ done if [ "$found_success" = false ]; then echo "Training had an issue and timed out." - kill $cl1_pid; kill $cl2_pid; kill $sl_pid; kill $se_pid; + kill $cl1_pid; kill $cl2_pid; kill $sl_pid; fi diff --git a/e2e/test_superlink.sh b/e2e/test_superlink.sh index 2016f6da1933..630c6dcf8e96 100755 --- a/e2e/test_superlink.sh +++ b/e2e/test_superlink.sh @@ -19,7 +19,7 @@ case "$2" in rest) rest_arg_superlink="--fleet-api-type rest" rest_arg_supernode="--rest" - server_address="http://localhost:9093" + server_address="http://localhost:9095" server_app_address="127.0.0.1:9091" db_arg="--database :flwr-in-memory-state:" server_auth="" diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index e99e0edaacd4..98607a46835e 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -40,15 +40,14 @@ # Addresses # SuperNode CLIENTAPPIO_API_DEFAULT_ADDRESS = "0.0.0.0:9094" -# SuperExec -EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093" # SuperLink DRIVER_API_DEFAULT_ADDRESS = "0.0.0.0:9091" FLEET_API_GRPC_RERE_DEFAULT_ADDRESS = "0.0.0.0:9092" FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS = ( "[::]:8080" # IPv6 to keep start_server compatible ) -FLEET_API_REST_DEFAULT_ADDRESS = "0.0.0.0:9093" +FLEET_API_REST_DEFAULT_ADDRESS = "0.0.0.0:9095" +EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093" # Constants for ping PING_DEFAULT_INTERVAL = 30 diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 58918dbb79ab..764b9b9b3025 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -35,9 +35,10 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event from flwr.common.address import parse_address -from flwr.common.config import get_flwr_dir +from flwr.common.config import get_flwr_dir, parse_config_args from flwr.common.constant import ( DRIVER_API_DEFAULT_ADDRESS, + EXEC_API_DEFAULT_ADDRESS, FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS, FLEET_API_GRPC_RERE_DEFAULT_ADDRESS, FLEET_API_REST_DEFAULT_ADDRESS, @@ -56,6 +57,8 @@ add_FleetServicer_to_server, ) from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server +from flwr.superexec.app import load_executor +from flwr.superexec.exec_grpc import run_superexec_api_grpc from .client_manager import ClientManager from .history import History @@ -205,8 +208,9 @@ def run_superlink() -> None: 
event(EventType.RUN_SUPERLINK_ENTER) - # Parse IP address + # Parse IP addresses driver_address, _, _ = _format_address(args.driver_api_address) + exec_address, _, _ = _format_address(args.exec_api_address) # Obtain certificates certificates = _try_obtain_certificates(args) @@ -224,8 +228,9 @@ def run_superlink() -> None: ffs_factory=ffs_factory, certificates=certificates, ) - grpc_servers = [driver_server] + + # Start Fleet API bckg_threads = [] if not args.fleet_api_address: if args.fleet_api_type in [ @@ -250,7 +255,6 @@ def run_superlink() -> None: ) num_workers = 1 - # Start Fleet API if args.fleet_api_type == TRANSPORT_TYPE_REST: if ( importlib.util.find_spec("requests") @@ -318,6 +322,17 @@ def run_superlink() -> None: else: raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") + # Start Exec API + exec_server: grpc.Server = run_superexec_api_grpc( + address=exec_address, + executor=load_executor(args), + certificates=certificates, + config=parse_config_args( + [args.executor_config] if args.executor_config else args.executor_config + ), + ) + grpc_servers.append(exec_server) + # Graceful shutdown register_exit_handlers( event_type=EventType.RUN_SUPERLINK_LEAVE, @@ -587,6 +602,7 @@ def _parse_args_run_superlink() -> argparse.ArgumentParser: _add_args_common(parser=parser) _add_args_driver_api(parser=parser) _add_args_fleet_api(parser=parser) + _add_args_exec_api(parser=parser) return parser @@ -681,3 +697,29 @@ def _add_args_fleet_api(parser: argparse.ArgumentParser) -> None: type=int, help="Set the number of concurrent workers for the Fleet API server.", ) + + +def _add_args_exec_api(parser: argparse.ArgumentParser) -> None: + """Add command line arguments for Exec API.""" + parser.add_argument( + "--exec-api-address", + help="Exec API server address (IPv4, IPv6, or a domain name)", + default=EXEC_API_DEFAULT_ADDRESS, + ) + parser.add_argument( + "--executor", + help="For example: `deployment:exec` or `project.package.module:wrapper.exec`. " + "The default is `flwr.superexec.deployment:executor`", + default="flwr.superexec.deployment:executor", + ) + parser.add_argument( + "--executor-dir", + help="The directory for the executor.", + default=".", + ) + parser.add_argument( + "--executor-config", + help="Key-value pairs for the executor config, separated by spaces. 
" + "For example:\n\n`--executor-config 'verbose=true " + 'root-certificates="certificates/superlink-ca.crt"\'`', + ) diff --git a/src/py/flwr/superexec/app.py b/src/py/flwr/superexec/app.py index c00aa0f88e7b..1da0557ceab9 100644 --- a/src/py/flwr/superexec/app.py +++ b/src/py/flwr/superexec/app.py @@ -54,7 +54,7 @@ def run_superexec() -> None: # Start SuperExec API superexec_server: grpc.Server = run_superexec_api_grpc( address=address, - executor=_load_executor(args), + executor=load_executor(args), certificates=certificates, config=parse_config_args( [args.executor_config] if args.executor_config else args.executor_config @@ -163,7 +163,7 @@ def _try_obtain_certificates( ) -def _load_executor( +def load_executor( args: argparse.Namespace, ) -> Executor: """Get the executor plugin.""" From 2a804964be8a33fbad244906b5f9274a5a81965b Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 22 Oct 2024 13:54:41 +0100 Subject: [PATCH 04/10] refactor(framework) Rename `State` to `LinkState` (#4347) --- src/py/flwr/client/app.py | 2 +- .../clientapp/clientappio_servicer_test.py | 2 +- src/py/flwr/server/app.py | 10 +-- src/py/flwr/server/driver/inmemory_driver.py | 4 +- .../server/driver/inmemory_driver_test.py | 16 ++-- .../server/superlink/driver/driver_grpc.py | 4 +- .../superlink/driver/driver_servicer.py | 16 ++-- .../grpc_adapter/grpc_adapter_servicer.py | 6 +- .../fleet/grpc_rere/fleet_servicer.py | 6 +- .../fleet/grpc_rere/server_interceptor.py | 4 +- .../grpc_rere/server_interceptor_test.py | 4 +- .../fleet/message_handler/message_handler.py | 14 +-- .../superlink/fleet/rest_rere/rest_api.py | 14 +-- .../server/superlink/fleet/vce/vce_api.py | 14 +-- .../superlink/fleet/vce/vce_api_test.py | 12 +-- .../{state => linkstate}/__init__.py | 18 ++-- .../in_memory_linkstate.py} | 16 ++-- .../state.py => linkstate/linkstate.py} | 20 ++--- .../linkstate_factory.py} | 18 ++-- .../linkstate_test.py} | 86 ++++++++++--------- .../sqlite_linkstate.py} | 28 +++--- .../sqlite_linkstate_test.py} | 4 +- .../superlink/{state => linkstate}/utils.py | 0 .../{state => linkstate}/utils_test.py | 0 src/py/flwr/simulation/app.py | 2 +- src/py/flwr/simulation/run_simulation.py | 6 +- src/py/flwr/superexec/simulation.py | 2 +- 27 files changed, 171 insertions(+), 157 deletions(-) rename src/py/flwr/server/superlink/{state => linkstate}/__init__.py (65%) rename src/py/flwr/server/superlink/{state/in_memory_state.py => linkstate/in_memory_linkstate.py} (97%) rename src/py/flwr/server/superlink/{state/state.py => linkstate/linkstate.py} (93%) rename src/py/flwr/server/superlink/{state/state_factory.py => linkstate/linkstate_factory.py} (80%) rename src/py/flwr/server/superlink/{state/state_test.py => linkstate/linkstate_test.py} (94%) rename src/py/flwr/server/superlink/{state/sqlite_state.py => linkstate/sqlite_linkstate.py} (97%) rename src/py/flwr/server/superlink/{state/sqlite_state_test.py => linkstate/sqlite_linkstate_test.py} (91%) rename src/py/flwr/server/superlink/{state => linkstate}/utils.py (100%) rename src/py/flwr/server/superlink/{state => linkstate}/utils_test.py (100%) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index fdb62578292a..3ae9b1b1e282 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -52,7 +52,7 @@ from flwr.common.typing import Fab, Run, UserConfig from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server -from 
flwr.server.superlink.state.utils import generate_rand_int_from_bytes +from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes from .clientapp.clientappio_servicer import ClientAppInputs, ClientAppIoServicer from .grpc_adapter_client.connection import grpc_adapter diff --git a/src/py/flwr/client/clientapp/clientappio_servicer_test.py b/src/py/flwr/client/clientapp/clientappio_servicer_test.py index a03400c12a86..82c9f16e8201 100644 --- a/src/py/flwr/client/clientapp/clientappio_servicer_test.py +++ b/src/py/flwr/client/clientapp/clientappio_servicer_test.py @@ -36,7 +36,7 @@ ) from flwr.proto.message_pb2 import Context as ProtoContext from flwr.proto.run_pb2 import Run as ProtoRun -from flwr.server.superlink.state.utils import generate_rand_int_from_bytes +from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes from .clientappio_servicer import ClientAppInputs, ClientAppIoServicer, ClientAppOutputs diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 764b9b9b3025..72cb2a9b3d9d 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -74,7 +74,7 @@ ) from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor -from .superlink.state import StateFactory +from .superlink.linkstate import LinkStateFactory DATABASE = ":flwr-in-memory-state:" BASE_DIR = get_flwr_dir() / "superlink" / "ffs" @@ -216,7 +216,7 @@ def run_superlink() -> None: certificates = _try_obtain_certificates(args) # Initialize StateFactory - state_factory = StateFactory(args.database) + state_factory = LinkStateFactory(args.database) # Initialize FfsFactory ffs_factory = FfsFactory(args.storage_dir) @@ -504,7 +504,7 @@ def _try_obtain_certificates( def _run_fleet_api_grpc_rere( address: str, - state_factory: StateFactory, + state_factory: LinkStateFactory, ffs_factory: FfsFactory, certificates: Optional[tuple[bytes, bytes, bytes]], interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, @@ -532,7 +532,7 @@ def _run_fleet_api_grpc_rere( def _run_fleet_api_grpc_adapter( address: str, - state_factory: StateFactory, + state_factory: LinkStateFactory, ffs_factory: FfsFactory, certificates: Optional[tuple[bytes, bytes, bytes]], ) -> grpc.Server: @@ -563,7 +563,7 @@ def _run_fleet_api_rest( port: int, ssl_keyfile: Optional[str], ssl_certfile: Optional[str], - state_factory: StateFactory, + state_factory: LinkStateFactory, ffs_factory: FfsFactory, num_workers: int, ) -> None: diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 130562c6defa..4eb1eb9c1040 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -25,7 +25,7 @@ from flwr.common.serde import message_from_taskres, message_to_taskins from flwr.common.typing import Run from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.server.superlink.state import StateFactory +from flwr.server.superlink.linkstate import LinkStateFactory from .driver import Driver @@ -46,7 +46,7 @@ class InMemoryDriver(Driver): def __init__( self, run_id: int, - state_factory: StateFactory, + state_factory: LinkStateFactory, pull_interval: float = 0.1, ) -> None: self._run_id = run_id diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index 9e5aaeaa9ca7..bd961bd05936 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py 
+++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -32,8 +32,12 @@ ) from flwr.common.typing import Run from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 -from flwr.server.superlink.state import InMemoryState, SqliteState, StateFactory -from flwr.server.superlink.state.utils import generate_rand_int_from_bytes +from flwr.server.superlink.linkstate import ( + InMemoryLinkState, + LinkStateFactory, + SqliteLinkState, +) +from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes from .inmemory_driver import InMemoryDriver @@ -227,12 +231,12 @@ def test_send_and_receive_messages_timeout(self) -> None: def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: """Test tasks are deleted in sqlite state once messages are pulled.""" # Prepare - state = StateFactory("").state() + state = LinkStateFactory("").state() self.driver = InMemoryDriver( state.create_run("", "", "", {}), MagicMock(state=lambda: state) ) msg_ids, node_id = push_messages(self.driver, self.num_nodes) - assert isinstance(state, SqliteState) + assert isinstance(state, SqliteLinkState) # Check recorded task_ins = state.query("SELECT * FROM task_ins;") @@ -253,11 +257,11 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None: """Test tasks are deleted in in-memory state once messages are pulled.""" # Prepare - state_factory = StateFactory(":flwr-in-memory-state:") + state_factory = LinkStateFactory(":flwr-in-memory-state:") state = state_factory.state() self.driver = InMemoryDriver(state.create_run("", "", "", {}), state_factory) msg_ids, node_id = push_messages(self.driver, self.num_nodes) - assert isinstance(state, InMemoryState) + assert isinstance(state, InMemoryLinkState) # Check recorded self.assertEqual(len(state.task_ins_store), len(list(msg_ids))) diff --git a/src/py/flwr/server/superlink/driver/driver_grpc.py b/src/py/flwr/server/superlink/driver/driver_grpc.py index 70354387812e..327d8244ba11 100644 --- a/src/py/flwr/server/superlink/driver/driver_grpc.py +++ b/src/py/flwr/server/superlink/driver/driver_grpc.py @@ -25,7 +25,7 @@ add_DriverServicer_to_server, ) from flwr.server.superlink.ffs.ffs_factory import FfsFactory -from flwr.server.superlink.state import StateFactory +from flwr.server.superlink.linkstate import LinkStateFactory from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server from .driver_servicer import DriverServicer @@ -33,7 +33,7 @@ def run_driver_api_grpc( address: str, - state_factory: StateFactory, + state_factory: LinkStateFactory, ffs_factory: FfsFactory, certificates: Optional[tuple[bytes, bytes, bytes]], ) -> grpc.Server: diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 72c0d110ac14..41a1a64e8879 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -51,14 +51,16 @@ from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611 from flwr.server.superlink.ffs.ffs import Ffs from flwr.server.superlink.ffs.ffs_factory import FfsFactory -from flwr.server.superlink.state import State, StateFactory +from flwr.server.superlink.linkstate import LinkState, LinkStateFactory from flwr.server.utils.validator import validate_task_ins_or_res class DriverServicer(driver_pb2_grpc.DriverServicer): """Driver API servicer.""" - def __init__(self, state_factory: StateFactory, 
ffs_factory: FfsFactory) -> None: + def __init__( + self, state_factory: LinkStateFactory, ffs_factory: FfsFactory + ) -> None: self.state_factory = state_factory self.ffs_factory = ffs_factory @@ -67,7 +69,7 @@ def GetNodes( ) -> GetNodesResponse: """Get available nodes.""" log(DEBUG, "DriverServicer.GetNodes") - state: State = self.state_factory.state() + state: LinkState = self.state_factory.state() all_ids: set[int] = state.get_nodes(request.run_id) nodes: list[Node] = [ Node(node_id=node_id, anonymous=False) for node_id in all_ids @@ -79,7 +81,7 @@ def CreateRun( ) -> CreateRunResponse: """Create run ID.""" log(DEBUG, "DriverServicer.CreateRun") - state: State = self.state_factory.state() + state: LinkState = self.state_factory.state() if request.HasField("fab"): fab = fab_from_proto(request.fab) ffs: Ffs = self.ffs_factory.ffs() @@ -116,7 +118,7 @@ def PushTaskIns( _raise_if(bool(validation_errors), ", ".join(validation_errors)) # Init state - state: State = self.state_factory.state() + state: LinkState = self.state_factory.state() # Store each TaskIns task_ids: list[Optional[UUID]] = [] @@ -138,7 +140,7 @@ def PullTaskRes( task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids} # Init state - state: State = self.state_factory.state() + state: LinkState = self.state_factory.state() # Register callback def on_rpc_done() -> None: @@ -167,7 +169,7 @@ def GetRun( log(DEBUG, "DriverServicer.GetRun") # Init state - state: State = self.state_factory.state() + state: LinkState = self.state_factory.state() # Retrieve run information run = state.get_run(request.run_id) diff --git a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py index 75aa6d370511..ffef57d89e8c 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py @@ -48,7 +48,7 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.server.superlink.ffs.ffs_factory import FfsFactory from flwr.server.superlink.fleet.message_handler import message_handler -from flwr.server.superlink.state import StateFactory +from flwr.server.superlink.linkstate import LinkStateFactory T = TypeVar("T", bound=GrpcMessage) @@ -77,7 +77,9 @@ def _handle( class GrpcAdapterServicer(grpcadapter_pb2_grpc.GrpcAdapterServicer): """Fleet API via GrpcAdapter servicer.""" - def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None: + def __init__( + self, state_factory: LinkStateFactory, ffs_factory: FfsFactory + ) -> None: self.state_factory = state_factory self.ffs_factory = ffs_factory diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py index 02e34e0bba02..dacbab135057 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py @@ -37,13 +37,15 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.server.superlink.ffs.ffs_factory import FfsFactory from flwr.server.superlink.fleet.message_handler import message_handler -from flwr.server.superlink.state import StateFactory +from flwr.server.superlink.linkstate import LinkStateFactory class FleetServicer(fleet_pb2_grpc.FleetServicer): """Fleet API servicer.""" - def __init__(self, state_factory: StateFactory, ffs_factory: 
FfsFactory) -> None: + def __init__( + self, state_factory: LinkStateFactory, ffs_factory: FfsFactory + ) -> None: self.state_factory = state_factory self.ffs_factory = ffs_factory diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index 855fab353ae6..2e7623c34241 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -45,7 +45,7 @@ ) from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 -from flwr.server.superlink.state import State +from flwr.server.superlink.linkstate import LinkState _PUBLIC_KEY_HEADER = "public-key" _AUTH_TOKEN_HEADER = "auth-token" @@ -84,7 +84,7 @@ def _get_value_from_tuples( class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore """Server interceptor for node authentication.""" - def __init__(self, state: State): + def __init__(self, state: LinkState): self.state = state self.node_public_keys = state.get_node_public_keys() diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index cf7e05f0fb00..d44f4eb7e8f9 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -45,7 +45,7 @@ from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 from flwr.server.app import _run_fleet_api_grpc_rere from flwr.server.superlink.ffs.ffs_factory import FfsFactory -from flwr.server.superlink.state.state_factory import StateFactory +from flwr.server.superlink.linkstate.linkstate_factory import LinkStateFactory from .server_interceptor import ( _AUTH_TOKEN_HEADER, @@ -62,7 +62,7 @@ def setUp(self) -> None: self._node_private_key, self._node_public_key = generate_key_pairs() self._server_private_key, self._server_public_key = generate_key_pairs() - state_factory = StateFactory(":flwr-in-memory-state:") + state_factory = LinkStateFactory(":flwr-in-memory-state:") self.state = state_factory.state() ffs_factory = FfsFactory(".") self.ffs = ffs_factory.ffs() diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index 85f3fa34e0ac..38df6f441a20 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -43,12 +43,12 @@ ) from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.superlink.ffs.ffs import Ffs -from flwr.server.superlink.state import State +from flwr.server.superlink.linkstate import LinkState def create_node( request: CreateNodeRequest, # pylint: disable=unused-argument - state: State, + state: LinkState, ) -> CreateNodeResponse: """.""" # Create node @@ -56,7 +56,7 @@ def create_node( return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False)) -def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse: +def delete_node(request: DeleteNodeRequest, state: LinkState) -> DeleteNodeResponse: """.""" # Validate node_id if request.node.anonymous or request.node.node_id == 0: @@ -69,14 +69,14 @@ def delete_node(request: DeleteNodeRequest, state: State) -> DeleteNodeResponse: def ping( 
request: PingRequest, # pylint: disable=unused-argument - state: State, # pylint: disable=unused-argument + state: LinkState, # pylint: disable=unused-argument ) -> PingResponse: """.""" res = state.acknowledge_ping(request.node.node_id, request.ping_interval) return PingResponse(success=res) -def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsResponse: +def pull_task_ins(request: PullTaskInsRequest, state: LinkState) -> PullTaskInsResponse: """Pull TaskIns handler.""" # Get node_id if client node is not anonymous node = request.node # pylint: disable=no-member @@ -92,7 +92,7 @@ def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsRespo return response -def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResResponse: +def push_task_res(request: PushTaskResRequest, state: LinkState) -> PushTaskResResponse: """Push TaskRes handler.""" # pylint: disable=no-member task_res: TaskRes = request.task_res_list[0] @@ -113,7 +113,7 @@ def push_task_res(request: PushTaskResRequest, state: State) -> PushTaskResRespo def get_run( - request: GetRunRequest, state: State # pylint: disable=W0613 + request: GetRunRequest, state: LinkState # pylint: disable=W0613 ) -> GetRunResponse: """Get run information.""" run = state.get_run(request.run_id) diff --git a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py index a988252b3ea2..d38a2b0a500b 100644 --- a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py +++ b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py @@ -40,7 +40,7 @@ from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.server.superlink.ffs.ffs import Ffs from flwr.server.superlink.fleet.message_handler import message_handler -from flwr.server.superlink.state import State +from flwr.server.superlink.linkstate import LinkState try: from starlette.applications import Starlette @@ -90,7 +90,7 @@ async def wrapper(request: Request) -> Response: async def create_node(request: CreateNodeRequest) -> CreateNodeResponse: """Create Node.""" # Get state from app - state: State = app.state.STATE_FACTORY.state() + state: LinkState = app.state.STATE_FACTORY.state() # Handle message return message_handler.create_node(request=request, state=state) @@ -100,7 +100,7 @@ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse: async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse: """Delete Node Id.""" # Get state from app - state: State = app.state.STATE_FACTORY.state() + state: LinkState = app.state.STATE_FACTORY.state() # Handle message return message_handler.delete_node(request=request, state=state) @@ -110,7 +110,7 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse: async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse: """Pull TaskIns.""" # Get state from app - state: State = app.state.STATE_FACTORY.state() + state: LinkState = app.state.STATE_FACTORY.state() # Handle message return message_handler.pull_task_ins(request=request, state=state) @@ -121,7 +121,7 @@ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse: async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse: """Push TaskRes.""" # Get state from app - state: State = app.state.STATE_FACTORY.state() + state: LinkState = app.state.STATE_FACTORY.state() # Handle message return message_handler.push_task_res(request=request, state=state) @@ -131,7 +131,7 @@ 
async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse: async def ping(request: PingRequest) -> PingResponse: """Ping.""" # Get state from app - state: State = app.state.STATE_FACTORY.state() + state: LinkState = app.state.STATE_FACTORY.state() # Handle message return message_handler.ping(request=request, state=state) @@ -141,7 +141,7 @@ async def ping(request: PingRequest) -> PingResponse: async def get_run(request: GetRunRequest) -> GetRunResponse: """GetRun.""" # Get state from app - state: State = app.state.STATE_FACTORY.state() + state: LinkState = app.state.STATE_FACTORY.state() # Handle message return message_handler.get_run(request=request, state=state) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 785390534001..570aa17edb53 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -40,7 +40,7 @@ from flwr.common.serde import message_from_taskins, message_to_taskres from flwr.common.typing import Run from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 -from flwr.server.superlink.state import State, StateFactory +from flwr.server.superlink.linkstate import LinkState, LinkStateFactory from .backend import Backend, error_messages_backends, supported_backends @@ -48,7 +48,7 @@ def _register_nodes( - num_nodes: int, state_factory: StateFactory + num_nodes: int, state_factory: LinkStateFactory ) -> NodeToPartitionMapping: """Register nodes with the StateFactory and create node-id:partition-id mapping.""" nodes_mapping: NodeToPartitionMapping = {} @@ -145,7 +145,7 @@ def worker( def add_taskins_to_queue( - state: State, + state: LinkState, queue: "Queue[TaskIns]", nodes_mapping: NodeToPartitionMapping, f_stop: threading.Event, @@ -160,7 +160,7 @@ def add_taskins_to_queue( def put_taskres_into_state( - state: State, queue: "Queue[TaskRes]", f_stop: threading.Event + state: LinkState, queue: "Queue[TaskRes]", f_stop: threading.Event ) -> None: """Put TaskRes into State from a queue.""" while not f_stop.is_set(): @@ -177,7 +177,7 @@ def run_api( app_fn: Callable[[], ClientApp], backend_fn: Callable[[], Backend], nodes_mapping: NodeToPartitionMapping, - state_factory: StateFactory, + state_factory: LinkStateFactory, node_states: dict[int, NodeState], f_stop: threading.Event, ) -> None: @@ -264,7 +264,7 @@ def start_vce( client_app: Optional[ClientApp] = None, client_app_attr: Optional[str] = None, num_supernodes: Optional[int] = None, - state_factory: Optional[StateFactory] = None, + state_factory: Optional[LinkStateFactory] = None, existing_nodes_mapping: Optional[NodeToPartitionMapping] = None, ) -> None: """Start Fleet API with the Simulation Engine.""" @@ -303,7 +303,7 @@ def start_vce( if not state_factory: log(INFO, "A StateFactory was not supplied to the SimulationEngine.") # Create an empty in-memory state factory - state_factory = StateFactory(":flwr-in-memory-state:") + state_factory = LinkStateFactory(":flwr-in-memory-state:") log(INFO, "Created new %s.", state_factory.__class__.__name__) if num_supernodes: diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index bc34b825c333..d14ce86c58c4 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -48,7 +48,7 @@ _register_nodes, start_vce, ) -from flwr.server.superlink.state import InMemoryState, StateFactory 
+from flwr.server.superlink.linkstate import InMemoryLinkState, LinkStateFactory class DummyClient(NumPyClient): @@ -86,11 +86,11 @@ def terminate_simulation(f_stop: threading.Event, sleep_duration: int) -> None: def init_state_factory_nodes_mapping( num_nodes: int, num_messages: int, -) -> tuple[StateFactory, NodeToPartitionMapping, dict[UUID, float]]: +) -> tuple[LinkStateFactory, NodeToPartitionMapping, dict[UUID, float]]: """Instatiate StateFactory, register nodes and pre-insert messages in the state.""" # Register a state and a run_id in it run_id = 1234 - state_factory = StateFactory(":flwr-in-memory-state:") + state_factory = LinkStateFactory(":flwr-in-memory-state:") # Register a few nodes nodes_mapping = _register_nodes(num_nodes=num_nodes, state_factory=state_factory) @@ -106,13 +106,13 @@ def init_state_factory_nodes_mapping( # pylint: disable=too-many-locals def register_messages_into_state( - state_factory: StateFactory, + state_factory: LinkStateFactory, nodes_mapping: NodeToPartitionMapping, run_id: int, num_messages: int, ) -> dict[UUID, float]: """Register `num_messages` into the state factory.""" - state: InMemoryState = state_factory.state() # type: ignore + state: InMemoryLinkState = state_factory.state() # type: ignore state.run_ids[run_id] = Run( run_id=run_id, fab_id="Mock/mock", @@ -176,7 +176,7 @@ def start_and_shutdown( client_app_attr: Optional[str] = None, app_dir: str = "", num_supernodes: Optional[int] = None, - state_factory: Optional[StateFactory] = None, + state_factory: Optional[LinkStateFactory] = None, nodes_mapping: Optional[NodeToPartitionMapping] = None, duration: int = 0, backend_config: str = "{}", diff --git a/src/py/flwr/server/superlink/state/__init__.py b/src/py/flwr/server/superlink/linkstate/__init__.py similarity index 65% rename from src/py/flwr/server/superlink/state/__init__.py rename to src/py/flwr/server/superlink/linkstate/__init__.py index 9d3bd220403b..471cfbd2b5ec 100644 --- a/src/py/flwr/server/superlink/state/__init__.py +++ b/src/py/flwr/server/superlink/linkstate/__init__.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Flower server state.""" +"""Flower LinkState.""" -from .in_memory_state import InMemoryState as InMemoryState -from .sqlite_state import SqliteState as SqliteState -from .state import State as State -from .state_factory import StateFactory as StateFactory +from .in_memory_linkstate import InMemoryLinkState as InMemoryLinkState +from .linkstate import LinkState as LinkState +from .linkstate_factory import LinkStateFactory as LinkStateFactory +from .sqlite_linkstate import SqliteLinkState as SqliteLinkState __all__ = [ - "InMemoryState", - "SqliteState", - "State", - "StateFactory", + "InMemoryLinkState", + "LinkState", + "LinkStateFactory", + "SqliteLinkState", ] diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py similarity index 97% rename from src/py/flwr/server/superlink/state/in_memory_state.py rename to src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index a9c4176ee5f2..8fdb5a1ed9ec 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""In-memory State implementation.""" +"""In-memory LinkState implementation.""" import threading @@ -29,14 +29,14 @@ ) from flwr.common.typing import Run, UserConfig from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 -from flwr.server.superlink.state.state import State +from flwr.server.superlink.linkstate.linkstate import LinkState from flwr.server.utils import validate_task_ins_or_res from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres -class InMemoryState(State): # pylint: disable=R0902,R0904 - """In-memory State implementation.""" +class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904 + """In-memory LinkState implementation.""" def __init__(self) -> None: @@ -277,7 +277,7 @@ def num_task_res(self) -> int: def create_node( self, ping_interval: float, public_key: Optional[bytes] = None ) -> int: - """Create, store in state, and return `node_id`.""" + """Create, store in the link state, and return `node_id`.""" # Sample a random int64 as node_id node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) @@ -365,7 +365,7 @@ def create_run( def store_server_private_public_key( self, private_key: bytes, public_key: bytes ) -> None: - """Store `server_private_key` and `server_public_key` in state.""" + """Store `server_private_key` and `server_public_key` in the link state.""" with self.lock: if self.server_private_key is None and self.server_public_key is None: self.server_private_key = private_key @@ -382,12 +382,12 @@ def get_server_public_key(self) -> Optional[bytes]: return self.server_public_key def store_node_public_keys(self, public_keys: set[bytes]) -> None: - """Store a set of `node_public_keys` in state.""" + """Store a set of `node_public_keys` in the link state.""" with self.lock: self.node_public_keys = public_keys def store_node_public_key(self, public_key: bytes) -> None: - """Store a `node_public_key` in state.""" + """Store a `node_public_key` in the link state.""" with self.lock: self.node_public_keys.add(public_key) diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/linkstate/linkstate.py similarity index 93% rename from src/py/flwr/server/superlink/state/state.py rename to src/py/flwr/server/superlink/linkstate/linkstate.py index b220aad3ebcc..e8e254873957 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Abstract base class State.""" +"""Abstract base class LinkState.""" import abc @@ -23,8 +23,8 @@ from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 -class State(abc.ABC): # pylint: disable=R0904 - """Abstract State.""" +class LinkState(abc.ABC): # pylint: disable=R0904 + """Abstract LinkState.""" @abc.abstractmethod def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: @@ -32,8 +32,8 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: Usually, the Driver API calls this to schedule instructions. - Stores the value of the `task_ins` in the state and, if successful, returns the - `task_id` (UUID) of the `task_ins`. If, for any reason, + Stores the value of the `task_ins` in the link state and, if successful, + returns the `task_id` (UUID) of the `task_ins`. 
If, for any reason, storing the `task_ins` fails, `None` is returned. Constraints @@ -130,11 +130,11 @@ def delete_tasks(self, task_ids: set[UUID]) -> None: def create_node( self, ping_interval: float, public_key: Optional[bytes] = None ) -> int: - """Create, store in state, and return `node_id`.""" + """Create, store in the link state, and return `node_id`.""" @abc.abstractmethod def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: - """Remove `node_id` from state.""" + """Remove `node_id` from the link state.""" @abc.abstractmethod def get_nodes(self, run_id: int) -> set[int]: @@ -182,7 +182,7 @@ def get_run(self, run_id: int) -> Optional[Run]: def store_server_private_public_key( self, private_key: bytes, public_key: bytes ) -> None: - """Store `server_private_key` and `server_public_key` in state.""" + """Store `server_private_key` and `server_public_key` in the link state.""" @abc.abstractmethod def get_server_private_key(self) -> Optional[bytes]: @@ -194,11 +194,11 @@ def get_server_public_key(self) -> Optional[bytes]: @abc.abstractmethod def store_node_public_keys(self, public_keys: set[bytes]) -> None: - """Store a set of `node_public_keys` in state.""" + """Store a set of `node_public_keys` in the link state.""" @abc.abstractmethod def store_node_public_key(self, public_key: bytes) -> None: - """Store a `node_public_key` in state.""" + """Store a `node_public_key` in the link state.""" @abc.abstractmethod def get_node_public_keys(self) -> set[bytes]: diff --git a/src/py/flwr/server/superlink/state/state_factory.py b/src/py/flwr/server/superlink/linkstate/linkstate_factory.py similarity index 80% rename from src/py/flwr/server/superlink/state/state_factory.py rename to src/py/flwr/server/superlink/linkstate/linkstate_factory.py index 96c8d445c16e..403b9bf5b4cc 100644 --- a/src/py/flwr/server/superlink/state/state_factory.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_factory.py @@ -20,13 +20,13 @@ from flwr.common.logger import log -from .in_memory_state import InMemoryState -from .sqlite_state import SqliteState -from .state import State +from .in_memory_linkstate import InMemoryLinkState +from .linkstate import LinkState +from .sqlite_linkstate import SqliteLinkState -class StateFactory: - """Factory class that creates State instances. +class LinkStateFactory: + """Factory class that creates LinkState instances. 
Parameters ---------- @@ -39,19 +39,19 @@ class StateFactory: def __init__(self, database: str) -> None: self.database = database - self.state_instance: Optional[State] = None + self.state_instance: Optional[LinkState] = None - def state(self) -> State: + def state(self) -> LinkState: """Return a State instance and create it, if necessary.""" # InMemoryState if self.database == ":flwr-in-memory-state:": if self.state_instance is None: - self.state_instance = InMemoryState() + self.state_instance = InMemoryLinkState() log(DEBUG, "Using InMemoryState") return self.state_instance # SqliteState - state = SqliteState(self.database) + state = SqliteLinkState(self.database) state.initialize() log(DEBUG, "Using SqliteState") return state diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py similarity index 94% rename from src/py/flwr/server/superlink/state/state_test.py rename to src/py/flwr/server/superlink/linkstate/linkstate_test.py index a4663f80f630..dec0a3b705e7 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests all state implemenations have to conform to.""" +"""Tests all LinkState implemenations have to conform to.""" # pylint: disable=invalid-name, too-many-lines, R0904, R0913 import tempfile @@ -33,7 +33,11 @@ from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 -from flwr.server.superlink.state import InMemoryState, SqliteState, State +from flwr.server.superlink.linkstate import ( + InMemoryLinkState, + LinkState, + SqliteLinkState, +) class StateTest(unittest.TestCase): @@ -43,14 +47,14 @@ class StateTest(unittest.TestCase): __test__ = False @abstractmethod - def state_factory(self) -> State: + def state_factory(self) -> LinkState: """Provide state implementation to test.""" raise NotImplementedError() def test_create_and_get_run(self) -> None: """Test if create_run and get_run work correctly.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) # Execute @@ -188,7 +192,7 @@ def test_init_state(self) -> None: state = self.state_factory() # Assert - assert isinstance(state, State) + assert isinstance(state, LinkState) # TaskIns tests def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: @@ -197,7 +201,7 @@ def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: Create anonymous task and retrieve it. 
""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -212,7 +216,7 @@ def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: def test_task_ins_store_anonymous_and_fail_retrieving_identitiy(self) -> None: """Store anonymous TaskIns and fail to retrieve it.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -226,7 +230,7 @@ def test_task_ins_store_anonymous_and_fail_retrieving_identitiy(self) -> None: def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: """Store identity TaskIns and fail retrieving it as anonymous.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) @@ -240,7 +244,7 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: def test_task_ins_store_identity_and_retrieve_identity(self) -> None: """Store identity TaskIns and retrieve it.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) @@ -257,7 +261,7 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: """Fail retrieving delivered task.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) @@ -278,7 +282,7 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: def test_get_task_ins_limit_throws_for_limit_zero(self) -> None: """Fail call with limit=0.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() # Execute & Assert with self.assertRaises(AssertionError): @@ -287,7 +291,7 @@ def test_get_task_ins_limit_throws_for_limit_zero(self) -> None: def test_task_ins_store_invalid_run_id_and_fail(self) -> None: """Store TaskIns with invalid run_id and fail.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=61016) # Execute @@ -300,7 +304,7 @@ def test_task_ins_store_invalid_run_id_and_fail(self) -> None: def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: """Store TaskRes retrieve it by task_ins_id.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -326,7 +330,7 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: def test_node_ids_initial_state(self) -> None: """Test retrieving all node_ids and empty initial state.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) # Execute @@ -338,7 +342,7 @@ def test_node_ids_initial_state(self) -> None: def 
test_create_node_and_get_nodes(self) -> None: """Test creating a client node.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) node_ids = [] @@ -354,7 +358,7 @@ def test_create_node_and_get_nodes(self) -> None: def test_create_node_public_key(self) -> None: """Test creating a client node with public key.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() public_key = b"mock" run_id = state.create_run(None, None, "9f86d08", {}) @@ -370,7 +374,7 @@ def test_create_node_public_key(self) -> None: def test_create_node_public_key_twice(self) -> None: """Test creating a client node with same public key twice.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() public_key = b"mock" run_id = state.create_run(None, None, "9f86d08", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) @@ -386,14 +390,14 @@ def test_create_node_public_key_twice(self) -> None: assert retrieved_node_id == node_id # Assert node_ids and public_key_to_node_id are synced - if isinstance(state, InMemoryState): + if isinstance(state, InMemoryLinkState): assert len(state.node_ids) == 1 assert len(state.public_key_to_node_id) == 1 def test_delete_node(self) -> None: """Test deleting a client node.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) node_id = state.create_node(ping_interval=10) @@ -407,7 +411,7 @@ def test_delete_node(self) -> None: def test_delete_node_public_key(self) -> None: """Test deleting a client node with public key.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() public_key = b"mock" run_id = state.create_run(None, None, "9f86d08", {}) node_id = state.create_node(ping_interval=10, public_key=public_key) @@ -424,7 +428,7 @@ def test_delete_node_public_key(self) -> None: def test_delete_node_public_key_none(self) -> None: """Test deleting a client node with public key.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() public_key = b"mock" run_id = state.create_run(None, None, "9f86d08", {}) node_id = 0 @@ -442,7 +446,7 @@ def test_delete_node_public_key_none(self) -> None: def test_delete_node_wrong_public_key(self) -> None: """Test deleting a client node with wrong public key.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" run_id = state.create_run(None, None, "9f86d08", {}) @@ -461,7 +465,7 @@ def test_delete_node_wrong_public_key(self) -> None: def test_get_node_id_wrong_public_key(self) -> None: """Test retrieving a client node with wrong public key.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" run_id = state.create_run(None, None, "9f86d08", {}) @@ -478,7 +482,7 @@ def test_get_node_id_wrong_public_key(self) -> None: def test_get_nodes_invalid_run_id(self) -> None: """Test retrieving all node_ids with invalid run_id.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() state.create_run(None, None, "9f86d08", {}) invalid_run_id = 61016 state.create_node(ping_interval=10) @@ -492,7 +496,7 @@ def test_get_nodes_invalid_run_id(self) -> None: def test_num_task_ins(self) -> None: """Test if 
num_tasks returns correct number of not delivered task_ins.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) task_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -510,7 +514,7 @@ def test_num_task_ins(self) -> None: def test_num_task_res(self) -> None: """Test if num_tasks returns correct number of not delivered task_res.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) task_ins_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -544,7 +548,7 @@ def test_num_task_res(self) -> None: def test_server_private_public_key(self) -> None: """Test get server private and public key after inserting.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() private_key, public_key = generate_key_pairs() private_key_bytes = private_key_to_bytes(private_key) public_key_bytes = public_key_to_bytes(public_key) @@ -561,7 +565,7 @@ def test_server_private_public_key(self) -> None: def test_server_private_public_key_none(self) -> None: """Test get server private and public key without inserting.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() # Execute server_private_key = state.get_server_private_key() @@ -574,7 +578,7 @@ def test_server_private_public_key_none(self) -> None: def test_store_server_private_public_key_twice(self) -> None: """Test inserting private and public key twice.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() private_key, public_key = generate_key_pairs() private_key_bytes = private_key_to_bytes(private_key) public_key_bytes = public_key_to_bytes(public_key) @@ -594,7 +598,7 @@ def test_store_server_private_public_key_twice(self) -> None: def test_node_public_keys(self) -> None: """Test store_node_public_keys and get_node_public_keys from state.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() key_pairs = [generate_key_pairs() for _ in range(3)] public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} @@ -608,7 +612,7 @@ def test_node_public_keys(self) -> None: def test_node_public_key(self) -> None: """Test store_node_public_key and get_node_public_keys from state.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() key_pairs = [generate_key_pairs() for _ in range(3)] public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} @@ -623,7 +627,7 @@ def test_node_public_key(self) -> None: def test_acknowledge_ping(self) -> None: """Test if acknowledge_ping works and if get_nodes return online nodes.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) node_ids = [state.create_node(ping_interval=10) for _ in range(100)] for node_id in node_ids[:70]: @@ -642,7 +646,7 @@ def test_acknowledge_ping(self) -> None: def test_node_unavailable_error(self) -> None: """Test if get_task_res return TaskRes containing node unavailable error.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) node_id_0 = state.create_node(ping_interval=90) node_id_1 = 
state.create_node(ping_interval=30) @@ -684,7 +688,7 @@ def test_node_unavailable_error(self) -> None: def test_store_task_res_task_ins_expired(self) -> None: """Test behavior of store_task_res when the TaskIns it references is expired.""" # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -738,7 +742,7 @@ def test_store_task_res_limit_ttl(self) -> None: ) in test_cases: # Prepare - state: State = self.state_factory() + state: LinkState = self.state_factory() run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins( @@ -957,9 +961,9 @@ class InMemoryStateTest(StateTest): __test__ = True - def state_factory(self) -> State: + def state_factory(self) -> LinkState: """Return InMemoryState.""" - return InMemoryState() + return InMemoryLinkState() class SqliteInMemoryStateTest(StateTest, unittest.TestCase): @@ -967,9 +971,9 @@ class SqliteInMemoryStateTest(StateTest, unittest.TestCase): __test__ = True - def state_factory(self) -> SqliteState: + def state_factory(self) -> SqliteLinkState: """Return SqliteState with in-memory database.""" - state = SqliteState(":memory:") + state = SqliteLinkState(":memory:") state.initialize() return state @@ -990,11 +994,11 @@ class SqliteFileBasedTest(StateTest, unittest.TestCase): __test__ = True - def state_factory(self) -> SqliteState: + def state_factory(self) -> SqliteLinkState: """Return SqliteState with file-based database.""" # pylint: disable-next=consider-using-with,attribute-defined-outside-init self.tmp_file = tempfile.NamedTemporaryFile() - state = SqliteState(database_path=self.tmp_file.name) + state = SqliteLinkState(database_path=self.tmp_file.name) state.initialize() return state diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py similarity index 97% rename from src/py/flwr/server/superlink/state/sqlite_state.py rename to src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index 6d644c3b2232..4344ce8b062d 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""SQLite based implemenation of server state.""" +"""SQLite based implemenation of the link state.""" # pylint: disable=too-many-lines @@ -37,7 +37,7 @@ from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.utils.validator import validate_task_ins_or_res -from .state import State +from .linkstate import LinkState from .utils import ( convert_sint64_to_uint64, convert_sint64_values_in_dict_to_uint64, @@ -126,8 +126,8 @@ DictOrTuple = Union[tuple[Any, ...], dict[str, Any]] -class SqliteState(State): # pylint: disable=R0904 - """SQLite-based state implementation.""" +class SqliteLinkState(LinkState): # pylint: disable=R0904 + """SQLite-based LinkState implementation.""" def __init__( self, @@ -183,7 +183,7 @@ def query( ) -> list[dict[str, Any]]: """Execute a SQL query.""" if self.conn is None: - raise AttributeError("State is not initialized.") + raise AttributeError("LinkState is not initialized.") if data is None: data = [] @@ -216,9 +216,9 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: Usually, the Driver API calls this to schedule instructions. - Stores the value of the task_ins in the state and, if successful, returns the - task_id (UUID) of the task_ins. If, for any reason, storing the task_ins fails, - `None` is returned. + Stores the value of the task_ins in the link state and, if successful, + returns the task_id (UUID) of the task_ins. If, for any reason, storing + the task_ins fails, `None` is returned. Constraints ----------- @@ -645,7 +645,7 @@ def delete_tasks(self, task_ids: set[UUID]) -> None: """ if self.conn is None: - raise AttributeError("State not intitialized") + raise AttributeError("LinkState not intitialized") with self.conn: self.conn.execute(query_1, data) @@ -656,7 +656,7 @@ def delete_tasks(self, task_ids: set[UUID]) -> None: def create_node( self, ping_interval: float, public_key: Optional[bytes] = None ) -> int: - """Create, store in state, and return `node_id`.""" + """Create, store in the link state, and return `node_id`.""" # Sample a random uint64 as node_id uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) @@ -706,7 +706,7 @@ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: params += (public_key,) # type: ignore if self.conn is None: - raise AttributeError("State is not initialized.") + raise AttributeError("LinkState is not initialized.") try: with self.conn: @@ -800,7 +800,7 @@ def create_run( def store_server_private_public_key( self, private_key: bytes, public_key: bytes ) -> None: - """Store `server_private_key` and `server_public_key` in state.""" + """Store `server_private_key` and `server_public_key` in the link state.""" query = "SELECT COUNT(*) FROM credential" count = self.query(query)[0]["COUNT(*)"] if count < 1: @@ -833,13 +833,13 @@ def get_server_public_key(self) -> Optional[bytes]: return public_key def store_node_public_keys(self, public_keys: set[bytes]) -> None: - """Store a set of `node_public_keys` in state.""" + """Store a set of `node_public_keys` in the link state.""" query = "INSERT INTO public_key (public_key) VALUES (?)" data = [(key,) for key in public_keys] self.query(query, data) def store_node_public_key(self, public_key: bytes) -> None: - """Store a `node_public_key` in state.""" + """Store a `node_public_key` in the link state.""" query = "INSERT INTO public_key (public_key) VALUES (:public_key)" self.query(query, {"public_key": 
public_key}) diff --git a/src/py/flwr/server/superlink/state/sqlite_state_test.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate_test.py similarity index 91% rename from src/py/flwr/server/superlink/state/sqlite_state_test.py rename to src/py/flwr/server/superlink/linkstate/sqlite_linkstate_test.py index 10e12da96bd5..ed2960ef76fa 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state_test.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate_test.py @@ -17,8 +17,8 @@ import unittest -from flwr.server.superlink.state.sqlite_state import task_ins_to_dict -from flwr.server.superlink.state.state_test import create_task_ins +from flwr.server.superlink.linkstate.linkstate_test import create_task_ins +from flwr.server.superlink.linkstate.sqlite_linkstate import task_ins_to_dict class SqliteStateTest(unittest.TestCase): diff --git a/src/py/flwr/server/superlink/state/utils.py b/src/py/flwr/server/superlink/linkstate/utils.py similarity index 100% rename from src/py/flwr/server/superlink/state/utils.py rename to src/py/flwr/server/superlink/linkstate/utils.py diff --git a/src/py/flwr/server/superlink/state/utils_test.py b/src/py/flwr/server/superlink/linkstate/utils_test.py similarity index 100% rename from src/py/flwr/server/superlink/state/utils_test.py rename to src/py/flwr/server/superlink/linkstate/utils_test.py diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index 0070d75c53dc..62efc5197d3f 100644 --- a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -36,7 +36,7 @@ from flwr.server.server import Server, init_defaults, run_fl from flwr.server.server_config import ServerConfig from flwr.server.strategy import Strategy -from flwr.server.superlink.state.utils import generate_rand_int_from_bytes +from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes from flwr.simulation.ray_transport.ray_actor import ( ClientAppActor, VirtualClientEngineActor, diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 8c4e42c34744..e9b2352e0c0c 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -44,8 +44,8 @@ from flwr.server.server_app import ServerApp from flwr.server.superlink.fleet import vce from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig -from flwr.server.superlink.state import StateFactory -from flwr.server.superlink.state.utils import generate_rand_int_from_bytes +from flwr.server.superlink.linkstate import LinkStateFactory +from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes from flwr.simulation.ray_transport.utils import ( enable_tf_gpu_growth as enable_gpu_growth, ) @@ -389,7 +389,7 @@ def _main_loop( ) -> None: """Start ServerApp on a separate thread, then launch Simulation Engine.""" # Initialize StateFactory - state_factory = StateFactory(":flwr-in-memory-state:") + state_factory = LinkStateFactory(":flwr-in-memory-state:") f_stop = threading.Event() # A Threading event to indicate if an exception was raised in the ServerApp thread diff --git a/src/py/flwr/superexec/simulation.py b/src/py/flwr/superexec/simulation.py index e913b6812556..820d80a89ac7 100644 --- a/src/py/flwr/superexec/simulation.py +++ b/src/py/flwr/superexec/simulation.py @@ -29,7 +29,7 @@ from flwr.common.constant import RUN_ID_NUM_BYTES from flwr.common.logger import log from flwr.common.typing import UserConfig -from flwr.server.superlink.state.utils import 
generate_rand_int_from_bytes +from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes from .executor import Executor, RunTracker From 2351cd7207a7e1b3461eac6ad72688996f615701 Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 22 Oct 2024 14:57:08 +0100 Subject: [PATCH 05/10] refactor(framework) Rename `NodeState` to `DeprecatedRunInfoStore` (#4348) --- src/py/flwr/client/app.py | 20 ++++++------ src/py/flwr/client/node_state_tests.py | 15 ++++----- .../{node_state.py => run_info_store.py} | 6 ++-- .../fleet/vce/backend/raybackend_test.py | 8 +++-- .../server/superlink/fleet/vce/vce_api.py | 32 +++++++++---------- .../ray_transport/ray_client_proxy.py | 4 +-- .../ray_transport/ray_client_proxy_test.py | 14 ++++---- 7 files changed, 50 insertions(+), 49 deletions(-) rename src/py/flwr/client/{node_state.py => run_info_store.py} (97%) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 3ae9b1b1e282..5b4eff51a7d6 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -59,8 +59,8 @@ from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response from .message_handler.message_handler import handle_control_message -from .node_state import NodeState from .numpy_client import NumPyClient +from .run_info_store import DeprecatedRunInfoStore ISOLATION_MODE_SUBPROCESS = "subprocess" ISOLATION_MODE_PROCESS = "process" @@ -364,8 +364,8 @@ def _on_backoff(retry_state: RetryState) -> None: on_backoff=_on_backoff, ) - # NodeState gets initialized when the first connection is established - node_state: Optional[NodeState] = None + # DeprecatedRunInfoStore gets initialized when the first connection is established + run_info_store: Optional[DeprecatedRunInfoStore] = None runs: dict[int, Run] = {} @@ -382,7 +382,7 @@ def _on_backoff(retry_state: RetryState) -> None: receive, send, create_node, delete_node, get_run, get_fab = conn # Register node when connecting the first time - if node_state is None: + if run_info_store is None: if create_node is None: if transport not in ["grpc-bidi", None]: raise NotImplementedError( @@ -391,7 +391,7 @@ def _on_backoff(retry_state: RetryState) -> None: ) # gRPC-bidi doesn't have the concept of node_id, # so we set it to -1 - node_state = NodeState( + run_info_store = DeprecatedRunInfoStore( node_id=-1, node_config={}, ) @@ -402,7 +402,7 @@ def _on_backoff(retry_state: RetryState) -> None: ) # pylint: disable=not-callable if node_id is None: raise ValueError("Node registration failed") - node_state = NodeState( + run_info_store = DeprecatedRunInfoStore( node_id=node_id, node_config=node_config, ) @@ -461,7 +461,7 @@ def _on_backoff(retry_state: RetryState) -> None: run.fab_id, run.fab_version = fab_id, fab_version # Register context for this run - node_state.register_context( + run_info_store.register_context( run_id=run_id, run=run, flwr_path=flwr_path, @@ -469,7 +469,7 @@ def _on_backoff(retry_state: RetryState) -> None: ) # Retrieve context for this run - context = node_state.retrieve_context(run_id=run_id) + context = run_info_store.retrieve_context(run_id=run_id) # Create an error reply message that will never be used to prevent # the used-before-assignment linting error reply_message = message.create_error_reply( @@ -542,7 +542,7 @@ def _on_backoff(retry_state: RetryState) -> None: # Raise exception, crash process raise ex - # Don't update/change NodeState + # Don't update/change DeprecatedRunInfoStore e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION # Ex fmt: 
":<'division by zero'>" @@ -567,7 +567,7 @@ def _on_backoff(retry_state: RetryState) -> None: ) else: # No exception, update node state - node_state.update_context( + run_info_store.update_context( run_id=run_id, context=context, ) diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index 26ac4fea6855..06ceb80a94ad 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -17,7 +17,7 @@ from typing import cast -from flwr.client.node_state import NodeState +from flwr.client.run_info_store import DeprecatedRunInfoStore from flwr.common import ConfigsRecord, Context from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 @@ -34,32 +34,31 @@ def _run_dummy_task(context: Context) -> Context: def test_multirun_in_node_state() -> None: - """Test basic NodeState logic.""" + """Test basic DeprecatedRunInfoStore logic.""" # Tasks to perform tasks = [TaskIns(run_id=run_id) for run_id in [0, 1, 1, 2, 3, 2, 1, 5]] # the "tasks" is to count how many times each run is executed expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"} - # NodeState - node_state = NodeState(node_id=0, node_config={}) + node_info_store = DeprecatedRunInfoStore(node_id=0, node_config={}) for task in tasks: run_id = task.run_id # Register - node_state.register_context(run_id=run_id) + node_info_store.register_context(run_id=run_id) # Get run state - context = node_state.retrieve_context(run_id=run_id) + context = node_info_store.retrieve_context(run_id=run_id) # Run "task" updated_state = _run_dummy_task(context) # Update run state - node_state.update_context(run_id=run_id, context=updated_state) + node_info_store.update_context(run_id=run_id, context=updated_state) # Verify values - for run_id, run_info in node_state.run_infos.items(): + for run_id, run_info in node_info_store.run_infos.items(): assert ( run_info.context.state.configs_records["counter"]["count"] == expected_values[run_id] diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/run_info_store.py similarity index 97% rename from src/py/flwr/client/node_state.py rename to src/py/flwr/client/run_info_store.py index 843c9890c5d2..6b0c3bd3a493 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/run_info_store.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Node state.""" +"""Deprecated Run Info Store.""" from dataclasses import dataclass @@ -36,7 +36,7 @@ class RunInfo: initial_run_config: UserConfig -class NodeState: +class DeprecatedRunInfoStore: """State of a node where client nodes execute runs.""" def __init__( diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index 1cbdc230c938..753f450e835c 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -22,7 +22,7 @@ from flwr.client import Client, NumPyClient from flwr.client.client_app import ClientApp -from flwr.client.node_state import NodeState +from flwr.client.run_info_store import DeprecatedRunInfoStore from flwr.common import ( DEFAULT_TTL, Config, @@ -104,8 +104,10 @@ def _create_message_and_context() -> tuple[Message, Context, float]: ), ) - # Construct NodeState and retrieve context - node_state = NodeState(node_id=run_id, node_config={PARTITION_ID_KEY: str(0)}) + # Construct DeprecatedRunInfoStore and retrieve context + node_state = DeprecatedRunInfoStore( + node_id=run_id, node_config={PARTITION_ID_KEY: str(0)} + ) node_state.register_context(run_id=run_id) context = node_state.retrieve_context(run_id=run_id) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 570aa17edb53..7a2d28dec4fb 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -28,7 +28,7 @@ from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.clientapp.utils import get_load_client_app_fn -from flwr.client.node_state import NodeState +from flwr.client.run_info_store import DeprecatedRunInfoStore from flwr.common.constant import ( NUM_PARTITIONS_KEY, PARTITION_ID_KEY, @@ -60,16 +60,16 @@ def _register_nodes( return nodes_mapping -def _register_node_states( +def _register_node_info_stores( nodes_mapping: NodeToPartitionMapping, run: Run, app_dir: Optional[str] = None, -) -> dict[int, NodeState]: - """Create NodeState objects and pre-register the context for the run.""" - node_states: dict[int, NodeState] = {} +) -> dict[int, DeprecatedRunInfoStore]: + """Create DeprecatedRunInfoStore objects and register the context for the run.""" + node_info_store: dict[int, DeprecatedRunInfoStore] = {} num_partitions = len(set(nodes_mapping.values())) for node_id, partition_id in nodes_mapping.items(): - node_states[node_id] = NodeState( + node_info_store[node_id] = DeprecatedRunInfoStore( node_id=node_id, node_config={ PARTITION_ID_KEY: partition_id, @@ -78,18 +78,18 @@ def _register_node_states( ) # Pre-register Context objects - node_states[node_id].register_context( + node_info_store[node_id].register_context( run_id=run.run_id, run=run, app_dir=app_dir ) - return node_states + return node_info_store # pylint: disable=too-many-arguments,too-many-locals def worker( taskins_queue: "Queue[TaskIns]", taskres_queue: "Queue[TaskRes]", - node_states: dict[int, NodeState], + node_info_store: dict[int, DeprecatedRunInfoStore], backend: Backend, f_stop: threading.Event, ) -> None: @@ -103,7 +103,7 @@ def worker( node_id = task_ins.task.consumer.node_id # Retrieve context - context = node_states[node_id].retrieve_context(run_id=task_ins.run_id) + context = 
node_info_store[node_id].retrieve_context(run_id=task_ins.run_id) # Convert TaskIns to Message message = message_from_taskins(task_ins) @@ -112,7 +112,7 @@ def worker( out_mssg, updated_context = backend.process_message(message, context) # Update Context - node_states[node_id].update_context( + node_info_store[node_id].update_context( task_ins.run_id, context=updated_context ) except Empty: @@ -178,7 +178,7 @@ def run_api( backend_fn: Callable[[], Backend], nodes_mapping: NodeToPartitionMapping, state_factory: LinkStateFactory, - node_states: dict[int, NodeState], + node_info_stores: dict[int, DeprecatedRunInfoStore], f_stop: threading.Event, ) -> None: """Run the VCE.""" @@ -223,7 +223,7 @@ def run_api( worker, taskins_queue, taskres_queue, - node_states, + node_info_stores, backend, f_stop, ) @@ -312,8 +312,8 @@ def start_vce( num_nodes=num_supernodes, state_factory=state_factory ) - # Construct mapping of NodeStates - node_states = _register_node_states( + # Construct mapping of DeprecatedRunInfoStore + node_info_stores = _register_node_info_stores( nodes_mapping=nodes_mapping, run=run, app_dir=app_dir if is_app else None ) @@ -376,7 +376,7 @@ def _load_client_app() -> ClientApp: backend_fn, nodes_mapping, state_factory, - node_states, + node_info_stores, f_stop, ) except LoadClientAppError as loadapp_ex: diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index ad9be6bd1fc0..a5d4b27d3e5a 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -22,7 +22,7 @@ from flwr import common from flwr.client import ClientFnExt from flwr.client.client_app import ClientApp -from flwr.client.node_state import NodeState +from flwr.client.run_info_store import DeprecatedRunInfoStore from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet from flwr.common.constant import ( NUM_PARTITIONS_KEY, @@ -65,7 +65,7 @@ def _load_app() -> ClientApp: self.app_fn = _load_app self.actor_pool = actor_pool - self.proxy_state = NodeState( + self.proxy_state = DeprecatedRunInfoStore( node_id=node_id, node_config={ PARTITION_ID_KEY: str(partition_id), diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index ce0ef46d135f..780092ecb78e 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -22,7 +22,7 @@ from flwr.client import Client, NumPyClient from flwr.client.client_app import ClientApp -from flwr.client.node_state import NodeState +from flwr.client.run_info_store import DeprecatedRunInfoStore from flwr.common import ( DEFAULT_TTL, Config, @@ -142,7 +142,7 @@ def test_cid_consistency_all_submit_first_run_consistency() -> None: """Test that ClientProxies get the result of client job they submit. All jobs are submitted at the same time. Then fetched one at a time. This also tests - NodeState (at each Proxy) and RunState basic functionality. + DeprecatedRunInfoStore (at each Proxy) and RunState basic functionality. 
""" proxies, _, _ = prep() run_id = 0 @@ -193,10 +193,10 @@ def test_cid_consistency_without_proxies() -> None: _, pool, mapping = prep() node_ids = list(mapping.keys()) - # register node states - node_states: dict[int, NodeState] = {} + # register DeprecatedRunInfoStores + node_info_stores: dict[int, DeprecatedRunInfoStore] = {} for node_id, partition_id in mapping.items(): - node_states[node_id] = NodeState( + node_info_stores[node_id] = DeprecatedRunInfoStore( node_id=node_id, node_config={ PARTITION_ID_KEY: str(partition_id), @@ -228,8 +228,8 @@ def _load_app() -> ClientApp: ), ) # register and retrieve context - node_states[node_id].register_context(run_id=run_id) - context = node_states[node_id].retrieve_context(run_id=run_id) + node_info_stores[node_id].register_context(run_id=run_id) + context = node_info_stores[node_id].retrieve_context(run_id=run_id) partition_id_str = str(context.node_config[PARTITION_ID_KEY]) pool.submit_client_job( lambda a, c_fn, j_fn, nid_, state: a.run.remote(c_fn, j_fn, nid_, state), From 2c0743a3b1989723d58eea39e708d96a85b9c167 Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 22 Oct 2024 19:40:33 +0100 Subject: [PATCH 06/10] feat(framework) Introduce `flwr-serverapp` CLI entrypoint (#4350) --- pyproject.toml | 3 ++- src/py/flwr/server/serverapp/__init__.py | 22 ++++++++++++++++++++++ src/py/flwr/server/serverapp/app.py | 20 ++++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 src/py/flwr/server/serverapp/__init__.py create mode 100644 src/py/flwr/server/serverapp/app.py diff --git a/pyproject.toml b/pyproject.toml index 4b8e671a50f3..d7d2d644a333 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,8 @@ flower-simulation = "flwr.simulation.run_simulation:run_simulation_from_cli" # Deployment Engine flower-superlink = "flwr.server.app:run_superlink" flower-supernode = "flwr.client.supernode.app:run_supernode" -flower-server-app = "flwr.server.run_serverapp:run_server_app" +flwr-serverapp = "flwr.server.serverapp:flwr_serverapp" +flower-server-app = "flwr.server.run_serverapp:run_server_app" # Deprecated flwr-clientapp = "flwr.client.clientapp:flwr_clientapp" flower-client-app = "flwr.client.supernode:run_client_app" # Deprecated diff --git a/src/py/flwr/server/serverapp/__init__.py b/src/py/flwr/server/serverapp/__init__.py new file mode 100644 index 000000000000..2873438e3c60 --- /dev/null +++ b/src/py/flwr/server/serverapp/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower AppIO service.""" + + +from .app import flwr_serverapp as flwr_serverapp + +__all__ = [ + "flwr_serverapp", +] diff --git a/src/py/flwr/server/serverapp/app.py b/src/py/flwr/server/serverapp/app.py new file mode 100644 index 000000000000..a02761372097 --- /dev/null +++ b/src/py/flwr/server/serverapp/app.py @@ -0,0 +1,20 @@ +# Copyright 2024 Flower Labs GmbH. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower ServerApp process.""" + + +def flwr_serverapp() -> None: + """Run process-isolated Flower ServerApp.""" + raise NotImplementedError() From 6fbd6782a7917e0c226b05a3ed59a8f2c72db189 Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 22 Oct 2024 19:47:42 +0100 Subject: [PATCH 07/10] refactor(framework) Show deprecation notice for `flower-superexec` (#4351) --- pyproject.toml | 2 +- src/py/flwr/superexec/app.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d7d2d644a333..2b789fc3d623 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ exclude = ["src/py/**/*_test.py"] # `flwr` CLI flwr = "flwr.cli.app:app" # SuperExec (can run with either Deployment Engine or Simulation Engine) -flower-superexec = "flwr.superexec.app:run_superexec" +flower-superexec = "flwr.superexec.app:run_superexec" # Deprecated # Simulation Engine flower-simulation = "flwr.simulation.run_simulation:run_simulation_from_cli" # Deployment Engine diff --git a/src/py/flwr/superexec/app.py b/src/py/flwr/superexec/app.py index 1da0557ceab9..4dcdfeefc4c9 100644 --- a/src/py/flwr/superexec/app.py +++ b/src/py/flwr/superexec/app.py @@ -27,6 +27,7 @@ from flwr.common.config import parse_config_args from flwr.common.constant import EXEC_API_DEFAULT_ADDRESS from flwr.common.exit_handlers import register_exit_handlers +from flwr.common.logger import warn_deprecated_feature from flwr.common.object_ref import load_app, validate from .exec_grpc import run_superexec_api_grpc @@ -37,6 +38,12 @@ def run_superexec() -> None: """Run Flower SuperExec.""" log(INFO, "Starting Flower SuperExec") + warn_deprecated_feature( + "Manually launching the SuperExec is deprecated. Since `flwr 1.13.0` " + "the executor service runs in the SuperLink. Launching it manually is not " + "recommended." 
+ ) + event(EventType.RUN_SUPEREXEC_ENTER) args = _parse_args_run_superexec().parse_args() From 49327eb553b8b6eb93fdf0dc52ff36ab5243c9ec Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 23 Oct 2024 12:19:24 +0100 Subject: [PATCH 08/10] refactor(framework:skip) Set isolation modes as codebase constants (#4353) --- src/py/flwr/client/app.py | 5 ++--- src/py/flwr/client/supernode/app.py | 14 ++++++-------- src/py/flwr/common/constant.py | 4 ++++ 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 5b4eff51a7d6..e803eaf88864 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -37,6 +37,8 @@ from flwr.common.address import parse_address from flwr.common.constant import ( CLIENTAPPIO_API_DEFAULT_ADDRESS, + ISOLATION_MODE_PROCESS, + ISOLATION_MODE_SUBPROCESS, MISSING_EXTRA_REST, RUN_ID_NUM_BYTES, TRANSPORT_TYPE_GRPC_ADAPTER, @@ -62,9 +64,6 @@ from .numpy_client import NumPyClient from .run_info_store import DeprecatedRunInfoStore -ISOLATION_MODE_SUBPROCESS = "subprocess" -ISOLATION_MODE_PROCESS = "process" - def _check_actionable_client( client: Optional[Client], client_fn: Optional[ClientFnExt] diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 4ddfe5d40aa3..92a3c9077f46 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -31,6 +31,8 @@ from flwr.common.config import parse_config_args from flwr.common.constant import ( FLEET_API_GRPC_RERE_DEFAULT_ADDRESS, + ISOLATION_MODE_PROCESS, + ISOLATION_MODE_SUBPROCESS, TRANSPORT_TYPE_GRPC_ADAPTER, TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, @@ -38,11 +40,7 @@ from flwr.common.exit_handlers import register_exit_handlers from flwr.common.logger import log, warn_deprecated_feature -from ..app import ( - ISOLATION_MODE_PROCESS, - ISOLATION_MODE_SUBPROCESS, - start_client_internal, -) +from ..app import start_client_internal from ..clientapp.utils import get_load_client_app_fn @@ -200,10 +198,10 @@ def _parse_args_run_supernode() -> argparse.ArgumentParser: ISOLATION_MODE_SUBPROCESS, ISOLATION_MODE_PROCESS, ], - help="Isolation mode when running `ClientApp` (optional, possible values: " - "`subprocess`, `process`). By default, `ClientApp` runs in the same process " + help="Isolation mode when running a `ClientApp` (optional, possible values: " + "`subprocess`, `process`). By default, a `ClientApp` runs in the same process " "that executes the SuperNode. Use `subprocess` to configure SuperNode to run " - "`ClientApp` in a subprocess. Use `process` to indicate that a separate " + "a `ClientApp` in a subprocess. 
Use `process` to indicate that a separate " "independent process gets created outside of SuperNode.", ) parser.add_argument( diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 98607a46835e..081fa49b2153 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -83,6 +83,10 @@ # Message TTL MESSAGE_TTL_TOLERANCE = 1e-1 +# Isolation modes +ISOLATION_MODE_SUBPROCESS = "subprocess" +ISOLATION_MODE_PROCESS = "process" + class MessageType: """Message type.""" From 7c5a2071e07f7d88479c1ad9c19103718372221f Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 23 Oct 2024 12:39:25 +0100 Subject: [PATCH 09/10] feat(framework) Introduce `--isolation` flag for `flower-superlink` (#4354) --- src/py/flwr/server/app.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 72cb2a9b3d9d..0b6325a81c2b 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -42,6 +42,8 @@ FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS, FLEET_API_GRPC_RERE_DEFAULT_ADDRESS, FLEET_API_REST_DEFAULT_ADDRESS, + ISOLATION_MODE_PROCESS, + ISOLATION_MODE_SUBPROCESS, MISSING_EXTRA_REST, TRANSPORT_TYPE_GRPC_ADAPTER, TRANSPORT_TYPE_GRPC_RERE, @@ -634,6 +636,19 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None: "to create a secure connection.", type=str, ) + parser.add_argument( + "--isolation", + default=ISOLATION_MODE_SUBPROCESS, + required=False, + choices=[ + ISOLATION_MODE_SUBPROCESS, + ISOLATION_MODE_PROCESS, + ], + help="Isolation mode when running a `ServerApp` (`subprocess` by default, " + "possible values: `subprocess`, `process`). Use `subprocess` to configure " + "SuperLink to run a `ServerApp` in a subprocess. Use `process` to indicate " + "that a separate independent process gets created outside of SuperLink.", + ) parser.add_argument( "--database", help="A string representing the path to the database " From 76e1c28900f7161d37330792da54ea994a7ec9b7 Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 23 Oct 2024 14:04:56 +0100 Subject: [PATCH 10/10] refactor(examples) Update `quickstart-jax` example (#4121) --- examples/quickstart-jax/README.md | 90 +++++------ examples/quickstart-jax/client.py | 54 ------- examples/quickstart-jax/jax_training.py | 74 --------- .../quickstart-jax/jaxexample/__init__.py | 1 + .../quickstart-jax/jaxexample/client_app.py | 66 ++++++++ .../quickstart-jax/jaxexample/server_app.py | 47 ++++++ examples/quickstart-jax/jaxexample/task.py | 152 ++++++++++++++++++ examples/quickstart-jax/pyproject.toml | 50 ++++-- examples/quickstart-jax/requirements.txt | 4 - examples/quickstart-jax/run.sh | 15 -- examples/quickstart-jax/server.py | 7 - 11 files changed, 338 insertions(+), 222 deletions(-) delete mode 100644 examples/quickstart-jax/client.py delete mode 100644 examples/quickstart-jax/jax_training.py create mode 100644 examples/quickstart-jax/jaxexample/__init__.py create mode 100644 examples/quickstart-jax/jaxexample/client_app.py create mode 100644 examples/quickstart-jax/jaxexample/server_app.py create mode 100644 examples/quickstart-jax/jaxexample/task.py delete mode 100644 examples/quickstart-jax/requirements.txt delete mode 100755 examples/quickstart-jax/run.sh delete mode 100644 examples/quickstart-jax/server.py diff --git a/examples/quickstart-jax/README.md b/examples/quickstart-jax/README.md index b47f3a82e13b..98f9ec8e7901 100644 --- a/examples/quickstart-jax/README.md +++ b/examples/quickstart-jax/README.md @@ -1,85 +1,67 @@ --- tags: [quickstart, 
linear regression] dataset: [Synthetic] -framework: [JAX] +framework: [JAX, FLAX] --- -# JAX: From Centralized To Federated +# Federated Learning with JAX and Flower (Quickstart Example) -This example demonstrates how an already existing centralized JAX-based machine learning project can be federated with Flower. +This introductory example to Flower uses JAX, but deep knowledge of JAX is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case. Running this example in itself is quite easy. This example uses [FLAX](https://flax.readthedocs.io/en/latest/index.html) to define and train a small CNN model. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the MINST dataset. -This introductory example for Flower uses JAX, but you're not required to be a JAX expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on an existing JAX project. +## Set up the project -## Project Setup +### Clone the project -Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: +Start by cloning the example project: ```shell -git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-jax . && rm -rf flower && cd quickstart-jax +git clone --depth=1 https://github.com/adap/flower.git _tmp \ + && mv _tmp/examples/quickstart-jax . \ + && rm -rf _tmp \ + && cd quickstart-jax ``` -This will create a new directory called `quickstart-jax`, containing the following files: +This will create a new directory called `quickstart-jax` with the following structure: ```shell --- pyproject.toml --- requirements.txt --- jax_training.py --- client.py --- server.py --- README.md +quickstart-jax +├── jaxexample +│ ├── __init__.py +│ ├── client_app.py # Defines your ClientApp +│ ├── server_app.py # Defines your ServerApp +│ └── task.py # Defines your model, training and data loading +├── pyproject.toml # Project metadata like dependencies and configs +└── README.md ``` -### Installing Dependencies +### Install dependencies and project -Project dependencies (such as `jax` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. +Install the dependencies defined in `pyproject.toml` as well as the `jaxexample` package. -#### Poetry - -```shell -poetry install -poetry shell -``` - -Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: - -```shell -poetry run python3 -c "import flwr" -``` - -If you don't see any errors you're good to go! - -#### pip - -Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. - -```shell -pip install -r requirements.txt +```bash +pip install -e . 
``` -## Run JAX Federated +## Run the project -This JAX example is based on the [Linear Regression with JAX](https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html) tutorial and uses a sklearn dataset (generating a random dataset for a regression problem). Feel free to consult the tutorial if you want to get a better understanding of JAX. If you play around with the dataset, please keep in mind that the data samples are generated randomly depending on the settings being done while calling the dataset function. Please checkout out the [scikit-learn tutorial for further information](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html). The file `jax_training.py` contains all the steps that are described in the tutorial. It loads the train and test dataset and a linear regression model, trains the model with the training set, and evaluates the trained model on the test set. +You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend you using the _simulation_ mode as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine. -The only things we need are a simple Flower server (in `server.py`) and a Flower client (in `client.py`). The Flower client basically takes model and training code tells Flower how to call it. +### Run with the Simulation Engine -Start the server in a terminal as follows: - -```shell -python3 server.py +```bash +flwr run . ``` -Now that the server is running and waiting for clients, we can start two clients that will participate in the federated learning process. To do so simply open two more terminal windows and run the following commands. +You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example: -Start client 1 in the first terminal: - -```shell -python3 client.py +```bash +flwr run . --run-config "num-server-rounds=5 batch-size=32" ``` -Start client 2 in the second terminal: +> \[!TIP\] +> For a more detailed walk-through check our [quickstart JAX tutorial](https://flower.ai/docs/framework/tutorial-quickstart-jax.html) -```shell -python3 client.py -``` +### Run with the Deployment Engine -You are now training a JAX-based linear regression model, federated across two clients. The setup is of course simplified since both clients hold a similar dataset, but you can now continue with your own explorations. How about changing from a linear regression to a more sophisticated model? How about adding more clients? +> \[!NOTE\] +> An update to this example will show how to run this Flower application with the Deployment Engine and TLS certificates, or with Docker. 
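The run-config keys shown above (`num-server-rounds`, `batch-size`) map to the `[tool.flwr.app.config]` table in the example's `pyproject.toml` and are read inside the apps through `context.run_config`. A minimal illustrative sketch of that pattern follows; it mirrors the `server_app.py` added later in this patch and is not itself part of the diff:

```python
# Sketch: reading run-config values inside a ServerApp via the Context object.
from flwr.common import Context
from flwr.server import ServerApp, ServerAppComponents, ServerConfig


def server_fn(context: Context):
    # Values passed with `flwr run . --run-config "num-server-rounds=5"`
    # override the defaults declared under [tool.flwr.app.config].
    num_rounds = context.run_config["num-server-rounds"]
    return ServerAppComponents(config=ServerConfig(num_rounds=num_rounds))


app = ServerApp(server_fn=server_fn)
```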
diff --git a/examples/quickstart-jax/client.py b/examples/quickstart-jax/client.py deleted file mode 100644 index 4a2aaf0e5a93..000000000000 --- a/examples/quickstart-jax/client.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Flower client example using JAX for linear regression.""" - -from typing import Callable, Dict, List, Tuple - -import flwr as fl -import jax -import jax.numpy as jnp -import jax_training -import numpy as np - -# Load data and determine model shape -train_x, train_y, test_x, test_y = jax_training.load_data() -grad_fn = jax.grad(jax_training.loss_fn) -model_shape = train_x.shape[1:] - - -class FlowerClient(fl.client.NumPyClient): - def __init__(self): - self.params = jax_training.load_model(model_shape) - - def get_parameters(self, config): - parameters = [] - for _, val in self.params.items(): - parameters.append(np.array(val)) - return parameters - - def set_parameters(self, parameters: List[np.ndarray]) -> None: - for key, value in list(zip(self.params.keys(), parameters)): - self.params[key] = value - - def fit( - self, parameters: List[np.ndarray], config: Dict - ) -> Tuple[List[np.ndarray], int, Dict]: - self.set_parameters(parameters) - self.params, loss, num_examples = jax_training.train( - self.params, grad_fn, train_x, train_y - ) - parameters = self.get_parameters(config={}) - return parameters, num_examples, {"loss": float(loss)} - - def evaluate( - self, parameters: List[np.ndarray], config: Dict - ) -> Tuple[float, int, Dict]: - self.set_parameters(parameters) - loss, num_examples = jax_training.evaluation( - self.params, grad_fn, test_x, test_y - ) - return float(loss), num_examples, {"loss": float(loss)} - - -# Start Flower client -fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() -) diff --git a/examples/quickstart-jax/jax_training.py b/examples/quickstart-jax/jax_training.py deleted file mode 100644 index f57db75d5963..000000000000 --- a/examples/quickstart-jax/jax_training.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Linear Regression with JAX. - -This code examples is based on the following code example: -https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html - -If you have any questions concerning the linear regression used with jax -please read the JAX documentation or the mentioned tutorial. 
-""" - -from typing import Callable, Dict, List, Tuple - -import jax -import jax.numpy as jnp -import numpy as np -from sklearn.datasets import make_regression -from sklearn.model_selection import train_test_split - -key = jax.random.PRNGKey(0) - - -def load_data() -> ( - Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray]] -): - # Load dataset - X, y = make_regression(n_features=3, random_state=0) - X, X_test, y, y_test = train_test_split(X, y) - return X, y, X_test, y_test - - -def load_model(model_shape) -> Dict: - # Extract model parameters - params = {"b": jax.random.uniform(key), "w": jax.random.uniform(key, model_shape)} - return params - - -def loss_fn(params, X, y) -> Callable: - # Return MSE as loss - err = jnp.dot(X, params["w"]) + params["b"] - y - return jnp.mean(jnp.square(err)) - - -def train(params, grad_fn, X, y) -> Tuple[np.array, float, int]: - num_examples = X.shape[0] - for epochs in range(50): - grads = grad_fn(params, X, y) - params = jax.tree_map(lambda p, g: p - 0.05 * g, params, grads) - loss = loss_fn(params, X, y) - if epochs % 10 == 0: - print(f"For Epoch {epochs} loss {loss}") - return params, loss, num_examples - - -def evaluation(params, grad_fn, X_test, y_test) -> Tuple[float, int]: - num_examples = X_test.shape[0] - err_test = loss_fn(params, X_test, y_test) - loss_test = jnp.mean(jnp.square(err_test)) - return loss_test, num_examples - - -def main(): - X, y, X_test, y_test = load_data() - model_shape = X.shape[1:] - grad_fn = jax.grad(loss_fn) - print("Model Shape", model_shape) - params = load_model(model_shape) - print("Params", params) - params, loss, num_examples = train(params, grad_fn, X, y) - print("Training loss:", loss) - loss, num_examples = evaluation(params, grad_fn, X_test, y_test) - print("Evaluation loss:", loss) - - -if __name__ == "__main__": - main() diff --git a/examples/quickstart-jax/jaxexample/__init__.py b/examples/quickstart-jax/jaxexample/__init__.py new file mode 100644 index 000000000000..f04ba7eccc81 --- /dev/null +++ b/examples/quickstart-jax/jaxexample/__init__.py @@ -0,0 +1 @@ +"""jaxexample: A Flower / JAX app.""" diff --git a/examples/quickstart-jax/jaxexample/client_app.py b/examples/quickstart-jax/jaxexample/client_app.py new file mode 100644 index 000000000000..915b0d4f16be --- /dev/null +++ b/examples/quickstart-jax/jaxexample/client_app.py @@ -0,0 +1,66 @@ +"""jaxexample: A Flower / JAX app.""" + +import numpy as np +from flwr.client import ClientApp, NumPyClient +from flwr.common import Context + +from jaxexample.task import ( + apply_model, + create_train_state, + get_params, + load_data, + set_params, + train, +) + + +# Define Flower Client and client_fn +class FlowerClient(NumPyClient): + def __init__(self, train_state, trainset, testset): + self.train_state = train_state + self.trainset, self.testset = trainset, testset + + def fit(self, parameters, config): + self.train_state = set_params(self.train_state, parameters) + self.train_state, loss, acc = train(self.train_state, self.trainset) + params = get_params(self.train_state.params) + return ( + params, + len(self.trainset), + {"train_acc": float(acc), "train_loss": float(loss)}, + ) + + def evaluate(self, parameters, config): + self.train_state = set_params(self.train_state, parameters) + + losses = [] + accs = [] + for batch in self.testset: + _, loss, accuracy = apply_model( + self.train_state, batch["image"], batch["label"] + ) + losses.append(float(loss)) + accs.append(float(accuracy)) + + return np.mean(losses), len(self.testset), 
{"accuracy": np.mean(accs)} + + +def client_fn(context: Context): + + num_partitions = context.node_config["num-partitions"] + partition_id = context.node_config["partition-id"] + batch_size = context.run_config["batch-size"] + trainset, testset = load_data(partition_id, num_partitions, batch_size) + + # Create train state object (model + optimizer) + lr = context.run_config["learning-rate"] + train_state = create_train_state(lr) + + # Return Client instance + return FlowerClient(train_state, trainset, testset).to_client() + + +# Flower ClientApp +app = ClientApp( + client_fn, +) diff --git a/examples/quickstart-jax/jaxexample/server_app.py b/examples/quickstart-jax/jaxexample/server_app.py new file mode 100644 index 000000000000..1accf9dabd21 --- /dev/null +++ b/examples/quickstart-jax/jaxexample/server_app.py @@ -0,0 +1,47 @@ +"""jaxexample: A Flower / JAX app.""" + +from typing import List, Tuple + +from flwr.common import Context, Metrics, ndarrays_to_parameters +from flwr.server import ServerApp, ServerAppComponents, ServerConfig +from flwr.server.strategy import FedAvg +from jax import random + +from jaxexample.task import create_model, get_params + + +# Define metric aggregation function +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} + + +def server_fn(context: Context): + # Read from config + num_rounds = context.run_config["num-server-rounds"] + + # Initialize global model + rng = random.PRNGKey(0) + rng, _ = random.split(rng) + _, model_params = create_model(rng) + params = get_params(model_params) + initial_parameters = ndarrays_to_parameters(params) + + # Define strategy + strategy = FedAvg( + fraction_fit=0.4, + fraction_evaluate=0.5, + evaluate_metrics_aggregation_fn=weighted_average, + initial_parameters=initial_parameters, + ) + config = ServerConfig(num_rounds=num_rounds) + + return ServerAppComponents(strategy=strategy, config=config) + + +# Create ServerApp +app = ServerApp(server_fn=server_fn) diff --git a/examples/quickstart-jax/jaxexample/task.py b/examples/quickstart-jax/jaxexample/task.py new file mode 100644 index 000000000000..3b923dbe6ae8 --- /dev/null +++ b/examples/quickstart-jax/jaxexample/task.py @@ -0,0 +1,152 @@ +"""jaxexample: A Flower / JAX app.""" + +import warnings + +import jax +import jax.numpy as jnp +import numpy as np +import optax +from datasets.utils.logging import disable_progress_bar +from flax import linen as nn +from flax.training.train_state import TrainState +from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner + +disable_progress_bar() + +rng = jax.random.PRNGKey(0) +rng, init_rng = jax.random.split(rng) + +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=RuntimeWarning) + + +class CNN(nn.Module): + """A simple CNN model.""" + + @nn.compact + def __call__(self, x): + x = nn.Conv(features=6, kernel_size=(5, 5))(x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = nn.Conv(features=16, kernel_size=(5, 5))(x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + x = nn.Dense(features=120)(x) + x = nn.relu(x) + x = 
nn.Dense(features=84)(x) + x = nn.relu(x) + x = nn.Dense(features=10)(x) + return x + + +def create_model(rng): + cnn = CNN() + return cnn, cnn.init(rng, jnp.ones([1, 28, 28, 1]))["params"] + + +def create_train_state(learning_rate: float): + """Creates initial `TrainState`.""" + + tx = optax.sgd(learning_rate, momentum=0.9) + model, model_params = create_model(rng) + return TrainState.create(apply_fn=model.apply, params=model_params, tx=tx) + + +def get_params(params): + """Get model parameters as list of numpy arrays.""" + return [np.array(param) for param in jax.tree_util.tree_leaves(params)] + + +def set_params(train_state: TrainState, global_params) -> TrainState: + """Create a new trainstate with the global_params.""" + new_params_dict = jax.tree_util.tree_unflatten( + jax.tree_util.tree_structure(train_state.params), global_params + ) + return train_state.replace(params=new_params_dict) + + +@jax.jit +def apply_model(state, images, labels): + """Computes gradients, loss and accuracy for a single batch.""" + + def loss_fn(params): + logits = state.apply_fn({"params": params}, images) + one_hot = jax.nn.one_hot(labels, 10) + loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) + return loss, logits + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss, logits), grads = grad_fn(state.params) + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) + return grads, loss, accuracy + + +@jax.jit +def update_model(state, grads): + return state.apply_gradients(grads=grads) + + +def train(state, train_ds): + """Train for a single epoch.""" + + epoch_loss = [] + epoch_accuracy = [] + + for batch in train_ds: + batch_images = batch["image"] + batch_labels = batch["label"] + grads, loss, accuracy = apply_model(state, batch_images, batch_labels) + state = update_model(state, grads) + epoch_loss.append(loss) + epoch_accuracy.append(accuracy) + train_loss = np.mean(epoch_loss) + train_accuracy = np.mean(epoch_accuracy) + return state, train_loss, train_accuracy + + +fds = None # Cache FederatedDataset + + +def load_data(partition_id: int, num_partitions: int, batch_size: int): + """Load partition MNIST data.""" + # Only initialize `FederatedDataset` once + global fds + if fds is None: + partitioner = IidPartitioner(num_partitions=num_partitions) + fds = FederatedDataset( + dataset="mnist", + partitioners={"train": partitioner}, + ) + partition = fds.load_partition(partition_id) + + # Divide data on each node: 80% train, 20% test + partition = partition.train_test_split(test_size=0.2) + + partition["train"].set_format("jax") + partition["test"].set_format("jax") + + def apply_transforms(batch): + """Apply transforms to the partition from FederatedDataset.""" + batch["image"] = [ + jnp.expand_dims(jnp.float32(img), 3) / 255 for img in batch["image"] + ] + batch["label"] = [jnp.int16(label) for label in batch["label"]] + return batch + + train_partition = ( + partition["train"] + .batch(batch_size, num_proc=2, drop_last_batch=True) + .with_transform(apply_transforms) + ) + test_partition = ( + partition["test"] + .batch(batch_size, num_proc=2, drop_last_batch=True) + .with_transform(apply_transforms) + ) + + train_partition.shuffle(seed=1234) + test_partition.shuffle(seed=1234) + + return train_partition, test_partition diff --git a/examples/quickstart-jax/pyproject.toml b/examples/quickstart-jax/pyproject.toml index 68a3455aedee..09fd32f7a318 100644 --- a/examples/quickstart-jax/pyproject.toml +++ b/examples/quickstart-jax/pyproject.toml @@ -1,16 +1,38 @@ 
-[tool.poetry] -name = "jax_example" -version = "0.1.0" -description = "JAX example training a linear regression model with federated learning" -authors = ["The Flower Authors "] +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" -[tool.poetry.dependencies] -python = ">=3.9,<3.11" -flwr = "1.0.0" -jax = "0.4.17" -jaxlib = "0.4.17" -scikit-learn = "1.1.1" +[project] +name = "jaxexample" +version = "1.0.0" +description = "" +license = "Apache-2.0" +dependencies = [ + "flwr[simulation]>=1.12.0", + "flwr-datasets[vision]>=0.4.0", + "datasets>=2.21.0", + "jax==0.4.31", + "jaxlib==0.4.31", + "flax==0.9.0", +] -[build-system] -requires = ["poetry-core>=1.4.0"] -build-backend = "poetry.core.masonry.api" +[tool.hatch.build.targets.wheel] +packages = ["."] + +[tool.flwr.app] +publisher = "flwrlabs" + +[tool.flwr.app.components] +serverapp = "jaxexample.server_app:app" +clientapp = "jaxexample.client_app:app" + +[tool.flwr.app.config] +num-server-rounds = 5 +learning-rate = 0.1 +batch-size = 64 + +[tool.flwr.federations] +default = "local-simulation" + +[tool.flwr.federations.local-simulation] +options.num-supernodes = 50 diff --git a/examples/quickstart-jax/requirements.txt b/examples/quickstart-jax/requirements.txt deleted file mode 100644 index 964f07a51b7d..000000000000 --- a/examples/quickstart-jax/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -flwr>=1.0,<2.0 -jax==0.4.17 -jaxlib==0.4.17 -scikit-learn==1.1.1 diff --git a/examples/quickstart-jax/run.sh b/examples/quickstart-jax/run.sh deleted file mode 100755 index c64f362086aa..000000000000 --- a/examples/quickstart-jax/run.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -echo "Starting server" -python server.py & -sleep 3 # Sleep for 3s to give the server enough time to start - -for i in `seq 0 1`; do - echo "Starting client $i" - python client.py & -done - -# This will allow you to use CTRL+C to stop all background processes -trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM -# Wait for all background processes to complete -wait diff --git a/examples/quickstart-jax/server.py b/examples/quickstart-jax/server.py deleted file mode 100644 index 2bc3716d84ae..000000000000 --- a/examples/quickstart-jax/server.py +++ /dev/null @@ -1,7 +0,0 @@ -import flwr as fl - -if __name__ == "__main__": - fl.server.start_server( - server_address="0.0.0.0:8080", - config=fl.server.ServerConfig(num_rounds=3), - )
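With the legacy `client.py`, `server.py`, and `run.sh` removed, the example is launched through `flwr run .` as described in the updated README. For completeness, a hedged sketch of driving the same apps programmatically with the Simulation Engine; this assumes the `flwr.simulation.run_simulation` helper available in recent Flower releases and is not part of this patch:

```python
# Sketch: run the jaxexample ServerApp/ClientApp without the `flwr` CLI.
from flwr.simulation import run_simulation

from jaxexample.client_app import app as client_app
from jaxexample.server_app import app as server_app

# 50 supernodes matches `options.num-supernodes` in the new pyproject.toml.
run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=50,
)
```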