-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixes for the encrypted gRPC backend.
- Run tests on the CI.
- Loading branch information
1 parent
09d32f1
commit ccba0d0
Showing
9 changed files
with
176 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# pylint: disable=unbalanced-tuple-unpacking, too-many-locals | ||
"""Tests for federated learning.""" | ||
|
||
import multiprocessing | ||
import os | ||
import subprocess | ||
import time | ||
|
||
from sklearn.datasets import dump_svmlight_file, load_svmlight_file | ||
from sklearn.model_selection import train_test_split | ||
|
||
import xgboost as xgb | ||
import xgboost.federated | ||
from xgboost import testing as tm | ||
|
||
# File names of the self-signed TLS key/certificate pairs. They are generated
# into the current working directory by ``run_federated_learning`` (via an
# ``openssl req`` subprocess) before the SSL-enabled tests run.
SERVER_KEY = "server-key.pem"
SERVER_CERT = "server-cert.pem"
CLIENT_KEY = "client-key.pem"
CLIENT_CERT = "client-cert.pem"
|
||
|
||
def run_server(port: int, world_size: int, with_ssl: bool) -> None:
    """Run federated server for test.

    When ``with_ssl`` is set, the server is started with the pre-generated
    key/certificate files; otherwise it runs unencrypted.
    """
    if not with_ssl:
        xgboost.federated.run_federated_server(world_size, port)
        return

    xgboost.federated.run_federated_server(
        world_size,
        port,
        server_key_path=SERVER_KEY,
        server_cert_path=SERVER_CERT,
        client_cert_path=CLIENT_CERT,
    )
|
||
|
||
def run_worker(
    port: int, world_size: int, rank: int, with_ssl: bool, device: str
) -> None:
    """Run federated client worker for test.

    Connects to the local federated server, trains a small model on this
    rank's data shard, and (on rank 0 only) saves the resulting model.
    """
    env = {
        "dmlc_communicator": "federated",
        "federated_server_address": f"localhost:{port}",
        "federated_world_size": world_size,
        "federated_rank": rank,
    }
    if with_ssl:
        env.update(
            {
                "federated_server_cert_path": SERVER_CERT,
                "federated_client_key_path": CLIENT_KEY,
                "federated_client_cert_path": CLIENT_CERT,
            }
        )

    n_cpus = os.cpu_count()
    assert n_cpus is not None
    # Share the available cores evenly between the workers.
    n_threads = n_cpus // world_size

    # Always call this before using distributed module
    with xgb.collective.CommunicatorContext(**env):
        # Load file, file will not be sharded in federated mode. Each rank
        # reads its own pre-split files produced by the launcher.
        X_train, y_train = load_svmlight_file(f"agaricus.txt-{rank}.train")
        X_valid, y_valid = load_svmlight_file(f"agaricus.txt-{rank}.test")
        dtrain = xgb.DMatrix(X_train, y_train)
        dtest = xgb.DMatrix(X_valid, y_valid)

        # Specify parameters via map, definition are same as c++ version
        params = {
            "max_depth": 2,
            "eta": 1,
            "objective": "binary:logistic",
            "nthread": n_threads,
            "tree_method": "hist",
            "device": device,
        }

        # Run training while watching both the validation and training sets;
        # all the features in the training API are available.
        booster = xgb.train(
            params,
            dtrain,
            num_boost_round=20,
            evals=[(dtest, "eval"), (dtrain, "train")],
            early_stopping_rounds=2,
        )

        # Save the model, only ask process 0 to save the model.
        if xgb.collective.get_rank() == 0:
            booster.save_model("test.model.json")
            xgb.collective.communicator_print("Finished training\n")
|
||
|
||
def run_federated(world_size: int, with_ssl: bool, use_gpu: bool) -> None:
    """Launcher for clients and the server.

    Starts one federated server process and ``world_size`` worker processes,
    then waits for all workers to finish before shutting the server down.

    Parameters
    ----------
    world_size : Number of client workers to launch.
    use_gpu : If True, pin worker ``rank`` to device ``cuda:<rank>``.

    Raises
    ------
    ValueError : If the server process dies during startup.
    RuntimeError : If any worker process exits with a non-zero status.
    """
    port = 9091

    server = multiprocessing.Process(
        target=run_server, args=(port, world_size, with_ssl)
    )
    server.start()
    time.sleep(1)  # Give the server a moment to bind the port.
    if not server.is_alive():
        raise ValueError("Error starting Federated Learning server")

    workers = []
    for rank in range(world_size):
        device = f"cuda:{rank}" if use_gpu else "cpu"
        worker = multiprocessing.Process(
            target=run_worker, args=(port, world_size, rank, with_ssl, device)
        )
        workers.append(worker)
        worker.start()

    try:
        for rank, worker in enumerate(workers):
            worker.join()
            # Without this check a crashed worker would be silently ignored
            # and the test would pass spuriously.
            if worker.exitcode != 0:
                raise RuntimeError(
                    f"Worker {rank} exited with code {worker.exitcode}"
                )
    finally:
        # Always tear the server down, and join it so the process is reaped
        # instead of lingering as a zombie.
        server.terminate()
        server.join()
