Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve speed of NaturalIdPartitioner #3276

Merged
merged 19 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 67 additions & 5 deletions datasets/flwr_datasets/partitioner/natural_id_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,52 @@

from typing import Dict

import numpy as np
from tqdm import tqdm

import datasets
from flwr_datasets.common.typing import NDArrayInt
from flwr_datasets.partitioner.partitioner import Partitioner


class NaturalIdPartitioner(Partitioner):
"""Partitioner for dataset that can be divided by a reference to id in dataset."""
"""Partitioner for dataset that can be divided by a reference to id in dataset.

Parameters
----------
partition_by: str
The name of the column that contains the unique values of partitions.


Examples
--------
"flwrlabs/shakespeare" dataset
>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import NaturalIdPartitioner
>>>
>>> partitioner = NaturalIdPartitioner(partition_by="character_id")
>>> fds = FederatedDataset(dataset="flwrlabs/shakespeare",
...                        partitioners={"train": partitioner})
>>> partition = fds.load_partition(0)

"sentiment140" (aka Twitter) dataset
>>> from flwr_datasets import FederatedDataset
>>> from flwr_datasets.partitioner import NaturalIdPartitioner
>>>
>>> partitioner = NaturalIdPartitioner(partition_by="user")
>>> fds = FederatedDataset(dataset="sentiment140",
...                        partitioners={"train": partitioner})
>>> partition = fds.load_partition(0)
"""

def __init__(self, partition_by: str):
    """Create a partitioner that splits the dataset by the `partition_by` column.

    Every lookup table starts out empty; they are filled in lazily by the
    other methods of this class the first time they are needed.
    """
    super().__init__()
    # Name of the dataset column whose values identify the partitions.
    self._partition_by = partition_by
    # Integer partition id -> row indices belonging to that partition.
    self._partition_id_to_indices: Dict[int, NDArrayInt] = {}
    # Natural id -> integer partition id, and its inverse below.
    self._natural_id_to_partition_id: Dict[str, int] = {}
    self._partition_id_to_natural_id: Dict[int, str] = {}

def _create_int_partition_id_to_natural_id(self) -> None:
Expand All @@ -42,6 +75,33 @@ def _create_int_partition_id_to_natural_id(self) -> None:
zip(range(len(unique_natural_ids)), unique_natural_ids)
)

def _create_natural_id_to_int_partition_id(self) -> None:
"""Create a mapping from unique client ids from dataset to int indices.

Natural ids come from the column specified in `partition_by`. This object is
inverse of the `self._partition_id_to_natural_id`. This method assumes that
`self._partition_id_to_natural_id` already exist.
"""
self._natural_id_to_partition_id = {
value: key for key, value in self._partition_id_to_natural_id.items()
}

def _create_partition_id_to_indices(self) -> None:
    """Group the dataset's row indices by integer partition id.

    The `partition_by` column is read once, the row positions of every natural
    id are collected in a single pass, and the result is stored in
    `self._partition_id_to_indices` keyed by partition id. This method assumes
    that `self._natural_id_to_partition_id` already exists; a natural id that
    is missing from it raises `KeyError`.
    """
    natural_id_to_indices = {}  # type: ignore
    natural_ids = np.array(self.dataset[self._partition_by])

    # `enumerate` has no length, so `total` is passed explicitly; without it
    # tqdm can only show a running counter, not a percentage or an ETA.
    for index, natural_id in tqdm(
        enumerate(natural_ids),
        total=len(natural_ids),
        desc="Generating partition_id_to_indices",
    ):
        natural_id_to_indices.setdefault(natural_id, []).append(index)

    # Re-key the groups from natural ids to integer partition ids.
    self._partition_id_to_indices = {
        self._natural_id_to_partition_id[natural_id]: indices
        for natural_id, indices in natural_id_to_indices.items()
    }

def load_partition(self, partition_id: int) -> datasets.Dataset:
"""Load a single partition corresponding to a single `partition_id`.

Expand All @@ -60,17 +120,19 @@ def load_partition(self, partition_id: int) -> datasets.Dataset:
"""
if len(self._partition_id_to_natural_id) == 0:
self._create_int_partition_id_to_natural_id()
self._create_natural_id_to_int_partition_id()

return self.dataset.filter(
lambda row: row[self._partition_by]
== self._partition_id_to_natural_id[partition_id]
)
if len(self._partition_id_to_indices) == 0:
self._create_partition_id_to_indices()

return self.dataset.select(self._partition_id_to_indices[partition_id])

@property
def num_partitions(self) -> int:
"""Total number of partitions."""
if len(self._partition_id_to_natural_id) == 0:
self._create_int_partition_id_to_natural_id()
self._create_natural_id_to_int_partition_id()
return len(self._partition_id_to_natural_id)

@property
Expand Down
1 change: 1 addition & 0 deletions datasets/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ datasets = "^2.14.6"
pillow = { version = ">=6.2.1", optional = true }
soundfile = { version = ">=0.12.1", optional = true }
librosa = { version = ">=0.10.0.post2", optional = true }
tqdm = "^4.66.1"

[tool.poetry.dev-dependencies]
isort = "==5.13.2"
Expand Down