Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Aug 22, 2024
1 parent 0b410a8 commit 2c6ebb0
Showing 1 changed file with 14 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.
# ==============================================================================
"""Grouped natural id partitioner class that works with Hugging Face Datasets."""
from typing import Dict, Literal, Tuple
from typing import Any, Dict, List, Literal

import numpy as np
from common.typing import NDArrayInt

import datasets
from flwr_datasets.common.typing import NDArrayInt
from flwr_datasets.partitioner.partitioner import Partitioner


Expand Down Expand Up @@ -69,8 +69,8 @@ def __init__(
sort_unique_ids: bool = False,
) -> None:
super().__init__()
self._partition_id_to_natural_ids: Dict[int, Tuple[str]] = {}
self._natural_ids_to_partition_id: Dict[Tuple[str], int] = {}
self._partition_id_to_natural_ids: Dict[int, List[Any]] = {}
self._natural_id_to_partition_id: Dict[Any, int] = {}
self._partition_id_to_indices: Dict[int, NDArrayInt] = {}
self._partition_by = partition_by
self._mode = mode
Expand Down Expand Up @@ -133,7 +133,7 @@ def _create_int_partition_id_to_natural_id(self) -> None:
for group_of_natural_ids_id, group_of_natural_ids in zip(
range(len(groups_of_natural_ids)), groups_of_natural_ids
):
self._partition_id_to_natural_ids[group_of_natural_ids_id] = tuple(
self._partition_id_to_natural_ids[group_of_natural_ids_id] = (
group_of_natural_ids.tolist()
)

Expand All @@ -144,9 +144,10 @@ def _create_natural_id_to_int_partition_id(self) -> None:
inverse of the `self._partition_id_to_natural_id`. This method assumes that
`self._partition_id_to_natural_id` already exists.
"""
self._natural_ids_to_partition_id = {
value: key for key, value in self._partition_id_to_natural_ids.items()
}
self._natural_id_to_partition_id = {}
for partition_id, natural_ids in self._partition_id_to_natural_ids.items():
for natural_id in natural_ids:
self._natural_id_to_partition_id[natural_id] = partition_id

def _create_partition_id_to_indices(self) -> None:
natural_id_to_indices = {} # type: ignore
Expand Down Expand Up @@ -199,15 +200,14 @@ def num_partitions(self) -> int:
return len(self._partition_id_to_natural_ids)

@property
def partition_id_to_natural_ids(self) -> Dict[int, Tuple[str]]:
def partition_id_to_natural_ids(self) -> Dict[int, List[Any]]:
"""Partition id to the corresponding group of natural ids present.
Natural ids are the unique values in `partition_by` column in dataset.
"""
return self._partition_id_to_natural_ids

@partition_id_to_natural_ids.setter
def partition_id_to_natural_ids(self, value: Dict[int, Tuple[str]]) -> None:
raise AttributeError(
"Setting the partition_id_to_natural_ids dictionary is not allowed."
)
@property
def natural_id_to_partition_id(self) -> Dict[Any, int]:
"""Natural id to the corresponding partition id."""
return self._natural_id_to_partition_id

0 comments on commit 2c6ebb0

Please sign in to comment.