Skip to content

Commit

Permalink
gracefully shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Feb 25, 2024
1 parent b16d0b8 commit 93918db
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 7 deletions.
4 changes: 4 additions & 0 deletions src/py/flwr/server/superlink/fleet/vce/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def num_workers(self) -> int:
def is_worker_idle(self) -> bool:
"""Report whether a backend worker is idle and can therefore run a ClientApp."""

@abstractmethod
async def terminate(self) -> None:
"""Terminate backend."""

@abstractmethod
async def process_message(
self,
Expand Down
4 changes: 4 additions & 0 deletions src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,7 @@ async def process_message(
) = await self.pool.fetch_result_and_return_actor_to_pool(future)

return out_mssg, updated_context

async def terminate(self) -> None:
"""Terminate all actors in actor pool."""
await self.pool.terminate_all_actors()
31 changes: 28 additions & 3 deletions src/py/flwr/server/superlink/fleet/vce/vce_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ async def worker(
# Store TaskRes in state
state.store_task_res(task_res)

except asyncio.CancelledError as e:
log(DEBUG, f"Async worker: {e}")
break

except Exception as ex: # pylint: disable=broad-exception-caught
# pylint: disable=fixme
# TODO: gen TaskRes with relevant error, add it to state_factory
Expand All @@ -103,17 +107,19 @@ async def generate_pull_requests(
queue: TaskInsQueue,
state_factory: StateFactory,
nodes_mapping: NodeToPartitionMapping,
f_stop: asyncio.Event,
) -> None:
"""Generate TaskIns and add it to the queue."""
state = state_factory.state()
while True:
while not (f_stop.is_set()):
for node_id in nodes_mapping.keys():
task_ins = state.get_task_ins(node_id=node_id, limit=1)
if task_ins:
await queue.put(task_ins[0])
log(DEBUG, "TaskIns in queue: %i", queue.qsize())
# pylint: disable=fixme
await asyncio.sleep(1.0) # TODO: revisit
log(DEBUG, "Async producer: Stopped pulling from StateFactory.")


async def run(
Expand All @@ -122,6 +128,7 @@ async def run(
nodes_mapping: NodeToPartitionMapping,
state_factory: StateFactory,
node_states: Dict[int, NodeState],
f_stop: asyncio.Event,
) -> None:
"""Run the VCE async."""
# pylint: disable=fixme
Expand All @@ -135,10 +142,26 @@ async def run(
)
for _ in range(backend.num_workers)
]
asyncio.create_task(generate_pull_requests(queue, state_factory, nodes_mapping))
await queue.join()
producer = asyncio.create_task(
generate_pull_requests(queue, state_factory, nodes_mapping, f_stop)
)

await asyncio.gather(producer)

# Produced task terminated, now cancel worker tasks
for w_t in worker_tasks:
_ = w_t.cancel("Terminate on Simulation Engine shutdown.")

# print('requested cancel')
while not all(w_t.done() for w_t in worker_tasks):
log(DEBUG, "Terminating async workers...")
await asyncio.sleep(0.5)

await asyncio.gather(*worker_tasks)

# Terminate backend
await backend.terminate()


# pylint: disable=too-many-arguments,unused-argument
def start_vce(
Expand All @@ -148,6 +171,7 @@ def start_vce(
backend_config_json_stream: str,
state_factory: StateFactory,
working_dir: str,
f_stop: asyncio.Event,
) -> None:
"""Start Fleet API with the VirtualClientEngine (VCE)."""
# Register SuperNodes
Expand Down Expand Up @@ -195,5 +219,6 @@ def _load() -> ClientApp:
nodes_mapping,
state_factory,
node_states,
f_stop,
)
)
18 changes: 15 additions & 3 deletions src/py/flwr/simulation/ray_transport/ray_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import threading
import traceback
from abc import ABC
from logging import ERROR, WARNING
from logging import DEBUG, ERROR, WARNING
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

import ray
Expand Down Expand Up @@ -46,7 +46,7 @@ class VirtualClientEngineActor(ABC):

def terminate(self) -> None:
"""Manually terminate Actor object."""
log(WARNING, "Manually terminating %s}", self.__class__.__name__)
log(WARNING, "Manually terminating %s", self.__class__.__name__)
ray.actor.exit_actor()

def run(
Expand Down Expand Up @@ -434,7 +434,9 @@ def __init__(
self.client_resources = client_resources

# Queue of idle actors
self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue()
self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue(
maxsize=1024
)
self.num_actors = 0

# Resolve arguments to pass during actor init
Expand Down Expand Up @@ -464,6 +466,16 @@ async def add_actors_to_pool(self, num_actors: int) -> None:
await self.pool.put(self.create_actor_fn()) # type: ignore
self.num_actors += num_actors

async def terminate_all_actors(self) -> None:
"""Terminate actors in pool."""
num_terminated = 0
while self.pool.qsize():
actor = await self.pool.get()
actor.terminate.remote() # type: ignore
num_terminated += 1

log(DEBUG, "Terminated %i actors", num_terminated)

async def submit(
self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context]
) -> Any:
Expand Down
11 changes: 10 additions & 1 deletion src/py/flwr/simulation/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Flower Simulation."""

import argparse
import asyncio
import threading

import grpc
Expand Down Expand Up @@ -44,6 +45,7 @@ def run_simulation() -> None:
)

# Superlink with Simulation Engine
f_stop = asyncio.Event()
superlink_th = threading.Thread(
target=start_vce,
args=(
Expand All @@ -53,6 +55,7 @@ def run_simulation() -> None:
args.backend_config,
state_factory,
args.dir,
f_stop,
),
daemon=False,
)
Expand All @@ -69,11 +72,17 @@ def run_simulation() -> None:
# Launch server app
run(args.server_app, driver, args.dir)

del driver

# Trigger stop event
f_stop.set()

_register_exit_handlers(
grpc_servers=[driver_server],
bckg_threads=[superlink_th],
event_type=EventType.RUN_SUPERLINK_LEAVE,
)
superlink_th.join()


def _parse_args_run_simulation() -> argparse.ArgumentParser:
Expand Down Expand Up @@ -106,7 +115,7 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
parser.add_argument(
"--backend-config",
type=str,
default='{"client_resources": {"num_cpus":1, "num_gpus":0.0}, "tensorflow": 0}',
default='{"client_resources": {"num_cpus":2, "num_gpus":0.0}, "tensorflow": 0}',
help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
"configure a backend. Values supported in <value> are those included by "
"`flwr.common.typing.ConfigsRecordValues`. ",
Expand Down

0 comments on commit 93918db

Please sign in to comment.