Skip to content

Commit

Permalink
[2.5] Add authentication to Client API messages (NVIDIA#3103)
Browse files Browse the repository at this point in the history
* fix cell pipe auth headers

* remove unused code

* remove unused imports
  • Loading branch information
yanchengnv authored Dec 12, 2024
1 parent f0378ac commit f73e127
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 20 deletions.
8 changes: 8 additions & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,14 @@ class FLMetaKey:
SITE_NAME = "site_name"
PROCESS_RC_FILE = "_process_rc.txt"
SUBMIT_MODEL_NAME = "submit_model_name"
AUTH_TOKEN = "auth_token"
AUTH_TOKEN_SIGNATURE = "auth_token_signature"


class CellMessageAuthHeaderKey:
CLIENT_NAME = "client_name"
TOKEN = "__token__"
TOKEN_SIGNATURE = "__token_signature__"


class FilterKey:
Expand Down
9 changes: 8 additions & 1 deletion nvflare/app_common/executors/client_api_launcher_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from nvflare.app_common.executors.launcher_executor import LauncherExecutor
from nvflare.client.config import ConfigKey, ExchangeFormat, TransferType, write_config_to_file
from nvflare.client.constants import CLIENT_API_CONFIG
from nvflare.fuel.data_event.utils import get_scope_property
from nvflare.fuel.utils.attributes_exportable import ExportMode


Expand Down Expand Up @@ -125,10 +126,16 @@ def prepare_config_for_launch(self, fl_ctx: FLContext):
ConfigKey.HEARTBEAT_TIMEOUT: self.heartbeat_timeout,
}

site_name = fl_ctx.get_identity_name()
auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA")
signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA")

config_data = {
ConfigKey.TASK_EXCHANGE: task_exchange_attributes,
FLMetaKey.SITE_NAME: fl_ctx.get_identity_name(),
FLMetaKey.SITE_NAME: site_name,
FLMetaKey.JOB_ID: fl_ctx.get_job_id(),
FLMetaKey.AUTH_TOKEN: auth_token,
FLMetaKey.AUTH_TOKEN_SIGNATURE: signature,
}

config_file_path = self._get_external_config_file_path(fl_ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
def _prepare_task_meta(self, fl_ctx, task_name):
job_id = fl_ctx.get_job_id()
site_name = fl_ctx.get_identity_name()

meta = {
FLMetaKey.SITE_NAME: site_name,
FLMetaKey.JOB_ID: job_id,
Expand Down
10 changes: 10 additions & 0 deletions nvflare/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
from typing import Dict, Optional

from nvflare.apis.fl_constant import FLMetaKey
from nvflare.fuel.utils.config_factory import ConfigFactory


Expand Down Expand Up @@ -155,6 +156,15 @@ def get_heartbeat_timeout(self):
self.config.get(ConfigKey.METRICS_EXCHANGE, {}).get(ConfigKey.HEARTBEAT_TIMEOUT, 60),
)

def get_site_name(self):
return self.config.get(FLMetaKey.SITE_NAME)

def get_auth_token(self):
return self.config.get(FLMetaKey.AUTH_TOKEN)

def get_auth_token_signature(self):
return self.config.get(FLMetaKey.AUTH_TOKEN_SIGNATURE)

def to_json(self, config_file: str):
with open(config_file, "w") as f:
json.dump(self.config, f, indent=2)
Expand Down
8 changes: 8 additions & 0 deletions nvflare/client/ex_process/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from nvflare.client.flare_agent import FlareAgentException
from nvflare.client.flare_agent_with_fl_model import FlareAgentWithFLModel
from nvflare.client.model_registry import ModelRegistry
from nvflare.fuel.data_event.utils import set_scope_property
from nvflare.fuel.utils import fobs
from nvflare.fuel.utils.import_utils import optional_import
from nvflare.fuel.utils.obj_utils import get_logger
Expand All @@ -37,6 +38,13 @@ def _create_client_config(config: str) -> ClientConfig:
client_config = from_file(config_file=config)
else:
raise ValueError(f"config should be a string but got: {type(config)}")

# get message auth info and put them into Databus for CellPipe to use
auth_token = client_config.get_auth_token()
signature = client_config.get_auth_token_signature()
site_name = client_config.get_site_name()
set_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, value=auth_token)
set_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, value=signature)
return client_config


