Skip to content

Commit

Permalink
Improve sampling speed
Browse files Browse the repository at this point in the history
* Track exhausted nodes
* Adjust the probabilities once there are no more samples from class k
  • Loading branch information
adam-narozniak committed Feb 29, 2024
1 parent 76d2fa9 commit 8c30794
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0
# Create class priors for the whole partitioning process
assert self._alpha is not None
class_priors = self._rng.dirichlet(alpha=self._alpha, size=self._num_partitions)
prior_cumsum = np.cumsum(class_priors, axis=1)
targets = np.asarray(self.dataset[self._partition_by])
# List representing indices of each class
assert self._num_unique_classes is not None
Expand All @@ -194,20 +193,33 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0
zip(range(self._num_partitions), self._partition_sizes)
)

not_full_node_ids = list(range(self._num_partitions))
while np.sum(list(node_id_to_left_to_allocate.values())) != 0:
# Choose a node
current_node_id = np.random.randint(self._num_partitions)
current_node_id = self._rng.choice(not_full_node_ids)
# If the current node is full, resample a client
if node_id_to_left_to_allocate[current_node_id] == 0:
# When the node is full, exclude it from the sampling nodes list
not_full_node_ids.pop(not_full_node_ids.index(current_node_id))
continue
node_id_to_left_to_allocate[current_node_id] -= 1
# Access the label distribution of the chose client
curr_prior = prior_cumsum[current_node_id]
# Access the label distribution of the chosen client
current_probabilities = class_priors[current_node_id]
while True:
curr_class = np.argmax(np.random.uniform() <= curr_prior)
# curr_class = np.argmax(np.random.uniform() <= curr_prior)
curr_class = self._rng.choice(
list(range(self._num_unique_classes)), p=current_probabilities
)
# Redraw class label if there are no samples left to allocate from
# that class
if class_sizes[curr_class] == 0:
# Class got exhausted, set probabilities to 0
class_priors[:, curr_class] = 0
# Renormalize such that the probability sums to 1
row_sums = class_priors.sum(axis=1, keepdims=True)
class_priors = class_priors / row_sums
# Adjust the current_probabilities (it won't sum up to 1 otherwise)
current_probabilities = class_priors[current_node_id]
continue
class_sizes[curr_class] -= 1
# Store sample index at the empty array cell
Expand Down

0 comments on commit 8c30794

Please sign in to comment.