|
||
|
||
def run_federated_learning(with_ssl: bool, use_gpu: bool) -> None:
    """Run federated learning tests.

    Generates TLS certificates (when ``with_ssl``), splits the agaricus
    dataset into one shard per worker, and launches the federated run.
    """
    n_workers = 2

    if with_ssl:
        # Generate a self-signed key/certificate pair for both the server
        # and the client.
        command = "openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout {part}-key.pem -out {part}-cert.pem -subj /C=US/CN=localhost"  # pylint: disable=line-too-long
        for part in ("server", "client"):
            subprocess.check_call(command.format(part=part).split())

    train_path = os.path.join(tm.data_dir(__file__), "agaricus.txt.train")
    test_path = os.path.join(tm.data_dir(__file__), "agaricus.txt.test")

    X_train, y_train = load_svmlight_file(train_path)
    X_test, y_test = load_svmlight_file(test_path)

    # Split both sets in half: one shard per worker.
    X0, X1, y0, y1 = train_test_split(X_train, y_train, test_size=0.5)
    X0_valid, X1_valid, y0_valid, y1_valid = train_test_split(
        X_test, y_test, test_size=0.5
    )

    shards = [(X0, y0, X0_valid, y0_valid), (X1, y1, X1_valid, y1_valid)]
    for rank, (X, y, X_valid, y_valid) in enumerate(shards):
        dump_svmlight_file(X, y, f"agaricus.txt-{rank}.train")
        dump_svmlight_file(X_valid, y_valid, f"agaricus.txt-{rank}.test")

    run_federated(world_size=n_workers, with_ssl=with_ssl, use_gpu=use_gpu)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
17 changes: 0 additions & 17 deletions
17
tests/test_distributed/test_federated/runtests-federated.sh
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,86 +1,8 @@ | ||
#!/usr/bin/python | ||
import multiprocessing | ||
import sys | ||
import time | ||
import pytest | ||
|
||
import xgboost as xgb | ||
import xgboost.federated | ||
from xgboost.testing.federated import run_federated_learning | ||
|
||
# File names of the self-signed TLS key/certificate pairs expected in the
# working directory when running with SSL enabled.
SERVER_KEY = 'server-key.pem'
SERVER_CERT = 'server-cert.pem'
CLIENT_KEY = 'client-key.pem'
CLIENT_CERT = 'client-cert.pem'
|
||
|
||
def run_server(port: int, world_size: int, with_ssl: bool) -> None:
    """Start the federated server, optionally secured with TLS."""
    if not with_ssl:
        xgboost.federated.run_federated_server(port, world_size)
        return
    xgboost.federated.run_federated_server(
        port, world_size, SERVER_KEY, SERVER_CERT, CLIENT_CERT
    )
|
||
|
||
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
    """Run a single federated client worker.

    Connects to the local federated server, trains a small classifier on
    this rank's shard, and saves the model on rank 0 only.
    """
    communicator_env = {
        'xgboost_communicator': 'federated',
        'federated_server_address': f'localhost:{port}',
        'federated_world_size': world_size,
        'federated_rank': rank
    }
    if with_ssl:
        # NOTE(review): the SSL env keys here differ from the newer
        # *_path variants used elsewhere — presumably the legacy spelling.
        communicator_env['federated_server_cert'] = SERVER_CERT
        communicator_env['federated_client_key'] = CLIENT_KEY
        communicator_env['federated_client_cert'] = CLIENT_CERT

    # Always call this before using distributed module
    with xgb.collective.CommunicatorContext(**communicator_env):
        # Load file, file will not be sharded in federated mode.
        dtrain = xgb.DMatrix('agaricus.txt.train-%02d?format=libsvm' % rank)
        dtest = xgb.DMatrix('agaricus.txt.test-%02d?format=libsvm' % rank)

        # Specify parameters via map, definition are same as c++ version
        param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
        if with_gpu:
            param['tree_method'] = 'hist'
            param['device'] = f"cuda:{rank}"

        # Specify validations set to watch performance
        watchlist = [(dtest, 'eval'), (dtrain, 'train')]
        num_round = 20

        # Run training, all the features in training API is available.
        bst = xgb.train(param, dtrain, num_round, evals=watchlist,
                        early_stopping_rounds=2)

        # Save the model, only ask process 0 to save the model.
        if xgb.collective.get_rank() == 0:
            bst.save_model("test.model.json")
            xgb.collective.communicator_print("Finished training\n")
|
||
|
||
def run_federated(with_ssl: bool = True, with_gpu: bool = False) -> None:
    """Launch the federated server plus one worker process per rank.

    The world size is taken from ``sys.argv[1]``.
    """
    port = 9091
    world_size = int(sys.argv[1])

    server = multiprocessing.Process(
        target=run_server, args=(port, world_size, with_ssl)
    )
    server.start()
    time.sleep(1)
    if not server.is_alive():
        raise Exception("Error starting Federated Learning server")

    workers = []
    for rank in range(world_size):
        proc = multiprocessing.Process(
            target=run_worker,
            args=(port, world_size, rank, with_ssl, with_gpu),
        )
        workers.append(proc)
        proc.start()
    for proc in workers:
        proc.join()
    server.terminate()
|
||
|
||
if __name__ == '__main__':
    # Exercise every SSL/GPU combination: CPU runs first, then GPU runs.
    for use_gpu in (False, True):
        for use_ssl in (True, False):
            run_federated(with_ssl=use_ssl, with_gpu=use_gpu)
@pytest.mark.parametrize("with_ssl", [True, False])
def test_federated_learning(with_ssl: bool) -> None:
    """Run the federated learning test on CPU, with and without SSL."""
    run_federated_learning(with_ssl, False)
8 changes: 8 additions & 0 deletions
8
tests/test_distributed/test_gpu_federated/test_gpu_federated.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import pytest | ||
|
||
from xgboost.testing.federated import run_federated_learning | ||
|
||
|
||
@pytest.mark.parametrize("with_ssl", [True, False])
def test_federated_learning(with_ssl: bool) -> None:
    """Run the federated learning test on GPU, with and without SSL."""
    run_federated_learning(with_ssl, True)