Skip to content

Commit

Permalink
Support Aux Message and Object Streaming in SP and CP (NVIDIA#3068)
Browse files Browse the repository at this point in the history
* support SP/CP aux msg

* resolve merge conflict

* resolve conflict

* reformat

* fix process type setting

* add streaming to CP and SP

* removed unused args

* remove unused self.fl_ctx_mgr

* address PR comments

* reformat
  • Loading branch information
yanchengnv authored Nov 23, 2024
1 parent 82fe663 commit 5534af5
Show file tree
Hide file tree
Showing 19 changed files with 471 additions and 181 deletions.
10 changes: 10 additions & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class ReturnCode(object):
BAD_REQUEST_DATA = "BAD_REQUEST_DATA"
BAD_TASK_DATA = "BAD_TASK_DATA"
COMMUNICATION_ERROR = "COMMUNICATION_ERROR"
TIMEOUT = "TIMEOUT"
ERROR = "ERROR"
EXECUTION_EXCEPTION = "EXECUTION_EXCEPTION"
EXECUTION_RESULT_ERROR = "EXECUTION_RESULT_ERROR"
Expand Down Expand Up @@ -104,6 +105,7 @@ class ReservedKey(object):
JOB_IS_UNSAFE = "__job_is_unsafe__"
CUSTOM_PROPS = "__custom_props__"
EXCEPTIONS = "__exceptions__"
PROCESS_TYPE = "__process_type__" # type of the current process (SP, CP, SJ, CJ)


class FLContextKey(object):
Expand Down Expand Up @@ -184,6 +186,14 @@ class FLContextKey(object):
CLIENT_CONFIG = "__client_config__"
SERVER_CONFIG = "__server_config__"
SERVER_HOST_NAME = "__server_host_name__"
PROCESS_TYPE = ReservedKey.PROCESS_TYPE


class ProcessType:
    """Short codes identifying the four kinds of FLARE processes.

    The current process's code is stored in the FLContext under
    ReservedKey.PROCESS_TYPE (see FLContextKey.PROCESS_TYPE).
    """

    SERVER_PARENT = "SP"  # server parent process
    SERVER_JOB = "SJ"  # server job process
    CLIENT_PARENT = "CP"  # client parent process
    CLIENT_JOB = "CJ"  # client job process


class ReservedTopic(object):
Expand Down
3 changes: 3 additions & 0 deletions nvflare/apis/fl_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ def _simple_get(self, key: str, default=None):
def get_engine(self, default=None):
return self._simple_get(ReservedKey.ENGINE, default)

def get_process_type(self, default=None):
    """Return the type of the current process (SP, CP, SJ, or CJ).

    Args:
        default: value returned when no process type has been set
            in this context.

    Returns:
        The process type code stored under ReservedKey.PROCESS_TYPE,
        or *default* if absent.
    """
    return self._simple_get(ReservedKey.PROCESS_TYPE, default)

def get_job_id(self, default=None):
return self._simple_get(ReservedKey.RUN_NUM, default)

Expand Down
58 changes: 0 additions & 58 deletions nvflare/apis/server_engine_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from typing import Dict, List, Optional, Tuple

from nvflare.apis.shareable import Shareable
from nvflare.apis.streaming import ConsumerFactory, ObjectProducer, StreamContext
from nvflare.widgets.widget import Widget

from .client import Client
Expand Down Expand Up @@ -165,63 +164,6 @@ def fire_and_forget_aux_request(
) -> dict:
return self.send_aux_request(targets, topic, request, 0.0, fl_ctx, optional, secure=secure)

@abstractmethod
def stream_objects(
    self,
    channel: str,
    topic: str,
    stream_ctx: StreamContext,
    targets: List[str],
    producer: ObjectProducer,
    fl_ctx: FLContext,
    optional=False,
    secure=False,
):
    """Send a stream of Shareable objects to receivers.

    Args:
        channel: the channel for this stream
        topic: topic of the stream
        stream_ctx: context of the stream
        targets: receiving sites
        producer: the ObjectProducer that can produce the stream of Shareable objects
        fl_ctx: the FLContext object
        optional: whether the stream is optional
        secure: whether to use P2P security

    Returns: result from the producer's reply processing
    """
    pass

@abstractmethod
def register_stream_processing(
    self,
    channel: str,
    topic: str,
    factory: ConsumerFactory,
    stream_done_cb=None,
    **cb_kwargs,
):
    """Register a ConsumerFactory for specified app channel and topic.

    Once a new streaming request is received for the channel/topic, the registered
    factory will be used to create an ObjectConsumer object to handle the new stream.

    Note: the factory should generate a new ObjectConsumer every time get_consumer()
    is called. This is because multiple streaming sessions could be going on at the
    same time. Each streaming session should have its own ObjectConsumer.

    Args:
        channel: app channel
        topic: app topic
        factory: the factory to be registered
        stream_done_cb: the callback to be called when streaming is done on receiving side
        **cb_kwargs: extra keyword args passed to stream_done_cb

    Returns: None
    """
    pass

@abstractmethod
def get_widget(self, widget_id: str) -> Widget:
"""Get the widget with the specified ID.
Expand Down
4 changes: 3 additions & 1 deletion nvflare/apis/shareable.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ class Shareable(dict):
It is recommended that keys are strings. Values must be serializable.
"""

def __init__(self):
def __init__(self, data: dict = None):
    """Init the Shareable, optionally pre-populated from *data*.

    Args:
        data: optional dict whose entries seed this Shareable's content.
    """
    super().__init__()
    if data:
        self.update(data)
    # set last so every Shareable always carries a headers entry,
    # even if `data` supplied one of its own
    self[ReservedHeaderKey.HEADERS] = {}

def set_header(self, key: str, value):
Expand Down
72 changes: 71 additions & 1 deletion nvflare/apis/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from abc import ABC, abstractmethod
from builtins import dict as StreamContext
from typing import Any, Dict, Tuple
from typing import Any, Dict, List, Tuple

from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
Expand Down Expand Up @@ -161,3 +161,73 @@ def stream_done_cb_signature(stream_ctx: StreamContext, fl_ctx: FLContext, **kwa
"""
pass


class StreamableEngine(ABC):
    """Interface that streaming-capable engines must implement."""

    @abstractmethod
    def stream_objects(
        self,
        channel: str,
        topic: str,
        stream_ctx: StreamContext,
        targets: List[str],
        producer: ObjectProducer,
        fl_ctx: FLContext,
        optional=False,
        secure=False,
    ):
        """Send a stream of Shareable objects to the specified targets.

        Args:
            channel: the channel for this stream
            topic: topic of the stream
            stream_ctx: context of the stream
            targets: names of the receiving sites
            producer: the ObjectProducer that can produce the stream of Shareable objects
            fl_ctx: the FLContext object
            optional: whether the stream is optional
            secure: whether to use P2P security

        Returns: result from the producer's reply processing
        """
        pass

    @abstractmethod
    def register_stream_processing(
        self,
        channel: str,
        topic: str,
        factory: ConsumerFactory,
        stream_done_cb=None,
        **cb_kwargs,
    ):
        """Register a ConsumerFactory for the specified app channel and topic.

        When a new streaming request arrives for the channel/topic, the registered
        factory is used to create an ObjectConsumer to handle the new stream.

        Note: the factory must create a new ObjectConsumer each time get_consumer()
        is called, because multiple streaming sessions may be going on at the same
        time and each session needs its own ObjectConsumer.

        Args:
            channel: app channel
            topic: app topic
            factory: the factory to be registered
            stream_done_cb: callback invoked on the receiving side when streaming is done
            **cb_kwargs: extra keyword args passed to stream_done_cb

        Returns: None
        """
        pass

    @abstractmethod
    def shutdown_streamer(self):
        """Shut down the engine's streamer.

        Returns: None
        """
        pass
10 changes: 9 additions & 1 deletion nvflare/app_common/streamers/file_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.streaming import ConsumerFactory, ObjectConsumer, ObjectProducer, StreamContext
from nvflare.apis.streaming import ConsumerFactory, ObjectConsumer, ObjectProducer, StreamableEngine, StreamContext
from nvflare.fuel.utils.obj_utils import get_logger
from nvflare.fuel.utils.validation_utils import check_positive_int, check_positive_number

Expand Down Expand Up @@ -179,6 +179,9 @@ def register_stream_processing(
raise ValueError(f"dest_dir '{dest_dir}' is not a valid dir")

engine = fl_ctx.get_engine()
if not isinstance(engine, StreamableEngine):
raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}")

engine.register_stream_processing(
channel=channel,
topic=topic,
Expand Down Expand Up @@ -238,7 +241,12 @@ def stream_file(
with open(file_name, "rb") as file:
producer = _ChunkProducer(file, chunk_size, chunk_timeout)
engine = fl_ctx.get_engine()

if not isinstance(engine, StreamableEngine):
raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}")

stream_ctx[_KEY_FILE_NAME] = os.path.basename(file_name)

return engine.stream_objects(
channel=channel,
topic=topic,
Expand Down
42 changes: 27 additions & 15 deletions nvflare/private/aux_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from nvflare.apis.client import Client
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import ConfigVarName, ReturnCode, SystemConfigs
from nvflare.apis.fl_constant import ConfigVarName, ProcessType, ReturnCode, SystemConfigs
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply
from nvflare.fuel.f3.cellnet.core_cell import Message, MessageHeaderKey
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(self, engine):
def register_aux_message_handler(self, topic: str, message_handle_func):
"""Register aux message handling function with specified topics.
This method should be called by ServerEngine's register_aux_message_handler method.
This method should be called by Engine's register_aux_message_handler method.
Args:
topic: the topic to be handled by the func
Expand Down Expand Up @@ -196,7 +196,7 @@ def _process_cell_replies(
if cell_replies:
for reply_cell_fqcn, v in cell_replies.items():
assert isinstance(v, Message)
rc = v.get_header(MessageHeaderKey.RETURN_CODE, ReturnCode.OK)
rc = v.get_header(MessageHeaderKey.RETURN_CODE, CellReturnCode.OK)
target_name = fqcn_to_name[reply_cell_fqcn]
if rc == CellReturnCode.OK:
result = v.payload
Expand Down Expand Up @@ -258,25 +258,23 @@ def _send_multi_requests(
if not cell:
return {}

job_id = fl_ctx.get_job_id()
public_props = fl_ctx.get_all_public_props()
target_messages = {}
fqcn_to_name = {}
for t in target_requests:
msg_target, req = t
assert isinstance(msg_target, AuxMsgTarget)
target_name = msg_target.name
target_fqcn = msg_target.fqcn
if not isinstance(req, Shareable):
raise ValueError(f"request of {target_name} should be Shareable but got {type(req)}")

req.set_header(ReservedHeaderKey.TOPIC, topic)
req.set_peer_props(public_props)
job_cell_fqcn = FQCN.join([target_fqcn, job_id])
self.log_info(fl_ctx, f"sending multicast aux: {job_cell_fqcn=}")
fqcn_to_name[job_cell_fqcn] = target_name
target_messages[job_cell_fqcn] = TargetMessage(
topic=topic, channel=channel, target=job_cell_fqcn, message=Message(payload=req)
cell_fqcn = self._get_target_fqcn(msg_target, fl_ctx)
self.log_debug(fl_ctx, f"sending multicast aux: {cell_fqcn=}")
fqcn_to_name[cell_fqcn] = target_name
target_messages[cell_fqcn] = TargetMessage(
topic=topic, channel=channel, target=cell_fqcn, message=Message(payload=req)
)

if timeout > 0:
Expand Down Expand Up @@ -374,7 +372,6 @@ def _send_to_cell(
request.set_header(ReservedHeaderKey.TOPIC, topic)
request.set_peer_props(fl_ctx.get_all_public_props())

job_id = fl_ctx.get_job_id()
cell = self._wait_for_cell()
if not cell:
return {}
Expand All @@ -383,9 +380,9 @@ def _send_to_cell(
fqcn_to_name = {}
for t in targets:
# targeting job cells!
job_cell_fqcn = FQCN.join([t.fqcn, job_id])
target_fqcns.append(job_cell_fqcn)
fqcn_to_name[job_cell_fqcn] = t.name
cell_fqcn = self._get_target_fqcn(t, fl_ctx)
target_fqcns.append(cell_fqcn)
fqcn_to_name[cell_fqcn] = t.name

cell_msg = Message(payload=request)
if timeout > 0:
Expand All @@ -409,10 +406,25 @@ def _send_to_cell(
)
return {}

@staticmethod
def _get_target_fqcn(target: AuxMsgTarget, fl_ctx: FLContext):
    """Determine the cell FQCN to address for the given aux message target.

    In a parent process (SP/CP) the message is addressed to the target's own
    cell; in a job process (SJ/CJ) it is addressed to the target's job cell.

    Args:
        target: the AuxMsgTarget whose cell is to be addressed
        fl_ctx: FLContext carrying the current process type and job ID

    Raises:
        RuntimeError: if the process type is invalid, or no job ID is
            available in a job process.
    """
    ptype = fl_ctx.get_process_type()
    if ptype in (ProcessType.SERVER_PARENT, ProcessType.CLIENT_PARENT):
        # parent process: address the site cell directly
        return target.fqcn

    if ptype in (ProcessType.SERVER_JOB, ProcessType.CLIENT_JOB):
        # job process: address the job cell of the target site
        job_id = fl_ctx.get_job_id()
        if not job_id:
            raise RuntimeError("no job ID in fl_ctx in Job Process!")
        return FQCN.join([target.fqcn, job_id])

    raise RuntimeError(f"invalid process_type {ptype}")

@staticmethod
def _convert_return_code(rc):
rc_table = {
CellReturnCode.TIMEOUT: ReturnCode.COMMUNICATION_ERROR,
CellReturnCode.TIMEOUT: ReturnCode.TIMEOUT,
CellReturnCode.COMM_ERROR: ReturnCode.COMMUNICATION_ERROR,
CellReturnCode.PROCESS_EXCEPTION: ReturnCode.EXECUTION_EXCEPTION,
CellReturnCode.ABORT_RUN: CellReturnCode.ABORT_RUN,
Expand Down
2 changes: 2 additions & 0 deletions nvflare/private/fed/app/client/client_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def main(args):
print("Waiting client cell to be created ....")
time.sleep(1.0)

client_engine.initialize_comm(federated_client.cell)

with client_engine.new_context() as fl_ctx:
client_engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx)

Expand Down
Loading

0 comments on commit 5534af5

Please sign in to comment.