Expand Down
53 changes: 53 additions & 0 deletions nvflare/fuel/data_event/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

from nvflare.fuel.utils.validation_utils import check_str

from .data_bus import DataBus


def _scope_prop_key(scope_name: str, key: str):
return f"{scope_name}::{key}"


def set_scope_property(scope_name: str, key: str, value: Any):
"""Save the specified property of the specified scope (globally).
Args:
scope_name: name of the scope
key: key of the property to be saved
value: value of property
Returns: None
"""
check_str("scope_name", scope_name)
check_str("key", key)
data_bus = DataBus()
data_bus.put_data(_scope_prop_key(scope_name, key), value)


def get_scope_property(scope_name: str, key: str, default=None) -> Any:
"""Get the value of a specified property from the specified scope.
Args:
scope_name: name of the scope
key: key of the scope
default: value to return if property is not found
Returns:
"""
check_str("scope_name", scope_name)
check_str("key", key)
data_bus = DataBus()
result = data_bus.get_data(_scope_prop_key(scope_name, key))
if result is None:
result = default
return result
21 changes: 20 additions & 1 deletion nvflare/fuel/utils/pipe/cell_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import time
from typing import Tuple, Union

from nvflare.apis.fl_constant import SystemVarName
from nvflare.apis.fl_constant import CellMessageAuthHeaderKey, FLMetaKey, SystemVarName
from nvflare.fuel.data_event.utils import get_scope_property
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.cell import Message as CellMessage
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
Expand Down Expand Up @@ -112,6 +113,9 @@ class CellPipe(Pipe):

_lock = threading.Lock()
_cells_info = {} # (root_url, site_name, token) => _CellInfo
_auth_token = None
_token_signature = None
_site_name = None

