diff --git a/datasets/flwr_datasets/partitioner/natural_id_partitioner.py b/datasets/flwr_datasets/partitioner/natural_id_partitioner.py
index 8bad0668595b..38698c58da56 100644
--- a/datasets/flwr_datasets/partitioner/natural_id_partitioner.py
+++ b/datasets/flwr_datasets/partitioner/natural_id_partitioner.py
@@ -14,10 +14,12 @@
 # ==============================================================================
 """Natural id partitioner class that works with Hugging Face Datasets."""
 
-
 from typing import Dict
 
+import numpy as np
+
 import datasets
+from flwr_datasets.common.typing import NDArrayInt
 from flwr_datasets.partitioner.partitioner import Partitioner
 
 
@@ -30,6 +32,8 @@ def __init__(
     ):
         super().__init__()
         self._partition_id_to_natural_id: Dict[int, str] = {}
+        self._natural_id_to_partition_id: Dict[str, int] = {}
+        self._partition_id_to_indices: Dict[int, NDArrayInt] = {}
         self._partition_by = partition_by
 
     def _create_int_partition_id_to_natural_id(self) -> None:
@@ -42,6 +46,26 @@ def _create_int_partition_id_to_natural_id(self) -> None:
             zip(range(len(unique_natural_ids)), unique_natural_ids)
         )
 
+    def _create_natural_id_to_int_partition_id(self) -> None:
+        """Create a mapping from unique natural ids to int partition ids.
+
+        Natural ids come from the column specified in `partition_by`. This mapping is
+        the inverse of `self._partition_id_to_natural_id`. This method assumes that
+        `self._partition_id_to_natural_id` has already been populated.
+        """
+        self._natural_id_to_partition_id = {
+            value: key for key, value in self._partition_id_to_natural_id.items()
+        }
+
+    def _create_partition_id_to_indices(self) -> None:
+        """Map each partition id to the dataset row indices that belong to it."""
+        natural_ids = self.dataset[self._partition_by]
+        unique_natural_ids, inverse = np.unique(natural_ids, return_inverse=True)
+
+        for i, natural_id in enumerate(unique_natural_ids):
+            partition_id = self._natural_id_to_partition_id[natural_id]
+            self._partition_id_to_indices[partition_id] = np.where(inverse == i)[0]
+
     def load_partition(self, partition_id: int) -> datasets.Dataset:
         """Load a single partition corresponding to a single `partition_id`.
 
@@ -60,17 +84,19 @@ def load_partition(self, partition_id: int) -> datasets.Dataset:
         """
         if len(self._partition_id_to_natural_id) == 0:
             self._create_int_partition_id_to_natural_id()
+            self._create_natural_id_to_int_partition_id()
 
-        return self.dataset.filter(
-            lambda row: row[self._partition_by]
-            == self._partition_id_to_natural_id[partition_id]
-        )
+        if len(self._partition_id_to_indices) == 0:  # built on first use
+            self._create_partition_id_to_indices()
+
+        return self.dataset.select(self._partition_id_to_indices[partition_id])
 
     @property
     def num_partitions(self) -> int:
         """Total number of partitions."""
         if len(self._partition_id_to_natural_id) == 0:
             self._create_int_partition_id_to_natural_id()
+            self._create_natural_id_to_int_partition_id()
         return len(self._partition_id_to_natural_id)
 
     @property