diff --git a/datasets/doc/source/how-to-use-with-local-data.rst b/datasets/doc/source/how-to-use-with-local-data.rst
new file mode 100644
index 000000000000..276f6d6936ee
--- /dev/null
+++ b/datasets/doc/source/how-to-use-with-local-data.rst
@@ -0,0 +1,257 @@
+Use with Local Data
+===================
+
+You can partition your local files and in-memory Python objects using any
+``Partitioner`` available in the ``Flower Datasets`` library.
+
+This guide details how to create a `Hugging Face <https://huggingface.co/>`_ `Dataset <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset>`_, which is the required input type for Partitioners.
+We will cover:
+
+* local files: CSV, JSON, image, audio,
+* in-memory data: dictionary, list, pd.DataFrame, np.ndarray.
+
+
+General Overview
+----------------
+An all-in-one dataset preparation (downloading, preprocessing, partitioning) happens
+using `FederatedDataset <https://flower.ai/docs/datasets/ref-api/flwr_datasets.FederatedDataset.html>`_.
+However, we will use only the ``Partitioner`` here, since the data is locally accessible.
+
+The rest of this guide explains how to create a
+`Dataset <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset>`_
+from local files and existing (in-memory) Python objects.
+
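+The general pattern is always the same: create a ``Dataset``, assign it to a
+``Partitioner``, and load partitions by their IDs. Here is a minimal sketch of that
+flow, using the ``IidPartitioner`` and a small in-memory dictionary as illustrative
+choices:
+
+.. code-block:: python
+
+    from datasets import Dataset
+    from flwr_datasets.partitioner import IidPartitioner
+
+    # Create a Dataset (here from an in-memory dictionary)
+    dataset = Dataset.from_dict({"features": [1, 2, 3, 4], "labels": [0, 0, 1, 1]})
+
+    # Assign it to a partitioner and load a single partition
+    partitioner = IidPartitioner(num_partitions=2)
+    partitioner.dataset = dataset
+    partition = partitioner.load_partition(partition_id=0)
+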
+Local Files
+-----------
+CSV
+^^^
+.. code-block:: python
+
+ from datasets import load_dataset
+ from flwr_datasets.partitioner import ChosenPartitioner
+
+ # Single file
+ data_files = "path-to-my-file.csv"
+
+    # Multiple files
+    data_files = ["path-to-my-file-1.csv", "path-to-my-file-2.csv", ...]
+ dataset = load_dataset("csv", data_files=data_files)
+
+ # Divided Dataset
+ data_files = {
+ "train": single_train_file_or_list_of_files,
+ "test": single_test_file_or_list_of_files,
+ "can-have-more-splits": ...
+ }
+ dataset = load_dataset("csv", data_files=data_files)
+
+ partitioner = ChosenPartitioner(...)
+ partitioner.dataset = dataset
+ partition = partitioner.load_partition(partition_id=0)
+
+JSON
+^^^^
+
+.. code-block:: python
+
+ from datasets import load_dataset
+ from flwr_datasets.partitioner import ChosenPartitioner
+
+ # Single file
+ data_files = "path-to-my-file.json"
+
+    # Multiple files
+    data_files = ["path-to-my-file-1.json", "path-to-my-file-2.json", ...]
+ dataset = load_dataset("json", data_files=data_files)
+
+ # Divided Dataset
+ data_files = {
+ "train": single_train_file_or_list_of_files,
+ "test": single_test_file_or_list_of_files,
+ "can-have-more-splits": ...
+ }
+ dataset = load_dataset("json", data_files=data_files)
+
+ partitioner = ChosenPartitioner(...)
+ partitioner.dataset = dataset
+ partition = partitioner.load_partition(partition_id=0)
+
+
+Image
+^^^^^
+You can create an image dataset in two ways:
+
+1) give a path to the directory
+
+The directory needs to be structured in the following way: ``dataset-name/split/class/name``. For example:
+
+.. code-block::
+
+    mnist/train/1/unique_name_1.png
+    mnist/train/1/unique_name_2.png
+    mnist/train/2/unique_name_3.png
+    ...
+    mnist/test/1/unique_name_4.png
+    mnist/test/1/unique_name_5.png
+    mnist/test/2/unique_name_6.png
+
+Then, the path you pass is ``./mnist``.
+
+.. code-block:: python
+
+ from datasets import load_dataset
+ from flwr_datasets.partitioner import ChosenPartitioner
+
+    # Directly from a directory; select one split of the returned DatasetDict
+    dataset = load_dataset("imagefolder", data_dir="/path/to/folder")["train"]
+ partitioner = ChosenPartitioner(...)
+ partitioner.dataset = dataset
+ partition = partitioner.load_partition(partition_id=0)
+
+2) create a dataset from a CSV/JSON file and cast the path column to Image.
+
+.. code-block:: python
+
+ from datasets import Image, load_dataset
+ from flwr_datasets.partitioner import ChosenPartitioner
+
+ dataset = load_dataset(...)
+ dataset = dataset.cast_column("path", Image())
+
+ partitioner = ChosenPartitioner(...)
+ partitioner.dataset = dataset
+ partition = partitioner.load_partition(partition_id=0)
+
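+For example, assuming a CSV file whose ``path`` column points to image files (the
+file name and the column name below are placeholders), the two steps combine like
+this:
+
+.. code-block:: python
+
+    from datasets import Image, load_dataset
+
+    # "metadata.csv" and the "path" column are placeholders
+    dataset = load_dataset("csv", data_files="metadata.csv")["train"]
+    dataset = dataset.cast_column("path", Image())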
+
+Audio
+^^^^^
+Analogously to the image datasets, there are two methods here:
+
+1) give a path to the directory
+
+.. code-block:: python
+
+ from datasets import load_dataset
+ from flwr_datasets.partitioner import ChosenPartitioner
+
+ dataset = load_dataset("audiofolder", data_dir="/path/to/folder")
+
+ partitioner = ChosenPartitioner(...)
+ partitioner.dataset = dataset
+ partition = partitioner.load_partition(partition_id=0)
+
+2) create a dataset from a CSV/JSON file and cast the path column to Audio.
+
+.. code-block:: python
+
+ from datasets import Audio, load_dataset
+ from flwr_datasets.partitioner import ChosenPartitioner
+
+ dataset = load_dataset(...)
+ dataset = dataset.cast_column("path", Audio())
+
+ partitioner = ChosenPartitioner(...)
+ partitioner.dataset = dataset
+ partition = partitioner.load_partition(partition_id=0)
+
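+For example, assuming a CSV file whose ``path`` column points to audio files
+(placeholder names again):
+
+.. code-block:: python
+
+    from datasets import Audio, load_dataset
+
+    dataset = load_dataset("csv", data_files="metadata.csv")["train"]
+    dataset = dataset.cast_column("path", Audio())
+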
+In-Memory
+---------
+
+From dictionary
+^^^^^^^^^^^^^^^
+.. code-block:: python
+
+ from datasets import Dataset
+    from flwr_datasets.partitioner import ChosenPartitioner
+
+    data = {"features": [1, 2, 3], "labels": [0, 0, 1]}
+ dataset = Dataset.from_dict(data)
+
+ partitioner = ChosenPartitioner(...)
+ partitioner.dataset = dataset
+ partition = partitioner.load_partition(partition_id=0)
+
+From list
+^^^^^^^^^
+.. code-block:: python
+
+ from datasets import Dataset
+ from flwr_datasets.partitioner import ChosenPartitioner
+
+ my_list = [
+ {"features": 1, "labels": 0},
+ {"features": 2, "labels": 0},
+ {"features": 3, "labels": 1}
+ ]
+ dataset = Dataset.from_list(my_list)
+
+ partitioner = ChosenPartitioner(...)
+ partitioner.dataset = dataset
+ partition = partitioner.load_partition(partition_id=0)
+
+From pd.DataFrame
+^^^^^^^^^^^^^^^^^
+.. code-block:: python
+
+    import pandas as pd
+    from datasets import Dataset
+    from flwr_datasets.partitioner import ChosenPartitioner
+
+ data = {"features": [1, 2, 3], "labels": [0, 0, 1]}
+ df = pd.DataFrame(data)
+ dataset = Dataset.from_pandas(df)
+
+ partitioner = ChosenPartitioner(...)
+ partitioner.dataset = dataset
+ partition = partitioner.load_partition(partition_id=0)
+
+From np.ndarray
+^^^^^^^^^^^^^^^
+The np.ndarray will first be transformed into a pd.DataFrame.
+
+.. code-block:: python
+
+    import numpy as np
+    import pandas as pd
+    from datasets import Dataset
+    from flwr_datasets.partitioner import ChosenPartitioner
+
+ data = np.array([[1, 2, 3], [0, 0, 1]]).T
+ # You can add the column names by passing columns=["features", "labels"]
+ df = pd.DataFrame(data)
+ dataset = Dataset.from_pandas(df)
+
+ partitioner = ChosenPartitioner(...)
+ partitioner.dataset = dataset
+ partition = partitioner.load_partition(partition_id=0)
+
+Partitioner Details
+-------------------
+Partitioning is triggered automatically during the first ``load_partition`` call.
+You do not need to call any ``do_partitioning`` method.
+
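+For example (a minimal sketch, assuming the ``IidPartitioner`` and an already created
+``your_dataset``):
+
+.. code-block:: python
+
+    from flwr_datasets.partitioner import IidPartitioner
+
+    partitioner = IidPartitioner(num_partitions=10)
+    partitioner.dataset = your_dataset
+
+    # The first call triggers the partitioning
+    partition_0 = partitioner.load_partition(partition_id=0)
+    # Further calls simply return other partitions; no extra step is needed
+    partition_1 = partitioner.load_partition(partition_id=1)
+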
+The ``Partitioner`` abstraction is designed to allow for a single dataset assignment.
+
+.. code-block:: python
+
+ partitioner.dataset = your_dataset
+
+If you need to apply the same partitioning to a different dataset, create a new
+``Partitioner`` for it, e.g.:
+
+.. code-block:: python
+
+ from flwr_datasets.partitioner import IidPartitioner
+
+ iid_partitioner_for_mnist = IidPartitioner(num_partitions=10)
+ iid_partitioner_for_mnist.dataset = mnist_dataset
+
+ iid_partitioner_for_cifar = IidPartitioner(num_partitions=10)
+ iid_partitioner_for_cifar.dataset = cifar_dataset
+
+
+More Resources
+--------------
+If you are looking for more details, or you have not found the format you are looking for, please visit the `Hugging Face Datasets docs <https://huggingface.co/docs/datasets/index>`_.
+This guide is based on the following ones:
+
+* `General Information <https://huggingface.co/docs/datasets>`_
+* `Tabular Data <https://huggingface.co/docs/datasets/tabular_load>`_
+* `Image Data <https://huggingface.co/docs/datasets/image_load>`_
+* `Audio Data <https://huggingface.co/docs/datasets/audio_load>`_
diff --git a/datasets/doc/source/index.rst b/datasets/doc/source/index.rst
index fd226b308bd5..2144c527f8cd 100644
--- a/datasets/doc/source/index.rst
+++ b/datasets/doc/source/index.rst
@@ -31,6 +31,7 @@ Problem-oriented how-to guides show step-by-step how to achieve a specific goal.
how-to-use-with-pytorch
how-to-use-with-tensorflow
how-to-use-with-numpy
+ how-to-use-with-local-data
how-to-disable-enable-progress-bar
References
diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py
index 55a7e597f6b4..6c41eaa3562f 100644
--- a/datasets/flwr_datasets/federated_dataset.py
+++ b/datasets/flwr_datasets/federated_dataset.py
@@ -59,7 +59,8 @@ class FederatedDataset:
argument. Defaults to True.
seed : Optional[int]
Seed used for dataset shuffling. It has no effect if `shuffle` is False. The
- seed cannot be set in the later stages.
+    seed cannot be set in the later stages. If `None`, then fresh, unpredictable
+    entropy will be pulled from the OS. Defaults to 42.
Examples
--------
diff --git a/examples/embedded-devices/Dockerfile b/examples/embedded-devices/Dockerfile
index a85c05c4bb7a..48602c89970a 100644
--- a/examples/embedded-devices/Dockerfile
+++ b/examples/embedded-devices/Dockerfile
@@ -8,7 +8,7 @@ RUN pip3 install --upgrade pip
# Install flower
RUN pip3 install flwr>=1.0
-RUN pip3 install flwr-datsets>=0.2
+RUN pip3 install "flwr-datasets>=0.0.2"
RUN pip3 install tqdm==4.65.0
WORKDIR /client
diff --git a/src/py/flwr/cli/example.py b/src/py/flwr/cli/example.py
index 625ca8729640..4790e72d48bf 100644
--- a/src/py/flwr/cli/example.py
+++ b/src/py/flwr/cli/example.py
@@ -39,7 +39,9 @@ def example() -> None:
with urllib.request.urlopen(examples_directory_url) as res:
data = json.load(res)
example_names = [
- item["path"] for item in data["tree"] if item["path"] not in [".gitignore"]
+ item["path"]
+ for item in data["tree"]
+ if item["path"] not in [".gitignore", "doc"]
]
example_name = prompt_options(
diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py
index 7eb47e3e3548..0c429ce34cf2 100644
--- a/src/py/flwr/cli/new/new.py
+++ b/src/py/flwr/cli/new/new.py
@@ -22,7 +22,12 @@
import typer
from typing_extensions import Annotated
-from ..utils import prompt_options, prompt_text
+from ..utils import (
+ is_valid_project_name,
+ prompt_options,
+ prompt_text,
+ sanitize_project_name,
+)
class MlFramework(str, Enum):
@@ -81,6 +86,16 @@ def new(
] = None,
) -> None:
"""Create new Flower project."""
+ if project_name is None:
+ project_name = prompt_text("Please provide project name")
+ if not is_valid_project_name(project_name):
+ project_name = prompt_text(
+ "Please provide a name that only contains "
+ "characters in {'_', 'a-zA-Z', '0-9'}",
+ predicate=is_valid_project_name,
+ default=sanitize_project_name(project_name),
+ )
+
print(
typer.style(
f"🔨 Creating Flower project {project_name}...",
@@ -89,9 +104,6 @@ def new(
)
)
- if project_name is None:
- project_name = prompt_text("Please provide project name")
-
if framework is not None:
framework_str = str(framework.value)
else:
@@ -116,7 +128,6 @@ def new(
# List of files to render
files = {
"README.md": {"template": "app/README.md.tpl"},
- "requirements.txt": {"template": f"app/requirements.{framework_str}.txt.tpl"},
"flower.toml": {"template": "app/flower.toml.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
f"{pnl}/__init__.py": {"template": "app/code/__init__.py.tpl"},
diff --git a/src/py/flwr/cli/new/new_test.py b/src/py/flwr/cli/new/new_test.py
index cedcb09b7755..11620d234191 100644
--- a/src/py/flwr/cli/new/new_test.py
+++ b/src/py/flwr/cli/new/new_test.py
@@ -66,7 +66,6 @@ def test_new(tmp_path: str) -> None:
project_name = "FedGPT"
framework = MlFramework.PYTORCH
expected_files_top_level = {
- "requirements.txt",
"fedgpt",
"README.md",
"flower.toml",
diff --git a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl
index 15d8211a1a25..9701c62af6f0 100644
--- a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl
+++ b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl
@@ -1,19 +1,19 @@
[build-system]
-requires = ["poetry-core>=1.4.0"]
-build-backend = "poetry.core.masonry.api"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
-[tool.poetry]
+[project]
name = "$project_name"
version = "1.0.0"
description = ""
-license = "Apache-2.0"
authors = [
- "The Flower Authors ",
+ { name = "The Flower Authors", email = "hello@flower.ai" },
+]
+license = {text = "Apache License (2.0)"}
+dependencies = [
+ "flwr[simulation]>=1.8.0,<2.0",
+ "numpy>=1.21.0",
]
-readme = "README.md"
-[tool.poetry.dependencies]
-python = "^3.9"
-# Mandatory dependencies
-numpy = "^1.21.0"
-flwr = { version = "^1.8.0", extras = ["simulation"] }
+[tool.hatch.build.targets.wheel]
+packages = ["."]
diff --git a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl
index 46a5508fe2ac..0661c7b730c1 100644
--- a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl
+++ b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl
@@ -1,21 +1,21 @@
[build-system]
-requires = ["poetry-core>=1.4.0"]
-build-backend = "poetry.core.masonry.api"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
-[tool.poetry]
+[project]
name = "$project_name"
version = "1.0.0"
description = ""
-license = "Apache-2.0"
authors = [
- "The Flower Authors ",
+ { name = "The Flower Authors", email = "hello@flower.ai" },
+]
+license = {text = "Apache License (2.0)"}
+dependencies = [
+ "flwr[simulation]>=1.8.0,<2.0",
+ "flwr-datasets[vision]>=0.0.2,<1.0.0",
+ "torch==2.2.1",
+ "torchvision==0.17.1",
]
-readme = "README.md"
-[tool.poetry.dependencies]
-python = "^3.9"
-# Mandatory dependencies
-flwr = { version = "^1.8.0", extras = ["simulation"] }
-flwr-datasets = { version = "0.0.2", extras = ["vision"] }
-torch = "2.2.1"
-torchvision = "0.17.1"
+[tool.hatch.build.targets.wheel]
+packages = ["."]
diff --git a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl
index f7383a78b7d5..5a017eb6ed74 100644
--- a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl
+++ b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl
@@ -1,21 +1,20 @@
[build-system]
-requires = ["poetry-core>=1.4.0"]
-build-backend = "poetry.core.masonry.api"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
-[tool.poetry]
+[project]
name = "$project_name"
version = "1.0.0"
description = ""
-license = "Apache-2.0"
authors = [
- "The Flower Authors ",
+ { name = "The Flower Authors", email = "hello@flower.ai" },
+]
+license = {text = "Apache License (2.0)"}
+dependencies = [
+ "flwr[simulation]>=1.8.0,<2.0",
+ "flwr-datasets[vision]>=0.0.2,<1.0.0",
+ "tensorflow>=2.11.1",
]
-readme = "README.md"
-[tool.poetry.dependencies]
-python = ">=3.9,<3.11"
-# Mandatory dependencies
-flwr = { version = "^1.8.0", extras = ["simulation"] }
-flwr-datasets = { version = "^0.0.2", extras = ["vision"] }
-tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" }
-tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" }
+[tool.hatch.build.targets.wheel]
+packages = ["."]
diff --git a/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl
deleted file mode 100644
index 4b460798e96f..000000000000
--- a/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl
+++ /dev/null
@@ -1,2 +0,0 @@
-flwr>=1.8, <2.0
-numpy>=1.21.0
diff --git a/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl
deleted file mode 100644
index f20b9d71e339..000000000000
--- a/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl
+++ /dev/null
@@ -1,4 +0,0 @@
-flwr[simulation]>=1.8.0
-flwr-datasets[vision]==0.0.2
-torch==2.2.1
-torchvision==0.17.1
diff --git a/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl
deleted file mode 100644
index b6fb49a4bbcb..000000000000
--- a/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl
+++ /dev/null
@@ -1,4 +0,0 @@
-flwr>=1.8, <2.0
-flwr-datasets[vision]>=0.0.2, <1.0.0
-tensorflow-macos>=2.9.1, !=2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64"
-tensorflow-cpu>=2.9.1, !=2.11.1 ; platform_machine == "x86_64"
diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py
index 4e86f0c3b8c8..7a36c3eb7b84 100644
--- a/src/py/flwr/cli/utils.py
+++ b/src/py/flwr/cli/utils.py
@@ -14,18 +14,23 @@
# ==============================================================================
"""Flower command line interface utils."""
-from typing import List, cast
+from typing import Callable, List, Optional, cast
import typer
-def prompt_text(text: str) -> str:
+def prompt_text(
+ text: str,
+ predicate: Callable[[str], bool] = lambda _: True,
+ default: Optional[str] = None,
+) -> str:
"""Ask user to enter text input."""
while True:
result = typer.prompt(
- typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True)
+ typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True),
+ default=default,
)
- if len(result) > 0:
+ if predicate(result) and len(result) > 0:
break
print(typer.style("❌ Invalid entry", fg=typer.colors.RED, bold=True))
@@ -65,3 +70,54 @@ def prompt_options(text: str, options: List[str]) -> str:
result = options[int(index)]
return result
+
+
+def is_valid_project_name(name: str) -> bool:
+ """Check if the given string is a valid Python module name.
+
+ A valid module name must start with a letter or an underscore, and can only contain
+ letters, digits, and underscores.
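+
+    Illustrative examples::
+
+        is_valid_project_name("my_project1")  # True
+        is_valid_project_name("1project")     # False (starts with a digit)
+        is_valid_project_name("my-project")   # False ("-" is not allowed)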
+ """
+ if not name:
+ return False
+
+ # Check if the first character is a letter or underscore
+ if not (name[0].isalpha() or name[0] == "_"):
+ return False
+
+ # Check if the rest of the characters are valid (letter, digit, or underscore)
+ for char in name[1:]:
+ if not (char.isalnum() or char == "_"):
+ return False
+
+ return True
+
+
+def sanitize_project_name(name: str) -> str:
+ """Sanitize the given string to make it a valid Python module name.
+
+ This version replaces hyphens with underscores, removes any characters not allowed
+ in Python module names, makes the string lowercase, and ensures it starts with a
+ valid character.
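+
+    Illustrative example::
+
+        sanitize_project_name("My-Flower App")  # returns "my_flower_app"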
+ """
+ # Replace '-' with '_'
+ name_with_underscores = name.replace("-", "_").replace(" ", "_")
+
+ # Allowed characters in a module name: letters, digits, underscore
+ allowed_chars = set(
+ "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
+ )
+
+ # Make the string lowercase
+ sanitized_name = name_with_underscores.lower()
+
+ # Remove any characters not allowed in Python module names
+ sanitized_name = "".join(c for c in sanitized_name if c in allowed_chars)
+
+ # Ensure the first character is a letter or underscore
+ if sanitized_name and (
+ sanitized_name[0].isdigit() or sanitized_name[0] not in allowed_chars
+ ):
+ sanitized_name = "_" + sanitized_name
+
+ return sanitized_name
diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py
index 1720405ab867..4fa9c80c6cdf 100644
--- a/src/py/flwr/client/app.py
+++ b/src/py/flwr/client/app.py
@@ -456,12 +456,13 @@ def _load_client_app() -> ClientApp:
continue
log(INFO, "")
- log(
- INFO,
- "[RUN %s, ROUND %s]",
- message.metadata.run_id,
- message.metadata.group_id,
- )
+ if len(message.metadata.group_id) > 0:
+ log(
+ INFO,
+ "[RUN %s, ROUND %s]",
+ message.metadata.run_id,
+ message.metadata.group_id,
+ )
log(
INFO,
"Received: %s message %s",
diff --git a/src/py/flwr/client/heartbeat.py b/src/py/flwr/client/heartbeat.py
index 0cc979ddfd13..b68e6163cc01 100644
--- a/src/py/flwr/client/heartbeat.py
+++ b/src/py/flwr/client/heartbeat.py
@@ -66,7 +66,9 @@ def start_ping_loop(
asynchronous ping operations. The loop can be terminated through the provided stop
event.
"""
- thread = threading.Thread(target=_ping_loop, args=(ping_fn, stop_event))
+ thread = threading.Thread(
+ target=_ping_loop, args=(ping_fn, stop_event), daemon=True
+ )
thread.start()
return thread
diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py
index 7cec319e7906..d12124b89840 100644
--- a/src/py/flwr/common/retry_invoker.py
+++ b/src/py/flwr/common/retry_invoker.py
@@ -261,6 +261,7 @@ def try_call_event_handler(
try:
ret = target(*args, **kwargs)
except self.recoverable_exceptions as err:
+ state.exception = err
# Check if giveup event should be triggered
max_tries_exceeded = try_cnt == self.max_tries
max_time_exceeded = (
diff --git a/src/py/flwr/server/compat/app_utils.py b/src/py/flwr/server/compat/app_utils.py
index 696ec1132c4a..709be3d96a33 100644
--- a/src/py/flwr/server/compat/app_utils.py
+++ b/src/py/flwr/server/compat/app_utils.py
@@ -16,7 +16,6 @@
import threading
-import time
from typing import Dict, Tuple
from ..client_manager import ClientManager
@@ -60,6 +59,7 @@ def start_update_client_manager_thread(
client_manager,
f_stop,
),
+ daemon=True,
)
thread.start()
@@ -99,4 +99,5 @@ def _update_client_manager(
raise RuntimeError("Could not register node.")
-    # Sleep for 3 seconds
-    time.sleep(3)
+    # Wait for up to 3 seconds, returning early if `f_stop` is set
+    if not f_stop.is_set():
+        f_stop.wait(3)
diff --git a/src/py/flwr/server/compat/app_utils_test.py b/src/py/flwr/server/compat/app_utils_test.py
index 7e47e6eaaf32..023d65b0dc72 100644
--- a/src/py/flwr/server/compat/app_utils_test.py
+++ b/src/py/flwr/server/compat/app_utils_test.py
@@ -17,6 +17,8 @@
import time
import unittest
+from threading import Event
+from typing import Optional
from unittest.mock import Mock, patch
from ..client_manager import SimpleClientManager
@@ -29,9 +31,6 @@ class TestUtils(unittest.TestCase):
def test_start_update_client_manager_thread(self) -> None:
"""Test start_update_client_manager_thread function."""
# Prepare
- sleep = time.sleep
- sleep_patch = patch("time.sleep", lambda x: sleep(x / 100))
- sleep_patch.start()
expected_node_ids = list(range(100))
updated_expected_node_ids = list(range(80, 120))
driver = Mock()
@@ -39,20 +38,30 @@ def test_start_update_client_manager_thread(self) -> None:
driver.run_id = 123
driver.get_node_ids.return_value = expected_node_ids
client_manager = SimpleClientManager()
+ original_wait = Event.wait
+
+ def custom_wait(self: Event, timeout: Optional[float] = None) -> None:
+ if timeout is not None:
+ timeout /= 100
+ original_wait(self, timeout)
# Execute
- thread, f_stop = start_update_client_manager_thread(driver, client_manager)
- # Wait until all nodes are registered via `client_manager.sample()`
- client_manager.sample(len(expected_node_ids))
- # Retrieve all nodes in `client_manager`
- node_ids = {proxy.node_id for proxy in client_manager.all().values()}
- # Update the GetNodesResponse and wait until the `client_manager` is updated
- driver.get_node_ids.return_value = updated_expected_node_ids
- sleep(0.1)
- # Retrieve all nodes in `client_manager`
- updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()}
- # Stop the thread
- f_stop.set()
+ # Patching Event.wait with our custom function
+ with patch.object(Event, "wait", new=custom_wait):
+ thread, f_stop = start_update_client_manager_thread(driver, client_manager)
+ # Wait until all nodes are registered via `client_manager.sample()`
+ client_manager.sample(len(expected_node_ids))
+ # Retrieve all nodes in `client_manager`
+ node_ids = {proxy.node_id for proxy in client_manager.all().values()}
+ # Update the GetNodesResponse and wait until the `client_manager` is updated
+ driver.get_node_ids.return_value = updated_expected_node_ids
+ time.sleep(0.1)
+ # Retrieve all nodes in `client_manager`
+ updated_node_ids = {
+ proxy.node_id for proxy in client_manager.all().values()
+ }
+ # Stop the thread
+ f_stop.set()
# Assert
assert node_ids == set(expected_node_ids)