Skip to content

Commit

Permalink
Improve speed of NaturalIdPartitioner (#3276)
Browse files Browse the repository at this point in the history
Co-authored-by: jafermarq <[email protected]>
  • Loading branch information
adam-narozniak and jafermarq authored May 6, 2024
1 parent 6bcf1fc commit 977844f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 5 deletions.
72 changes: 67 additions & 5 deletions datasets/flwr_datasets/partitioner/natural_id_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,52 @@

from typing import Dict

import numpy as np
from tqdm import tqdm

import datasets
from flwr_datasets.common.typing import NDArrayInt
from flwr_datasets.partitioner.partitioner import Partitioner


class NaturalIdPartitioner(Partitioner):
"""Partitioner for dataset that can be divided by a reference to id in dataset."""
"""Partitioner for a dataset that can be divided by a column with partition ids.
Parameters
----------
partition_by: str
The name of the column that contains the unique values of partitions.
Examples
--------
"flwrlabs/shakespeare" dataset
>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import NaturalIdPartitioner
>>>
>>> partitioner = NaturalIdPartitioner(partition_by="character_id")
>>> fds = FederatedDataset(dataset="flwrlabs/shakespeare",
...                        partitioners={"train": partitioner})
>>> partition = fds.load_partition(0)
"sentiment140" (aka Twitter) dataset
>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import NaturalIdPartitioner
>>>
>>> partitioner = NaturalIdPartitioner(partition_by="user")
>>> fds = FederatedDataset(dataset="sentiment140",
...                        partitioners={"train": partitioner})
>>> partition = fds.load_partition(0)
"""

def __init__(self, partition_by: str):
    """Remember the partitioning column and start with empty lookup tables.

    Parameters
    ----------
    partition_by: str
        The name of the column that contains the unique values of partitions.
    """
    super().__init__()
    self._partition_by = partition_by
    # The three lookup tables start out empty; `load_partition` and
    # `num_partitions` populate them the first time they find them empty.
    self._partition_id_to_indices: Dict[int, NDArrayInt] = {}
    self._natural_id_to_partition_id: Dict[str, int] = {}
    self._partition_id_to_natural_id: Dict[int, str] = {}

def _create_int_partition_id_to_natural_id(self) -> None:
Expand All @@ -42,6 +75,33 @@ def _create_int_partition_id_to_natural_id(self) -> None:
zip(range(len(unique_natural_ids)), unique_natural_ids)
)

def _create_natural_id_to_int_partition_id(self) -> None:
"""Create a mapping from unique client ids from dataset to int indices.
Natural ids come from the column specified in `partition_by`. This object is
inverse of the `self._partition_id_to_natural_id`. This method assumes that
`self._partition_id_to_natural_id` already exist.
"""
self._natural_id_to_partition_id = {
value: key for key, value in self._partition_id_to_natural_id.items()
}

def _create_partition_id_to_indices(self) -> None:
    """Create the mapping from int partition ids to dataset row indices.

    The column named by `partition_by` is read once and the row indices are
    grouped by natural id in a single pass. The natural ids are then
    translated to int partition ids, so this method assumes that
    `self._natural_id_to_partition_id` already exists.
    """
    natural_ids = np.array(self.dataset[self._partition_by])
    natural_id_to_indices = {}  # type: ignore

    # `enumerate` has no length, so pass `total` explicitly; otherwise tqdm
    # can only show a running counter instead of a progress bar with an ETA.
    for index, natural_id in tqdm(
        enumerate(natural_ids),
        total=len(natural_ids),
        desc="Generating partition_id_to_indices",
    ):
        natural_id_to_indices.setdefault(natural_id, []).append(index)

    self._partition_id_to_indices = {
        self._natural_id_to_partition_id[natural_id]: indices
        for natural_id, indices in natural_id_to_indices.items()
    }

def load_partition(self, partition_id: int) -> datasets.Dataset:
"""Load a single partition corresponding to a single `partition_id`.
Expand All @@ -60,17 +120,19 @@ def load_partition(self, partition_id: int) -> datasets.Dataset:
"""
if len(self._partition_id_to_natural_id) == 0:
self._create_int_partition_id_to_natural_id()
self._create_natural_id_to_int_partition_id()

return self.dataset.filter(
lambda row: row[self._partition_by]
== self._partition_id_to_natural_id[partition_id]
)
if len(self._partition_id_to_indices) == 0:
self._create_partition_id_to_indices()

return self.dataset.select(self._partition_id_to_indices[partition_id])

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
if len(self._partition_id_to_natural_id) == 0:
self._create_int_partition_id_to_natural_id()
self._create_natural_id_to_int_partition_id()
return len(self._partition_id_to_natural_id)

@property
Expand Down
1 change: 1 addition & 0 deletions datasets/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ datasets = "^2.14.6"
pillow = { version = ">=6.2.1", optional = true }
soundfile = { version = ">=0.12.1", optional = true }
librosa = { version = ">=0.10.0.post2", optional = true }
tqdm ="^4.66.1"

[tool.poetry.dev-dependencies]
isort = "==5.13.2"
Expand Down

0 comments on commit 977844f

Please sign in to comment.