Skip to content

Commit

Permalink
Merge branch 'NVIDIA:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
nvidianz authored Mar 25, 2024
2 parents a655d9d + 385dc3c commit 4e5ba5d
Show file tree
Hide file tree
Showing 47 changed files with 884 additions and 255 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ All client controlled workflows must have a server side controller that extends
result_clients_policy: str = DefaultValuePolicy.ALL,
max_status_report_interval: float = Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT,
progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
private_p2p: bool = True,
):
Init args for ServerSideController
Expand Down Expand Up @@ -233,10 +234,11 @@ Cyclic Learning: Server Side Controller
starting_client: str = "",
max_status_report_interval: float = Constant.PER_CLIENT_STATUS_REPORT_TIMEOUT,
progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
rr_order: str = RROrder.FIXED,
private_p2p: bool = True,
cyclic_order: str = CyclicOrder.FIXED,
):
The only extra init arg is ``rr_order``, which specifies how the round-robin sequence is to be computed for each round: fixed order or random order.
The only extra init arg is ``cyclic_order``, which specifies how the cyclic sequence is to be computed for each round: fixed order or random order.

Of all the init args, only the ``num_rounds`` must be explicitly specified. All others can take default values:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class ServerCustomSecurityHandler(FLComponent):
def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.CLIENT_REGISTERED:
if event_type == EventType.CLIENT_REGISTER_RECEIVED:
self.authenticate(fl_ctx=fl_ctx)

