Complete VCE Fleet API loop #2998

Merged
merged 101 commits into from
Feb 28, 2024
Changes from 95 commits
4c83a5e
init
jafermarq Feb 21, 2024
a85db40
base backend
jafermarq Feb 21, 2024
b770312
update
jafermarq Feb 21, 2024
b9c6455
update docstrings
jafermarq Feb 21, 2024
cd48539
minor fixes
jafermarq Feb 21, 2024
2791163
updates
jafermarq Feb 22, 2024
4ca33ec
backend-config should contain value types
jafermarq Feb 22, 2024
1b6564a
fix
jafermarq Feb 22, 2024
bad8727
w/ previous
jafermarq Feb 22, 2024
a68172f
fix
jafermarq Feb 22, 2024
e2b072e
Merge branch 'vce-flee-api' into vce-fleet-api-backends
jafermarq Feb 22, 2024
6881866
fix for json.loads
jafermarq Feb 22, 2024
935e333
keep backend-config as json string
jafermarq Feb 22, 2024
704d667
merged
jafermarq Feb 22, 2024
adfe198
added `RayBackend` and `SimpleActorPool`
jafermarq Feb 22, 2024
d0bab9a
complete VCE loop; works with `simulation-pytorch` example
jafermarq Feb 22, 2024
5e0ee74
fix exclude generation logic
jafermarq Feb 22, 2024
82836ad
Merge branch 'vce-fleet-api-backends-ray' into vce-fleet-api-loop
jafermarq Feb 22, 2024
4853813
simulation-tf w/ Flower-next; updates pytorch example too
jafermarq Feb 22, 2024
31787cf
format
jafermarq Feb 22, 2024
8522022
passing actor init kwargs
jafermarq Feb 22, 2024
788e63f
Merge branch 'vce-fleet-api-backends-ray' into vce-fleet-api-loop
jafermarq Feb 22, 2024
9200513
updated examples
jafermarq Feb 22, 2024
d8935b3
auto enable GPU growth if 'tensorflow' passed
jafermarq Feb 22, 2024
b108be2
return to default 1xCPU for virtual client
jafermarq Feb 22, 2024
17e3089
Merge branch 'main' into vce-flee-api
danieljanes Feb 22, 2024
0e02b05
moved import
jafermarq Feb 22, 2024
b77e4e8
Merge branch 'vce-flee-api' into vce-fleet-api-backends
jafermarq Feb 22, 2024
fd67f22
Apply suggestions from code review
jafermarq Feb 22, 2024
cf004d8
renamed vars; exporting
jafermarq Feb 22, 2024
f10a6ca
Merge branch 'vce-flee-api' into vce-fleet-api-backends
jafermarq Feb 22, 2024
d097bcd
Merge branch 'main' into vce-fleet-api-backends
jafermarq Feb 22, 2024
e669e2a
merge
jafermarq Feb 22, 2024
a521b40
moved
jafermarq Feb 22, 2024
890e329
Merge branch 'move-server-functions' into vce-fleet-api-backends-ray
jafermarq Feb 22, 2024
2ccf612
Merge branch 'main' into vce-fleet-api-backends
jafermarq Feb 22, 2024
0f7a071
Merge branch 'vce-fleet-api-backends' into vce-fleet-api-backends-ray
jafermarq Feb 22, 2024
443551f
revisited imports readiness for chosen backend
jafermarq Feb 23, 2024
bdfcccb
merge w/ main
jafermarq Feb 23, 2024
79f363e
Apply suggestions from code review
jafermarq Feb 23, 2024
12fa44c
remove suprefluous if
jafermarq Feb 23, 2024
c309046
fixes
jafermarq Feb 23, 2024
d217677
merge and more
jafermarq Feb 23, 2024
ee20d50
merge w/ main
jafermarq Feb 23, 2024
0e4ab14
terminate method for backend; asyncio event to trigger stop
jafermarq Feb 25, 2024
e173312
Merge branch 'vce-fleet-terminate-and-rename' into vce-fleet-api-loop
jafermarq Feb 25, 2024
21e9932
propagate terminate asyncio logic
jafermarq Feb 25, 2024
2faef27
merge
jafermarq Feb 26, 2024
f8b57c5
added build/process/terminate tests
jafermarq Feb 26, 2024
accc67c
Merge branch 'vce-fleet-raybackend-tests' into vce-fleet-api-loop
jafermarq Feb 26, 2024
39e3234
format
jafermarq Feb 26, 2024
8ea4b08
fix for py3.8
jafermarq Feb 26, 2024
fb9bfc2
Merge branch 'vce-fleet-raybackend-tests' into vce-fleet-api-loop
jafermarq Feb 26, 2024
35c55d4
fix py3.11
jafermarq Feb 26, 2024
49bc661
fix import
jafermarq Feb 26, 2024
4506a17
wrapped asyncio test under `IsolatedAsyncioTestCase` class
jafermarq Feb 26, 2024
7d6f821
Merge branch 'vce-fleet-raybackend-tests' into vce-fleet-api-loop
jafermarq Feb 26, 2024
ed5b181
start/shutdown tests
jafermarq Feb 26, 2024
2c05cdd
full loop tests; tweaks
jafermarq Feb 26, 2024
c589f7d
erge w/ main
jafermarq Feb 26, 2024
98fb4b4
.
jafermarq Feb 26, 2024
65c8b79
undoing changes to simulation examples
jafermarq Feb 26, 2024
35ab1f3
Apply suggestions from code review
jafermarq Feb 26, 2024
5b3365a
Apply suggestions from code review
jafermarq Feb 26, 2024
785ac91
introduced `partition_id`.
jafermarq Feb 26, 2024
eba053a
fix for ray proxies and tests
jafermarq Feb 26, 2024
3d15041
Merge branch 'main' into metadata-with-partition-id
jafermarq Feb 26, 2024
ee84f85
Merge branch 'main' into metadata-with-partition-id
danieljanes Feb 26, 2024
27d2bb1
re written
jafermarq Feb 26, 2024
b7d5521
Merge branch 'metadata-with-partition-id' into vce-fleet-api-loop
jafermarq Feb 26, 2024
1969aac
using `metadata.partition_id`
jafermarq Feb 26, 2024
33e7be3
Merge branch 'main' into vce-fleet-api-loop
jafermarq Feb 27, 2024
666c65f
Merge branch 'main' into vce-fleet-api-loop
danieljanes Feb 27, 2024
ab55b0c
more tests
jafermarq Feb 27, 2024
28dda2d
more
jafermarq Feb 27, 2024
5cd047e
minor update
jafermarq Feb 27, 2024
4be09c2
handle loading of non-existing ClientApp
jafermarq Feb 27, 2024
17d3d34
merge /w branch ahead; test vce with non existing clientapp
jafermarq Feb 27, 2024
3c616e9
better tests; reorg
jafermarq Feb 27, 2024
ffed29b
Merge branch 'more-raybackend-tests-yes' into vce-fleet-api-loop
jafermarq Feb 27, 2024
aed4420
update
jafermarq Feb 27, 2024
96519dc
w/ previous
jafermarq Feb 27, 2024
e491b7b
Merge branch 'more-raybackend-tests-yes' into vce-fleet-api-loop
jafermarq Feb 27, 2024
21f03a9
fix
jafermarq Feb 27, 2024
ab63974
Merge branch 'minor-fix-cli-test' into more-raybackend-tests-yes
jafermarq Feb 27, 2024
1d137da
Merge branch 'main' into more-raybackend-tests-yes
jafermarq Feb 27, 2024
a4590af
Merge branch 'more-raybackend-tests-yes' into vce-fleet-api-loop
jafermarq Feb 27, 2024
8f6de1e
Merge branch 'main' into more-raybackend-tests-yes
danieljanes Feb 27, 2024
c45c4af
no need for separate test/ dir
jafermarq Feb 27, 2024
e2ac2b0
Merge branch 'more-raybackend-tests-yes' into vce-fleet-api-loop
jafermarq Feb 27, 2024
c9492f0
update
jafermarq Feb 27, 2024
b1e0460
merge w/ main
jafermarq Feb 27, 2024
b3d397b
better handling of exceptions in vce's ; adjust test for
jafermarq Feb 27, 2024
bd7b1aa
completed tests.
jafermarq Feb 27, 2024
6e3271b
minior formatting
jafermarq Feb 28, 2024
67777c5
Apply suggestions from code review
jafermarq Feb 28, 2024
35f4566
Merge branch 'main' into vce-fleet-api-loop
jafermarq Feb 28, 2024
46eac84
fixes post review
jafermarq Feb 28, 2024
f2ee2cd
Merge branch 'main' into vce-fleet-api-loop
jafermarq Feb 28, 2024
cc6a145
instantiating backend in asyncio event loop
jafermarq Feb 28, 2024
5cffeba
Merge branch 'main' into vce-fleet-api-loop
danieljanes Feb 28, 2024
6 changes: 5 additions & 1 deletion src/py/flwr/server/app.py
@@ -14,8 +14,8 @@
# ==============================================================================
"""Flower server app."""


import argparse
import asyncio
import importlib.util
import sys
import threading
@@ -362,13 +362,15 @@ def run_superlink() -> None:
)
grpc_servers.append(fleet_server)
elif args.fleet_api_type == TRANSPORT_TYPE_VCE:
f_stop = asyncio.Event() # Does nothing
_run_fleet_api_vce(
num_supernodes=args.num_supernodes,
client_app_module_name=args.client_app,
backend_name=args.backend,
backend_config_json_stream=args.backend_config,
working_dir=args.dir,
state_factory=state_factory,
f_stop=f_stop,
)
else:
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
@@ -468,6 +470,7 @@ def _run_fleet_api_vce(
backend_config_json_stream: str,
working_dir: str,
state_factory: StateFactory,
f_stop: asyncio.Event,
) -> None:
log(INFO, "Flower VCE: Starting Fleet API (VirtualClientEngine)")

@@ -478,6 +481,7 @@ def _run_fleet_api_vce(
backend_config_json_stream=backend_config_json_stream,
state_factory=state_factory,
working_dir=working_dir,
f_stop=f_stop,
)


2 changes: 1 addition & 1 deletion src/py/flwr/server/superlink/fleet/vce/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fleet VirtualClientEngine side."""
"""Fleet Simulation Engine side."""

from .vce_api import start_vce

Original file line number Diff line number Diff line change
@@ -141,7 +141,7 @@ async def process_message(

Return output message and updated context.
"""
node_id = message.metadata.dst_node_id
node_id = message.metadata.partition_id

try:
# Submit a task to the pool
@@ -163,10 +163,9 @@ async def process_message(
except LoadClientAppError as load_ex:
log(
ERROR,
"An exception was raised when processing a message. Terminating %s",
"An exception was raised when processing a message by %s",
self.__class__.__name__,
)
await self.terminate()
raise load_ex

async def terminate(self) -> None:
Original file line number Diff line number Diff line change
@@ -20,6 +20,8 @@
from typing import Callable, Dict, Optional, Tuple, Union
from unittest import IsolatedAsyncioTestCase

import ray

from flwr.client import Client, NumPyClient
from flwr.client.client_app import ClientApp, LoadClientAppError, load_client_app
from flwr.common import (
@@ -119,6 +121,11 @@ def _create_message_and_context() -> Tuple[Message, Context, float]:
class AsyncTestRayBackend(IsolatedAsyncioTestCase):
"""A basic class that allows running multiple asyncio tests."""

async def on_cleanup(self) -> None:
"""Ensure Ray has shutdown."""
if ray.is_initialized():
ray.shutdown()

def test_backend_creation_and_termination(self) -> None:
"""Test creation of RayBackend and its termination."""
backend = RayBackend(backend_config={}, work_dir="")
@@ -171,6 +178,7 @@ def test_backend_creation_submit_and_termination_non_existing_client_app(
self.test_backend_creation_submit_and_termination(
client_app_loader=_load_from_module("a_non_existing_module:app")
)
self.addAsyncCleanup(self.on_cleanup)

def test_backend_creation_submit_and_termination_existing_client_app(
self,
@@ -198,3 +206,4 @@ def test_backend_creation_submit_and_termination_existing_client_app_unsetworkdi
client_app_loader=_load_from_module("raybackend_test:client_app"),
workdir="/?&%$^#%@$!",
)
self.addAsyncCleanup(self.on_cleanup)
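
Note on the cleanup hook added above: the tests register `on_cleanup` as the last statement of the test body, so it would not be registered if an earlier statement in the body raised. A minimal sketch of the same `IsolatedAsyncioTestCase.addAsyncCleanup` pattern with the registration done first is shown below. `FakeRuntime`, `CleanupExample`, and `run_example` are hypothetical names used only for illustration; `FakeRuntime` stands in for Ray so the snippet runs with the standard library alone.

```python
import io
import unittest
from unittest import IsolatedAsyncioTestCase


class FakeRuntime:
    """Hypothetical stand-in for a runtime such as Ray (not part of the PR)."""

    initialized = False

    @classmethod
    def init(cls) -> None:
        cls.initialized = True

    @classmethod
    def shutdown(cls) -> None:
        cls.initialized = False


class CleanupExample(IsolatedAsyncioTestCase):
    """Register the async cleanup before touching the runtime."""

    async def on_cleanup(self) -> None:
        # Mirrors the `ray.is_initialized()` / `ray.shutdown()` check in the diff
        if FakeRuntime.initialized:
            FakeRuntime.shutdown()

    async def test_runtime_is_shut_down(self) -> None:
        # Registered first, so it also runs if a later statement fails
        self.addAsyncCleanup(self.on_cleanup)
        FakeRuntime.init()
        self.assertTrue(FakeRuntime.initialized)


def run_example() -> bool:
    """Run the test case and report whether it passed and cleaned up."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(CleanupExample)
    result = unittest.TextTestRunner(stream=io.StringIO(), verbosity=0).run(suite)
    return result.wasSuccessful() and not FakeRuntime.initialized
```

Calling `run_example()` runs the single test and then checks that the cleanup left the fake runtime shut down.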
243 changes: 227 additions & 16 deletions src/py/flwr/server/superlink/fleet/vce/vce_api.py
@@ -12,19 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fleet VirtualClientEngine API."""
"""Fleet Simulation Engine API."""


import asyncio
import json
from logging import ERROR, INFO
from typing import Dict, Optional
import traceback
from logging import DEBUG, ERROR, INFO, WARN
from typing import Callable, Dict, List, Optional

from flwr.client.client_app import ClientApp, load_client_app
from flwr.client.client_app import ClientApp, LoadClientAppError, load_client_app
from flwr.client.node_state import NodeState
from flwr.common.logger import log
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611
from flwr.server.superlink.state import StateFactory

from .backend import error_messages_backends, supported_backends
from .backend import Backend, error_messages_backends, supported_backends

NodeToPartitionMapping = Dict[int, int]

@@ -42,21 +46,217 @@ def _register_nodes(
return nodes_mapping


# pylint: disable=too-many-arguments,unused-argument
# pylint: disable=too-many-arguments,too-many-locals
async def worker(
app: Callable[[], ClientApp],
queue: "asyncio.Queue[TaskIns]",
node_states: Dict[int, NodeState],
state_factory: StateFactory,
nodes_mapping: NodeToPartitionMapping,
backend: Backend,
) -> None:
"""Get TaskIns from queue and pass it to an actor in the pool to execute it."""
state = state_factory.state()
while True:
try:
task_ins: TaskIns = await queue.get()
node_id = task_ins.task.consumer.node_id

# Register and retrieve runstate
node_states[node_id].register_context(run_id=task_ins.run_id)
context = node_states[node_id].retrieve_context(run_id=task_ins.run_id)

# Convert TaskIns to Message
message = message_from_taskins(task_ins)
# Replace node ID with data partition ID
message.metadata.partition_id = nodes_mapping[node_id]

# Let backend process message
out_mssg, updated_context = await backend.process_message(
app, message, context
)

# Update Context
node_states[node_id].update_context(
task_ins.run_id, context=updated_context
)

# Convert to TaskRes
task_res = message_to_taskres(out_mssg)
# Store TaskRes in state
state.store_task_res(task_res)

except asyncio.CancelledError as e:
log(DEBUG, "Async worker: %s", e)
break

except LoadClientAppError as app_ex:
log(ERROR, "Async worker: %s", app_ex)
log(ERROR, traceback.format_exc())
raise

except Exception as ex: # pylint: disable=broad-exception-caught
log(ERROR, ex)
log(ERROR, traceback.format_exc())
break


async def add_taskins_to_queue(
queue: "asyncio.Queue[TaskIns]",
state_factory: StateFactory,
nodes_mapping: NodeToPartitionMapping,
backend: Backend,
consumers: List["asyncio.Task[None]"],
f_stop: asyncio.Event,
) -> None:
"""Retrieve TaskIns and add it to the queue."""
state = state_factory.state()
num_initial_consumers = len(consumers)
while not f_stop.is_set():
for node_id in nodes_mapping.keys():
task_ins = state.get_task_ins(node_id=node_id, limit=1)
if task_ins:
await queue.put(task_ins[0])

# Count consumers that are running
num_active = sum(not (cc.done()) for cc in consumers)

# Alert if number of consumers decreased by half
if num_active < num_initial_consumers // 2:
log(
WARN,
"Number of active workers has more than halved: (%i/%i active)",
num_active,
num_initial_consumers,
)

# Break if consumers died
if num_active == 0:
raise RuntimeError("All workers have died. Ending Simulation.")

# Log some stats
log(
DEBUG,
"Simulation Engine stats: "
"Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i",
num_active,
num_initial_consumers,
backend.__class__.__name__,
backend.num_workers,
queue.qsize(),
)
await asyncio.sleep(1.0)
log(DEBUG, "Async producer: Stopped pulling from StateFactory.")


async def run(
app: Callable[[], ClientApp],
backend: Backend,
nodes_mapping: NodeToPartitionMapping,
state_factory: StateFactory,
node_states: Dict[int, NodeState],
f_stop: asyncio.Event,
) -> None:
"""Run the VCE async."""
# pylint: disable=fixme
queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128)

try:
# Build backend
await backend.build()

# Add workers (they submit Messages to Backend)
worker_tasks = [
asyncio.create_task(
worker(app, queue, node_states, state_factory, nodes_mapping, backend)
)
for _ in range(backend.num_workers)
]
# Create producer (adds TaskIns into Queue)
producer = asyncio.create_task(
add_taskins_to_queue(
queue, state_factory, nodes_mapping, backend, worker_tasks, f_stop
)
)

# Wait for producer to finish
# The producer runs forever until f_stop is set or until
# all worker (consumer) coroutines are completed. Workers
# also run forever and only end if an exception is raised.
await asyncio.gather(producer)

except Exception as ex:

log(ERROR, "An exception occurred!! %s", ex)
log(ERROR, traceback.format_exc())
log(WARN, "Stopping Simulation Engine.")

# Manually trigger stopping event
f_stop.set()

# Raise exception
raise RuntimeError("Simulation Engine crashed.") from ex

finally:
# Producer task terminated, now cancel worker tasks
for w_t in worker_tasks:
_ = w_t.cancel()

while not all(w_t.done() for w_t in worker_tasks):
log(DEBUG, "Terminating async workers...")
await asyncio.sleep(0.5)

await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])

# Terminate backend
await backend.terminate()


# pylint: disable=too-many-arguments,unused-argument,too-many-locals
def start_vce(
num_supernodes: int,
client_app_module_name: str,
backend_name: str,
backend_config_json_stream: str,
state_factory: StateFactory,
working_dir: str,
f_stop: Optional[asyncio.Event] = None,
f_stop: asyncio.Event,
num_supernodes: Optional[int] = None,
state_factory: Optional[StateFactory] = None,
existing_nodes_mapping: Optional[NodeToPartitionMapping] = None,
) -> None:
"""Start Fleet API with the VirtualClientEngine (VCE)."""
# Register SuperNodes
nodes_mapping = _register_nodes(
num_nodes=num_supernodes, state_factory=state_factory
)
"""Start Fleet API with the Simulation Engine."""
if num_supernodes is not None and existing_nodes_mapping is not None:
raise ValueError(
"Both `num_supernodes` and `existing_nodes_mapping` are provided, "
"but only one is allowed."
)
if num_supernodes is None:
if state_factory is None or existing_nodes_mapping is None:
raise ValueError(
"If not passing an existing `state_factory` and associated "
"`existing_nodes_mapping` you must supply `num_supernodes` to indicate "
"how many nodes to insert into a new StateFactory that will be created."
)
if existing_nodes_mapping:
if state_factory is None:
raise ValueError(
"You passed `existing_nodes_mapping` but no `state_factory` was passed."
)
log(INFO, "Using existing NodeToPartitionMapping and StateFactory.")
# Use mapping constructed externally. This also means nodes
# have previously been registered.
nodes_mapping = existing_nodes_mapping

if not state_factory:
log(INFO, "A StateFactory was not supplied to the SimulationEngine.")
# Create an empty in-memory state factory
state_factory = StateFactory(":flwr-in-memory-state:")
log(INFO, "Created new %s.", state_factory.__class__.__name__)

if num_supernodes:
# Register SuperNodes
nodes_mapping = _register_nodes(
num_nodes=num_supernodes, state_factory=state_factory
)

# Construct mapping of NodeStates
node_states: Dict[int, NodeState] = {}
@@ -69,7 +269,7 @@ def start_vce(

try:
backend_type = supported_backends[backend_name]
_ = backend_type(backend_config, work_dir=working_dir)
backend = backend_type(backend_config, work_dir=working_dir)
except KeyError as ex:
log(
ERROR,
@@ -89,4 +289,15 @@ def _load() -> ClientApp:
app: ClientApp = load_client_app(client_app_module_name)
return app

# start backend
app = _load

asyncio.run(
run(
app,
backend,
nodes_mapping,
state_factory,
node_states,
f_stop,
)
)
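
Note on the loop added in `vce_api.py`: the structure is a producer/consumer pair around an `asyncio.Queue`, with an `asyncio.Event` as the stop flag and worker tasks cancelled in a `finally` block. The sketch below reproduces only that shape with the standard library. It is not the Flower implementation: integers replace `TaskIns`, doubling a number replaces `backend.process_message`, and `run_loop` and the simplified `producer` are hypothetical names. One deliberate difference: the worker tasks are created before the `try`, because in the diff as shown `worker_tasks` is assigned after `await backend.build()` inside the `try`, so it appears it would be unbound in `finally` if `build()` raised.

```python
import asyncio
from typing import List


async def worker(queue: "asyncio.Queue[int]", results: List[int]) -> None:
    """Consume items until cancelled (stand-in for the PR's `worker`)."""
    while True:
        item = await queue.get()
        results.append(item * 2)  # placeholder for `backend.process_message`
        queue.task_done()


async def producer(
    queue: "asyncio.Queue[int]", items: List[int], f_stop: asyncio.Event
) -> None:
    """Feed items into the queue until done or until `f_stop` is set."""
    for item in items:
        if f_stop.is_set():
            break
        await queue.put(item)
    await queue.join()  # wait until the workers have handled everything
    f_stop.set()


async def run_loop(items: List[int], num_workers: int = 2) -> List[int]:
    """Start workers, run the producer, then always cancel the workers."""
    results: List[int] = []
    queue: "asyncio.Queue[int]" = asyncio.Queue(8)
    f_stop = asyncio.Event()
    # Created before `try`, so `finally` can always reference the list
    worker_tasks = [
        asyncio.create_task(worker(queue, results)) for _ in range(num_workers)
    ]
    try:
        await producer(queue, items, f_stop)
    finally:
        for task in worker_tasks:
            task.cancel()
        # Collect the resulting cancellations instead of polling `done()`
        await asyncio.gather(*worker_tasks, return_exceptions=True)
    return results
```

For example, `asyncio.run(run_loop([1, 2, 3]))` returns the doubled values once every queued item has been processed. Awaiting `gather(..., return_exceptions=True)` after `cancel()` is one way to wait for the workers to finish without a sleep loop.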