Commit

Fix formatting
adam-narozniak committed Mar 13, 2024
1 parent d374bc4 commit a10f40d
Showing 3 changed files with 18 additions and 14 deletions.
14 changes: 8 additions & 6 deletions datasets/flwr_datasets/partitioner/dirichlet_partitioner.py
@@ -41,7 +41,7 @@ class DirichletPartitioner(Partitioner):
     The notion of balancing is explicitly introduced here (not mentioned in paper but
     implemented in the code). It is a mechanism that excludes the partition from
     assigning new samples to it if the current number of samples on that partition
-    exceeds the average number that the partition would get in case of even data
+    exceeds the average number that the partition would get in case of even data
     distribution. It is controlled by`self_balancing` parameter.
 
     Parameters
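For context, a minimal usage sketch of the `self_balancing` switch documented above. The parameter names (`num_partitions`, `partition_by`, `alpha`, `self_balancing`) come from this file; the dataset name "mnist" and `partition_by="label"` are only assumed for illustration, mirroring the doctest style this module uses elsewhere:

>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import DirichletPartitioner
>>>
>>> partitioner = DirichletPartitioner(
...     num_partitions=10,
...     partition_by="label",
...     alpha=0.5,
...     self_balancing=True,
... )
>>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
>>> partition = fds.load_partition(0)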
@@ -205,7 +205,9 @@ def _determine_partition_id_to_indices_if_needed(
         self._unique_classes = self.dataset.unique(self._partition_by)
         assert self._unique_classes is not None
         # This is needed only if self._self_balancing is True (the default option)
-        self._avg_num_of_samples_per_partition = self.dataset.num_rows / self._num_partitions
+        self._avg_num_of_samples_per_partition = (
+            self.dataset.num_rows / self._num_partitions
+        )
 
         # Change targets list data type to numpy
         targets = np.array(self.dataset[self._partition_by])
@@ -232,10 +234,10 @@ def _determine_partition_id_to_indices_if_needed(
                     nid
                 ]
                 # Balancing (not mentioned in the paper but implemented)
-                # Do not assign additional samples to the partition if it already has more
-                # than the average numbers of samples per partition. Note that it might
-                # especially affect classes that are later in the order. This is the
-                # reason for more sparse division that the alpha might suggest.
+                # Do not assign additional samples to the partition if it already has
+                # more than the average numbers of samples per partition. Note that it
+                # might especially affect classes that are later in the order. This is
+                # the reason for more sparse division that the alpha might suggest.
                 if self._self_balancing:
                     assert self._avg_num_of_samples_per_partition is not None
                     for nid in nid_to_proportion_of_k_samples.copy():
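The balancing step that the rewrapped comment describes reduces to a small standalone sketch, assuming illustrative names (`rebalance`, `partition_sizes`) rather than the module's exact ones: any partition that already holds at least the average number of samples gets a zero share of the current class, and the remaining shares are renormalized.

import numpy as np

def rebalance(proportions, partition_sizes, avg_samples_per_partition):
    # Zero out the share of any partition that already meets or exceeds the
    # average partition size, then renormalize what is left.
    proportions = np.asarray(proportions, dtype=float).copy()
    for pid, size in enumerate(partition_sizes):
        if size >= avg_samples_per_partition:
            proportions[pid] = 0.0  # this partition receives no new samples
    total = proportions.sum()
    return proportions / total if total > 0 else proportions

# Partition 1 already holds 130 samples against an average of 100, so its
# share is dropped and the rest is rescaled: prints [0.4, 0.0, 0.6].
print(rebalance([0.2, 0.5, 0.3], [40, 130, 60], 100.0))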
2 changes: 1 addition & 1 deletion (second changed file; file name not shown in this view)

@@ -221,7 +221,7 @@ def _determine_partition_id_to_indices_if_needed(
             current_partition_id = self._rng.choice(not_full_partition_ids)
             # If current partition is full resample a client
             if partition_id_to_left_to_allocate[current_partition_id] == 0:
-                # When the partition is full, exclude it from the sampling partitions list
+                # When the partition is full, exclude it from the sampling list
                 not_full_partition_ids.pop(
                     not_full_partition_ids.index(current_partition_id)
                 )
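The hunk above belongs to a quota-based allocation loop. A hedged, self-contained sketch of the exclusion behavior follows; the quotas are made up, and unlike the code shown it drops a partition as soon as its quota is exhausted rather than on resampling:

import numpy as np

rng = np.random.default_rng(42)
# Illustrative number of samples still to allocate per partition.
left_to_allocate = {0: 2, 1: 1, 2: 3}
not_full_partition_ids = list(left_to_allocate)

while not_full_partition_ids:
    current_partition_id = rng.choice(not_full_partition_ids)
    left_to_allocate[current_partition_id] -= 1
    # When the partition is full, exclude it from the sampling list
    if left_to_allocate[current_partition_id] == 0:
        not_full_partition_ids.pop(
            not_full_partition_ids.index(current_partition_id)
        )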
16 changes: 9 additions & 7 deletions datasets/flwr_datasets/partitioner/shard_partitioner.py
@@ -50,10 +50,10 @@ class ShardPartitioner(Partitioner): # pylint: disable=R0902
     a) `num_shards_per_partitions`, `shard_size`; or b) `num_shards_per_partition`)
     In case of b the `shard_size` is calculated as floor(len(dataset) /
     (`num_shards_per_partitions` * `num_partitions`))
-    2) possibly different number of shards per partition (use nearly all data) + the same
-    shard size (specify: `shard_size` + `keep_incomplete_shard=False`)
-    3) possibly different number of shards per partition (use all data) + possibly different
-    shard size (specify: `shard_size` + `keep_incomplete_shard=True`)
+    2) possibly different number of shards per partition (use nearly all data) + the
+    same shard size (specify: `shard_size` + `keep_incomplete_shard=False`)
+    3) possibly different number of shards per partition (use all data) + possibly
+    different shard size (specify: `shard_size` + `keep_incomplete_shard=True`)
 
     Algorithm based on the description in Communication-Efficient Learning of Deep
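As a quick numerical check of the floor formula quoted in the hunk above (all numbers are illustrative, not from the library):

# shard_size = floor(len(dataset) / (num_shards_per_partition * num_partitions))
dataset_len = 60_000              # e.g. a MNIST-sized training split
num_partitions = 10
num_shards_per_partition = 2
shard_size = dataset_len // (num_shards_per_partition * num_partitions)
print(shard_size)  # 3000 samples per shard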
@@ -88,8 +88,8 @@ class ShardPartitioner(Partitioner): # pylint: disable=R0902
     Examples
     --------
-    1) If you need same number of shards per partitions + the same shard size (and you know
-    both of these values)
+    1) If you need same number of shards per partitions + the same shard size (and you
+    know both of these values)
     >>> from flwr_datasets import FederatedDataset
     >>> from flwr_datasets.partitioner import ShardPartitioner
     >>>
@@ -149,7 +149,9 @@ def __init__( # pylint: disable=R0913
         _check_if_natual_number(num_partitions, "num_partitions")
         self._num_partitions = num_partitions
         self._partition_by = partition_by
-        _check_if_natual_number(num_shards_per_partition, "num_shards_per_partition", True)
+        _check_if_natual_number(
+            num_shards_per_partition, "num_shards_per_partition", True
+        )
         self._num_shards_per_partition = num_shards_per_partition
         self._num_shards_used: Optional[int] = None
         _check_if_natual_number(shard_size, "shard_size", True)
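A hedged sketch of mode 1 from the docstring above (fixed shard count and fixed shard size). The parameters `num_partitions`, `partition_by`, `num_shards_per_partition`, and `shard_size` appear in this file's `__init__`; the dataset name "mnist", `partition_by="label"`, and the specific values are assumed for illustration:

>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import ShardPartitioner
>>>
>>> partitioner = ShardPartitioner(
...     num_partitions=10,
...     partition_by="label",
...     num_shards_per_partition=2,
...     shard_size=1000,
... )
>>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
>>> partition = fds.load_partition(0)  # 2 shards * 1000 samples = 2000 rows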
