From 2c6ebb011f46139cf3fb62c24a1e636c21a29ff0 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Thu, 22 Aug 2024 12:27:56 +0200 Subject: [PATCH] Fix tests --- .../grouped_natural_id_partitioner.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/grouped_natural_id_partitioner.py b/datasets/flwr_datasets/partitioner/grouped_natural_id_partitioner.py index d1b8e6edd137..531972264c4e 100644 --- a/datasets/flwr_datasets/partitioner/grouped_natural_id_partitioner.py +++ b/datasets/flwr_datasets/partitioner/grouped_natural_id_partitioner.py @@ -13,12 +13,12 @@ # limitations under the License. # ============================================================================== """Grouped natural id partitioner class that works with Hugging Face Datasets.""" -from typing import Dict, Literal, Tuple +from typing import Any, Dict, List, Literal import numpy as np -from common.typing import NDArrayInt import datasets +from flwr_datasets.common.typing import NDArrayInt from flwr_datasets.partitioner.partitioner import Partitioner @@ -69,8 +69,8 @@ def __init__( sort_unique_ids: bool = False, ) -> None: super().__init__() - self._partition_id_to_natural_ids: Dict[int, Tuple[str]] = {} - self._natural_ids_to_partition_id: Dict[Tuple[str], int] = {} + self._partition_id_to_natural_ids: Dict[int, List[Any]] = {} + self._natural_id_to_partition_id: Dict[Any, int] = {} self._partition_id_to_indices: Dict[int, NDArrayInt] = {} self._partition_by = partition_by self._mode = mode @@ -133,7 +133,7 @@ def _create_int_partition_id_to_natural_id(self) -> None: for group_of_natural_ids_id, group_of_natural_ids in zip( range(len(groups_of_natural_ids)), groups_of_natural_ids ): - self._partition_id_to_natural_ids[group_of_natural_ids_id] = tuple( + self._partition_id_to_natural_ids[group_of_natural_ids_id] = ( group_of_natural_ids.tolist() ) @@ -144,9 +144,10 @@ def _create_natural_id_to_int_partition_id(self) -> None: inverse of the `self._partition_id_to_natural_id`. This method assumes that `self._partition_id_to_natural_id` already exists. """ - self._natural_ids_to_partition_id = { - value: key for key, value in self._partition_id_to_natural_ids.items() - } + self._natural_id_to_partition_id = {} + for partition_id, natural_ids in self._partition_id_to_natural_ids.items(): + for natural_id in natural_ids: + self._natural_id_to_partition_id[natural_id] = partition_id def _create_partition_id_to_indices(self) -> None: natural_id_to_indices = {} # type: ignore @@ -199,15 +200,14 @@ def num_partitions(self) -> int: return len(self._partition_id_to_natural_ids) @property - def partition_id_to_natural_ids(self) -> Dict[int, Tuple[str]]: + def partition_id_to_natural_ids(self) -> Dict[int, List[Any]]: """Partition id to the corresponding group of natural ids present. Natural ids are the unique values in `partition_by` column in dataset. """ return self._partition_id_to_natural_ids - @partition_id_to_natural_ids.setter - def partition_id_to_natural_ids(self, value: Dict[int, Tuple[str]]) -> None: - raise AttributeError( - "Setting the partition_id_to_natural_ids dictionary is not allowed." - ) + @property + def natural_id_to_partition_id(self) -> Dict[Any, int]: + """Natural id to the corresponding partition id.""" + return self._natural_id_to_partition_id