diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index 17fc263078a8..ac3beae66a9c 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -134,29 +134,23 @@ def divide_dataset( Use `divide_dataset` with division specified as a list. >>> from flwr_datasets import FederatedDataset - >>> from flwr_datasets.utils import concatenate_divisions + >>> from flwr_datasets.utils import divide_dataset >>> >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) - >>> concatenated_divisions = concatenate_divisions( - ... partitioner=fds.partitioners["train"], - ... partition_division=[0.8, 0.2], - ... division_id=1 - ... ) - >>> print(concatenated_divisions) + >>> partition = fds.load_partition(0) + >>> division = [0.8, 0.2] + >>> train, test = divide_dataset(dataset=partition, division=division) Use `divide_dataset` with division specified as a dict (this accomplishes the same goal as the example with a list above). >>> from flwr_datasets import FederatedDataset - >>> from flwr_datasets.utils import concatenate_divisions + >>> from flwr_datasets.utils import divide_dataset >>> - >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) - >>> concatenated_divisions = concatenate_divisions( - ... partitioner=fds.partitioners["train"], - ... partition_division={"train": 0.8, "test": 0.2}, - ... division_id="test", - ... ) - >>> print(concatenated_divisions) + >>> partition = fds.load_partition(0) + >>> division = {"train": 0.8, "test": 0.2} + >>> train_test = divide_dataset(dataset=partition, division=division) + >>> train, test = train_test["train"], train_test["test"] """ _check_division_config_correctness(division) dataset_length = len(dataset)