Skip to content

Commit

Permalink
support passing custom env vars for flower client (NVIDIA#2870)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv authored Aug 28, 2024
1 parent a16ff0a commit e1c0a74
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
7 changes: 5 additions & 2 deletions nvflare/app_opt/flower/applet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@


class FlowerClientApplet(CLIApplet):
def __init__(self):
def __init__(self, extra_env: dict = None):
"""Constructor of FlowerClientApplet, which extends CLIApplet."""
CLIApplet.__init__(self)
self.extra_env = extra_env

def get_command(self, ctx: dict) -> CommandDescriptor:
"""Implementation of the get_command method required by the super class CLIApplet.
Expand Down Expand Up @@ -62,7 +63,9 @@ def get_command(self, ctx: dict) -> CommandDescriptor:
# this is necessary for client_api to be used with the flower client app for metrics logging
# client_api expects config info from the "config" folder in the cwd!
self.logger.info(f"starting flower client app: {cmd}")
return CommandDescriptor(cmd=cmd, cwd=app_dir, log_file_name="client_app_log.txt", stdout_msg_prefix="FLWR-CA")
return CommandDescriptor(
cmd=cmd, cwd=app_dir, env=self.extra_env, log_file_name="client_app_log.txt", stdout_msg_prefix="FLWR-CA"
)


class FlowerServerApplet(Applet):
Expand Down
18 changes: 17 additions & 1 deletion nvflare/app_opt/flower/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from nvflare.app_common.tie.executor import TieExecutor
from nvflare.app_opt.flower.applet import FlowerClientApplet
from nvflare.app_opt.flower.connectors.grpc_client_connector import GrpcClientConnector
from nvflare.fuel.utils.validation_utils import check_object_type

from .defs import Constant

Expand All @@ -27,18 +28,33 @@ def __init__(
per_msg_timeout=10.0,
tx_timeout=100.0,
client_shutdown_timeout=5.0,
extra_env: dict = None,
):
"""FlowerExecutor constructor
Args:
start_task_name: name of the "start" task
configure_task_name: name of the "config" task
per_msg_timeout: per-msg timeout for ReliableMessage
tx_timeout: transaction timeout for ReliableMessage
client_shutdown_timeout: how long to wait for graceful shutdown of the client
extra_env: extra env variables to be passed to client applet
"""
TieExecutor.__init__(
self,
start_task_name=start_task_name,
configure_task_name=configure_task_name,
)

if extra_env:
check_object_type("extra_env", extra_env, dict)

self.int_server_grpc_options = None
self.per_msg_timeout = per_msg_timeout
self.tx_timeout = tx_timeout
self.client_shutdown_timeout = client_shutdown_timeout
self.num_rounds = None
self.extra_env = extra_env

def get_connector(self, fl_ctx: FLContext):
return GrpcClientConnector(
Expand All @@ -48,7 +64,7 @@ def get_connector(self, fl_ctx: FLContext):
)

def get_applet(self, fl_ctx: FLContext):
return FlowerClientApplet()
return FlowerClientApplet(extra_env=self.extra_env)

def configure(self, config: dict, fl_ctx: FLContext):
self.num_rounds = config.get(Constant.CONF_KEY_NUM_ROUNDS)
Expand Down
9 changes: 3 additions & 6 deletions nvflare/app_opt/flower/flower_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
client_shutdown_timeout=5.0,
stream_metrics=False,
analytics_receiver=None,
client_api_type: str = "EX_PROCESS_API",
extra_env: dict = None,
):
"""
Flower Job.
Expand All @@ -69,17 +69,13 @@ def __init__(
client_shutdown_timeout (float, optional): Timeout for client shutdown. Defaults to 5.0 seconds.
stream_metrics (bool, optional): Whether to stream metrics from Flower client to Flare
analytics_receiver (AnalyticsReceiver, optional): the AnalyticsReceiver to use to process received metrics.
client_api_type (str, optional): Client API type, can choose from EX_PROCESS_API and IN_PROCESS_API
extra_env (dict, optional): optional extra env variables to be passed to Flower client
"""
if not os.path.isdir(flower_content):
raise ValueError(f"{flower_content} is not a valid directory")

super().__init__(name=name, min_clients=min_clients, mandatory_clients=mandatory_clients)

if client_api_type not in ["EX_PROCESS_API", "IN_PROCESS_API"]:
raise ValueError("Invalid client api type.")
os.environ["CLIENT_API_TYPE"] = client_api_type

controller = FlowerController(
database=database,
server_app_args=server_app_args,
Expand All @@ -96,6 +92,7 @@ def __init__(
per_msg_timeout=per_msg_timeout,
tx_timeout=tx_timeout,
client_shutdown_timeout=client_shutdown_timeout,
extra_env=extra_env,
)
self.to_clients(executor)
self.to_clients(obj=flower_content)
Expand Down

0 comments on commit e1c0a74

Please sign in to comment.