Skip to content

Commit

Permalink
Split indices in a better way
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Jun 18, 2024
1 parent dd5addb commit 8633037
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions datasets/flwr_datasets/partitioner/probability_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Probability partitioner class that works with Hugging Face Datasets."""


import warnings
from typing import Dict, List, Optional, Union

import numpy as np
Expand All @@ -26,7 +25,6 @@


class ProbabilityPartitioner(Partitioner):

def __init__(
self,
probabilities: NDArrayFloat,
Expand Down Expand Up @@ -80,7 +78,6 @@ def num_partitions(self) -> int:
self._determine_partition_id_to_indices_if_needed()
return self._num_partitions


def _determine_partition_id_to_indices_if_needed(
self,
) -> None:
Expand All @@ -96,19 +93,28 @@ def _determine_partition_id_to_indices_if_needed(

for unique_label in self._unique_classes:
unique_label_to_indices[unique_label] = np.where(labels == unique_label)[0]
unique_label_to_size[unique_label] = len(unique_label_to_indices[unique_label])
unique_label_to_size[unique_label] = len(
unique_label_to_indices[unique_label]
)

self._partition_id_to_indices = {partition_id: [] for partition_id in range(self._num_partitions)}
self._partition_id_to_indices = {
partition_id: [] for partition_id in range(self._num_partitions)
}

for unique_label in self._unique_classes:
probabilities_per_label = self._probabilities[:, unique_label]
split_sizes = (unique_label_to_size[unique_label] * probabilities_per_label).astype(int)

beg_id = 0
split_sizes = (
unique_label_to_size[unique_label] * probabilities_per_label
).astype(int)
cumsum_division_numbers = np.cumsum(split_sizes)
indices_on_which_split = cumsum_division_numbers.astype(int)[:-1]
split_indices = np.split(
unique_label_to_indices[unique_label], indices_on_which_split
)
for partition_id in range(self._num_partitions):
end_id = beg_id + split_sizes[partition_id]
self._partition_id_to_indices[partition_id].extend(unique_label_to_indices[unique_label][beg_id:end_id].tolist())
beg_id = end_id
self._partition_id_to_indices[partition_id].extend(
split_indices[partition_id]
)

self._partition_id_to_indices_determined = True

Expand Down

0 comments on commit 8633037

Please sign in to comment.