From 977844fdf121211ee00ae83708287015add8869a Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Mon, 6 May 2024 13:42:08 +0200 Subject: [PATCH] Improve speed of NaturalIdPartitioner (#3276) Co-authored-by: jafermarq --- .../partitioner/natural_id_partitioner.py | 72 +++++++++++++++++-- datasets/pyproject.toml | 1 + 2 files changed, 68 insertions(+), 5 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/natural_id_partitioner.py b/datasets/flwr_datasets/partitioner/natural_id_partitioner.py index 8bad0668595b..85f1b3af43c2 100644 --- a/datasets/flwr_datasets/partitioner/natural_id_partitioner.py +++ b/datasets/flwr_datasets/partitioner/natural_id_partitioner.py @@ -17,12 +17,43 @@ from typing import Dict +import numpy as np +from tqdm import tqdm + import datasets +from flwr_datasets.common.typing import NDArrayInt from flwr_datasets.partitioner.partitioner import Partitioner class NaturalIdPartitioner(Partitioner): - """Partitioner for dataset that can be divided by a reference to id in dataset.""" + """Partitioner for a dataset that can be divided by a column with partition ids. + + Parameters + ---------- + partition_by: str + The name of the column that contains the unique values of partitions. + + + Examples + -------- + "flwrlabs/shakespeare" dataset + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.partitioner import NaturalIdPartitioner + >>> + >>> partitioner = NaturalIdPartitioner(partition_by="character_id") + >>> fds = FederatedDataset(dataset="flwrlabs/shakespeare", + >>> partitioners={"train": partitioner}) + >>> partition = fds.load_partition(0) + + "sentiment140" (aka Twitter) dataset + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.partitioner import NaturalIdPartitioner + >>> + >>> partitioner = NaturalIdPartitioner(partition_by="user") + >>> fds = FederatedDataset(dataset="sentiment140", + >>> partitioners={"train": partitioner}) + >>> partition = fds.load_partition(0) + """ def __init__( self, @@ -30,6 +61,8 @@ def __init__( ): super().__init__() self._partition_id_to_natural_id: Dict[int, str] = {} + self._natural_id_to_partition_id: Dict[str, int] = {} + self._partition_id_to_indices: Dict[int, NDArrayInt] = {} self._partition_by = partition_by def _create_int_partition_id_to_natural_id(self) -> None: @@ -42,6 +75,33 @@ def _create_int_partition_id_to_natural_id(self) -> None: zip(range(len(unique_natural_ids)), unique_natural_ids) ) + def _create_natural_id_to_int_partition_id(self) -> None: + """Create a mapping from unique client ids from dataset to int indices. + + Natural ids come from the column specified in `partition_by`. This object is + inverse of the `self._partition_id_to_natural_id`. This method assumes that + `self._partition_id_to_natural_id` already exist. + """ + self._natural_id_to_partition_id = { + value: key for key, value in self._partition_id_to_natural_id.items() + } + + def _create_partition_id_to_indices(self) -> None: + natural_id_to_indices = {} # type: ignore + natural_ids = np.array(self.dataset[self._partition_by]) + + for index, natural_id in tqdm( + enumerate(natural_ids), desc="Generating partition_id_to_indices" + ): + if natural_id not in natural_id_to_indices: + natural_id_to_indices[natural_id] = [] + natural_id_to_indices[natural_id].append(index) + + self._partition_id_to_indices = { + self._natural_id_to_partition_id[natural_id]: indices + for natural_id, indices in natural_id_to_indices.items() + } + def load_partition(self, partition_id: int) -> datasets.Dataset: """Load a single partition corresponding to a single `partition_id`. @@ -60,17 +120,19 @@ def load_partition(self, partition_id: int) -> datasets.Dataset: """ if len(self._partition_id_to_natural_id) == 0: self._create_int_partition_id_to_natural_id() + self._create_natural_id_to_int_partition_id() - return self.dataset.filter( - lambda row: row[self._partition_by] - == self._partition_id_to_natural_id[partition_id] - ) + if len(self._partition_id_to_indices) == 0: + self._create_partition_id_to_indices() + + return self.dataset.select(self._partition_id_to_indices[partition_id]) @property def num_partitions(self) -> int: """Total number of partitions.""" if len(self._partition_id_to_natural_id) == 0: self._create_int_partition_id_to_natural_id() + self._create_natural_id_to_int_partition_id() return len(self._partition_id_to_natural_id) @property diff --git a/datasets/pyproject.toml b/datasets/pyproject.toml index 7dfa60138582..c16389e1529b 100644 --- a/datasets/pyproject.toml +++ b/datasets/pyproject.toml @@ -58,6 +58,7 @@ datasets = "^2.14.6" pillow = { version = ">=6.2.1", optional = true } soundfile = { version = ">=0.12.1", optional = true } librosa = { version = ">=0.10.0.post2", optional = true } +tqdm ="^4.66.1" [tool.poetry.dev-dependencies] isort = "==5.13.2"