Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retire the example of quickstart-xgboost-horizontal #2801

Merged
merged 10 commits into from
Jan 18, 2024
2 changes: 1 addition & 1 deletion baselines/hfedxgboost/hfedxgboost/conf/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ client_resources:
num_gpus: 0.0

strategy:
_target_: flwr.server.strategy.FedXgbNnAvg
_target_: hfedxgboost.strategy.FedXgbNnAvg
_recursive_: true #everything to be instantiated
fraction_fit: 1.0
fraction_evaluate: 0.0 # no clients will be sampled for federated evaluation (we will still perform global evaluation)
Expand Down
74 changes: 74 additions & 0 deletions baselines/hfedxgboost/hfedxgboost/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,77 @@
Needed only when the strategy is not yet implemented in Flower or because you want to
extend or modify the functionality of an existing strategy.
"""
from logging import WARNING
from typing import Any, Dict, List, Optional, Tuple, Union

from flwr.common import FitRes, Scalar, ndarrays_to_parameters, parameters_to_ndarrays
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy

from flwr.server.strategy.aggregate import aggregate
from flwr.server.strategy import FedAvg


class FedXgbNnAvg(FedAvg):
"""Configurable FedXgbNnAvg strategy implementation."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Federated XGBoost [Ma et al., 2023] strategy.

Implementation based on https://arxiv.org/abs/2304.07537.
"""
super().__init__(*args, **kwargs)

def __repr__(self) -> str:
"""Compute a string representation of the strategy."""
rep = f"FedXgbNnAvg(accept_failures={self.accept_failures})"
return rep

def evaluate(
self, server_round: int, parameters: Any
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
"""Evaluate model parameters using an evaluation function."""
if self.evaluate_fn is None:
# No evaluation function provided
return None
eval_res = self.evaluate_fn(server_round, parameters, {})
if eval_res is None:
return None
loss, metrics = eval_res
return loss, metrics

def aggregate_fit(
self,
server_round: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Any], Dict[str, Scalar]]:
"""Aggregate fit results using weighted average."""
if not results:
return None, {}
# Do not aggregate if there are failures and failures are not accepted
if not self.accept_failures and failures:
return None, {}

# Convert results
weights_results = [
(
parameters_to_ndarrays(fit_res.parameters[0].parameters), # type: ignore # noqa: E501 # pylint: disable=line-too-long
fit_res.num_examples,
)
for _, fit_res in results
]
parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))

# Aggregate XGBoost trees from all clients
trees_aggregated = [fit_res.parameters[1] for _, fit_res in results] # type: ignore # noqa: E501 # pylint: disable=line-too-long

# Aggregate custom metrics if aggregation fn was provided
metrics_aggregated = {}
if self.fit_metrics_aggregation_fn:
fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
elif server_round == 1: # Only log this warning once
log(WARNING, "No fit_metrics_aggregation_fn provided")

return [parameters_aggregated, trees_aggregated], metrics_aggregated
2 changes: 0 additions & 2 deletions examples/quickstart-xgboost-horizontal/.gitignore

This file was deleted.

19 changes: 0 additions & 19 deletions examples/quickstart-xgboost-horizontal/README.md

This file was deleted.

Loading