Skip to content

Commit

Permalink
[2.5] Support one-way SSL (NVIDIA#3062)
Browse files Browse the repository at this point in the history
* experiment with 1-way ssl and utc logging

* split err log

* support one-way ssl

* added value check

* consolidate ssl_mode and connection_security

* fix test case

* address pr comments
  • Loading branch information
yanchengnv authored Nov 20, 2024
1 parent 1098fed commit 3d6846c
Show file tree
Hide file tree
Showing 16 changed files with 272 additions and 62 deletions.
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ class SecureTrainConst:
SSL_ROOT_CERT = "ssl_root_cert"
SSL_CERT = "ssl_cert"
PRIVATE_KEY = "ssl_private_key"
CONNECTION_SECURITY = "connection_security"


class FLMetaKey:
Expand Down
22 changes: 22 additions & 0 deletions nvflare/app_common/utils/log_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time


class UTCEnabler:
    """Switch ``logging`` timestamps to UTC as a side effect of construction.

    Instantiating this class replaces the class-level
    ``logging.Formatter.converter`` with ``time.gmtime``. The attribute is set
    on the ``Formatter`` class itself, not on a single formatter instance.
    """

    def __init__(self):
        # Patch the class attribute so the change is not tied to one formatter object.
        formatter_cls = logging.Formatter
        formatter_cls.converter = time.gmtime
        print("ENABLED UTC")
11 changes: 6 additions & 5 deletions nvflare/fuel/f3/cellnet/connector_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ class _Defaults:


class ConnectorData:
def __init__(self, handle, connect_url: str, active: bool):
def __init__(self, handle, connect_url: str, active: bool, params: dict):
self.handle = handle
self.connect_url = connect_url
self.active = active
self.params = params

def get_connection_url(self):
return self.connect_url
Expand Down Expand Up @@ -192,19 +193,19 @@ def _get_connector(

try:
if active:
handle = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required)
handle, conn_params = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required)
connect_url = url
elif url:
handle = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required)
handle, conn_params = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required)
connect_url = url
else:
self.logger.info(f"{os.getpid()}: Try start_listener Listener resources: {reqs}")
handle, connect_url = self.communicator.start_listener(scheme, reqs)
handle, connect_url, conn_params = self.communicator.start_listener(scheme, reqs)
self.logger.debug(f"{os.getpid()}: ############ dynamic listener at {connect_url}")
# Kludge: to wait for listener ready and avoid race
time.sleep(0.5)

