diff --git a/datasets/doc/source/how-to-use-with-local-data.rst b/datasets/doc/source/how-to-use-with-local-data.rst
new file mode 100644
index 000000000000..276f6d6936ee
--- /dev/null
+++ b/datasets/doc/source/how-to-use-with-local-data.rst
@@ -0,0 +1,257 @@
+Use with Local Data
+===================
+
+You can partition your local files and Python objects using any available
+``Partitioner`` in the ``Flower Datasets`` library.
+
+This guide details how to create a `Hugging Face <https://huggingface.co>`_
+`Dataset <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset>`_,
+which is the required type of input for every ``Partitioner``.
+We will cover:
+
+* local files: CSV, JSON, image, audio,
+* in-memory data: dictionary, list, pd.DataFrame, np.ndarray.
+
+
+General Overview
+----------------
+An all-in-one dataset preparation (downloading, preprocessing, partitioning) happens
+using ``FederatedDataset``. However, we will use only the ``Partitioner`` here, since
+we work with locally accessible data.
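+
+For contrast, the all-in-one flow looks roughly like this (a minimal sketch; the
+``"mnist"`` dataset name and the number of partitions are arbitrary example choices):
+
+.. code-block:: python
+
+    from flwr_datasets import FederatedDataset
+
+    # Download "mnist", partition its "train" split, and load a single partition
+    fds = FederatedDataset(dataset="mnist", partitioners={"train": 10})
+    partition = fds.load_partition(0, "train")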
+
+The rest of this guide will explain how to create a
+`Dataset <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset>`_
+from local files and existing (in-memory) Python objects.
+
+Local Files
+-----------
+CSV
+^^^
+.. code-block:: python
+
+    from datasets import load_dataset
+    from flwr_datasets.partitioner import ChosenPartitioner
+
+    # Single file
+    data_files = "path-to-my-file.csv"
+
+    # Multiple files
+    data_files = ["path-to-my-file-1.csv", "path-to-my-file-2.csv", ...]
+    dataset = load_dataset("csv", data_files=data_files)
+
+    # Divided dataset
+    data_files = {
+        "train": single_train_file_or_list_of_files,
+        "test": single_test_file_or_list_of_files,
+        "can-have-more-splits": ...,
+    }
+    dataset = load_dataset("csv", data_files=data_files)
+
+    partitioner = ChosenPartitioner(...)
+    partitioner.dataset = dataset
+    partition = partitioner.load_partition(partition_id=0)
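+
+``ChosenPartitioner`` is a placeholder for any concrete partitioner. As a minimal
+sketch (assuming the loaded dataset contains a ``train`` split, which is the default
+split name):
+
+.. code-block:: python
+
+    from flwr_datasets.partitioner import IidPartitioner
+
+    partitioner = IidPartitioner(num_partitions=10)
+    partitioner.dataset = dataset["train"]
+    partition = partitioner.load_partition(partition_id=0)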
+
+JSON
+^^^^
+
+.. code-block:: python
+
+    from datasets import load_dataset
+    from flwr_datasets.partitioner import ChosenPartitioner
+
+    # Single file
+    data_files = "path-to-my-file.json"
+
+    # Multiple files
+    data_files = ["path-to-my-file-1.json", "path-to-my-file-2.json", ...]
+    dataset = load_dataset("json", data_files=data_files)
+
+    # Divided dataset
+    data_files = {
+        "train": single_train_file_or_list_of_files,
+        "test": single_test_file_or_list_of_files,
+        "can-have-more-splits": ...,
+    }
+    dataset = load_dataset("json", data_files=data_files)
+
+    partitioner = ChosenPartitioner(...)
+    partitioner.dataset = dataset
+    partition = partitioner.load_partition(partition_id=0)
+
+
+Image
+^^^^^
+You can create an image dataset in two ways:
+
+1) give a path to the directory
+
+The directory needs to be structured in the following way:
+``dataset-name/split/class/name``. For example:
+
+.. code-block::
+
+    mnist/train/1/unique_name_1.png
+    mnist/train/1/unique_name_2.png
+    mnist/train/2/unique_name_3.png
+    ...
+    mnist/test/1/unique_name_4.png
+    mnist/test/1/unique_name_5.png
+    mnist/test/2/unique_name_6.png
+
+Then, the path you pass is ``./mnist``.
+
+.. code-block:: python
+
+    from datasets import load_dataset
+    from flwr_datasets.partitioner import ChosenPartitioner
+
+    # Directly from a directory
+    dataset = load_dataset("imagefolder", data_dir="/path/to/folder")
+
+    partitioner = ChosenPartitioner(...)
+    partitioner.dataset = dataset
+    partition = partitioner.load_partition(partition_id=0)
+
+2) create a dataset from a CSV/JSON file and cast the path column to ``Image``.
+
+.. code-block:: python
+
+    from datasets import Image, load_dataset
+    from flwr_datasets.partitioner import ChosenPartitioner
+
+    dataset = load_dataset(...)
+    dataset = dataset.cast_column("path", Image())
+
+    partitioner = ChosenPartitioner(...)
+    partitioner.dataset = dataset
+    partition = partitioner.load_partition(partition_id=0)
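+
+The cast assumes the loaded dataset has a column (here named ``path``) that stores
+paths to the image files. A hypothetical CSV for this pattern could look like:
+
+.. code-block::
+
+    path,label
+    images/0001.png,7
+    images/0002.png,2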
+
+
+Audio
+^^^^^
+Analogously to the image datasets, there are two methods here:
+
+1) give a path to the directory
+
+.. code-block:: python
+
+    from datasets import load_dataset
+    from flwr_datasets.partitioner import ChosenPartitioner
+
+    dataset = load_dataset("audiofolder", data_dir="/path/to/folder")
+
+    partitioner = ChosenPartitioner(...)
+    partitioner.dataset = dataset
+    partition = partitioner.load_partition(partition_id=0)
+
+2) create a dataset from a CSV/JSON file and cast the path column to ``Audio``.
+
+.. code-block:: python
+
+    from datasets import Audio, load_dataset
+    from flwr_datasets.partitioner import ChosenPartitioner
+
+    dataset = load_dataset(...)
+    dataset = dataset.cast_column("path", Audio())
+
+    partitioner = ChosenPartitioner(...)
+    partitioner.dataset = dataset
+    partition = partitioner.load_partition(partition_id=0)
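+
+After the cast, accessing an example decodes the audio on the fly. A quick check
+(a sketch; it assumes the ``path`` column from the snippet above):
+
+.. code-block:: python
+
+    example = partition[0]
+    # example["path"] is now a dict with "array", "sampling_rate", and "path" keys
+    print(example["path"]["sampling_rate"])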
"poetry.core.masonry.api" +requires = ["hatchling"] +build-backend = "hatchling.build" -[tool.poetry] +[project] name = "$project_name" version = "1.0.0" description = "" -license = "Apache-2.0" authors = [ - "The Flower Authors ", + { name = "The Flower Authors", email = "hello@flower.ai" }, +] +license = {text = "Apache License (2.0)"} +dependencies = [ + "flwr[simulation]>=1.8.0,<2.0", + "flwr-datasets[vision]>=0.0.2,<1.0.0", + "tensorflow>=2.11.1", ] -readme = "README.md" -[tool.poetry.dependencies] -python = ">=3.9,<3.11" -# Mandatory dependencies -flwr = { version = "^1.8.0", extras = ["simulation"] } -flwr-datasets = { version = "^0.0.2", extras = ["vision"] } -tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } -tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } +[tool.hatch.build.targets.wheel] +packages = ["."] diff --git a/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl deleted file mode 100644 index 4b460798e96f..000000000000 --- a/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl +++ /dev/null @@ -1,2 +0,0 @@ -flwr>=1.8, <2.0 -numpy>=1.21.0 diff --git a/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl deleted file mode 100644 index f20b9d71e339..000000000000 --- a/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl +++ /dev/null @@ -1,4 +0,0 @@ -flwr[simulation]>=1.8.0 -flwr-datasets[vision]==0.0.2 -torch==2.2.1 -torchvision==0.17.1 diff --git a/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl deleted file mode 100644 index b6fb49a4bbcb..000000000000 --- a/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl +++ /dev/null @@ -1,4 +0,0 @@ -flwr>=1.8, <2.0 -flwr-datasets[vision]>=0.0.2, <1.0.0 -tensorflow-macos>=2.9.1, !=2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64" -tensorflow-cpu>=2.9.1, !=2.11.1 ; platform_machine == "x86_64" diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py index 4e86f0c3b8c8..7a36c3eb7b84 100644 --- a/src/py/flwr/cli/utils.py +++ b/src/py/flwr/cli/utils.py @@ -14,18 +14,23 @@ # ============================================================================== """Flower command line interface utils.""" -from typing import List, cast +from typing import Callable, List, Optional, cast import typer -def prompt_text(text: str) -> str: +def prompt_text( + text: str, + predicate: Callable[[str], bool] = lambda _: True, + default: Optional[str] = None, +) -> str: """Ask user to enter text input.""" while True: result = typer.prompt( - typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True) + typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True), + default=default, ) - if len(result) > 0: + if predicate(result) and len(result) > 0: break print(typer.style("❌ Invalid entry", fg=typer.colors.RED, bold=True)) @@ -65,3 +70,54 @@ def prompt_options(text: str, options: List[str]) -> str: result = options[int(index)] return result + + +def is_valid_project_name(name: str) -> bool: + """Check if the given string is a valid Python module name. + + A valid module name must start with a letter or an underscore, and can only contain + letters, digits, and underscores. 
+ """ + if not name: + return False + + # Check if the first character is a letter or underscore + if not (name[0].isalpha() or name[0] == "_"): + return False + + # Check if the rest of the characters are valid (letter, digit, or underscore) + for char in name[1:]: + if not (char.isalnum() or char == "_"): + return False + + return True + + +def sanitize_project_name(name: str) -> str: + """Sanitize the given string to make it a valid Python module name. + + This version replaces hyphens with underscores, removes any characters not allowed + in Python module names, makes the string lowercase, and ensures it starts with a + valid character. + """ + # Replace '-' with '_' + name_with_underscores = name.replace("-", "_").replace(" ", "_") + + # Allowed characters in a module name: letters, digits, underscore + allowed_chars = set( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" + ) + + # Make the string lowercase + sanitized_name = name_with_underscores.lower() + + # Remove any characters not allowed in Python module names + sanitized_name = "".join(c for c in sanitized_name if c in allowed_chars) + + # Ensure the first character is a letter or underscore + if sanitized_name and ( + sanitized_name[0].isdigit() or sanitized_name[0] not in allowed_chars + ): + sanitized_name = "_" + sanitized_name + + return sanitized_name diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 1720405ab867..4fa9c80c6cdf 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -456,12 +456,13 @@ def _load_client_app() -> ClientApp: continue log(INFO, "") - log( - INFO, - "[RUN %s, ROUND %s]", - message.metadata.run_id, - message.metadata.group_id, - ) + if len(message.metadata.group_id) > 0: + log( + INFO, + "[RUN %s, ROUND %s]", + message.metadata.run_id, + message.metadata.group_id, + ) log( INFO, "Received: %s message %s", diff --git a/src/py/flwr/client/heartbeat.py b/src/py/flwr/client/heartbeat.py index 0cc979ddfd13..b68e6163cc01 100644 --- a/src/py/flwr/client/heartbeat.py +++ b/src/py/flwr/client/heartbeat.py @@ -66,7 +66,9 @@ def start_ping_loop( asynchronous ping operations. The loop can be terminated through the provided stop event. 
""" - thread = threading.Thread(target=_ping_loop, args=(ping_fn, stop_event)) + thread = threading.Thread( + target=_ping_loop, args=(ping_fn, stop_event), daemon=True + ) thread.start() return thread diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index 7cec319e7906..d12124b89840 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -261,6 +261,7 @@ def try_call_event_handler( try: ret = target(*args, **kwargs) except self.recoverable_exceptions as err: + state.exception = err # Check if giveup event should be triggered max_tries_exceeded = try_cnt == self.max_tries max_time_exceeded = ( diff --git a/src/py/flwr/server/compat/app_utils.py b/src/py/flwr/server/compat/app_utils.py index 696ec1132c4a..709be3d96a33 100644 --- a/src/py/flwr/server/compat/app_utils.py +++ b/src/py/flwr/server/compat/app_utils.py @@ -16,7 +16,6 @@ import threading -import time from typing import Dict, Tuple from ..client_manager import ClientManager @@ -60,6 +59,7 @@ def start_update_client_manager_thread( client_manager, f_stop, ), + daemon=True, ) thread.start() @@ -99,4 +99,5 @@ def _update_client_manager( raise RuntimeError("Could not register node.") # Sleep for 3 seconds - time.sleep(3) + if not f_stop.is_set(): + f_stop.wait(3) diff --git a/src/py/flwr/server/compat/app_utils_test.py b/src/py/flwr/server/compat/app_utils_test.py index 7e47e6eaaf32..023d65b0dc72 100644 --- a/src/py/flwr/server/compat/app_utils_test.py +++ b/src/py/flwr/server/compat/app_utils_test.py @@ -17,6 +17,8 @@ import time import unittest +from threading import Event +from typing import Optional from unittest.mock import Mock, patch from ..client_manager import SimpleClientManager @@ -29,9 +31,6 @@ class TestUtils(unittest.TestCase): def test_start_update_client_manager_thread(self) -> None: """Test start_update_client_manager_thread function.""" # Prepare - sleep = time.sleep - sleep_patch = patch("time.sleep", lambda x: sleep(x / 100)) - sleep_patch.start() expected_node_ids = list(range(100)) updated_expected_node_ids = list(range(80, 120)) driver = Mock() @@ -39,20 +38,30 @@ def test_start_update_client_manager_thread(self) -> None: driver.run_id = 123 driver.get_node_ids.return_value = expected_node_ids client_manager = SimpleClientManager() + original_wait = Event.wait + + def custom_wait(self: Event, timeout: Optional[float] = None) -> None: + if timeout is not None: + timeout /= 100 + original_wait(self, timeout) # Execute - thread, f_stop = start_update_client_manager_thread(driver, client_manager) - # Wait until all nodes are registered via `client_manager.sample()` - client_manager.sample(len(expected_node_ids)) - # Retrieve all nodes in `client_manager` - node_ids = {proxy.node_id for proxy in client_manager.all().values()} - # Update the GetNodesResponse and wait until the `client_manager` is updated - driver.get_node_ids.return_value = updated_expected_node_ids - sleep(0.1) - # Retrieve all nodes in `client_manager` - updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()} - # Stop the thread - f_stop.set() + # Patching Event.wait with our custom function + with patch.object(Event, "wait", new=custom_wait): + thread, f_stop = start_update_client_manager_thread(driver, client_manager) + # Wait until all nodes are registered via `client_manager.sample()` + client_manager.sample(len(expected_node_ids)) + # Retrieve all nodes in `client_manager` + node_ids = {proxy.node_id for proxy in client_manager.all().values()} + # 
+
+If you need to do the same partitioning on a different dataset, create a new
+``Partitioner`` for it, e.g.:
+
+.. code-block:: python
+
+    from flwr_datasets.partitioner import IidPartitioner
+
+    iid_partitioner_for_mnist = IidPartitioner(num_partitions=10)
+    iid_partitioner_for_mnist.dataset = mnist_dataset
+
+    iid_partitioner_for_cifar = IidPartitioner(num_partitions=10)
+    iid_partitioner_for_cifar.dataset = cifar_dataset
+
+
+More Resources
+--------------
+If you are looking for more details, or you have not found the format you are looking
+for, please visit the `Hugging Face Datasets docs <https://huggingface.co/docs/datasets>`_.
+This guide is based on the following ones:
+
+* `General Information <https://huggingface.co/docs/datasets/index>`_
+* `Tabular Data <https://huggingface.co/docs/datasets/tabular_load>`_
+* `Image Data <https://huggingface.co/docs/datasets/image_load>`_
+* `Audio Data <https://huggingface.co/docs/datasets/audio_load>`_
diff --git a/datasets/doc/source/index.rst b/datasets/doc/source/index.rst
index fd226b308bd5..2144c527f8cd 100644
--- a/datasets/doc/source/index.rst
+++ b/datasets/doc/source/index.rst
@@ -31,6 +31,7 @@ Problem-oriented how-to guides show step-by-step how to achieve a specific goal
    how-to-use-with-pytorch
    how-to-use-with-tensorflow
    how-to-use-with-numpy
+   how-to-use-with-local-data
    how-to-disable-enable-progress-bar
 
 References
diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py
index 55a7e597f6b4..6c41eaa3562f 100644
--- a/datasets/flwr_datasets/federated_dataset.py
+++ b/datasets/flwr_datasets/federated_dataset.py
@@ -59,7 +59,8 @@ class FederatedDataset:
         argument. Defaults to True.
     seed : Optional[int]
         Seed used for dataset shuffling. It has no effect if `shuffle` is False. The
-        seed cannot be set in the later stages.
+        seed cannot be set in the later stages. If `None`, then fresh, unpredictable
+        entropy will be pulled from the OS. Defaults to 42.
 
     Examples
     --------
diff --git a/examples/embedded-devices/Dockerfile b/examples/embedded-devices/Dockerfile
index a85c05c4bb7a..48602c89970a 100644
--- a/examples/embedded-devices/Dockerfile
+++ b/examples/embedded-devices/Dockerfile
@@ -8,7 +8,7 @@ RUN pip3 install --upgrade pip
 
 # Install flower
 RUN pip3 install flwr>=1.0
-RUN pip3 install flwr-datsets>=0.2
+RUN pip3 install flwr-datasets>=0.0.2
 RUN pip3 install tqdm==4.65.0
 
 WORKDIR /client
diff --git a/src/py/flwr/cli/example.py b/src/py/flwr/cli/example.py
index 625ca8729640..4790e72d48bf 100644
--- a/src/py/flwr/cli/example.py
+++ b/src/py/flwr/cli/example.py
@@ -39,7 +39,9 @@ def example() -> None:
     with urllib.request.urlopen(examples_directory_url) as res:
         data = json.load(res)
         example_names = [
-            item["path"] for item in data["tree"] if item["path"] not in [".gitignore"]
+            item["path"]
+            for item in data["tree"]
+            if item["path"] not in [".gitignore", "doc"]
         ]
 
         example_name = prompt_options(
diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py
index 7eb47e3e3548..0c429ce34cf2 100644
--- a/src/py/flwr/cli/new/new.py
+++ b/src/py/flwr/cli/new/new.py
@@ -22,7 +22,12 @@
 import typer
 from typing_extensions import Annotated
 
-from ..utils import prompt_options, prompt_text
+from ..utils import (
+    is_valid_project_name,
+    prompt_options,
+    prompt_text,
+    sanitize_project_name,
+)
 
 
 class MlFramework(str, Enum):
@@ -81,6 +86,16 @@ def new(
     ] = None,
 ) -> None:
     """Create new Flower project."""
+    if project_name is None:
+        project_name = prompt_text("Please provide project name")
+    if not is_valid_project_name(project_name):
+        project_name = prompt_text(
+            "Please provide a name that only contains "
+            "characters in {'_', 'a-zA-Z', '0-9'}",
+            predicate=is_valid_project_name,
+            default=sanitize_project_name(project_name),
+        )
+
     print(
         typer.style(
             f"🔨 Creating Flower project {project_name}...",
@@ -89,9 +104,6 @@ def new(
         )
     )
 
-    if project_name is None:
-        project_name = prompt_text("Please provide project name")
-
     if framework is not None:
         framework_str = str(framework.value)
     else:
@@ -116,7 +128,6 @@ def new(
     # List of files to render
     files = {
         "README.md": {"template": "app/README.md.tpl"},
-        "requirements.txt": {"template": f"app/requirements.{framework_str}.txt.tpl"},
         "flower.toml": {"template": "app/flower.toml.tpl"},
         "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
         f"{pnl}/__init__.py": {"template": "app/code/__init__.py.tpl"},
diff --git a/src/py/flwr/cli/new/new_test.py b/src/py/flwr/cli/new/new_test.py
index cedcb09b7755..11620d234191 100644
--- a/src/py/flwr/cli/new/new_test.py
+++ b/src/py/flwr/cli/new/new_test.py
@@ -66,7 +66,6 @@ def test_new(tmp_path: str) -> None:
     project_name = "FedGPT"
     framework = MlFramework.PYTORCH
     expected_files_top_level = {
-        "requirements.txt",
         "fedgpt",
         "README.md",
         "flower.toml",
diff --git a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl
index 15d8211a1a25..9701c62af6f0 100644
--- a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl
+++ b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl
@@ -1,19 +1,19 @@
 [build-system]
-requires = ["poetry-core>=1.4.0"]
-build-backend = "poetry.core.masonry.api"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
 
-[tool.poetry]
+[project]
 name = "$project_name"
 version = "1.0.0"
 description = ""
-license = "Apache-2.0"
 authors = [
-    "The Flower Authors <hello@flower.ai>",
+    { name = "The Flower Authors", email = "hello@flower.ai" },
+]
+license = {text = "Apache License (2.0)"}
+dependencies = [
+    "flwr[simulation]>=1.8.0,<2.0",
+    "numpy>=1.21.0",
 ]
-readme = "README.md"
 
-[tool.poetry.dependencies]
-python = "^3.9"
-# Mandatory dependencies
-numpy = "^1.21.0"
-flwr = { version = "^1.8.0", extras = ["simulation"] }
+[tool.hatch.build.targets.wheel]
+packages = ["."]
diff --git a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl
index 46a5508fe2ac..0661c7b730c1 100644
--- a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl
+++ b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl
@@ -1,21 +1,21 @@
 [build-system]
-requires = ["poetry-core>=1.4.0"]
-build-backend = "poetry.core.masonry.api"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
 
-[tool.poetry]
+[project]
 name = "$project_name"
 version = "1.0.0"
 description = ""
-license = "Apache-2.0"
 authors = [
-    "The Flower Authors <hello@flower.ai>",
+    { name = "The Flower Authors", email = "hello@flower.ai" },
+]
+license = {text = "Apache License (2.0)"}
+dependencies = [
+    "flwr[simulation]>=1.8.0,<2.0",
+    "flwr-datasets[vision]>=0.0.2,<1.0.0",
+    "torch==2.2.1",
+    "torchvision==0.17.1",
 ]
-readme = "README.md"
 
-[tool.poetry.dependencies]
-python = "^3.9"
-# Mandatory dependencies
-flwr = { version = "^1.8.0", extras = ["simulation"] }
-flwr-datasets = { version = "0.0.2", extras = ["vision"] }
-torch = "2.2.1"
-torchvision = "0.17.1"
+[tool.hatch.build.targets.wheel]
+packages = ["."]
diff --git a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl
index f7383a78b7d5..5a017eb6ed74 100644
--- a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl
+++ b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl
@@ -1,21 +1,20 @@
 [build-system]
-requires = ["poetry-core>=1.4.0"]
-build-backend = "poetry.core.masonry.api"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
 
-[tool.poetry]
+[project]
 name = "$project_name"
 version = "1.0.0"
 description = ""
-license = "Apache-2.0"
 authors = [
-    "The Flower Authors <hello@flower.ai>",
+    { name = "The Flower Authors", email = "hello@flower.ai" },
+]
+license = {text = "Apache License (2.0)"}
+dependencies = [
+    "flwr[simulation]>=1.8.0,<2.0",
+    "flwr-datasets[vision]>=0.0.2,<1.0.0",
+    "tensorflow>=2.11.1",
 ]
-readme = "README.md"
 
-[tool.poetry.dependencies]
-python = ">=3.9,<3.11"
-# Mandatory dependencies
-flwr = { version = "^1.8.0", extras = ["simulation"] }
-flwr-datasets = { version = "^0.0.2", extras = ["vision"] }
-tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" }
-tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" }
+[tool.hatch.build.targets.wheel]
+packages = ["."]
diff --git a/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl
deleted file mode 100644
index 4b460798e96f..000000000000
--- a/src/py/flwr/cli/new/templates/app/requirements.numpy.txt.tpl
+++ /dev/null
@@ -1,2 +0,0 @@
-flwr>=1.8, <2.0
-numpy>=1.21.0
diff --git a/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl
deleted file mode 100644
index f20b9d71e339..000000000000
--- a/src/py/flwr/cli/new/templates/app/requirements.pytorch.txt.tpl
+++ /dev/null
@@ -1,4 +0,0 @@
-flwr[simulation]>=1.8.0
-flwr-datasets[vision]==0.0.2
-torch==2.2.1
-torchvision==0.17.1
diff --git a/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl b/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl
deleted file mode 100644
index b6fb49a4bbcb..000000000000
--- a/src/py/flwr/cli/new/templates/app/requirements.tensorflow.txt.tpl
+++ /dev/null
@@ -1,4 +0,0 @@
-flwr>=1.8, <2.0
-flwr-datasets[vision]>=0.0.2, <1.0.0
-tensorflow-macos>=2.9.1, !=2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64"
-tensorflow-cpu>=2.9.1, !=2.11.1 ; platform_machine == "x86_64"
diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py
index 4e86f0c3b8c8..7a36c3eb7b84 100644
--- a/src/py/flwr/cli/utils.py
+++ b/src/py/flwr/cli/utils.py
@@ -14,18 +14,23 @@
 # ==============================================================================
 """Flower command line interface utils."""
 
-from typing import List, cast
+from typing import Callable, List, Optional, cast
 
 import typer
 
 
-def prompt_text(text: str) -> str:
+def prompt_text(
+    text: str,
+    predicate: Callable[[str], bool] = lambda _: True,
+    default: Optional[str] = None,
+) -> str:
     """Ask user to enter text input."""
     while True:
         result = typer.prompt(
-            typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True)
+            typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True),
+            default=default,
         )
-        if len(result) > 0:
+        if predicate(result) and len(result) > 0:
             break
 
         print(typer.style("❌ Invalid entry", fg=typer.colors.RED, bold=True))
@@ -65,3 +70,54 @@ def prompt_options(text: str, options: List[str]) -> str:
     result = options[int(index)]
 
     return result
+
+
+def is_valid_project_name(name: str) -> bool:
+    """Check if the given string is a valid Python module name.
+
+    A valid module name must start with a letter or an underscore, and can only contain
+    letters, digits, and underscores.
+    """
+    if not name:
+        return False
+
+    # Check if the first character is a letter or underscore
+    if not (name[0].isalpha() or name[0] == "_"):
+        return False
+
+    # Check if the rest of the characters are valid (letter, digit, or underscore)
+    for char in name[1:]:
+        if not (char.isalnum() or char == "_"):
+            return False
+
+    return True
+
+
+def sanitize_project_name(name: str) -> str:
+    """Sanitize the given string to make it a valid Python module name.
+
+    This version replaces hyphens with underscores, removes any characters not allowed
+    in Python module names, makes the string lowercase, and ensures it starts with a
+    valid character.
+    """
+    # Replace '-' and ' ' with '_'
+    name_with_underscores = name.replace("-", "_").replace(" ", "_")
+
+    # Allowed characters in a module name: letters, digits, underscore
+    allowed_chars = set(
+        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
+    )
+
+    # Make the string lowercase
+    sanitized_name = name_with_underscores.lower()
+
+    # Remove any characters not allowed in Python module names
+    sanitized_name = "".join(c for c in sanitized_name if c in allowed_chars)
+
+    # Ensure the first character is a letter or underscore
+    if sanitized_name and (
+        sanitized_name[0].isdigit() or sanitized_name[0] not in allowed_chars
+    ):
+        sanitized_name = "_" + sanitized_name
+
+    return sanitized_name
diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py
index 1720405ab867..4fa9c80c6cdf 100644
--- a/src/py/flwr/client/app.py
+++ b/src/py/flwr/client/app.py
@@ -456,12 +456,13 @@ def _load_client_app() -> ClientApp:
                 continue
 
             log(INFO, "")
-            log(
-                INFO,
-                "[RUN %s, ROUND %s]",
-                message.metadata.run_id,
-                message.metadata.group_id,
-            )
+            if len(message.metadata.group_id) > 0:
+                log(
+                    INFO,
+                    "[RUN %s, ROUND %s]",
+                    message.metadata.run_id,
+                    message.metadata.group_id,
+                )
             log(
                 INFO,
                 "Received: %s message %s",
diff --git a/src/py/flwr/client/heartbeat.py b/src/py/flwr/client/heartbeat.py
index 0cc979ddfd13..b68e6163cc01 100644
--- a/src/py/flwr/client/heartbeat.py
+++ b/src/py/flwr/client/heartbeat.py
@@ -66,7 +66,9 @@ def start_ping_loop(
     asynchronous ping operations. The loop can be terminated through the provided
     stop event.
     """
-    thread = threading.Thread(target=_ping_loop, args=(ping_fn, stop_event))
+    thread = threading.Thread(
+        target=_ping_loop, args=(ping_fn, stop_event), daemon=True
+    )
     thread.start()
 
     return thread
diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py
index 7cec319e7906..d12124b89840 100644
--- a/src/py/flwr/common/retry_invoker.py
+++ b/src/py/flwr/common/retry_invoker.py
@@ -261,6 +261,7 @@ def try_call_event_handler(
             try:
                 ret = target(*args, **kwargs)
             except self.recoverable_exceptions as err:
+                state.exception = err
                 # Check if giveup event should be triggered
                 max_tries_exceeded = try_cnt == self.max_tries
                 max_time_exceeded = (
diff --git a/src/py/flwr/server/compat/app_utils.py b/src/py/flwr/server/compat/app_utils.py
index 696ec1132c4a..709be3d96a33 100644
--- a/src/py/flwr/server/compat/app_utils.py
+++ b/src/py/flwr/server/compat/app_utils.py
@@ -16,7 +16,6 @@
 
 
 import threading
-import time
 from typing import Dict, Tuple
 
 from ..client_manager import ClientManager
@@ -60,6 +59,7 @@ def start_update_client_manager_thread(
             client_manager,
             f_stop,
         ),
+        daemon=True,
     )
     thread.start()
 
@@ -99,4 +99,5 @@ def _update_client_manager(
                 raise RuntimeError("Could not register node.")
 
     # Sleep for 3 seconds
-    time.sleep(3)
+    if not f_stop.is_set():
+        f_stop.wait(3)
diff --git a/src/py/flwr/server/compat/app_utils_test.py b/src/py/flwr/server/compat/app_utils_test.py
index 7e47e6eaaf32..023d65b0dc72 100644
--- a/src/py/flwr/server/compat/app_utils_test.py
+++ b/src/py/flwr/server/compat/app_utils_test.py
@@ -17,6 +17,8 @@
 
 import time
 import unittest
+from threading import Event
+from typing import Optional
 from unittest.mock import Mock, patch
 
 from ..client_manager import SimpleClientManager
@@ -29,9 +31,6 @@ class TestUtils(unittest.TestCase):
     def test_start_update_client_manager_thread(self) -> None:
         """Test start_update_client_manager_thread function."""
         # Prepare
-        sleep = time.sleep
-        sleep_patch = patch("time.sleep", lambda x: sleep(x / 100))
-        sleep_patch.start()
         expected_node_ids = list(range(100))
         updated_expected_node_ids = list(range(80, 120))
         driver = Mock()
@@ -39,20 +38,30 @@ def test_start_update_client_manager_thread(self) -> None:
         driver.run_id = 123
         driver.get_node_ids.return_value = expected_node_ids
         client_manager = SimpleClientManager()
+        original_wait = Event.wait
+
+        def custom_wait(self: Event, timeout: Optional[float] = None) -> None:
+            if timeout is not None:
+                timeout /= 100
+            original_wait(self, timeout)
 
         # Execute
-        thread, f_stop = start_update_client_manager_thread(driver, client_manager)
-        # Wait until all nodes are registered via `client_manager.sample()`
-        client_manager.sample(len(expected_node_ids))
-        # Retrieve all nodes in `client_manager`
-        node_ids = {proxy.node_id for proxy in client_manager.all().values()}
-        # Update the GetNodesResponse and wait until the `client_manager` is updated
-        driver.get_node_ids.return_value = updated_expected_node_ids
-        sleep(0.1)
-        # Retrieve all nodes in `client_manager`
-        updated_node_ids = {proxy.node_id for proxy in client_manager.all().values()}
-        # Stop the thread
-        f_stop.set()
+        # Patching Event.wait with our custom function
+        with patch.object(Event, "wait", new=custom_wait):
+            thread, f_stop = start_update_client_manager_thread(driver, client_manager)
+            # Wait until all nodes are registered via `client_manager.sample()`
+            client_manager.sample(len(expected_node_ids))
+            # Retrieve all nodes in `client_manager`
+            node_ids = {proxy.node_id for proxy in client_manager.all().values()}
+            # Update the GetNodesResponse and wait until the `client_manager` is updated
+            driver.get_node_ids.return_value = updated_expected_node_ids
+            time.sleep(0.1)
+            # Retrieve all nodes in `client_manager`
+            updated_node_ids = {
+                proxy.node_id for proxy in client_manager.all().values()
+            }
+            # Stop the thread
+            f_stop.set()
 
         # Assert
         assert node_ids == set(expected_node_ids)