Skip to content

Commit

Permalink
Add examples
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Jun 5, 2024
1 parent 9eaefdc commit d841e02
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions datasets/flwr_datasets/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,31 @@ def compute_counts(
dataframe: pd.DataFrame
DataFrame where the rows represent the partition id and the column represent
the unique values found in column specified by `column_name`.
Examples
--------
Generate DataFrame with label counts resulting from DirichletPartitioner on cifar10
>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import DirichletPartitioner
>>> from flwr_datasets.metrics import compute_counts
>>>
>>> fds = FederatedDataset(
>>> dataset="cifar10",
>>> partitioners={
>>> "train": DirichletPartitioner(
>>> num_partitions=20,
>>> partition_by="label",
>>> alpha=0.3,
>>> min_partition_size=0,
>>> ),
>>> },
>>> )
>>> partitioner = fds.partitioners["train"]
>>> counts_dataframe = compute_counts(
>>> partitioner=partitioner,
>>> column_name="label"
>>> )
"""
if column_name not in partitioner.dataset.column_names:
raise ValueError(
Expand Down Expand Up @@ -137,6 +162,31 @@ def compute_frequencies(
dataframe: pd.DataFrame
DataFrame where the rows represent the partition id and the column represent
the unique values found in column specified by `column_name`.
Examples
--------
Generate DataFrame with label counts resulting from DirichletPartitioner on cifar10
>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import DirichletPartitioner
>>> from flwr_datasets.metrics import compute_frequencies
>>>
>>> fds = FederatedDataset(
>>> dataset="cifar10",
>>> partitioners={
>>> "train": DirichletPartitioner(
>>> num_partitions=20,
>>> partition_by="label",
>>> alpha=0.3,
>>> min_partition_size=0,
>>> ),
>>> },
>>> )
>>> partitioner = fds.partitioners["train"]
>>> counts_dataframe = compute_frequencies(
>>> partitioner=partitioner,
>>> column_name="label"
>>> )
"""
dataframe = compute_counts(
partitioner, column_name, verbose_names, max_num_partitions
Expand Down

0 comments on commit d841e02

Please sign in to comment.