Skip to content

Commit

Permalink
[backport] Allow blocking launch of federated tracker. (#10414) (#10425)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jun 15, 2024
1 parent 6094106 commit 63b49f3
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
2 changes: 1 addition & 1 deletion plugin/example/custom_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class MyLogistic : public ObjFunction {

void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String("my_logistic");
out["name"] = String("mylogistic");
out["my_logistic_param"] = ToJson(param_);
}

Expand Down
18 changes: 16 additions & 2 deletions python-package/xgboost/federated.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,19 @@ def run_federated_server( # pylint: disable=too-many-arguments
server_key_path: Optional[str] = None,
server_cert_path: Optional[str] = None,
client_cert_path: Optional[str] = None,
blocking: bool = True,
timeout: int = 300,
) -> Dict[str, Any]:
"""See :py:class:`~xgboost.federated.FederatedTracker` for more info."""
) -> Optional[Dict[str, Any]]:
"""See :py:class:`~xgboost.federated.FederatedTracker` for more info.
Parameters
----------
blocking :
Block the server until the training is finished. If set to False, the function
launches an additional thread and returns the worker arguments. The default is
True and a higher level framework is responsible for setting worker parameters.
"""
args: Dict[str, Any] = {"n_workers": n_workers}
secure = all(
path is not None
Expand All @@ -78,6 +88,10 @@ def run_federated_server( # pylint: disable=too-many-arguments
)
tracker.start()

if blocking:
tracker.wait_for()
return None

thread = Thread(target=tracker.wait_for)
thread.daemon = True
thread.start()
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_federated_communicator():
world_size = 2
tracker = multiprocessing.Process(
target=federated.run_federated_server,
kwargs={"port": port, "n_workers": world_size},
kwargs={"port": port, "n_workers": world_size, "blocking": False},
)
tracker.start()
if not tracker.is_alive():
Expand Down

0 comments on commit 63b49f3

Please sign in to comment.