From 6a8b9b6e75e05abde02dc82afe8f3b57b978d932 Mon Sep 17 00:00:00 2001 From: Robert Steiner Date: Tue, 5 Nov 2024 16:07:49 +0100 Subject: [PATCH 1/9] feat(framework:skip) Remove SuperExec from Docker CI jobs (#4415) Signed-off-by: Robert Steiner --- .github/workflows/docker-build-main.yml | 1 - .github/workflows/release-nightly.yml | 1 - dev/build-docker-image-matrix.py | 7 ------- 3 files changed, 9 deletions(-) diff --git a/.github/workflows/docker-build-main.yml b/.github/workflows/docker-build-main.yml index 81ef845eae29..e54257048245 100644 --- a/.github/workflows/docker-build-main.yml +++ b/.github/workflows/docker-build-main.yml @@ -56,7 +56,6 @@ jobs: { repository: "flwr/superlink", file_dir: "src/docker/superlink" }, { repository: "flwr/supernode", file_dir: "src/docker/supernode" }, { repository: "flwr/serverapp", file_dir: "src/docker/serverapp" }, - { repository: "flwr/superexec", file_dir: "src/docker/superexec" }, { repository: "flwr/clientapp", file_dir: "src/docker/clientapp" } ] with: diff --git a/.github/workflows/release-nightly.yml b/.github/workflows/release-nightly.yml index fcefff300cb7..32f76cc86c5b 100644 --- a/.github/workflows/release-nightly.yml +++ b/.github/workflows/release-nightly.yml @@ -70,7 +70,6 @@ jobs: { repository: "flwr/superlink", file_dir: "src/docker/superlink" }, { repository: "flwr/supernode", file_dir: "src/docker/supernode" }, { repository: "flwr/serverapp", file_dir: "src/docker/serverapp" }, - { repository: "flwr/superexec", file_dir: "src/docker/superexec" }, { repository: "flwr/clientapp", file_dir: "src/docker/clientapp" } ] with: diff --git a/dev/build-docker-image-matrix.py b/dev/build-docker-image-matrix.py index 52c96e3cca7a..9c658b846e3b 100644 --- a/dev/build-docker-image-matrix.py +++ b/dev/build-docker-image-matrix.py @@ -172,13 +172,6 @@ def tag_latest_ubuntu_with_flwr_version(image: BaseImage) -> List[str]: lambda image: image.distro.name == DistroName.UBUNTU, ) # ubuntu images for each supported python 
version - + generate_binary_images( - "superexec", - base_images, - tag_latest_ubuntu_with_flwr_version, - lambda image: image.distro.name == DistroName.UBUNTU, - ) - # ubuntu images for each supported python version + generate_binary_images( "clientapp", base_images, From 57e49fee15be5f9d6b4c86f06fde2dc35b03a0e6 Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 5 Nov 2024 16:23:04 +0000 Subject: [PATCH 2/9] refactor(framework) Remove support for non-app simulations (#4431) --- src/py/flwr/simulation/run_simulation.py | 146 ++++------------------- 1 file changed, 26 insertions(+), 120 deletions(-) diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 012af8760c12..929824843f54 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -21,7 +21,6 @@ import sys import threading import traceback -from argparse import Namespace from logging import DEBUG, ERROR, INFO, WARNING from pathlib import Path from time import sleep @@ -35,7 +34,6 @@ from flwr.common.logger import ( set_logger_propagation, update_console_handler, - warn_deprecated_feature, warn_deprecated_feature_with_example, ) from flwr.common.typing import Run, RunStatus, UserConfig @@ -52,47 +50,6 @@ ) -def _check_args_do_not_interfere(args: Namespace) -> bool: - """Ensure decoupling of flags for different ways to start the simulation.""" - mode_one_args = ["app", "run_config"] - mode_two_args = ["client_app", "server_app"] - - def _resolve_message(conflict_keys: list[str]) -> str: - return ",".join([f"`--{key}`".replace("_", "-") for key in conflict_keys]) - - # When passing `--app`, `--app-dir` is ignored - if args.app and args.app_dir: - log(ERROR, "Either `--app` or `--app-dir` can be set, but not both.") - return False - - if any(getattr(args, key) for key in mode_one_args): - if any(getattr(args, key) for key in mode_two_args): - log( - ERROR, - "Passing any of {%s} alongside with any of {%s}", - 
_resolve_message(mode_one_args), - _resolve_message(mode_two_args), - ) - return False - - if not args.app: - log(ERROR, "You need to pass --app") - return False - - return True - - # Ensure all args are set (required for the non-FAB mode of execution) - if not all(getattr(args, key) for key in mode_two_args): - log( - ERROR, - "Passing all of %s keys are required.", - _resolve_message(mode_two_args), - ) - return False - - return True - - def _replace_keys(d: Any, match: str, target: str) -> Any: if isinstance(d, dict): return { @@ -115,19 +72,6 @@ def run_simulation_from_cli() -> None: event_details={"backend": args.backend, "num-supernodes": args.num_supernodes}, ) - # Add warnings for deprecated server_app and client_app arguments - if args.server_app: - warn_deprecated_feature( - "The `--server-app` argument is deprecated. " - "Please use the `--app` argument instead." - ) - - if args.client_app: - warn_deprecated_feature( - "The `--client-app` argument is deprecated. " - "Use the `--app` argument instead." 
- ) - if args.enable_tf_gpu_growth: warn_deprecated_feature_with_example( "Passing `--enable-tf-gpu-growth` is deprecated.", @@ -144,60 +88,39 @@ def run_simulation_from_cli() -> None: backend_config_dict = _replace_keys(backend_config_dict, match="-", target="_") log(DEBUG, "backend_config_dict: %s", backend_config_dict) - # We are supporting two modes for the CLI entrypoint: - # 1) Running an app dir containing a `pyproject.toml` - # 2) Running any ClientApp and SeverApp w/o pyproject.toml being present - # For 2), some CLI args are compulsory, but they are not required for 1) - # We first do these checks - args_check_pass = _check_args_do_not_interfere(args) - if not args_check_pass: - sys.exit("Simulation Engine cannot start.") - run_id = ( generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) if args.run_id is None else args.run_id ) - if args.app: - # Mode 1 - app_path = Path(args.app) - if not app_path.is_dir(): - log(ERROR, "--app is not a directory") - sys.exit("Simulation Engine cannot start.") - - # Load pyproject.toml - config, errors, warnings = load_and_validate( - app_path / "pyproject.toml", check_module=False - ) - if errors: - raise ValueError(errors) - if warnings: - log(WARNING, warnings) + app_path = Path(args.app) + if not app_path.is_dir(): + log(ERROR, "--app is not a directory") + sys.exit("Simulation Engine cannot start.") + + # Load pyproject.toml + config, errors, warnings = load_and_validate( + app_path / "pyproject.toml", check_module=False + ) + if errors: + raise ValueError(errors) - if config is None: - raise ValueError("Config extracted from FAB's pyproject.toml is not valid") + if warnings: + log(WARNING, warnings) - # Get ClientApp and SeverApp components - app_components = config["tool"]["flwr"]["app"]["components"] - client_app_attr = app_components["clientapp"] - server_app_attr = app_components["serverapp"] + if config is None: + raise ValueError("Config extracted from FAB's pyproject.toml is not valid") - override_config = 
parse_config_args( - [args.run_config] if args.run_config else args.run_config - ) - fused_config = get_fused_config_from_dir(app_path, override_config) - app_dir = args.app - is_app = True + # Get ClientApp and SeverApp components + app_components = config["tool"]["flwr"]["app"]["components"] + client_app_attr = app_components["clientapp"] + server_app_attr = app_components["serverapp"] - else: - # Mode 2 - client_app_attr = args.client_app - server_app_attr = args.server_app - override_config = {} - fused_config = None - app_dir = args.app_dir - is_app = False + override_config = parse_config_args( + [args.run_config] if args.run_config else args.run_config + ) + fused_config = get_fused_config_from_dir(app_path, override_config) # Create run run = Run( @@ -214,13 +137,13 @@ def run_simulation_from_cli() -> None: num_supernodes=args.num_supernodes, backend_name=args.backend, backend_config=backend_config_dict, - app_dir=app_dir, + app_dir=args.app, run=run, enable_tf_gpu_growth=args.enable_tf_gpu_growth, delay_start=args.delay_start, verbose_logging=args.verbose, server_app_run_config=fused_config, - is_app=is_app, + is_app=True, exit_event=EventType.CLI_FLOWER_SIMULATION_LEAVE, ) @@ -583,20 +506,10 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser: parser.add_argument( "--app", type=str, - default=None, + required=True, help="Path to a directory containing a FAB-like structure with a " "pyproject.toml.", ) - parser.add_argument( - "--server-app", - help="(DEPRECATED: use --app instead) For example: `server:app` or " - "`project.package.module:wrapper.app`", - ) - parser.add_argument( - "--client-app", - help="(DEPRECATED: use --app instead) For example: `client:app` or " - "`project.package.module:wrapper.app`", - ) parser.add_argument( "--num-supernodes", type=int, @@ -645,13 +558,6 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser: help="When unset, only INFO, WARNING and ERROR log messages will be shown. 
" "If set, DEBUG-level logs will be displayed. ", ) - parser.add_argument( - "--app-dir", - default="", - help="Add specified directory to the PYTHONPATH and load" - "ClientApp and ServerApp from there." - " Default: current working directory.", - ) parser.add_argument( "--flwr-dir", default=None, From 87bea16e36085463a68c796014f7277a8dd371c5 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 5 Nov 2024 17:42:59 +0000 Subject: [PATCH 3/9] feat(framework) Validate node IDs for `TaskIns` in `LinkState.store_task_ins` (#4378) --- .../server/driver/inmemory_driver_test.py | 3 +- .../linkstate/in_memory_linkstate.py | 19 ++++- .../superlink/linkstate/linkstate_test.py | 79 +++++++++++++------ .../superlink/linkstate/sqlite_linkstate.py | 31 ++++++-- 4 files changed, 98 insertions(+), 34 deletions(-) diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index a5e893c707fa..c10c57648900 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -45,9 +45,8 @@ def push_messages(driver: InMemoryDriver, num_nodes: int) -> tuple[Iterable[str], int]: """Help push messages to state.""" for _ in range(num_nodes): - driver.state.create_node(ping_interval=PING_MAX_INTERVAL) + node_id = driver.state.create_node(ping_interval=PING_MAX_INTERVAL) num_messages = 3 - node_id = 1 msgs = [ driver.create_message(RecordSet(), "message_type", node_id, "") for _ in range(num_messages) diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index 0830c26fc49c..52194a5a9ac8 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -87,8 +87,25 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: return None # Validate run_id if task_ins.run_id not in self.run_ids: - log(ERROR, "`run_id` is 
invalid") + log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id) + return None + # Validate source node ID + if task_ins.task.producer.node_id != 0: + log( + ERROR, + "Invalid source node ID for TaskIns: %s", + task_ins.task.producer.node_id, + ) return None + # Validate destination node ID + if not task_ins.task.consumer.anonymous: + if task_ins.task.consumer.node_id not in self.node_ids: + log( + ERROR, + "Invalid destination node ID for TaskIns: %s", + task_ins.task.consumer.node_id, + ) + return None # Create task_id task_id = uuid4() diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 1fc21bf02a2a..9e00e4a0c49a 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -192,11 +192,11 @@ def test_get_task_res_empty(self) -> None: def test_store_task_ins_one(self) -> None: """Test store_task_ins.""" # Prepare - consumer_node_id = 1 state = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) assert task_ins.task.created_at < time.time() # pylint: disable=no-member @@ -204,7 +204,7 @@ def test_store_task_ins_one(self) -> None: # Execute state.store_task_ins(task_ins=task_ins) - task_ins_list = state.get_task_ins(node_id=consumer_node_id, limit=10) + task_ins_list = state.get_task_ins(node_id=node_id, limit=10) # Assert assert len(task_ins_list) == 1 @@ -224,20 +224,39 @@ def test_store_task_ins_one(self) -> None: ) assert actual_task.ttl > 0 + def test_store_task_ins_invalid_node_id(self) -> None: + """Test store_task_ins with invalid node_id.""" + # Prepare + state = self.state_factory() + node_id = state.create_node(1e3) + invalid_node_id = 61016 if node_id != 61016 else 61017 + run_id 
= state.create_run(None, None, "9f86d08", {}) + task_ins = create_task_ins( + consumer_node_id=invalid_node_id, anonymous=False, run_id=run_id + ) + task_ins2 = create_task_ins( + consumer_node_id=node_id, anonymous=False, run_id=run_id + ) + task_ins2.task.producer.node_id = 61016 + + # Execute and assert + assert state.store_task_ins(task_ins) is None + assert state.store_task_ins(task_ins2) is None + def test_store_and_delete_tasks(self) -> None: """Test delete_tasks.""" # Prepare - consumer_node_id = 1 state = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) task_ins_0 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) task_ins_1 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) task_ins_2 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) # Insert three TaskIns @@ -250,11 +269,11 @@ def test_store_and_delete_tasks(self) -> None: assert task_id_2 # Get TaskIns to mark them delivered - _ = state.get_task_ins(node_id=consumer_node_id, limit=None) + _ = state.get_task_ins(node_id=node_id, limit=None) # Insert one TaskRes and retrive it to mark it as delivered task_res_0 = create_task_res( - producer_node_id=consumer_node_id, + producer_node_id=node_id, anonymous=False, ancestry=[str(task_id_0)], run_id=run_id, @@ -265,7 +284,7 @@ def test_store_and_delete_tasks(self) -> None: # Insert one TaskRes, but don't retrive it task_res_1: TaskRes = create_task_res( - producer_node_id=consumer_node_id, + producer_node_id=node_id, anonymous=False, ancestry=[str(task_id_1)], run_id=run_id, @@ -332,8 +351,11 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: """Store identity TaskIns and fail 
retrieving it as anonymous.""" # Prepare state: LinkState = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) - task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) + task_ins = create_task_ins( + consumer_node_id=node_id, anonymous=False, run_id=run_id + ) # Execute _ = state.store_task_ins(task_ins) @@ -346,12 +368,15 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: """Store identity TaskIns and retrieve it.""" # Prepare state: LinkState = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) - task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) + task_ins = create_task_ins( + consumer_node_id=node_id, anonymous=False, run_id=run_id + ) # Execute task_ins_uuid = state.store_task_ins(task_ins) - task_ins_list = state.get_task_ins(node_id=1, limit=None) + task_ins_list = state.get_task_ins(node_id=node_id, limit=None) # Assert assert len(task_ins_list) == 1 @@ -363,14 +388,17 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: """Fail retrieving delivered task.""" # Prepare state: LinkState = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) - task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) + task_ins = create_task_ins( + consumer_node_id=node_id, anonymous=False, run_id=run_id + ) # Execute _ = state.store_task_ins(task_ins) # 1st get: set to delivered - task_ins_list = state.get_task_ins(node_id=1, limit=None) + task_ins_list = state.get_task_ins(node_id=node_id, limit=None) assert len(task_ins_list) == 1 @@ -874,11 +902,11 @@ def test_store_task_res_limit_ttl(self) -> None: def test_get_task_ins_not_return_expired(self) -> None: """Test get_task_ins not to return expired tasks.""" # Prepare - consumer_node_id = 1 state = self.state_factory() + node_id = 
state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) task_ins.task.created_at = time.time() - 5 task_ins.task.ttl = 5.0 @@ -894,11 +922,11 @@ def test_get_task_ins_not_return_expired(self) -> None: def test_get_task_res_not_return_expired(self) -> None: """Test get_task_res not to return TaskRes if its TaskIns is expired.""" # Prepare - consumer_node_id = 1 state = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) task_ins.task.created_at = time.time() - 5 task_ins.task.ttl = 5.1 @@ -948,11 +976,11 @@ def test_get_task_res_return_if_not_expired(self) -> None: """Test get_task_res to return TaskRes if its TaskIns exists and is not expired.""" # Prepare - consumer_node_id = 1 state = self.state_factory() + node_id = state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) task_ins.task.created_at = time.time() - 5 task_ins.task.ttl = 7.1 @@ -960,7 +988,7 @@ def test_get_task_res_return_if_not_expired(self) -> None: task_id = state.store_task_ins(task_ins=task_ins) task_res = create_task_res( - producer_node_id=1, + producer_node_id=node_id, anonymous=False, ancestry=[str(task_id)], run_id=run_id, @@ -980,17 +1008,18 @@ def test_store_task_res_fail_if_consumer_producer_id_mismatch(self) -> None: """Test store_task_res to fail if there is a mismatch between the consumer_node_id of taskIns and the producer_node_id of taskRes.""" # Prepare - consumer_node_id = 1 state = self.state_factory() + node_id = 
state.create_node(1e3) run_id = state.create_run(None, None, "9f86d08", {}) task_ins = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + consumer_node_id=node_id, anonymous=False, run_id=run_id ) task_id = state.store_task_ins(task_ins=task_ins) task_res = create_task_res( - producer_node_id=100, # different than consumer_node_id + # Different than consumer_node_id + producer_node_id=100 if node_id != 100 else 101, anonymous=False, ancestry=[str(task_id)], run_id=run_id, diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index 2094bd1d8592..ad73bd4fcce0 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -271,7 +271,6 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: if any(errors): log(ERROR, errors) return None - # Create task_id task_id = uuid4() @@ -284,16 +283,36 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: data[0], ["run_id", "producer_node_id", "consumer_node_id"] ) + # Validate run_id + query = "SELECT run_id FROM run WHERE run_id = ?;" + if not self.query(query, (data[0]["run_id"],)): + log(ERROR, "Invalid run ID for TaskIns: %s", task_ins.run_id) + return None + # Validate source node ID + if task_ins.task.producer.node_id != 0: + log( + ERROR, + "Invalid source node ID for TaskIns: %s", + task_ins.task.producer.node_id, + ) + return None + # Validate destination node ID + query = "SELECT node_id FROM node WHERE node_id = ?;" + if not task_ins.task.consumer.anonymous: + if not self.query(query, (data[0]["consumer_node_id"],)): + log( + ERROR, + "Invalid destination node ID for TaskIns: %s", + task_ins.task.consumer.node_id, + ) + return None + columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_ins VALUES({columns});" # Only invalid run_id can trigger IntegrityError. 
# This may need to be changed in the future version with more integrity checks. - try: - self.query(query, data) - except sqlite3.IntegrityError: - log(ERROR, "`run` is invalid") - return None + self.query(query, data) return task_id From a079987ba87290794c081470c11041dbe4ec23d5 Mon Sep 17 00:00:00 2001 From: Yan Gao Date: Wed, 6 Nov 2024 04:15:45 +0800 Subject: [PATCH 4/9] fix(benchmarks) Fix an empty output issue for finance evaluation pipeline (#4436) --- benchmarks/flowertune-llm/evaluation/finance/benchmarks.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/benchmarks/flowertune-llm/evaluation/finance/benchmarks.py b/benchmarks/flowertune-llm/evaluation/finance/benchmarks.py index 2b1a174e571f..f2dad1e056b8 100644 --- a/benchmarks/flowertune-llm/evaluation/finance/benchmarks.py +++ b/benchmarks/flowertune-llm/evaluation/finance/benchmarks.py @@ -122,7 +122,10 @@ def inference(dataset, model, tokenizer, batch_size): **tokens, max_length=512, eos_token_id=tokenizer.eos_token_id ) res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res] - out_text = [o.split("Answer: ")[1] for o in res_sentences] + out_text = [ + o.split("Answer: ")[1] if len(o.split("Answer: ")) > 1 else "None" + for o in res_sentences + ] out_text_list += out_text torch.cuda.empty_cache() From a1f74f6dada500b9828b260c834afd87fb4a9057 Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 6 Nov 2024 12:25:06 +0000 Subject: [PATCH 5/9] docs(framework) Update instructions in quickstart compose guide (#4409) Co-authored-by: Robert Steiner --- .../docker/run-quickstart-examples-docker-compose.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/doc/source/docker/run-quickstart-examples-docker-compose.rst b/doc/source/docker/run-quickstart-examples-docker-compose.rst index a92f5fffdc3f..70e9b190faaf 100644 --- a/doc/source/docker/run-quickstart-examples-docker-compose.rst +++ b/doc/source/docker/run-quickstart-examples-docker-compose.rst @@ 
-39,13 +39,16 @@ Run the Quickstart Example .. code-block:: bash :substitutions: - $ curl https://raw.githubusercontent.com/adap/flower/refs/tags/v|stable_flwr_version|/src/docker/complete/compose.yml \ + $ curl https://raw.githubusercontent.com/adap/flower/24b2861465431a5ab234a8c4f76faea7a742b1fd/src/docker/complete/compose.yml \ -o compose.yml -3. Build and start the services using the following command: +3. Export the version of Flower that your environment uses. Then, build and start the + services using the following command: .. code-block:: bash + :substitutions: + $ export FLWR_VERSION="|stable_flwr_version|" # update with your version $ docker compose up --build -d 4. Append the following lines to the end of the ``pyproject.toml`` file and save it: From caafd022693149cd1fa67a304d3783b81e5499f4 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 6 Nov 2024 12:33:44 +0000 Subject: [PATCH 6/9] feat(framework:skip) Add `run_id` to `Context` (#4429) --- src/proto/flwr/proto/message.proto | 9 +++++---- .../clientapp/clientappio_servicer_test.py | 3 +++ .../message_handler/message_handler_test.py | 8 ++++++-- .../secure_aggregation/secaggplus_mod_test.py | 1 + src/py/flwr/client/mod/utils_test.py | 8 ++++++-- src/py/flwr/client/run_info_store.py | 1 + src/py/flwr/common/context.py | 13 +++++++++---- src/py/flwr/common/serde.py | 2 ++ src/py/flwr/common/serde_test.py | 1 + src/py/flwr/proto/message_pb2.py | 16 ++++++++-------- src/py/flwr/proto/message_pb2.pyi | 5 ++++- src/py/flwr/server/run_serverapp.py | 3 +++ src/py/flwr/server/server_app_test.py | 4 +++- .../server/superlink/linkstate/linkstate_test.py | 2 ++ src/py/flwr/simulation/run_simulation.py | 3 +++ src/py/flwr/superexec/deployment.py | 4 +++- 16 files changed, 60 insertions(+), 23 deletions(-) diff --git a/src/proto/flwr/proto/message.proto b/src/proto/flwr/proto/message.proto index 7066da5b7e76..cbe4bf7e027f 100644 --- a/src/proto/flwr/proto/message.proto +++ b/src/proto/flwr/proto/message.proto 
@@ -28,10 +28,11 @@ message Message { } message Context { - uint64 node_id = 1; - map node_config = 2; - RecordSet state = 3; - map run_config = 4; + uint64 run_id = 1; + uint64 node_id = 2; + map node_config = 3; + RecordSet state = 4; + map run_config = 5; } message Metadata { diff --git a/src/py/flwr/client/clientapp/clientappio_servicer_test.py b/src/py/flwr/client/clientapp/clientappio_servicer_test.py index 82c9f16e8201..3c862884a5f3 100644 --- a/src/py/flwr/client/clientapp/clientappio_servicer_test.py +++ b/src/py/flwr/client/clientapp/clientappio_servicer_test.py @@ -66,6 +66,7 @@ def test_set_inputs(self) -> None: content=self.maker.recordset(2, 2, 1), ) context = Context( + run_id=1, node_id=1, node_config={"nodeconfig1": 4.2}, state=self.maker.recordset(2, 2, 1), @@ -122,6 +123,7 @@ def test_get_outputs(self) -> None: content=self.maker.recordset(2, 2, 1), ) context = Context( + run_id=1, node_id=1, node_config={"nodeconfig1": 4.2}, state=self.maker.recordset(2, 2, 1), @@ -186,6 +188,7 @@ def test_push_clientapp_outputs(self) -> None: content=self.maker.recordset(2, 2, 1), ) context = Context( + run_id=1, node_id=1, node_config={"nodeconfig1": 4.2}, state=self.maker.recordset(2, 2, 1), diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 311f8c37e1b1..0be5ab30e026 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -142,7 +142,9 @@ def test_client_without_get_properties() -> None: actual_msg = handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}), + context=Context( + run_id=2234, node_id=1123, node_config={}, state=RecordSet(), run_config={} + ), ) # Assert @@ -206,7 +208,9 @@ def test_client_with_get_properties() -> None: actual_msg = 
handle_legacy_message_from_msgtype( client_fn=_get_client_fn(client), message=message, - context=Context(node_id=1123, node_config={}, state=RecordSet(), run_config={}), + context=Context( + run_id=2234, node_id=1123, node_config={}, state=RecordSet(), run_config={} + ), ) # Assert diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index e68bf5177797..89729bca1b9c 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -74,6 +74,7 @@ def func(configs: dict[str, ConfigsRecordValues]) -> ConfigsRecord: def _make_ctxt() -> Context: cfg = ConfigsRecord(SecAggPlusState().to_dict()) return Context( + run_id=234, node_id=123, node_config={}, state=RecordSet(configs_records={RECORD_KEY_STATE: cfg}), diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py index e75fb5530b2c..248ee5bdae81 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -104,7 +104,9 @@ def test_multiple_mods(self) -> None: state = RecordSet() state.metrics_records[METRIC] = MetricsRecord({COUNTER: 0.0}) - context = Context(node_id=0, node_config={}, state=state, run_config={}) + context = Context( + run_id=1, node_id=0, node_config={}, state=state, run_config={} + ) message = _get_dummy_flower_message() # Execute @@ -129,7 +131,9 @@ def test_filter(self) -> None: # Prepare footprint: list[str] = [] mock_app = make_mock_app("app", footprint) - context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) + context = Context( + run_id=1, node_id=0, node_config={}, state=RecordSet(), run_config={} + ) message = _get_dummy_flower_message() def filter_mod( diff --git a/src/py/flwr/client/run_info_store.py b/src/py/flwr/client/run_info_store.py index 6b0c3bd3a493..a5cd5129bc3a 100644 --- a/src/py/flwr/client/run_info_store.py +++ 
b/src/py/flwr/client/run_info_store.py @@ -83,6 +83,7 @@ def register_context( self.run_infos[run_id] = RunInfo( initial_run_config=initial_run_config, context=Context( + run_id=run_id, node_id=self.node_id, node_config=self.node_config, state=RecordSet(), diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py index 1544b96d3fa3..edf2024c2b1c 100644 --- a/src/py/flwr/common/context.py +++ b/src/py/flwr/common/context.py @@ -27,36 +27,41 @@ class Context: Parameters ---------- + run_id : int + The ID that identifies the run. node_id : int The ID that identifies the node. node_config : UserConfig A config (key/value mapping) unique to the node and independent of the `run_config`. This config persists across all runs this node participates in. state : RecordSet - Holds records added by the entity in a given run and that will stay local. + Holds records added by the entity in a given `run_id` and that will stay local. This means that the data it holds will never leave the system it's running from. This can be used as an intermediate storage or scratchpad when executing mods. It can also be used as a memory to access at different points during the lifecycle of this entity (e.g. across multiple rounds) run_config : UserConfig - A config (key/value mapping) held by the entity in a given run and that will - stay local. It can be used at any point during the lifecycle of this entity + A config (key/value mapping) held by the entity in a given `run_id` and that + will stay local. It can be used at any point during the lifecycle of this entity (e.g. 
across multiple rounds) """ + run_id: int node_id: int node_config: UserConfig state: RecordSet run_config: UserConfig - def __init__( # pylint: disable=too-many-arguments + def __init__( # pylint: disable=too-many-arguments, too-many-positional-arguments self, + run_id: int, node_id: int, node_config: UserConfig, state: RecordSet, run_config: UserConfig, ) -> None: + self.run_id = run_id self.node_id = node_id self.node_config = node_config self.state = state diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index acac1ca046b7..99c52289b5a1 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -840,6 +840,7 @@ def message_from_proto(message_proto: ProtoMessage) -> Message: def context_to_proto(context: Context) -> ProtoContext: """Serialize `Context` to ProtoBuf.""" proto = ProtoContext( + run_id=context.run_id, node_id=context.node_id, node_config=user_config_to_proto(context.node_config), state=recordset_to_proto(context.state), @@ -851,6 +852,7 @@ def context_to_proto(context: Context) -> ProtoContext: def context_from_proto(context_proto: ProtoContext) -> Context: """Deserialize `Context` from ProtoBuf.""" context = Context( + run_id=context_proto.run_id, node_id=context_proto.node_id, node_config=user_config_from_proto(context_proto.node_config), state=recordset_from_proto(context_proto.state), diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 19e9889158a0..38ad1894f411 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -503,6 +503,7 @@ def test_context_serialization_deserialization() -> None: # Prepare maker = RecordMaker() original = Context( + run_id=0, node_id=1, node_config=maker.user_config(), state=maker.recordset(1, 1, 1), diff --git a/src/py/flwr/proto/message_pb2.py b/src/py/flwr/proto/message_pb2.py index d2201cb07b56..92e37d3b7ed4 100644 --- a/src/py/flwr/proto/message_pb2.py +++ b/src/py/flwr/proto/message_pb2.py @@ -17,7 +17,7 
@@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/message.proto\x12\nflwr.proto\x1a\x16\x66lwr/proto/error.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"{\n\x07Message\x12&\n\x08metadata\x18\x01 \x01(\x0b\x32\x14.flwr.proto.Metadata\x12&\n\x07\x63ontent\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\x03 \x01(\x0b\x32\x11.flwr.proto.Error\"\xbf\x02\n\x07\x43ontext\x12\x0f\n\x07node_id\x18\x01 \x01(\x04\x12\x38\n\x0bnode_config\x18\x02 \x03(\x0b\x32#.flwr.proto.Context.NodeConfigEntry\x12$\n\x05state\x18\x03 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12\x36\n\nrun_config\x18\x04 \x03(\x0b\x32\".flwr.proto.Context.RunConfigEntry\x1a\x45\n\x0fNodeConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x44\n\x0eRunConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xbb\x01\n\x08Metadata\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\x13\n\x0bsrc_node_id\x18\x03 \x01(\x04\x12\x13\n\x0b\x64st_node_id\x18\x04 \x01(\x04\x12\x18\n\x10reply_to_message\x18\x05 \x01(\t\x12\x10\n\x08group_id\x18\x06 \x01(\t\x12\x0b\n\x03ttl\x18\x07 \x01(\x01\x12\x14\n\x0cmessage_type\x18\x08 \x01(\t\x12\x12\n\ncreated_at\x18\t \x01(\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/message.proto\x12\nflwr.proto\x1a\x16\x66lwr/proto/error.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"{\n\x07Message\x12&\n\x08metadata\x18\x01 \x01(\x0b\x32\x14.flwr.proto.Metadata\x12&\n\x07\x63ontent\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\x03 \x01(\x0b\x32\x11.flwr.proto.Error\"\xcf\x02\n\x07\x43ontext\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0f\n\x07node_id\x18\x02 
\x01(\x04\x12\x38\n\x0bnode_config\x18\x03 \x03(\x0b\x32#.flwr.proto.Context.NodeConfigEntry\x12$\n\x05state\x18\x04 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12\x36\n\nrun_config\x18\x05 \x03(\x0b\x32\".flwr.proto.Context.RunConfigEntry\x1a\x45\n\x0fNodeConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x44\n\x0eRunConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xbb\x01\n\x08Metadata\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\x13\n\x0bsrc_node_id\x18\x03 \x01(\x04\x12\x13\n\x0b\x64st_node_id\x18\x04 \x01(\x04\x12\x18\n\x10reply_to_message\x18\x05 \x01(\t\x12\x10\n\x08group_id\x18\x06 \x01(\t\x12\x0b\n\x03ttl\x18\x07 \x01(\x01\x12\x14\n\x0cmessage_type\x18\x08 \x01(\t\x12\x12\n\ncreated_at\x18\t \x01(\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -31,11 +31,11 @@ _globals['_MESSAGE']._serialized_start=120 _globals['_MESSAGE']._serialized_end=243 _globals['_CONTEXT']._serialized_start=246 - _globals['_CONTEXT']._serialized_end=565 - _globals['_CONTEXT_NODECONFIGENTRY']._serialized_start=426 - _globals['_CONTEXT_NODECONFIGENTRY']._serialized_end=495 - _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_start=497 - _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_end=565 - _globals['_METADATA']._serialized_start=568 - _globals['_METADATA']._serialized_end=755 + _globals['_CONTEXT']._serialized_end=581 + _globals['_CONTEXT_NODECONFIGENTRY']._serialized_start=442 + _globals['_CONTEXT_NODECONFIGENTRY']._serialized_end=511 + _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_start=513 + _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_end=581 + _globals['_METADATA']._serialized_start=584 + _globals['_METADATA']._serialized_end=771 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/message_pb2.pyi 
b/src/py/flwr/proto/message_pb2.pyi index b352917f217e..766829a4798c 100644 --- a/src/py/flwr/proto/message_pb2.pyi +++ b/src/py/flwr/proto/message_pb2.pyi @@ -67,10 +67,12 @@ class Context(google.protobuf.message.Message): def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + RUN_ID_FIELD_NUMBER: builtins.int NODE_ID_FIELD_NUMBER: builtins.int NODE_CONFIG_FIELD_NUMBER: builtins.int STATE_FIELD_NUMBER: builtins.int RUN_CONFIG_FIELD_NUMBER: builtins.int + run_id: builtins.int node_id: builtins.int @property def node_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... @@ -80,13 +82,14 @@ class Context(google.protobuf.message.Message): def run_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ... def __init__(self, *, + run_id: builtins.int = ..., node_id: builtins.int = ..., node_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., state: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ..., run_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["state",b"state"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["node_config",b"node_config","node_id",b"node_id","run_config",b"run_config","state",b"state"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["node_config",b"node_config","node_id",b"node_id","run_config",b"run_config","run_id",b"run_id","state",b"state"]) -> None: ... 
global___Context = Context class Metadata(google.protobuf.message.Message): diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index 9937b993fd02..2215b87295b6 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -188,6 +188,7 @@ def run_server_app() -> None: app_path = str(get_project_dir(fab_id, fab_version, run_.fab_hash, flwr_dir)) config = get_project_config(app_path) + run_id = run_.run_id else: # User provided `app_dir`, but not `--run-id` # Create run if run_id is not provided @@ -204,6 +205,7 @@ def run_server_app() -> None: res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212 # Fetch full `Run` using `run_id` driver.init_run(res.run_id) # pylint: disable=W0212 + run_id = res.run_id # Obtain server app reference and the run config server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"] @@ -221,6 +223,7 @@ def run_server_app() -> None: # Initialize Context context = Context( + run_id=run_id, node_id=0, node_config={}, state=RecordSet(), diff --git a/src/py/flwr/server/server_app_test.py b/src/py/flwr/server/server_app_test.py index b0672b3202ed..b2515f09fdac 100644 --- a/src/py/flwr/server/server_app_test.py +++ b/src/py/flwr/server/server_app_test.py @@ -29,7 +29,9 @@ def test_server_app_custom_mode() -> None: # Prepare app = ServerApp() driver = MagicMock() - context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) + context = Context( + run_id=1, node_id=0, node_config={}, state=RecordSet(), run_config={} + ) called = {"called": False} diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 9e00e4a0c49a..2cdea58a7cb7 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -1036,6 +1036,7 @@ def test_get_set_serverapp_context(self) -> None: # Prepare state: 
LinkState = self.state_factory() context = Context( + run_id=1, node_id=0, node_config={"mock": "mock"}, state=RecordSet(), @@ -1057,6 +1058,7 @@ def test_set_context_invalid_run_id(self) -> None: # Prepare state: LinkState = self.state_factory() context = Context( + run_id=1, node_id=0, node_config={"mock": "mock"}, state=RecordSet(), diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 929824843f54..88d3fc8b213c 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -234,6 +234,7 @@ def run_serverapp_th( f_stop: threading.Event, has_exception: threading.Event, enable_tf_gpu_growth: bool, + run_id: int, ) -> threading.Thread: """Run SeverApp in a thread.""" @@ -258,6 +259,7 @@ def server_th_with_start_checks( # Initialize Context context = Context( + run_id=run_id, node_id=0, node_config={}, state=RecordSet(), @@ -357,6 +359,7 @@ def _main_loop( f_stop=f_stop, has_exception=server_app_thread_has_exception, enable_tf_gpu_growth=enable_tf_gpu_growth, + run_id=run.run_id, ) # Buffer time so the `ServerApp` in separate thread is ready diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 96d184661048..5d31bcd5edc4 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -139,7 +139,9 @@ def _create_run( def _create_context(self, run_id: int) -> None: """Register a Context for a Run.""" # Create an empty context for the Run - context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) + context = Context( + run_id=run_id, node_id=0, node_config={}, state=RecordSet(), run_config={} + ) # Register the context at the LinkState self.linkstate.set_serverapp_context(run_id=run_id, context=context) From 51ea8f8bb972be9e033ff1f95f6be5bfda55f0f3 Mon Sep 17 00:00:00 2001 From: Robert Steiner Date: Wed, 6 Nov 2024 17:35:51 +0100 Subject: [PATCH 7/9] feat(framework) Add Docker GPU base 
image (#4420) Signed-off-by: Robert Steiner --- src/docker/base/ubuntu-cuda/Dockerfile | 116 +++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 src/docker/base/ubuntu-cuda/Dockerfile diff --git a/src/docker/base/ubuntu-cuda/Dockerfile b/src/docker/base/ubuntu-cuda/Dockerfile new file mode 100644 index 000000000000..3ffb6401805a --- /dev/null +++ b/src/docker/base/ubuntu-cuda/Dockerfile @@ -0,0 +1,116 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +# hadolint global ignore=DL3008 +ARG CUDA_VERSION=12.4.1 +ARG DISTRO=ubuntu +ARG DISTRO_VERSION=24.04 +FROM nvidia/cuda:${CUDA_VERSION}-base-${DISTRO}${DISTRO_VERSION} AS python + +ENV DEBIAN_FRONTEND=noninteractive + +# Install system dependencies +RUN apt-get update \ + && apt-get -y --no-install-recommends install \ + clang-format git unzip ca-certificates openssh-client liblzma-dev \ + build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev wget\ + libsqlite3-dev curl llvm libncursesw5-dev xz-utils tk-dev libxml2-dev \ + libxmlsec1-dev libffi-dev liblzma-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install PyEnv and Python +ARG PYTHON_VERSION=3.11 +ENV PYENV_ROOT=/root/.pyenv +ENV PATH=$PYENV_ROOT/bin:$PATH +# https://github.com/hadolint/hadolint/wiki/DL4006 +SHELL ["/bin/bash", "-o", "pipefail", "-c"] +RUN curl -L https://github.com/pyenv/pyenv-installer/raw/master/bin/pyenv-installer | bash + +# hadolint ignore=DL3003 +RUN git clone https://github.com/pyenv/pyenv.git \ + && cd pyenv/plugins/python-build || exit \ + && ./install.sh + +# Issue: python-build only accepts the exact Python version e.g. 3.11.1 but +# we want to allow more general versions like 3.11 +# Solution: first use pyenv to get the exact version and then pass it to python-build +RUN LATEST=$(pyenv latest -k ${PYTHON_VERSION}) \ + && python-build "${LATEST}" /usr/local/bin/python + +ENV PATH=/usr/local/bin/python/bin:$PATH + +ARG PIP_VERSION +ARG SETUPTOOLS_VERSION +# Keep the version of system Python pip and setuptools in sync with those installed in the +# virtualenv. 
+RUN pip install -U --no-cache-dir pip==${PIP_VERSION} setuptools==${SETUPTOOLS_VERSION} \ + # Use a virtual environment to ensure that Python packages are installed in the same location + # regardless of whether the subsequent image build is run with the app or the root user + && python -m venv /python/venv +ENV PATH=/python/venv/bin:$PATH + +RUN pip install -U --no-cache-dir \ + pip==${PIP_VERSION} \ + setuptools==${SETUPTOOLS_VERSION} + +ARG FLWR_VERSION +ARG FLWR_VERSION_REF +ARG FLWR_PACKAGE=flwr +# hadolint ignore=DL3013 +RUN if [ -z "${FLWR_VERSION_REF}" ]; then \ + pip install -U --no-cache-dir ${FLWR_PACKAGE}==${FLWR_VERSION}; \ + else \ + pip install -U --no-cache-dir ${FLWR_PACKAGE}@${FLWR_VERSION_REF}; \ + fi + +FROM nvidia/cuda:${CUDA_VERSION}-base-${DISTRO}${DISTRO_VERSION} AS base + +COPY --from=python /usr/local/bin/python /usr/local/bin/python + +ENV DEBIAN_FRONTEND=noninteractive \ + PATH=/usr/local/bin/python/bin:$PATH + +RUN apt-get update \ + && apt-get -y --no-install-recommends install \ + libsqlite3-0 \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* \ + # add non-root user + && useradd \ + --no-create-home \ + --home-dir /app \ + -c "" \ + --uid 49999 app \ + && mkdir -p /app \ + && chown -R app:app /app + +COPY --from=python --chown=app:app /python/venv /python/venv + +ENV PATH=/python/venv/bin:$PATH \ + # Send stdout and stderr stream directly to the terminal. Ensures that no + # output is retained in a buffer if the application crashes. + PYTHONUNBUFFERED=1 \ + # Typically, bytecode is created on the first invocation to speed up following invocation. + # However, in Docker we only make a single invocation (when we start the container). + # Therefore, we can disable bytecode writing. + PYTHONDONTWRITEBYTECODE=1 \ + # Ensure that python encoding is always UTF-8. 
+ PYTHONIOENCODING=UTF-8 \ + LANG=C.UTF-8 \ + LC_ALL=C.UTF-8 + +WORKDIR /app +USER app +ENV HOME=/app From a23f730225a7187ea40516d60b3e26b9119345d0 Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 6 Nov 2024 17:07:10 +0000 Subject: [PATCH 8/9] feat(framework) Introduce `SimulationIoServicer` and enable its usage in `SuperLink` (#4427) --- src/py/flwr/common/constant.py | 1 + src/py/flwr/server/app.py | 247 ++++++++++-------- .../server/superlink/simulation/__init__.py | 15 ++ .../superlink/simulation/simulationio_grpc.py | 65 +++++ .../simulation/simulationio_servicer.py | 132 ++++++++++ 5 files changed, 353 insertions(+), 107 deletions(-) create mode 100644 src/py/flwr/server/superlink/simulation/__init__.py create mode 100644 src/py/flwr/server/superlink/simulation/simulationio_grpc.py create mode 100644 src/py/flwr/server/superlink/simulation/simulationio_servicer.py diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 7fddc4a0e110..8aafb68ea17d 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -48,6 +48,7 @@ ) FLEET_API_REST_DEFAULT_ADDRESS = "0.0.0.0:9095" EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093" +SIMULATIONIO_API_DEFAULT_ADDRESS = "0.0.0.0:9096" # Constants for ping PING_DEFAULT_INTERVAL = 30 diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index cfada7fca933..e931cf550014 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -47,6 +47,7 @@ ISOLATION_MODE_SUBPROCESS, MISSING_EXTRA_REST, SERVERAPPIO_API_DEFAULT_ADDRESS, + SIMULATIONIO_API_DEFAULT_ADDRESS, TRANSPORT_TYPE_GRPC_ADAPTER, TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, @@ -63,6 +64,7 @@ from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server from flwr.superexec.app import load_executor from flwr.superexec.exec_grpc import run_exec_api_grpc +from flwr.superexec.simulation import SimulationEngine from .client_manager import ClientManager from .history import History @@ -79,6 
+81,7 @@ from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor from .superlink.linkstate import LinkStateFactory +from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc DATABASE = ":flwr-in-memory-state:" BASE_DIR = get_flwr_dir() / "superlink" / "ffs" @@ -215,6 +218,7 @@ def run_superlink() -> None: # Parse IP addresses serverappio_address, _, _ = _format_address(args.serverappio_api_address) exec_address, _, _ = _format_address(args.exec_api_address) + simulationio_address, _, _ = _format_address(args.simulationio_api_address) # Obtain certificates certificates = _try_obtain_certificates(args) @@ -225,128 +229,148 @@ def run_superlink() -> None: # Initialize FfsFactory ffs_factory = FfsFactory(args.storage_dir) - # Start ServerAppIo API - serverappio_server: grpc.Server = run_serverappio_api_grpc( - address=serverappio_address, + # Start Exec API + executor = load_executor(args) + exec_server: grpc.Server = run_exec_api_grpc( + address=exec_address, state_factory=state_factory, ffs_factory=ffs_factory, + executor=executor, certificates=certificates, + config=parse_config_args( + [args.executor_config] if args.executor_config else args.executor_config + ), ) - grpc_servers = [serverappio_server] + grpc_servers = [exec_server] - # Start Fleet API - bckg_threads = [] - if not args.fleet_api_address: - if args.fleet_api_type in [ - TRANSPORT_TYPE_GRPC_RERE, - TRANSPORT_TYPE_GRPC_ADAPTER, - ]: - args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS - elif args.fleet_api_type == TRANSPORT_TYPE_REST: - args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS - - fleet_address, host, port = _format_address(args.fleet_api_address) - - num_workers = args.fleet_api_num_workers - if num_workers != 1: - log( - WARN, - "The Fleet API currently supports only 1 worker. " - "You have specified %d workers. 
" - "Support for multiple workers will be added in future releases. " - "Proceeding with a single worker.", - args.fleet_api_num_workers, - ) - num_workers = 1 + # Determine Exec plugin + # If simulation is used, don't start ServerAppIo and Fleet APIs + sim_exec = isinstance(executor, SimulationEngine) - if args.fleet_api_type == TRANSPORT_TYPE_REST: - if ( - importlib.util.find_spec("requests") - and importlib.util.find_spec("starlette") - and importlib.util.find_spec("uvicorn") - ) is None: - sys.exit(MISSING_EXTRA_REST) - - _, ssl_certfile, ssl_keyfile = ( - certificates if certificates is not None else (None, None, None) - ) - - fleet_thread = threading.Thread( - target=_run_fleet_api_rest, - args=( - host, - port, - ssl_keyfile, - ssl_certfile, - state_factory, - ffs_factory, - num_workers, - ), - ) - fleet_thread.start() - bckg_threads.append(fleet_thread) - elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE: - maybe_keys = _try_setup_node_authentication(args, certificates) - interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None - if maybe_keys is not None: - ( - node_public_keys, - server_private_key, - server_public_key, - ) = maybe_keys - state = state_factory.state() - state.store_node_public_keys(node_public_keys) - state.store_server_private_public_key( - private_key_to_bytes(server_private_key), - public_key_to_bytes(server_public_key), - ) - log( - INFO, - "Node authentication enabled with %d known public keys", - len(node_public_keys), - ) - interceptors = [AuthenticateServerInterceptor(state)] + bckg_threads = [] - fleet_server = _run_fleet_api_grpc_rere( - address=fleet_address, + if sim_exec: + simulationio_server: grpc.Server = run_simulationio_api_grpc( + address=simulationio_address, state_factory=state_factory, ffs_factory=ffs_factory, certificates=certificates, - interceptors=interceptors, ) - grpc_servers.append(fleet_server) - elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER: - fleet_server = _run_fleet_api_grpc_adapter( 
- address=fleet_address, + grpc_servers.append(simulationio_server) + + else: + # Start ServerAppIo API + serverappio_server: grpc.Server = run_serverappio_api_grpc( + address=serverappio_address, state_factory=state_factory, ffs_factory=ffs_factory, certificates=certificates, ) - grpc_servers.append(fleet_server) - else: - raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") - - # Start Exec API - exec_server: grpc.Server = run_exec_api_grpc( - address=exec_address, - state_factory=state_factory, - ffs_factory=ffs_factory, - executor=load_executor(args), - certificates=certificates, - config=parse_config_args( - [args.executor_config] if args.executor_config else args.executor_config - ), - ) - grpc_servers.append(exec_server) + grpc_servers.append(serverappio_server) + + # Start Fleet API + if not args.fleet_api_address: + if args.fleet_api_type in [ + TRANSPORT_TYPE_GRPC_RERE, + TRANSPORT_TYPE_GRPC_ADAPTER, + ]: + args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS + elif args.fleet_api_type == TRANSPORT_TYPE_REST: + args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS + + fleet_address, host, port = _format_address(args.fleet_api_address) + + num_workers = args.fleet_api_num_workers + if num_workers != 1: + log( + WARN, + "The Fleet API currently supports only 1 worker. " + "You have specified %d workers. " + "Support for multiple workers will be added in future releases. 
" + "Proceeding with a single worker.", + args.fleet_api_num_workers, + ) + num_workers = 1 + + if args.fleet_api_type == TRANSPORT_TYPE_REST: + if ( + importlib.util.find_spec("requests") + and importlib.util.find_spec("starlette") + and importlib.util.find_spec("uvicorn") + ) is None: + sys.exit(MISSING_EXTRA_REST) + + _, ssl_certfile, ssl_keyfile = ( + certificates if certificates is not None else (None, None, None) + ) - if args.isolation == ISOLATION_MODE_SUBPROCESS: - # Scheduler thread - scheduler_th = threading.Thread( - target=_flwr_serverapp_scheduler, - args=(state_factory, args.serverappio_api_address, args.ssl_ca_certfile), - ) - scheduler_th.start() - bckg_threads.append(scheduler_th) + fleet_thread = threading.Thread( + target=_run_fleet_api_rest, + args=( + host, + port, + ssl_keyfile, + ssl_certfile, + state_factory, + ffs_factory, + num_workers, + ), + ) + fleet_thread.start() + bckg_threads.append(fleet_thread) + elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE: + maybe_keys = _try_setup_node_authentication(args, certificates) + interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None + if maybe_keys is not None: + ( + node_public_keys, + server_private_key, + server_public_key, + ) = maybe_keys + state = state_factory.state() + state.store_node_public_keys(node_public_keys) + state.store_server_private_public_key( + private_key_to_bytes(server_private_key), + public_key_to_bytes(server_public_key), + ) + log( + INFO, + "Node authentication enabled with %d known public keys", + len(node_public_keys), + ) + interceptors = [AuthenticateServerInterceptor(state)] + + fleet_server = _run_fleet_api_grpc_rere( + address=fleet_address, + state_factory=state_factory, + ffs_factory=ffs_factory, + certificates=certificates, + interceptors=interceptors, + ) + grpc_servers.append(fleet_server) + elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER: + fleet_server = _run_fleet_api_grpc_adapter( + address=fleet_address, + 
state_factory=state_factory, + ffs_factory=ffs_factory, + certificates=certificates, + ) + grpc_servers.append(fleet_server) + else: + raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") + + if args.isolation == ISOLATION_MODE_SUBPROCESS: + # Scheduler thread + scheduler_th = threading.Thread( + target=_flwr_serverapp_scheduler, + args=( + state_factory, + args.serverappio_api_address, + args.ssl_ca_certfile, + ), + ) + scheduler_th.start() + bckg_threads.append(scheduler_th) # Graceful shutdown register_exit_handlers( @@ -361,7 +385,7 @@ def run_superlink() -> None: for thread in bckg_threads: if not thread.is_alive(): sys.exit(1) - serverappio_server.wait_for_termination(timeout=1) + exec_server.wait_for_termination(timeout=1) def _flwr_serverapp_scheduler( @@ -657,6 +681,7 @@ def _parse_args_run_superlink() -> argparse.ArgumentParser: _add_args_serverappio_api(parser=parser) _add_args_fleet_api(parser=parser) _add_args_exec_api(parser=parser) + _add_args_simulationio_api(parser=parser) return parser @@ -790,3 +815,11 @@ def _add_args_exec_api(parser: argparse.ArgumentParser) -> None: "For example:\n\n`--executor-config 'verbose=true " 'root-certificates="certificates/superlink-ca.crt"\'`', ) + + +def _add_args_simulationio_api(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--simulationio-api-address", + help="SimulationIo API (gRPC) server address (IPv4, IPv6, or a domain name).", + default=SIMULATIONIO_API_DEFAULT_ADDRESS, + ) diff --git a/src/py/flwr/server/superlink/simulation/__init__.py b/src/py/flwr/server/superlink/simulation/__init__.py new file mode 100644 index 000000000000..8485a3c9a3c7 --- /dev/null +++ b/src/py/flwr/server/superlink/simulation/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower SimulationIo service.""" diff --git a/src/py/flwr/server/superlink/simulation/simulationio_grpc.py b/src/py/flwr/server/superlink/simulation/simulationio_grpc.py new file mode 100644 index 000000000000..d1e79306e0b7 --- /dev/null +++ b/src/py/flwr/server/superlink/simulation/simulationio_grpc.py @@ -0,0 +1,65 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""SimulationIo gRPC API.""" + + +from logging import INFO +from typing import Optional + +import grpc + +from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from flwr.common.logger import log +from flwr.proto.simulationio_pb2_grpc import ( # pylint: disable=E0611 + add_SimulationIoServicer_to_server, +) +from flwr.server.superlink.ffs.ffs_factory import FfsFactory +from flwr.server.superlink.linkstate import LinkStateFactory + +from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server +from .simulationio_servicer import SimulationIoServicer + + +def run_simulationio_api_grpc( + address: str, + state_factory: LinkStateFactory, + ffs_factory: FfsFactory, + certificates: Optional[tuple[bytes, bytes, bytes]], +) -> grpc.Server: + """Run SimulationIo API (gRPC, request-response).""" + # Create SimulationIo API gRPC server + simulationio_servicer: grpc.Server = SimulationIoServicer( + state_factory=state_factory, + ffs_factory=ffs_factory, + ) + simulationio_add_servicer_to_server_fn = add_SimulationIoServicer_to_server + simulationio_grpc_server = generic_create_grpc_server( + servicer_and_add_fn=( + simulationio_servicer, + simulationio_add_servicer_to_server_fn, + ), + server_address=address, + max_message_length=GRPC_MAX_MESSAGE_LENGTH, + certificates=certificates, + ) + + log( + INFO, + "Flower Simulation Engine: Starting SimulationIo API on %s", + address, + ) + simulationio_grpc_server.start() + + return simulationio_grpc_server diff --git a/src/py/flwr/server/superlink/simulation/simulationio_servicer.py b/src/py/flwr/server/superlink/simulation/simulationio_servicer.py new file mode 100644 index 000000000000..03bed32e4332 --- /dev/null +++ b/src/py/flwr/server/superlink/simulation/simulationio_servicer.py @@ -0,0 +1,132 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. 
class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
    """SimulationIo API servicer.

    Bridges the SuperLink's LinkState and Ffs with a simulation process:
    hands out pending-run inputs, receives outputs, and records run status
    and logs.
    """

    def __init__(
        self, state_factory: LinkStateFactory, ffs_factory: FfsFactory
    ) -> None:
        # Factories are stored (not instantiated) so each RPC obtains a
        # fresh LinkState/Ffs handle per call.
        self.state_factory = state_factory
        self.ffs_factory = ffs_factory
        # Serializes the pending-run lookup so two concurrent
        # PullSimulationInputs calls cannot claim the same run_id.
        self.lock = threading.RLock()

    def PullSimulationInputs(
        self, request: PullSimulationInputsRequest, context: ServicerContext
    ) -> PullSimulationInputsResponse:
        """Pull simulation process inputs.

        Returns an empty response when no run is pending; otherwise marks
        the run as STARTING and returns its Context, Run, and Fab.

        Raises
        ------
        RuntimeError
            If the Run, Fab, or Context for the pending run_id cannot be
            retrieved, or the status cannot be updated to STARTING.
        """
        # Fix: log tag previously read "SimultionIoServicer.SimultionIoInputs"
        # (misspelled class name and a non-existent RPC name).
        log(DEBUG, "SimulationIoServicer.PullSimulationInputs")
        # Init access to LinkState and Ffs
        state = self.state_factory.state()
        ffs = self.ffs_factory.ffs()

        # Lock access to LinkState, preventing obtaining the same pending run_id
        with self.lock:
            # Attempt getting the run_id of a pending run
            run_id = state.get_pending_run_id()
            # If there's no pending run, return an empty response
            if run_id is None:
                return PullSimulationInputsResponse()

            # Retrieve Context, Run and Fab for the run_id
            serverapp_ctxt = state.get_serverapp_context(run_id)
            run = state.get_run(run_id)
            fab = None
            if run and run.fab_hash:
                # ffs.get returns None when the hash is unknown
                if result := ffs.get(run.fab_hash):
                    fab = Fab(run.fab_hash, result[0])
            if run and fab and serverapp_ctxt:
                # Update run status to STARTING
                if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
                    log(INFO, "Starting run %d", run_id)
                    return PullSimulationInputsResponse(
                        context=context_to_proto(serverapp_ctxt),
                        run=run_to_proto(run),
                        fab=fab_to_proto(fab),
                    )

        # Raise an exception if the Run or Fab is not found,
        # or if the status cannot be updated to STARTING
        raise RuntimeError(f"Failed to start run {run_id}")

    def PushSimulationOutputs(
        self, request: PushSimulationOutputsRequest, context: ServicerContext
    ) -> PushSimulationOutputsResponse:
        """Push simulation process outputs.

        Persists the (updated) ServerApp Context for the given run_id.
        """
        # Fix: log tag previously read "SimultionIoServicer..." (typo).
        log(DEBUG, "SimulationIoServicer.PushSimulationOutputs")
        state = self.state_factory.state()
        state.set_serverapp_context(request.run_id, context_from_proto(request.context))
        return PushSimulationOutputsResponse()

    def UpdateRunStatus(
        self, request: UpdateRunStatusRequest, context: grpc.ServicerContext
    ) -> UpdateRunStatusResponse:
        """Update the status of a run."""
        # Fix: log tag previously read "SimultionIoServicer..." (typo).
        log(DEBUG, "SimulationIoServicer.UpdateRunStatus")
        state = self.state_factory.state()

        # Update the run status
        state.update_run_status(
            run_id=request.run_id, new_status=run_status_from_proto(request.run_status)
        )
        return UpdateRunStatusResponse()

    def PushLogs(
        self, request: PushLogsRequest, context: grpc.ServicerContext
    ) -> PushLogsResponse:
        """Push logs.

        Appends the concatenated log lines to the run's ServerApp log in
        the LinkState.
        """
        # Fix: log tag previously read "ServerAppIoServicer.PushLogs" —
        # copy-paste from the ServerAppIo servicer; this is SimulationIo.
        log(DEBUG, "SimulationIoServicer.PushLogs")
        state = self.state_factory.state()

        # Add logs to LinkState
        merged_logs = "".join(request.logs)
        state.add_serverapp_log(request.run_id, merged_logs)
        return PushLogsResponse()
class SimulationIoConnection:
    """Client-side interface to the SuperLink SimulationIo API.

    The gRPC channel is opened lazily: nothing connects until the `_stub`
    property is first accessed.

    Parameters
    ----------
    simulationio_service_address : str (default: "[::]:9094")
        The address (URL, IPv6, IPv4) of the SuperLink SimulationIo API service.
    root_certificates : Optional[bytes] (default: None)
        The PEM-encoded root certificates as a byte string.
        If provided, a secure connection using the certificates will be
        established to an SSL-enabled Flower server.
    """

    def __init__(
        self,
        simulationio_service_address: str = SIMULATIONIO_API_DEFAULT_ADDRESS,
        root_certificates: Optional[bytes] = None,
    ) -> None:
        self._addr = simulationio_service_address
        self._cert = root_certificates
        # Both stay None until the first `_stub` access triggers `_connect`.
        self._grpc_stub: Optional[SimulationIoStub] = None
        self._channel: Optional[grpc.Channel] = None

    @property
    def _is_connected(self) -> bool:
        """Return True once a channel to the SimulationIo API is open."""
        return self._channel is not None

    @property
    def _stub(self) -> SimulationIoStub:
        """Return the SimulationIo stub, connecting on first use."""
        if self._channel is None:
            self._connect()
        # `_connect` guarantees the stub is set once a channel exists
        return cast(SimulationIoStub, self._grpc_stub)

    def _connect(self) -> None:
        """Open a channel to the SimulationIo API (no-op if already open)."""
        if self._channel is not None:
            log(WARNING, "Already connected")
            return
        channel = create_channel(
            server_address=self._addr,
            insecure=(self._cert is None),
            root_certificates=self._cert,
        )
        self._channel = channel
        self._grpc_stub = SimulationIoStub(channel)
        log(DEBUG, "[SimulationIO] Connected to %s", self._addr)

    def _disconnect(self) -> None:
        """Close the channel and drop the stub (no-op if already closed)."""
        if self._channel is None:
            log(DEBUG, "Already disconnected")
            return
        # Clear both references before closing so the object reads as
        # disconnected even while close() is in flight.
        channel = self._channel
        self._channel = None
        self._grpc_stub = None
        channel.close()
        log(DEBUG, "[SimulationIO] Disconnected")