return ConnectorData(handle, connect_url, active)
return ConnectorData(handle, connect_url, active, conn_params)
except CommError as ex:
self.logger.error(f"Failed to get connector: {secure_format_exception(ex)}")
return None
Expand Down
12 changes: 11 additions & 1 deletion nvflare/fuel/f3/cellnet/core_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.communicator import Communicator, MessageReceiver
from nvflare.fuel.f3.connection import Connection
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams
from nvflare.fuel.f3.drivers.net_utils import enhance_credential_info
from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor, EndpointState
from nvflare.fuel.f3.message import Message
Expand Down Expand Up @@ -330,6 +330,16 @@ def __init__(
if err:
raise ValueError(f"Invalid FQCN '{fqcn}': {err}")

# Determine the value of 'secure' based on configured connection_security in credentials.
# If configured, use it; otherwise keep the original value of 'secure'.
conn_security = credentials.get(DriverParams.CONNECTION_SECURITY.value)
if conn_security:
if conn_security == ConnectionSecurity.INSECURE:
secure = False
else:
secure = True

self.logger.debug(f"connection secure: {secure}")
self.my_info = FqcnInfo(FQCN.normalize(fqcn))
self.secure = secure
self.logger.debug(f"{self.my_info.fqcn}: max_msg_size={self.max_msg_size}")
Expand Down
7 changes: 6 additions & 1 deletion nvflare/fuel/f3/cellnet/net_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,12 @@ def get_peers(self, target_fqcn: str) -> (Union[None, dict], List[str]):

@staticmethod
def _connector_info(info: ConnectorData) -> dict:
    """Describe a connector as a plain dict.

    Args:
        info: the connector record to describe; its ``connect_url``, ``handle``,
            ``active`` and ``params`` attributes are read.

    Returns:
        A dict with keys ``url``, ``handle``, ``type`` (``"connector"`` when the
        record is active, otherwise ``"listener"``) and ``params``.
    """
    # The superseded one-line return (without "params") has been dropped; it shadowed
    # this return and made the "params" entry unreachable.
    return {
        "url": info.connect_url,
        "handle": info.handle,
        "type": "connector" if info.active else "listener",
        "params": info.params,
    }

def _get_connectors(self) -> dict:
cell = self.cell
Expand Down
16 changes: 7 additions & 9 deletions nvflare/fuel/f3/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def register_message_receiver(self, app_id: int, receiver: MessageReceiver):

self.conn_manager.register_message_receiver(app_id, receiver)

def add_connector(self, url: str, mode: Mode, secure: bool = False) -> str:
def add_connector(self, url: str, mode: Mode, secure: bool = False) -> (str, dict):
"""Load a connector. The driver is selected based on the URL
Args:
Expand All @@ -163,7 +163,7 @@ def add_connector(self, url: str, mode: Mode, secure: bool = False) -> str:
secure: True if SSL is required.
Returns:
A handle that can be used to delete connector
A tuple of (A handle that can be used to delete connector, connector params)
Raises:
CommError: If any errors
Expand All @@ -177,17 +177,17 @@ def add_connector(self, url: str, mode: Mode, secure: bool = False) -> str:
raise CommError(CommError.NOT_SUPPORTED, f"No driver found for URL {url}")

params = parse_url(url)
return self.add_connector_advanced(driver_class(), mode, params, secure, False)
return self.add_connector_advanced(driver_class(), mode, params, secure, False), params

def start_listener(self, scheme: str, resources: dict) -> (str, str):
def start_listener(self, scheme: str, resources: dict) -> (str, str, dict):
"""Add and start a connector in passive mode on an address selected by the driver.
Args:
scheme: Connection scheme, e.g. http, https
resources: User specified resources like host and port ranges
Returns:
A tuple with connector handle and connect url
A tuple with connector handle and connect url, and connection params
Raises:
CommError: If any errors like invalid host or port not available
Expand All @@ -205,7 +205,7 @@ def start_listener(self, scheme: str, resources: dict) -> (str, str):

handle = self.add_connector_advanced(driver_class(), Mode.PASSIVE, params, False, True)

return handle, connect_url
return handle, connect_url, params

def add_connector_advanced(
self, driver: Driver, mode: Mode, params: dict, secure: bool, start: bool = False
Expand All @@ -229,9 +229,7 @@ def add_connector_advanced(
if self.local_endpoint.conn_props:
params.update(self.local_endpoint.conn_props)

if secure:
params[DriverParams.SECURE] = secure

params[DriverParams.SECURE] = secure
handle = self.conn_manager.add_connector(driver, params, mode)

if not start:
Expand Down
9 changes: 9 additions & 0 deletions nvflare/fuel/f3/drivers/driver_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,21 @@ class DriverParams(str, Enum):
SERVER_KEY = "server_key"
CLIENT_CERT = "client_cert"
CLIENT_KEY = "client_key"
CONNECTION_SECURITY = "connection_security"
CUSTOM_CA_CERT = "custom_ca_cert"
SECURE = "secure"
PORTS = "ports"
SOCKET = "socket"
LOCAL_ADDR = "local_addr"
PEER_ADDR = "peer_addr"
PEER_CN = "peer_cn"
IMPLEMENTED_CONN_SEC = "implemented_conn_sec"


class ConnectionSecurity:
    """String constants naming the supported connection security modes.

    These are plain ``str`` class attributes (not an ``Enum``), so they compare
    directly against configured string values.
    """

    # No transport security.
    INSECURE = "insecure"
    # One-way SSL.
    TLS = "tls"
    # Two-way (mutual) SSL.
    MTLS = "mtls"


class DriverCap(str, Enum):
Expand Down
44 changes: 36 additions & 8 deletions nvflare/fuel/f3/drivers/grpc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import grpc

from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams


def use_aio_grpc():
Expand All @@ -23,23 +23,51 @@ def use_aio_grpc():


def get_grpc_client_credentials(params: dict):
    """Build gRPC channel credentials according to the configured connection security.

    Args:
        params: driver parameters holding certificate/key file paths and, optionally,
            the ``connection_security`` mode (defaults to mTLS when absent). The dict is
            updated in place: a description of the mode actually applied is stored under
            ``DriverParams.IMPLEMENTED_CONN_SEC``.

    Returns:
        The credentials object produced by ``grpc.ssl_channel_credentials``.

    Raises:
        ValueError: if one-way SSL is requested but neither a custom CA cert nor the
            provisioned CA cert is configured.
    """
    conn_security = params.get(DriverParams.CONNECTION_SECURITY.value, ConnectionSecurity.MTLS)

    if conn_security == ConnectionSecurity.TLS:
        # One-way SSL: only a CA cert is needed; no client cert or key.
        # Prefer the custom CA cert when provided, because the client may be connecting
        # to an ALB or proxy that presents its own CA cert. Otherwise fall back to the
        # Flare provisioned CA cert.
        root_cert_file = params.get(DriverParams.CUSTOM_CA_CERT)
        if root_cert_file:
            params[DriverParams.IMPLEMENTED_CONN_SEC] = "Client TLS: Custom CA Cert used"
        else:
            params[DriverParams.IMPLEMENTED_CONN_SEC] = "Client TLS: Flare CA Cert used"
            root_cert_file = params.get(DriverParams.CA_CERT.value)

        if not root_cert_file:
            raise ValueError(f"cannot get CA cert for one-way SSL: {params}")

        return grpc.ssl_channel_credentials(root_certificates=_read_file(root_cert_file))

    # Two-way SSL: always use our own provisioned certs.
    # In the future, we may change to also support other ways to get cert and key.
    params[DriverParams.IMPLEMENTED_CONN_SEC] = "Client mTLS: Flare credentials used"
    root_cert = _read_file(params.get(DriverParams.CA_CERT.value))
    cert_chain = _read_file(params.get(DriverParams.CLIENT_CERT))
    private_key = _read_file(params.get(DriverParams.CLIENT_KEY))
    return grpc.ssl_channel_credentials(
        certificate_chain=cert_chain, private_key=private_key, root_certificates=root_cert
    )


def get_grpc_server_credentials(params: dict):
    """Build gRPC server credentials according to the configured connection security.

    Args:
        params: driver parameters holding the CA cert, server cert and server key file
            paths and, optionally, the ``connection_security`` mode (defaults to mTLS
            when absent). The dict is updated in place: a description of the mode
            actually applied is stored under ``DriverParams.IMPLEMENTED_CONN_SEC``.

    Returns:
        The credentials object produced by ``grpc.ssl_server_credentials``.
    """
    root_cert = _read_file(params.get(DriverParams.CA_CERT.value))
    cert_chain = _read_file(params.get(DriverParams.SERVER_CERT))
    private_key = _read_file(params.get(DriverParams.SERVER_KEY))

    conn_security = params.get(DriverParams.CONNECTION_SECURITY.value, ConnectionSecurity.MTLS)
    # Only explicit one-way TLS waives client authentication; every other value requires it.
    require_client_auth = conn_security != ConnectionSecurity.TLS

    if require_client_auth:
        params[DriverParams.IMPLEMENTED_CONN_SEC] = "Server mTLS: client auth required"
    else:
        params[DriverParams.IMPLEMENTED_CONN_SEC] = "Server TLS: client auth not required"

    # Pass the computed flag exactly once; a second hard-coded require_client_auth=True
    # would repeat the keyword argument and override the configured mode.
    return grpc.ssl_server_credentials(
        [(private_key, cert_chain)],
        root_certificates=root_cert,
        require_client_auth=require_client_auth,
    )


Expand Down
Loading

0 comments on commit 3d6846c

Please sign in to comment.