Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
yan-gao-GY committed Sep 18, 2024
1 parent 517c15a commit 6fb13bc
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 8 deletions.
1 change: 0 additions & 1 deletion examples/xgboost-comprehensive/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ flwr run . --run-config "train-method='cyclic' partitioner-type='linear' central
> \[!NOTE\]
> An update to this example will show how to run this Flower application with the Deployment Engine and TLS certificates, or with Docker.

## Expected Experimental Results

### Bagging aggregation experiment
Expand Down
1 change: 0 additions & 1 deletion examples/xgboost-comprehensive/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,3 @@ default = "local-simulation"
[tool.flwr.federations.local-simulation]
options.num-supernodes = 5
options.backend.client-resources.num-cpus = 2
options.backend.client-resources.num-gpus = 0.0
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,14 @@ def client_fn(context: Context):
centralised_eval_client = cfg["centralised_eval_client"]

# Load training and validation data
train_dmatrix, valid_dmatrix, num_train, num_val = load_data(partitioner_type, partition_id, num_partitions,
centralised_eval_client, test_fraction, seed)
train_dmatrix, valid_dmatrix, num_train, num_val = load_data(
partitioner_type,
partition_id,
num_partitions,
centralised_eval_client,
test_fraction,
seed,
)

# Setup learning rate
if cfg["scaled_lr"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ def server_fn(context: Context):
config = ServerConfig(num_rounds=num_rounds)
client_manager = CyclicClientManager() if train_method == "cyclic" else None

return ServerAppComponents(strategy=strategy, config=config, client_manager=client_manager)
return ServerAppComponents(
strategy=strategy, config=config, client_manager=client_manager
)


# Create ServerApp
Expand Down
12 changes: 9 additions & 3 deletions examples/xgboost-comprehensive/xgboost_comprehensive/task.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""xgboost_comprehensive: A Flower / XGBoost app."""

from logging import INFO
from typing import Union
from tqdm import tqdm
import xgboost as xgb
from datasets import Dataset, DatasetDict, concatenate_datasets

Expand All @@ -24,6 +22,7 @@

fds = None # Cache FederatedDataset


def train_test_split(partition, test_fraction, seed):
"""Split the data into train and validation set given split rate."""
train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
Expand Down Expand Up @@ -60,7 +59,14 @@ def instantiate_fds(partitioner_type, num_partitions):
return fds


def load_data(partitioner_type, partition_id, num_partitions, centralised_eval_client, test_fraction, seed):
def load_data(
partitioner_type,
partition_id,
num_partitions,
centralised_eval_client,
test_fraction,
seed,
):
"""Load partition data."""
fds_ = instantiate_fds(partitioner_type, num_partitions)
partition = fds_.load_partition(partition_id)
Expand Down

0 comments on commit 6fb13bc

Please sign in to comment.