diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index b591d4d3e2ff..388865a26cf6 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -146,7 +146,8 @@ def divide_dataset( >>> division = [0.8, 0.2] >>> train, test = divide_dataset(dataset=partition, division=division) - Use `divide_dataset` with division specified as a dict. + Use `divide_dataset` with division specified as a dict + (this accomplishes the same goal as the example with a list above). >>> from flwr_datasets import FederatedDataset >>> from flwr_datasets.utils import divide_dataset @@ -273,11 +274,11 @@ def concatenate_divisions( partition_division: Union[List[float], Tuple[float, ...], Dict[str, float]], division_id: Union[int, str], ) -> Dataset: - """Create a dataset by concatenation of all partitions in the same division. + """Create a dataset by concatenation of divisions from all partitions. The divisions are created based on the `partition_division` and accessed based - on the `division_id`. It can be used to create e.g. centralized dataset from - federated on-edge test sets. + on the `division_id`. This function can be used to create e.g. centralized dataset + from federated on-edge test sets. Parameters ---------- @@ -298,6 +299,35 @@ def concatenate_divisions( ------- concatenated_divisions : Dataset A dataset created as concatenation of the divisions from all partitions. + + Examples + -------- + Use `concatenate_divisions` with division specified as a list. + + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.utils import concatenate_divisions + >>> + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) + >>> concatenated_divisions = concatenate_divisions( + ... partitioner=fds.partitioners["train"], + ... partition_division=[0.8, 0.2], + ... division_id=1 + ... ) + >>> print(concatenated_divisions) + + Use `concatenate_divisions` with division specified as a dict. 
+ This accomplishes the same goal as the example with a list above. + + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.utils import concatenate_divisions + >>> + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) + >>> concatenated_divisions = concatenate_divisions( + ... partitioner=fds.partitioners["train"], + ... partition_division={"train": 0.8, "test": 0.2}, + ... division_id="test" + ... ) + >>> print(concatenated_divisions) """ _check_division_config_correctness(partition_division) divisions = []