@classmethod
def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_dir):
Expand All @@ -131,6 +135,7 @@ def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_di
"""
with cls._lock:
cls._site_name = site_name
cell_key = f"{root_url}.{site_name}.{token}"
ci = cls._cells_info.get(cell_key)
if not ci:
Expand All @@ -151,11 +156,25 @@ def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_di
credentials=credentials,
create_internal_listener=False,
)

# set filter to add additional auth headers
cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=cls._add_auth_headers)
cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=cls._add_auth_headers)

net_agent = NetAgent(cell)
ci = _CellInfo(cell, net_agent)
cls._cells_info[cell_key] = ci
return ci

@classmethod
def _add_auth_headers(cls, message: CellMessage):
if not cls._auth_token:
cls._auth_token = get_scope_property(scope_name=cls._site_name, key=FLMetaKey.AUTH_TOKEN, default="NA")
cls._token_signature = get_scope_property(cls._site_name, FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA")
message.set_header(CellMessageAuthHeaderKey.CLIENT_NAME, cls._site_name)
message.set_header(CellMessageAuthHeaderKey.TOKEN, cls._auth_token)
message.set_header(CellMessageAuthHeaderKey.TOKEN_SIGNATURE, cls._token_signature)

def __init__(
self,
mode: Mode,
Expand Down
7 changes: 4 additions & 3 deletions nvflare/private/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import uuid

# this import is to let existing scripts import from nvflare.private.defs
from nvflare.apis.fl_constant import CellMessageAuthHeaderKey
from nvflare.fuel.f3.cellnet.defs import CellChannel, CellChannelTopic, SSLConstants # noqa: F401
from nvflare.fuel.f3.message import Message
from nvflare.fuel.hci.server.constants import ConnProps
Expand Down Expand Up @@ -137,11 +138,11 @@ class AppFolderConstants:

class CellMessageHeaderKeys:

CLIENT_NAME = "client_name"
CLIENT_NAME = CellMessageAuthHeaderKey.CLIENT_NAME
TOKEN = CellMessageAuthHeaderKey.TOKEN
TOKEN_SIGNATURE = CellMessageAuthHeaderKey.TOKEN_SIGNATURE
CLIENT_IP = "client_ip"
PROJECT_NAME = "project_name"
TOKEN = "__token__"
TOKEN_SIGNATURE = "__token_signature__"
SSID = "ssid"
UNAUTHENTICATED = "unauthenticated"
JOB_ID = "job_id"
Expand Down
10 changes: 6 additions & 4 deletions nvflare/private/fed/client/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from nvflare.apis.event_type import EventType
from nvflare.apis.filter import Filter
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_constant import FLContextKey, FLMetaKey
from nvflare.apis.fl_constant import ReturnCode as ShareableRC
from nvflare.apis.fl_constant import SecureTrainConst, ServerCommandKey, ServerCommandNames
from nvflare.apis.fl_context import FLContext
Expand All @@ -34,7 +34,7 @@
from nvflare.fuel.f3.message import Message as CellMessage
from nvflare.private.defs import CellChannel, CellChannelTopic, CellMessageHeaderKeys, SpecialTaskName, new_cell_message
from nvflare.private.fed.client.client_engine_internal_spec import ClientEngineInternalSpec
from nvflare.private.fed.utils.fed_utils import get_scope_prop
from nvflare.private.fed.utils.fed_utils import get_scope_prop, set_scope_prop
from nvflare.private.fed.utils.identity_utils import IdentityAsserter, IdentityVerifier, load_crt_bytes
from nvflare.security.logging import secure_format_exception

Expand Down Expand Up @@ -100,16 +100,18 @@ def __init__(
self.token_signature = None
self.ssid = None
self.client_name = None

self.logger = logging.getLogger(self.__class__.__name__)
self.logger.info(f"==== Communicator GOT CELL: {type(cell)}")

def set_auth(self, client_name, token, token_signature, ssid):
self.ssid = ssid
self.token_signature = token_signature
self.token = token
self.client_name = client_name

# put auth properties in database so that they can be used elsewhere
set_scope_prop(scope_name=client_name, key=FLMetaKey.AUTH_TOKEN, value=token)
set_scope_prop(scope_name=client_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, value=token_signature)

def set_cell(self, cell):
self.cell = cell

Expand Down
16 changes: 5 additions & 11 deletions nvflare/private/fed/utils/fed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,13 @@
from nvflare.apis.utils.decomposers import flare_decomposers
from nvflare.apis.workspace import Workspace
from nvflare.app_common.decomposers import common_decomposers
from nvflare.fuel.data_event.data_bus import DataBus
from nvflare.fuel.data_event.utils import get_scope_property, set_scope_property
from nvflare.fuel.f3.stats_pool import CsvRecordHandler, StatsPoolManager
from nvflare.fuel.sec.audit import AuditService
from nvflare.fuel.sec.authz import AuthorizationService
from nvflare.fuel.sec.security_content_service import LoadResult, SecurityContentService
from nvflare.fuel.utils import fobs
from nvflare.fuel.utils.fobs.fobs import register_custom_folder
from nvflare.fuel.utils.validation_utils import check_str
from nvflare.private.defs import RequestHeader, SSLConstants
from nvflare.private.event import fire_event
from nvflare.private.fed.utils.decomposers import private_decomposers
Expand Down Expand Up @@ -415,20 +414,15 @@ def set_scope_prop(scope_name: str, key: str, value: Any):
value: value of property
Returns: None
"""
check_str("scope_name", scope_name)
check_str("key", key)
data_bus = DataBus()
data_bus.put_data(_scope_prop_key(scope_name, key), value)
set_scope_property(scope_name, key, value)


def get_scope_prop(scope_name: str, key: str) -> Any:
def get_scope_prop(scope_name: str, key: str, default=None) -> Any:
"""Get the value of a specified property from the specified scope.
Args:
scope_name: name of the scope
key: key of the scope
default: value to return if the prop is not found
Returns:
"""
check_str("scope_name", scope_name)
check_str("key", key)
data_bus = DataBus()
return data_bus.get_data(_scope_prop_key(scope_name, key))
return get_scope_property(scope_name, key, default)

0 comments on commit f73e127

Please sign in to comment.