diff --git a/nvflare/app_opt/flower/applet.py b/nvflare/app_opt/flower/applet.py index f34361bab5..63c63efc2d 100644 --- a/nvflare/app_opt/flower/applet.py +++ b/nvflare/app_opt/flower/applet.py @@ -26,9 +26,10 @@ class FlowerClientApplet(CLIApplet): - def __init__(self): + def __init__(self, extra_env: dict = None): """Constructor of FlowerClientApplet, which extends CLIApplet.""" CLIApplet.__init__(self) + self.extra_env = extra_env def get_command(self, ctx: dict) -> CommandDescriptor: """Implementation of the get_command method required by the super class CLIApplet. @@ -62,7 +63,9 @@ def get_command(self, ctx: dict) -> CommandDescriptor: # this is necessary for client_api to be used with the flower client app for metrics logging # client_api expects config info from the "config" folder in the cwd! self.logger.info(f"starting flower client app: {cmd}") - return CommandDescriptor(cmd=cmd, cwd=app_dir, log_file_name="client_app_log.txt", stdout_msg_prefix="FLWR-CA") + return CommandDescriptor( + cmd=cmd, cwd=app_dir, env=self.extra_env, log_file_name="client_app_log.txt", stdout_msg_prefix="FLWR-CA" + ) class FlowerServerApplet(Applet): diff --git a/nvflare/app_opt/flower/executor.py b/nvflare/app_opt/flower/executor.py index 2ee1d89e5a..88d8745b6f 100644 --- a/nvflare/app_opt/flower/executor.py +++ b/nvflare/app_opt/flower/executor.py @@ -15,6 +15,7 @@ from nvflare.app_common.tie.executor import TieExecutor from nvflare.app_opt.flower.applet import FlowerClientApplet from nvflare.app_opt.flower.connectors.grpc_client_connector import GrpcClientConnector +from nvflare.fuel.utils.validation_utils import check_object_type from .defs import Constant @@ -27,18 +28,33 @@ def __init__( per_msg_timeout=10.0, tx_timeout=100.0, client_shutdown_timeout=5.0, + extra_env: dict = None, ): + """FlowerExecutor constructor + + Args: + start_task_name: name of the "start" task + configure_task_name: name of the "config" task + per_msg_timeout: per-msg timeout for ReliableMessage + tx_timeout: transaction timeout for ReliableMessage + client_shutdown_timeout: how long to wait for graceful shutdown of the client + extra_env: extra env variables to be passed to client applet + """ TieExecutor.__init__( self, start_task_name=start_task_name, configure_task_name=configure_task_name, ) + if extra_env: + check_object_type("extra_env", extra_env, dict) + self.int_server_grpc_options = None self.per_msg_timeout = per_msg_timeout self.tx_timeout = tx_timeout self.client_shutdown_timeout = client_shutdown_timeout self.num_rounds = None + self.extra_env = extra_env def get_connector(self, fl_ctx: FLContext): return GrpcClientConnector( @@ -48,7 +64,7 @@ def get_connector(self, fl_ctx: FLContext): ) def get_applet(self, fl_ctx: FLContext): - return FlowerClientApplet() + return FlowerClientApplet(extra_env=self.extra_env) def configure(self, config: dict, fl_ctx: FLContext): self.num_rounds = config.get(Constant.CONF_KEY_NUM_ROUNDS) diff --git a/nvflare/app_opt/flower/flower_job.py b/nvflare/app_opt/flower/flower_job.py index 0451bee409..97bcdbe899 100644 --- a/nvflare/app_opt/flower/flower_job.py +++ b/nvflare/app_opt/flower/flower_job.py @@ -47,7 +47,7 @@ def __init__( client_shutdown_timeout=5.0, stream_metrics=False, analytics_receiver=None, - client_api_type: str = "EX_PROCESS_API", + extra_env: dict = None, ): """ Flower Job. @@ -69,17 +69,13 @@ def __init__( client_shutdown_timeout (float, optional): Timeout for client shutdown. Defaults to 5.0 seconds. stream_metrics (bool, optional): Whether to stream metrics from Flower client to Flare analytics_receiver (AnalyticsReceiver, optional): the AnalyticsReceiver to use to process received metrics. - client_api_type (str, optional): Client API type, can choose from EX_PROCESS_API and IN_PROCESS_API + extra_env (dict, optional): optional extra env variables to be passed to Flower client """ if not os.path.isdir(flower_content): raise ValueError(f"{flower_content} is not a valid directory") super().__init__(name=name, min_clients=min_clients, mandatory_clients=mandatory_clients) - if client_api_type not in ["EX_PROCESS_API", "IN_PROCESS_API"]: - raise ValueError("Invalid client api type.") - os.environ["CLIENT_API_TYPE"] = client_api_type - controller = FlowerController( database=database, server_app_args=server_app_args, @@ -96,6 +92,7 @@ def __init__( per_msg_timeout=per_msg_timeout, tx_timeout=tx_timeout, client_shutdown_timeout=client_shutdown_timeout, + extra_env=extra_env, ) self.to_clients(executor) self.to_clients(obj=flower_content)