Skip to content

Commit

Permalink
terminate method for backend; asyncio event to trigger stop
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Feb 25, 2024
1 parent de5af24 commit 0e4ab14
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/py/flwr/server/superlink/fleet/vce/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def num_workers(self) -> int:
def is_worker_idle(self) -> bool:
"""Report whether a backend worker is idle and can therefore run a ClientApp."""

@abstractmethod
async def terminate(self) -> None:
"""Terminate backend."""

@abstractmethod
async def process_message(
self,
Expand Down
9 changes: 8 additions & 1 deletion src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ClientAppActor,
init_ray,
)
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth

from .backend import Backend, BackendConfig

Expand Down Expand Up @@ -56,7 +57,9 @@ def __init__(
self.client_resources_key = "client_resources"

# Create actor pool
actor_kwargs = backend_config.get("actor_kwargs", {})
use_tf = backend_config.get("tensorflow", False)
actor_kwargs = {"on_actor_init_fn": enable_tf_gpu_growth} if use_tf else {}

client_resources = self._validate_client_resources(config=backend_config)
self.pool = BasicActorPool(
actor_type=ClientAppActor,
Expand Down Expand Up @@ -151,3 +154,7 @@ async def process_message(
) = await self.pool.fetch_result_and_return_actor_to_pool(future)

return out_mssg, updated_context

async def terminate(self) -> None:
"""Terminate all actors in actor pool."""
await self.pool.terminate_all_actors()
4 changes: 3 additions & 1 deletion src/py/flwr/server/superlink/fleet/vce/vce_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
# ==============================================================================
"""Fleet VirtualClientEngine API."""

import asyncio
import json
from logging import ERROR, INFO
from typing import Dict
from typing import Dict, Optional

from flwr.client.clientapp import ClientApp, load_client_app
from flwr.client.node_state import NodeState
Expand Down Expand Up @@ -49,6 +50,7 @@ def start_vce(
backend_config_json_stream: str,
state_factory: StateFactory,
working_dir: str,
f_stop: Optional[asyncio.Event] = None,
) -> None:
"""Start Fleet API with the VirtualClientEngine (VCE)."""
# Register SuperNodes
Expand Down
18 changes: 15 additions & 3 deletions src/py/flwr/simulation/ray_transport/ray_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import threading
import traceback
from abc import ABC
from logging import ERROR, WARNING
from logging import DEBUG, ERROR, WARNING
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

import ray
Expand Down Expand Up @@ -46,7 +46,7 @@ class VirtualClientEngineActor(ABC):

def terminate(self) -> None:
"""Manually terminate Actor object."""
log(WARNING, "Manually terminating %s}", self.__class__.__name__)
log(WARNING, "Manually terminating %s", self.__class__.__name__)
ray.actor.exit_actor()

def run(
Expand Down Expand Up @@ -434,7 +434,9 @@ def __init__(
self.client_resources = client_resources

# Queue of idle actors
self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue()
self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue(
maxsize=1024
)
self.num_actors = 0

# Resolve arguments to pass during actor init
Expand Down Expand Up @@ -464,6 +466,16 @@ async def add_actors_to_pool(self, num_actors: int) -> None:
await self.pool.put(self.create_actor_fn()) # type: ignore
self.num_actors += num_actors

async def terminate_all_actors(self) -> None:
"""Terminate actors in pool."""
num_terminated = 0
while self.pool.qsize():
actor = await self.pool.get()
actor.terminate.remote() # type: ignore
num_terminated += 1

log(DEBUG, "Terminated %i actors", num_terminated)

async def submit(
self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context]
) -> Any:
Expand Down

0 comments on commit 0e4ab14

Please sign in to comment.