def authenticate(self, fl_ctx: FLContext):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"metadata": {},
"source": [
"\n",
"Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier using federated averaging ([FedAvg]([FedAvg](https://arxiv.org/abs/1602.05629))) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the streaming capability from the clients to the server with Tensorboard SummaryWriter sender syntax, but with a MLflow receiver\n",
"Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the streaming capability from the clients to the server with Tensorboard SummaryWriter sender syntax, but with a MLflow receiver\n",
"\n",
"> **_NOTE:_** This example uses the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the trainer code.\n"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
"mlflow.note.content": "## **Hello PyTorch experiment with MLflow**"
},
"run_tags": {
"mlflow.note.content": "## Federated Experiment tracking with MLflow \n### Example of using **[NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)** to train an image classifier using federated averaging ([FedAvg]([FedAvg](https://arxiv.org/abs/1602.05629))) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server for server delivery to MLflow.\n\n> **_NOTE:_** \n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* dataset and will load its data within the trainer code.\n"
"mlflow.note.content": "## Federated Experiment tracking with MLflow \n### Example of using **[NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)** to train an image classifier using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server for server delivery to MLflow.\n\n> **_NOTE:_** \n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* dataset and will load its data within the trainer code.\n"
}
},
"artifact_location": "artifacts"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
"mlflow.note.content": "## **Hello PyTorch experiment with MLflow**"
},
"run_tags": {
"mlflow.note.content": "## Federated Experiment tracking with MLflow \n### Example of using **[NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)** to train an image classifier using federated averaging ([FedAvg]([FedAvg](https://arxiv.org/abs/1602.05629))) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server for server delivery to MLflow.\n\n> **_NOTE:_** \n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* dataset and will load its data within the trainer code.\n"
"mlflow.note.content": "## Federated Experiment tracking with MLflow \n### Example of using **[NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)** to train an image classifier using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server for server delivery to MLflow.\n\n> **_NOTE:_** \n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* dataset and will load its data within the trainer code.\n"
}
},
"artifact_location": "artifacts"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"kwargs" : {
"project": "hello-pt-experiment",
"name": "hello-pt",
"notes": "Federated Experiment tracking with W&B \n Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier using federated averaging ([FedAvg]([FedAvg](https://arxiv.org/abs/1602.05629))) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server and deliver to MLFLow.\\n\\n> **_NOTE:_** \\n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the trainer code.\n",
"notes": "Federated Experiment tracking with W&B \n Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server and deliver to MLFLow.\\n\\n> **_NOTE:_** \\n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the trainer code.\n",
"tags": ["baseline", "paper1"],
"job_type": "train-validate",
"config": {
Expand Down
1 change: 0 additions & 1 deletion nvflare/apis/controller_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ def stop_controller(self, fl_ctx: FLContext):
"""
pass

@abstractmethod
def process_result_of_unknown_task(
self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
):
Expand Down
11 changes: 10 additions & 1 deletion nvflare/apis/event_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class EventType(object):
JOB_COMPLETED = "_job_completed"
JOB_ABORTED = "_job_aborted"
JOB_CANCELLED = "_job_cancelled"
JOB_DEAD = "_job_dead"

BEFORE_PULL_TASK = "_before_pull_task"
AFTER_PULL_TASK = "_after_pull_task"
Expand All @@ -50,6 +51,8 @@ class EventType(object):
AFTER_TASK_EXECUTION = "_after_task_execution"
BEFORE_SEND_TASK_RESULT = "_before_send_task_result"
AFTER_SEND_TASK_RESULT = "_after_send_task_result"
BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK = "_before_process_result_of_unknown_task"
AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK = "_after_process_result_of_unknown_task"

CRITICAL_LOG_AVAILABLE = "_critical_log_available"
ERROR_LOG_AVAILABLE = "_error_log_available"
Expand All @@ -73,8 +76,14 @@ class EventType(object):

BEFORE_CLIENT_REGISTER = "_before_client_register"
AFTER_CLIENT_REGISTER = "_after_client_register"
CLIENT_REGISTERED = "_client_registered"
CLIENT_REGISTER_RECEIVED = "_client_register_received"
CLIENT_REGISTER_PROCESSED = "_client_register_processed"
CLIENT_QUIT = "_client_quit"
SYSTEM_BOOTSTRAP = "_system_bootstrap"
BEFORE_CLIENT_HEARTBEAT = "_before_client_heartbeat"
AFTER_CLIENT_HEARTBEAT = "_after_client_heartbeat"
CLIENT_HEARTBEAT_RECEIVED = "_client_heartbeat_received"
CLIENT_HEARTBEAT_PROCESSED = "_client_heartbeat_processed"

AUTHORIZE_COMMAND_CHECK = "_authorize_command_check"
BEFORE_BUILD_COMPONENT = "_before_build_component"
2 changes: 2 additions & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,14 @@ class FLContextKey(object):
COMMUNICATION_ERROR = "Flare_communication_error__"
UNAUTHENTICATED = "Flare_unauthenticated__"
CLIENT_RESOURCE_SPECS = "__client_resource_specs"
RESOURCE_CHECK_RESULT = "__resource_check_result"
JOB_PARTICIPANTS = "__job_participants"
JOB_BLOCK_REASON = "__job_block_reason" # why the job should be blocked from scheduling
SSID = "__ssid__"
CLIENT_TOKEN = "__client_token"
AUTHORIZATION_RESULT = "_authorization_result"
AUTHORIZATION_REASON = "_authorization_reason"
DEAD_JOB_CLIENT_NAME = "_dead_job_client_name"

CLIENT_REGISTER_DATA = "_client_register_data"
SECURITY_ITEMS = "_security_items"
Expand Down
1 change: 1 addition & 0 deletions nvflare/apis/impl/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def initialize(self, fl_ctx: FLContext):
return

self._engine = engine
self.start_controller(fl_ctx)

def set_communicator(self, communicator: WFCommSpec):
if not isinstance(communicator, WFCommSpec):
Expand Down
34 changes: 16 additions & 18 deletions nvflare/apis/impl/wf_comm_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, SendOrder, Task, TaskCompletionStatus
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
Expand Down Expand Up @@ -133,8 +134,8 @@ def _set_stats(self, fl_ctx: FLContext):
raise TypeError(
"collector must be an instance of GroupInfoCollector, but got {}".format(type(collector))
)
collector.set_info(
group_name=self._name,
collector.add_info(
group_name=self.controller._name,
info={
"tasks": {t.name: [ct.client.name for ct in t.client_tasks] for t in self._tasks},
},
Expand All @@ -149,6 +150,12 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
"""
if event_type == InfoCollector.EVENT_TYPE_GET_STATS:
self._set_stats(fl_ctx)
elif event_type == EventType.JOB_DEAD:
client_name = fl_ctx.get_prop(FLContextKey.DEAD_JOB_CLIENT_NAME)
with self._dead_clients_lock:
self.log_info(fl_ctx, f"received dead job report from client {client_name}")
if not self._dead_client_reports.get(client_name):
self._dead_client_reports[client_name] = time.time()

def process_task_request(self, client: Client, fl_ctx: FLContext) -> Tuple[str, str, Shareable]:
"""Called by runner when a client asks for a task.
Expand Down Expand Up @@ -330,22 +337,6 @@ def handle_exception(self, task_id: str, fl_ctx: FLContext) -> None:
self.cancel_task(task=task, fl_ctx=fl_ctx)
self.log_error(fl_ctx, "task {} is cancelled due to exception".format(task.name))

def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
"""Called by the Engine to handle the case that the job on the client is dead.
Args:
client_name: name of the client on which the job is dead
fl_ctx: the FLContext
"""
# record the report and to be used by the task monitor
with self._dead_clients_lock:
self.log_info(fl_ctx, f"received dead job report from client {client_name}")
if not self._dead_client_reports.get(client_name):
self._dead_client_reports[client_name] = time.time()

self.controller.handle_dead_job(client_name, fl_ctx)

def process_task_check(self, task_id: str, fl_ctx: FLContext):
with self._task_lock:
# task_id is the uuid associated with the client_task
Expand Down Expand Up @@ -400,7 +391,14 @@ def _do_process_submission(
if client_task is None:
# cannot find a standing task for the submission
self.log_debug(fl_ctx, "no standing task found for {}:{}".format(task_name, task_id))

self.log_debug(fl_ctx, "firing event EventType.BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK")
self.fire_event(EventType.BEFORE_PROCESS_RESULT_OF_UNKNOWN_TASK, fl_ctx)

self.controller.process_result_of_unknown_task(client, task_name, task_id, result, fl_ctx)

self.log_debug(fl_ctx, "firing event EventType.AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK")
self.fire_event(EventType.AFTER_PROCESS_RESULT_OF_UNKNOWN_TASK, fl_ctx)
return

task = client_task.task
Expand Down
6 changes: 4 additions & 2 deletions nvflare/apis/server_engine_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,13 @@ def restore_components(self, snapshot: RunSnapshot, fl_ctx: FLContext):
pass

@abstractmethod
def start_client_job(self, job_id, client_sites):
def start_client_job(self, job_id, client_sites, fl_ctx: FLContext):
"""To send the start client run commands to the clients
Args:
client_sites: client sites
job_id: job_id
fl_ctx: FLContext
Returns:
Expand Down Expand Up @@ -187,14 +188,15 @@ def check_client_resources(

@abstractmethod
def cancel_client_resources(
self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict]
self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict], fl_ctx: FLContext
):
"""Cancels the request resources for the job.
Args:
resource_check_results: A dict of {client_name: client_check_result}
where client_check_result is a tuple of (is_resource_enough, resource reserve token if any)
resource_reqs: A dict of {client_name: resource requirements dict}
fl_ctx: FLContext
"""
pass

Expand Down
10 changes: 0 additions & 10 deletions nvflare/apis/wf_comm_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,16 +293,6 @@ def process_task_check(self, task_id: str, fl_ctx: FLContext):
"""
raise NotImplementedError

def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
"""Called by the Engine to handle the case that the job on the client is dead.
Args:
client_name: name of the client on which the job is dead
fl_ctx: the FLContext
"""
raise NotImplementedError

def initialize_run(self, fl_ctx: FLContext):
"""Called when a new RUN is about to start.
Expand Down
2 changes: 2 additions & 0 deletions nvflare/app_common/ccwf/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class Constant:
FINAL_RESULT_ACK_TIMEOUT = 10
GET_MODEL_TIMEOUT = 10

PROP_KEY_TRAIN_CLIENTS = "cwf.train_clients"


class ModelType:

Expand Down
6 changes: 6 additions & 0 deletions nvflare/app_common/ccwf/cse_server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,13 @@ def sub_flow(self, abort_signal: Signal, fl_ctx: FLContext):
self.current_round += 1

# ask everyone to eval everyone else's local model
train_clients = fl_ctx.get_prop(Constant.PROP_KEY_TRAIN_CLIENTS)
for c in self.evaluatees:
if train_clients and c not in train_clients:
# this client does not have local models
self.log_info(fl_ctx, f"ignore client {c} since it does not have local models")
continue

self._ask_to_evaluate(
current_round=self.current_round,
model_name=ModelName.BEST_MODEL,
Expand Down
4 changes: 2 additions & 2 deletions nvflare/app_common/ccwf/cyclic_client_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def start_workflow(self, shareable: Shareable, fl_ctx: FLContext, abort_signal:
clients = self.get_config_prop(Constant.CLIENTS)
# make sure the starting client is the 1st
rotate_to_front(self.me, clients)
rr_order = self.get_config_prop(Constant.ORDER)
self.log_info(fl_ctx, f"Starting cyclic workflow on clients {clients} with order {rr_order} ")
cyclic_order = self.get_config_prop(Constant.ORDER)
self.log_info(fl_ctx, f"Starting cyclic workflow on clients {clients} with order {cyclic_order} ")
self._set_task_headers(
task_data=shareable,
num_rounds=self.get_config_prop(AppConstants.NUM_ROUNDS),
Expand Down
4 changes: 3 additions & 1 deletion nvflare/app_common/ccwf/cyclic_server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def __init__(
)
check_str("cyclic_order", cyclic_order)
if cyclic_order not in [CyclicOrder.FIXED, CyclicOrder.RANDOM]:
raise ValueError(f"invalid rr_order {cyclic_order}: must be in {[CyclicOrder.FIXED, CyclicOrder.RANDOM]}")
raise ValueError(
f"invalid cyclic_order {cyclic_order}: must be in {[CyclicOrder.FIXED, CyclicOrder.RANDOM]}"
)
self.cyclic_order = cyclic_order

def prepare_config(self):
Expand Down
4 changes: 4 additions & 0 deletions nvflare/app_common/ccwf/swarm_server_ctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,9 @@ def start_controller(self, fl_ctx: FLContext):
if c not in self.train_clients and c not in self.aggr_clients:
raise RuntimeError(f"Config Error: client {c} is neither train client nor aggr client")

# set train_clients as a sticky prop in fl_ctx
# in case CSE (cross site eval) workflow follows, it will know that only training clients have local models
fl_ctx.set_prop(key=Constant.PROP_KEY_TRAIN_CLIENTS, value=self.train_clients, private=True, sticky=True)

def prepare_config(self):
return {Constant.AGGR_CLIENTS: self.aggr_clients, Constant.TRAIN_CLIENTS: self.train_clients}
3 changes: 2 additions & 1 deletion nvflare/app_common/job_schedulers/job_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _cancel_resources(
if not isinstance(engine, ServerEngineSpec):
raise RuntimeError(f"engine inside fl_ctx should be of type ServerEngineSpec, but got {type(engine)}.")

engine.cancel_client_resources(resource_check_results, resource_reqs)
engine.cancel_client_resources(resource_check_results, resource_reqs, fl_ctx)
self.log_debug(fl_ctx, f"cancel client resources using check results: {resource_check_results}")
return False, None

Expand Down Expand Up @@ -165,6 +165,7 @@ def _try_job(self, job: Job, fl_ctx: FLContext) -> (int, Optional[Dict[str, Disp
return SCHEDULE_RESULT_NO_RESOURCE, None, block_reason

resource_check_results = self._check_client_resources(job=job, resource_reqs=resource_reqs, fl_ctx=fl_ctx)
fl_ctx.set_prop(FLContextKey.RESOURCE_CHECK_RESULT, resource_check_results, private=True, sticky=False)
self.fire_event(EventType.AFTER_CHECK_CLIENT_RESOURCES, fl_ctx)

if not resource_check_results:
Expand Down
Loading

0 comments on commit 4e5ba5d

Please sign in to comment.