From d841e02425b222608413b0900d0db5074a2bde5b Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Wed, 5 Jun 2024 15:15:54 +0200 Subject: [PATCH] Add examples --- datasets/flwr_datasets/metrics/utils.py | 50 +++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/datasets/flwr_datasets/metrics/utils.py b/datasets/flwr_datasets/metrics/utils.py index 43a2e88682e1..065e2daf26ec 100644 --- a/datasets/flwr_datasets/metrics/utils.py +++ b/datasets/flwr_datasets/metrics/utils.py @@ -52,6 +52,31 @@ def compute_counts( dataframe: pd.DataFrame DataFrame where the rows represent the partition id and the column represent the unique values found in column specified by `column_name`. + + Examples + -------- + Generate DataFrame with label counts resulting from DirichletPartitioner on cifar10 + + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.partitioner import DirichletPartitioner + >>> from flwr_datasets.metrics import compute_counts + >>> + >>> fds = FederatedDataset( + >>> dataset="cifar10", + >>> partitioners={ + >>> "train": DirichletPartitioner( + >>> num_partitions=20, + >>> partition_by="label", + >>> alpha=0.3, + >>> min_partition_size=0, + >>> ), + >>> }, + >>> ) + >>> partitioner = fds.partitioners["train"] + >>> counts_dataframe = compute_counts( + >>> partitioner=partitioner, + >>> column_name="label" + >>> ) """ if column_name not in partitioner.dataset.column_names: raise ValueError( @@ -137,6 +162,31 @@ def compute_frequencies( dataframe: pd.DataFrame DataFrame where the rows represent the partition id and the column represent the unique values found in column specified by `column_name`. + + Examples + -------- + Generate DataFrame with label counts resulting from DirichletPartitioner on cifar10 + + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.partitioner import DirichletPartitioner + >>> from flwr_datasets.metrics import compute_frequencies + >>> + >>> fds = FederatedDataset( + >>> dataset="cifar10", + >>> partitioners={ + >>> "train": DirichletPartitioner( + >>> num_partitions=20, + >>> partition_by="label", + >>> alpha=0.3, + >>> min_partition_size=0, + >>> ), + >>> }, + >>> ) + >>> partitioner = fds.partitioners["train"] + >>> counts_dataframe = compute_frequencies( + >>> partitioner=partitioner, + >>> column_name="label" + >>> ) """ dataframe = compute_counts( partitioner, column_name, verbose_names, max_num_partitions