diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index a4913d51315b..cf6b716bd189 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -14,8 +14,8 @@ # ============================================================================== """Flower server app.""" - import argparse +import asyncio import importlib.util import sys import threading @@ -362,6 +362,7 @@ def run_superlink() -> None: ) grpc_servers.append(fleet_server) elif args.fleet_api_type == TRANSPORT_TYPE_VCE: + f_stop = asyncio.Event() # Does nothing _run_fleet_api_vce( num_supernodes=args.num_supernodes, client_app_module_name=args.client_app, @@ -369,6 +370,7 @@ def run_superlink() -> None: backend_config_json_stream=args.backend_config, working_dir=args.dir, state_factory=state_factory, + f_stop=f_stop, ) else: raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") @@ -468,6 +470,7 @@ def _run_fleet_api_vce( backend_config_json_stream: str, working_dir: str, state_factory: StateFactory, + f_stop: asyncio.Event, ) -> None: log(INFO, "Flower VCE: Starting Fleet API (VirtualClientEngine)") @@ -478,6 +481,7 @@ def _run_fleet_api_vce( backend_config_json_stream=backend_config_json_stream, state_factory=state_factory, working_dir=working_dir, + f_stop=f_stop, ) diff --git a/src/py/flwr/server/superlink/fleet/vce/__init__.py b/src/py/flwr/server/superlink/fleet/vce/__init__.py index 72cd76f73761..57d39688b527 100644 --- a/src/py/flwr/server/superlink/fleet/vce/__init__.py +++ b/src/py/flwr/server/superlink/fleet/vce/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Fleet VirtualClientEngine side.""" +"""Fleet Simulation Engine side.""" from .vce_api import start_vce diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py index 409deb077f1d..8ef0d54622ae 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -141,13 +141,13 @@ async def process_message( Return output message and updated context. """ - node_id = message.metadata.dst_node_id + partition_id = message.metadata.partition_id try: # Submite a task to the pool future = await self.pool.submit( lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), - (app, message, str(node_id), context), + (app, message, str(partition_id), context), ) await future @@ -163,10 +163,9 @@ async def process_message( except LoadClientAppError as load_ex: log( ERROR, - "An exception was raised when processing a message. 
Terminating %s", + "An exception was raised when processing a message by %s", self.__class__.__name__, ) - await self.terminate() raise load_ex async def terminate(self) -> None: diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index fd246b5fc2af..e14c466e7b82 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -20,6 +20,8 @@ from typing import Callable, Dict, Optional, Tuple, Union from unittest import IsolatedAsyncioTestCase +import ray + from flwr.client import Client, NumPyClient from flwr.client.client_app import ClientApp, LoadClientAppError, load_client_app from flwr.common import ( @@ -119,6 +121,11 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: class AsyncTestRayBackend(IsolatedAsyncioTestCase): """A basic class that allows runnig multliple asyncio tests.""" + async def on_cleanup(self) -> None: + """Ensure Ray has shutdown.""" + if ray.is_initialized(): + ray.shutdown() + def test_backend_creation_and_termination(self) -> None: """Test creation of RayBackend and its termination.""" backend = RayBackend(backend_config={}, work_dir="") @@ -171,6 +178,7 @@ def test_backend_creation_submit_and_termination_non_existing_client_app( self.test_backend_creation_submit_and_termination( client_app_loader=_load_from_module("a_non_existing_module:app") ) + self.addAsyncCleanup(self.on_cleanup) def test_backend_creation_submit_and_termination_existing_client_app( self, @@ -198,3 +206,4 @@ def test_backend_creation_submit_and_termination_existing_client_app_unsetworkdi client_app_loader=_load_from_module("raybackend_test:client_app"), workdir="/?&%$^#%@$!", ) + self.addAsyncCleanup(self.on_cleanup) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 5d194632541e..ad858cbb9979 
100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -12,19 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Fleet VirtualClientEngine API.""" +"""Fleet Simulation Engine API.""" + import asyncio import json -from logging import ERROR, INFO -from typing import Dict, Optional +import traceback +from logging import DEBUG, ERROR, INFO, WARN +from typing import Callable, Dict, List, Optional -from flwr.client.client_app import ClientApp, load_client_app +from flwr.client.client_app import ClientApp, LoadClientAppError, load_client_app from flwr.client.node_state import NodeState from flwr.common.logger import log +from flwr.common.serde import message_from_taskins, message_to_taskres +from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 from flwr.server.superlink.state import StateFactory -from .backend import error_messages_backends, supported_backends +from .backend import Backend, error_messages_backends, supported_backends NodeToPartitionMapping = Dict[int, int] @@ -42,21 +46,223 @@ def _register_nodes( return nodes_mapping -# pylint: disable=too-many-arguments,unused-argument +# pylint: disable=too-many-arguments,too-many-locals +async def worker( + app_fn: Callable[[], ClientApp], + queue: "asyncio.Queue[TaskIns]", + node_states: Dict[int, NodeState], + state_factory: StateFactory, + nodes_mapping: NodeToPartitionMapping, + backend: Backend, +) -> None: + """Get TaskIns from queue and pass it to an actor in the pool to execute it.""" + state = state_factory.state() + while True: + try: + task_ins: TaskIns = await queue.get() + node_id = task_ins.task.consumer.node_id + + # Register and retrieve runstate + node_states[node_id].register_context(run_id=task_ins.run_id) + context = node_states[node_id].retrieve_context(run_id=task_ins.run_id) 
+ + # Convert TaskIns to Message + message = message_from_taskins(task_ins) + # Set partition_id + message.metadata.partition_id = nodes_mapping[node_id] + + # Let backend process message + out_mssg, updated_context = await backend.process_message( + app_fn, message, context + ) + + # Update Context + node_states[node_id].update_context( + task_ins.run_id, context=updated_context + ) + + # Convert to TaskRes + task_res = message_to_taskres(out_mssg) + # Store TaskRes in state + state.store_task_res(task_res) + + except asyncio.CancelledError as e: + log(DEBUG, "Async worker: %s", e) + break + + except LoadClientAppError as app_ex: + log(ERROR, "Async worker: %s", app_ex) + log(ERROR, traceback.format_exc()) + raise + + except Exception as ex: # pylint: disable=broad-exception-caught + log(ERROR, ex) + log(ERROR, traceback.format_exc()) + break + + +async def add_taskins_to_queue( + queue: "asyncio.Queue[TaskIns]", + state_factory: StateFactory, + nodes_mapping: NodeToPartitionMapping, + backend: Backend, + consumers: List["asyncio.Task[None]"], + f_stop: asyncio.Event, +) -> None: + """Retrieve TaskIns and add it to the queue.""" + state = state_factory.state() + num_initial_consumers = len(consumers) + while not f_stop.is_set(): + for node_id in nodes_mapping.keys(): + task_ins = state.get_task_ins(node_id=node_id, limit=1) + if task_ins: + await queue.put(task_ins[0]) + + # Count consumers that are running + num_active = sum(not (cc.done()) for cc in consumers) + + # Alert if number of consumers decreased by half + if num_active < num_initial_consumers // 2: + log( + WARN, + "Number of active workers has more than halved: (%i/%i active)", + num_active, + num_initial_consumers, + ) + + # Break if consumers died + if num_active == 0: + raise RuntimeError("All workers have died. 
Ending Simulation.") + + # Log some stats + log( + DEBUG, + "Simulation Engine stats: " + "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)", + num_active, + num_initial_consumers, + backend.__class__.__name__, + backend.num_workers, + queue.qsize(), + ) + await asyncio.sleep(1.0) + log(DEBUG, "Async producer: Stopped pulling from StateFactory.") + + +async def run( + app_fn: Callable[[], ClientApp], + backend_fn: Callable[[], Backend], + nodes_mapping: NodeToPartitionMapping, + state_factory: StateFactory, + node_states: Dict[int, NodeState], + f_stop: asyncio.Event, +) -> None: + """Run the VCE async.""" + queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128) + + try: + + # Instantiate backend + backend = backend_fn() + + # Build backend + await backend.build() + + # Add workers (they submit Messages to Backend) + worker_tasks = [ + asyncio.create_task( + worker( + app_fn, queue, node_states, state_factory, nodes_mapping, backend + ) + ) + for _ in range(backend.num_workers) + ] + # Create producer (adds TaskIns into Queue) + producer = asyncio.create_task( + add_taskins_to_queue( + queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop + ) + ) + + # Wait for producer to finish + # The producer runs forever until f_stop is set or until + # all worker (consumer) coroutines are completed. Workers + # also run forever and only end if an exception is raised. + await asyncio.gather(producer) + + except Exception as ex: + + log(ERROR, "An exception occured!! 
%s", ex) + log(ERROR, traceback.format_exc()) + log(WARN, "Stopping Simulation Engine.") + + # Manually trigger stopping event + f_stop.set() + + # Raise exception + raise RuntimeError("Simulation Engine crashed.") from ex + + finally: + # Produced task terminated, now cancel worker tasks + for w_t in worker_tasks: + _ = w_t.cancel() + + while not all(w_t.done() for w_t in worker_tasks): + log(DEBUG, "Terminating async workers...") + await asyncio.sleep(0.5) + + await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()]) + + # Terminate backend + await backend.terminate() + + +# pylint: disable=too-many-arguments,unused-argument,too-many-locals def start_vce( - num_supernodes: int, client_app_module_name: str, backend_name: str, backend_config_json_stream: str, - state_factory: StateFactory, working_dir: str, - f_stop: Optional[asyncio.Event] = None, + f_stop: asyncio.Event, + num_supernodes: Optional[int] = None, + state_factory: Optional[StateFactory] = None, + existing_nodes_mapping: Optional[NodeToPartitionMapping] = None, ) -> None: - """Start Fleet API with the VirtualClientEngine (VCE).""" - # Register SuperNodes - nodes_mapping = _register_nodes( - num_nodes=num_supernodes, state_factory=state_factory - ) + """Start Fleet API with the Simulation Engine.""" + if num_supernodes is not None and existing_nodes_mapping is not None: + raise ValueError( + "Both `num_supernodes` and `existing_nodes_mapping` are provided, " + "but only one is allowed." + ) + if num_supernodes is None: + if state_factory is None or existing_nodes_mapping is None: + raise ValueError( + "If not passing an existing `state_factory` and associated " + "`existing_nodes_mapping` you must supply `num_supernodes` to indicate " + "how many nodes to insert into a new StateFactory that will be created." + ) + if existing_nodes_mapping: + if state_factory is None: + raise ValueError( + "`existing_nodes_mapping` was passed, but no `state_factory` was " + "passed." 
+ ) + log(INFO, "Using exiting NodeToPartitionMapping and StateFactory.") + # Use mapping constructed externally. This also means nodes + # have previously been registered. + nodes_mapping = existing_nodes_mapping + + if not state_factory: + log(INFO, "A StateFactory was not supplied to the SimulationEngine.") + # Create an empty in-memory state factory + state_factory = StateFactory(":flwr-in-memory-state:") + log(INFO, "Created new %s.", state_factory.__class__.__name__) + + if num_supernodes: + # Register SuperNodes + nodes_mapping = _register_nodes( + num_nodes=num_supernodes, state_factory=state_factory + ) # Construct mapping of NodeStates node_states: Dict[int, NodeState] = {} @@ -69,7 +275,6 @@ def start_vce( try: backend_type = supported_backends[backend_name] - _ = backend_type(backend_config, work_dir=working_dir) except KeyError as ex: log( ERROR, @@ -83,10 +288,25 @@ def start_vce( raise ex + def backend_fn() -> Backend: + """Instantiate a Backend.""" + return backend_type(backend_config, work_dir=working_dir) + log(INFO, "client_app_module_name = %s", client_app_module_name) def _load() -> ClientApp: app: ClientApp = load_client_app(client_app_module_name) return app - # start backend + app_fn = _load + + asyncio.run( + run( + app_fn, + backend_fn, + nodes_mapping, + state_factory, + node_states, + f_stop, + ) + ) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py new file mode 100644 index 000000000000..ea2de2e636ba --- /dev/null +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -0,0 +1,292 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test Fleet Simulation Engine API.""" + + +import asyncio +import threading +from itertools import cycle +from json import JSONDecodeError +from math import pi +from pathlib import Path +from time import sleep +from typing import Dict, Optional, Set, Tuple +from unittest import IsolatedAsyncioTestCase +from uuid import UUID + +from flwr.common import GetPropertiesIns, Message, Metadata +from flwr.common.constant import MESSAGE_TYPE_GET_PROPERTIES +from flwr.common.recordset_compat import getpropertiesins_to_recordset +from flwr.common.serde import message_from_taskres, message_to_taskins +from flwr.server.superlink.fleet.vce.vce_api import ( + NodeToPartitionMapping, + _register_nodes, + start_vce, +) +from flwr.server.superlink.state import InMemoryState, StateFactory + + +def terminate_simulation(f_stop: asyncio.Event, sleep_duration: int) -> None: + """Set event to terminate Simulation Engine after `sleep_duration` seconds.""" + sleep(sleep_duration) + f_stop.set() + + +def init_state_factory_nodes_mapping( + num_nodes: int, + num_messages: int, + erroneous_message: Optional[bool] = False, +) -> Tuple[StateFactory, NodeToPartitionMapping, Dict[UUID, float]]: + """Instantiate StateFactory, register nodes and pre-insert messages in the state.""" + # Register a state and a run_id in it + run_id = 1234 + state_factory = StateFactory(":flwr-in-memory-state:") + + # Register a few nodes + nodes_mapping = _register_nodes(num_nodes=num_nodes, state_factory=state_factory) + + 
expected_results = register_messages_into_state( + state_factory=state_factory, + nodes_mapping=nodes_mapping, + run_id=run_id, + num_messages=num_messages, + erroneous_message=erroneous_message, + ) + return state_factory, nodes_mapping, expected_results + + +# pylint: disable=too-many-locals +def register_messages_into_state( + state_factory: StateFactory, + nodes_mapping: NodeToPartitionMapping, + run_id: int, + num_messages: int, + erroneous_message: Optional[bool] = False, +) -> Dict[UUID, float]: + """Register `num_messages` into the state factory.""" + state: InMemoryState = state_factory.state() # type: ignore + state.run_ids.add(run_id) + # Artificially add TaskIns to state so they can be processed + # by the Simulation Engine logic + nodes_cycle = cycle(nodes_mapping.keys()) # we have more messages than supernodes + task_ids: Set[UUID] = set() # so we can retrieve them later + expected_results = {} + for i in range(num_messages): + dst_node_id = next(nodes_cycle) + # Construct a Message + mult_factor = 2024 + i + getproperties_ins = GetPropertiesIns(config={"factor": mult_factor}) + recordset = getpropertiesins_to_recordset(getproperties_ins) + message = Message( + content=recordset, + metadata=Metadata( + run_id=run_id, + message_id="", + group_id="", + src_node_id=0, + dst_node_id=dst_node_id, # indicate destination node + reply_to_message="", + ttl="", + message_type=( + "a bad message" + if erroneous_message + else MESSAGE_TYPE_GET_PROPERTIES + ), + ), + ) + # Convert Message to TaskIns + taskins = message_to_taskins(message) + # Insert in state + task_id = state.store_task_ins(taskins) + if task_id: + # Add to UUID set + task_ids.add(task_id) + # Store expected output for check later on + expected_results[task_id] = mult_factor * pi + + return expected_results + + +def _autoresolve_working_dir(rel_client_app_dir: str = "backend") -> str: + """Correctly resolve working directory.""" + file_path = Path(__file__) + working_dir = Path.cwd() + 
rel_workdir = file_path.relative_to(working_dir) + + # Subtract last element and append "backend/test" (where the client module is.) + return str(rel_workdir.parent / rel_client_app_dir) + + +# pylint: disable=too-many-arguments +def start_and_shutdown( + backend: str = "ray", + clientapp_module: str = "raybackend_test:client_app", + working_dir: str = "", + num_supernodes: Optional[int] = None, + state_factory: Optional[StateFactory] = None, + nodes_mapping: Optional[NodeToPartitionMapping] = None, + duration: int = 0, + backend_config: str = "{}", +) -> None: + """Start Simulation Engine and terminate after specified number of seconds. + + Some tests need to be terminated by triggering externally an asyncio.Event. This + is enabled when passing `duration`>0. + """ + f_stop = asyncio.Event() + + if duration: + + # Setup thread that will set the f_stop event, triggering the termination of all + # asyncio logic in the Simulation Engine. It will also terminate the Backend. + termination_th = threading.Thread( + target=terminate_simulation, args=(f_stop, duration) + ) + termination_th.start() + + # Resolve working directory if not passed + if not working_dir: + working_dir = _autoresolve_working_dir() + + start_vce( + num_supernodes=num_supernodes, + client_app_module_name=clientapp_module, + backend_name=backend, + backend_config_json_stream=backend_config, + state_factory=state_factory, + working_dir=working_dir, + f_stop=f_stop, + existing_nodes_mapping=nodes_mapping, + ) + + if duration: + termination_th.join() + + +class AsyncTestFleetSimulationEngineRayBackend(IsolatedAsyncioTestCase): + """A basic class that enables testing asyncio functionalities.""" + + def test_erroneous_no_supernodes_client_mapping(self) -> None: + """Test with unset arguments.""" + with self.assertRaises(ValueError): + start_and_shutdown(duration=2) + + def test_erroneous_clientapp_module_name(self) -> None: + """Tests attempt to load a ClientApp that can't be found.""" + num_messages = 7 
+ num_nodes = 59 + + state_factory, nodes_mapping, _ = init_state_factory_nodes_mapping( + num_nodes=num_nodes, num_messages=num_messages + ) + with self.assertRaises(RuntimeError): + start_and_shutdown( + clientapp_module="totally_fictitious_app:client", + state_factory=state_factory, + nodes_mapping=nodes_mapping, + ) + + def test_erroneous_messages(self) -> None: + """Test handling of error in async worker (consumer). + + We register messages which will trigger an error when handling, triggering an + error. + """ + num_messages = 100 + num_nodes = 59 + + state_factory, nodes_mapping, _ = init_state_factory_nodes_mapping( + num_nodes=num_nodes, num_messages=num_messages, erroneous_message=True + ) + + with self.assertRaises(RuntimeError): + start_and_shutdown( + state_factory=state_factory, + nodes_mapping=nodes_mapping, + ) + + def test_erroneous_backend_config(self) -> None: + """Backend Config should be a JSON stream.""" + with self.assertRaises(JSONDecodeError): + start_and_shutdown(num_supernodes=50, backend_config="not a proper config") + + def test_with_nonexistent_backend(self) -> None: + """Test specifying a backend that does not exist.""" + with self.assertRaises(KeyError): + start_and_shutdown(num_supernodes=50, backend="this-backend-does-not-exist") + + def test_erroneous_arguments_num_supernodes_and_existing_mapping(self) -> None: + """Test ValueError if a node mapping is passed but also num_supernodes. + + Passing `num_supernodes` does nothing since we assume that if a node mapping + is supplied, nodes have been registered externally already. Therefore passing + `num_supernodes` might give the impression that that many nodes will be + registered. We don't do that since a mapping already exists. + """ + with self.assertRaises(ValueError): + start_and_shutdown(num_supernodes=50, nodes_mapping={0: 1}) + + def test_erroneous_arguments_existing_mapping_but_no_state_factory(self) -> None: + """Test ValueError if a node mapping is passed but no state. 
+ + Passing a node mapping indicates that (externally) nodes have registered with a + state factory. Therefore, that state factory should be passed too. + """ + with self.assertRaises(ValueError): + start_and_shutdown(nodes_mapping={0: 1}) + + def test_start_and_shutdown(self) -> None: + """Start Simulation Engine Fleet and terminate it.""" + start_and_shutdown(num_supernodes=50, duration=10) + + # pylint: disable=too-many-locals + def test_start_and_shutdown_with_tasks_in_state(self) -> None: + """Run Simulation Engine with some TaskIns in State. + + This test creates a few nodes and submits a few messages that need to be + executed by the Backend. In order for that to happen the asyncio + producer/consumer logic must function. This also serves to evaluate a valid + ClientApp. + """ + num_messages = 229 + num_nodes = 59 + + state_factory, nodes_mapping, expected_results = ( + init_state_factory_nodes_mapping( + num_nodes=num_nodes, num_messages=num_messages + ) + ) + + # Run + start_and_shutdown( + state_factory=state_factory, nodes_mapping=nodes_mapping, duration=10 + ) + + # Get all TaskRes + state = state_factory.state() + task_ids = set(expected_results.keys()) + task_res_list = state.get_task_res(task_ids=task_ids, limit=len(task_ids)) + + # Check results by first converting to Message + for task_res in task_res_list: + + message = message_from_taskres(task_res) + + # Verify message content is as expected + content = message.content + assert ( + content.configs_records["getpropertiesres.properties"]["result"] + == expected_results[UUID(task_res.task.ancestry[0])] + )