Skip to content

Commit

Permalink
Improve speed of NaturalIdPartitioner
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Apr 17, 2024
1 parent 58c47b5 commit 9369b81
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions datasets/flwr_datasets/partitioner/natural_id_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# ==============================================================================
"""Natural id partitioner class that works with Hugging Face Datasets."""


from typing import Dict

import numpy as np

import datasets
from flwr_datasets.common.typing import NDArrayInt
from flwr_datasets.partitioner.partitioner import Partitioner


Expand All @@ -30,6 +32,8 @@ def __init__(
):
super().__init__()
self._partition_id_to_natural_id: Dict[int, str] = {}
self._natural_id_to_partition_id: Dict[str, int] = {}
self._partition_id_to_indices: Dict[int, NDArrayInt] = {}
self._partition_by = partition_by

def _create_int_partition_id_to_natural_id(self) -> None:
Expand All @@ -42,6 +46,26 @@ def _create_int_partition_id_to_natural_id(self) -> None:
zip(range(len(unique_natural_ids)), unique_natural_ids)
)

def _create_natural_id_to_int_partition_id(self) -> None:
    """Build the reverse lookup from natural id to integer partition id.

    Natural ids come from the column specified in `partition_by`. The result
    is stored in `self._natural_id_to_partition_id` and is the inverse of
    `self._partition_id_to_natural_id`, which must already exist when this
    method is called.
    """
    forward = self._partition_id_to_natural_id
    # Pair every natural id with its partition id, i.e. swap keys and values.
    self._natural_id_to_partition_id = dict(zip(forward.values(), forward.keys()))

def _create_partition_id_to_indices(self) -> None:
    """Create an assignment of dataset row indices to partition ids.

    Reads the `partition_by` column of `self.dataset`, groups the row indices
    by natural id and stores each group in `self._partition_id_to_indices`
    under the matching integer partition id. Indices within a group are kept
    in ascending order.

    This method assumes that `self._natural_id_to_partition_id` already
    contains every natural id present in the column; a missing id raises
    `KeyError`.
    """
    natural_ids = self.dataset[self._partition_by]
    unique_natural_ids, inverse, counts = np.unique(
        natural_ids, return_inverse=True, return_counts=True
    )

    # One stable sort groups the rows of each natural id together while
    # keeping the original row order inside a group. This replaces a full
    # scan of the dataset per unique id (`np.where(inverse == i)`).
    grouped_rows = np.argsort(inverse, kind="stable")
    # Cumulative group sizes (without the last) are the split points.
    boundaries = np.cumsum(counts)[:-1]

    for natural_id, indices in zip(
        unique_natural_ids, np.split(grouped_rows, boundaries)
    ):
        partition_id = self._natural_id_to_partition_id[natural_id]
        self._partition_id_to_indices[partition_id] = indices

def load_partition(self, partition_id: int) -> datasets.Dataset:
"""Load a single partition corresponding to a single `partition_id`.
Expand All @@ -60,17 +84,19 @@ def load_partition(self, partition_id: int) -> datasets.Dataset:
"""
if len(self._partition_id_to_natural_id) == 0:
self._create_int_partition_id_to_natural_id()
self._create_natural_id_to_int_partition_id()

return self.dataset.filter(
lambda row: row[self._partition_by]
== self._partition_id_to_natural_id[partition_id]
)
if len(self._partition_id_to_indices) == 0:
self._create_partition_id_to_indices()

return self.dataset.select(self._partition_id_to_indices[partition_id])

@property
def num_partitions(self) -> int:
    """Return the total number of partitions.

    If the partition-id-to-natural-id mapping is still empty, both id
    mappings are created first, so the count reflects the dataset.
    """
    mappings_missing = not self._partition_id_to_natural_id
    if mappings_missing:
        # Lazily build the forward mapping and then its inverse.
        self._create_int_partition_id_to_natural_id()
        self._create_natural_id_to_int_partition_id()
    return len(self._partition_id_to_natural_id)

@property
Expand Down

0 comments on commit 9369b81

Please sign in to comment.