diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 64544b1c4463..72ea54773564 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -317,7 +317,8 @@ def _prepare_dataset(self) -> None: raise ValueError( "Probably one of the specified parameter in `load_dataset_kwargs` " "change the return type of the datasets.load_dataset function. " - "Make sure to use parameter such that the return type is DatasetDict." + "Make sure to use parameter such that the return type is DatasetDict. " + f"The return type is currently: {type(self._dataset)}." ) if self._shuffle: # Note it shuffles all the splits. The self._dataset is DatasetDict diff --git a/examples/advanced-pytorch/pyproject.toml b/examples/advanced-pytorch/pyproject.toml index b846a6054cc8..f2c9ad731196 100644 --- a/examples/advanced-pytorch/pyproject.toml +++ b/examples/advanced-pytorch/pyproject.toml @@ -12,7 +12,7 @@ authors = [ ] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = ">=1.0,<2.0" flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } torch = "1.13.1" diff --git a/examples/advanced-tensorflow/pyproject.toml b/examples/advanced-tensorflow/pyproject.toml index 02bd923129a4..9fc623a0f3ec 100644 --- a/examples/advanced-tensorflow/pyproject.toml +++ b/examples/advanced-tensorflow/pyproject.toml @@ -9,7 +9,7 @@ description = "Advanced Flower/TensorFlow Example" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = ">=1.0,<2.0" flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } diff --git a/examples/android-kotlin/gen_tflite/pyproject.toml b/examples/android-kotlin/gen_tflite/pyproject.toml index aabf351bd51d..884e7148cc3d 100644 --- a/examples/android-kotlin/gen_tflite/pyproject.toml +++ b/examples/android-kotlin/gen_tflite/pyproject.toml @@ -5,7 +5,7 @@ description = "" authors = ["Steven Hé (Sīchàng) "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" numpy = ">=1.23,<2.0" tensorflow-cpu = ">=2.12,<3.0" pandas = ">=2.0,<3.0" diff --git a/examples/android-kotlin/pyproject.toml b/examples/android-kotlin/pyproject.toml index 9cf0688d83b5..b83b243a349d 100644 --- a/examples/android-kotlin/pyproject.toml +++ b/examples/android-kotlin/pyproject.toml @@ -9,5 +9,5 @@ description = "" authors = ["Steven Hé (Sīchàng) "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = ">=1.0,<2.0" diff --git a/examples/android/pyproject.toml b/examples/android/pyproject.toml index 0371f7208292..d0d18ebc48bc 100644 --- a/examples/android/pyproject.toml +++ b/examples/android/pyproject.toml @@ -9,7 +9,7 @@ description = "Android Example" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = ">=1.0,<2.0" tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } diff --git a/examples/app-pytorch/pyproject.toml b/examples/app-pytorch/pyproject.toml index c00e38aef19b..88e916546632 100644 --- a/examples/app-pytorch/pyproject.toml +++ b/examples/app-pytorch/pyproject.toml @@ -9,7 +9,7 @@ description = "Multi-Tenant Federated Learning with Flower and PyTorch" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = "^3.8" +python = "^3.9" # Mandatory dependencies flwr = { version = "^1.8.0", extras = ["simulation"] } torch = "2.2.1" diff --git a/examples/custom-mods/pyproject.toml b/examples/custom-mods/pyproject.toml index e690e05bab8f..ff36398ef157 100644 --- a/examples/custom-mods/pyproject.toml +++ b/examples/custom-mods/pyproject.toml @@ -9,7 +9,7 @@ description = "Multi-Tenant Federated Learning with Flower and PyTorch" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = { path = "../../", develop = true, extras = ["simulation"] } tensorboard = "2.16.2" torch = "1.13.1" diff --git a/examples/ios/pyproject.toml b/examples/ios/pyproject.toml index 2e55b14cf761..03ea89ea3e54 100644 --- a/examples/ios/pyproject.toml +++ b/examples/ios/pyproject.toml @@ -9,5 +9,5 @@ description = "Example Server for Flower iOS/CoreML" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = ">=1.0,<2.0" diff --git a/examples/pytorch-from-centralized-to-federated/pyproject.toml b/examples/pytorch-from-centralized-to-federated/pyproject.toml index 3d1559e3a515..57a8082fd6bf 100644 --- a/examples/pytorch-from-centralized-to-federated/pyproject.toml +++ b/examples/pytorch-from-centralized-to-federated/pyproject.toml @@ -9,7 +9,7 @@ description = "PyTorch: From Centralized To Federated with Flower" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = ">=1.0,<2.0" flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } torch = "1.13.1" diff --git a/examples/quickstart-jax/pyproject.toml b/examples/quickstart-jax/pyproject.toml index c956191369b5..68a3455aedee 100644 --- a/examples/quickstart-jax/pyproject.toml +++ b/examples/quickstart-jax/pyproject.toml @@ -5,7 +5,7 @@ description = "JAX example training a linear regression model with federated lea authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = "1.0.0" jax = "0.4.17" jaxlib = "0.4.17" diff --git a/examples/quickstart-mlcube/pyproject.toml b/examples/quickstart-mlcube/pyproject.toml index a2862bd5ebb7..f790a596ed19 100644 --- a/examples/quickstart-mlcube/pyproject.toml +++ b/examples/quickstart-mlcube/pyproject.toml @@ -9,7 +9,7 @@ description = "Keras Federated Learning Quickstart with Flower" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = ">=1.0,<2.0" # For development: { path = "../../", develop = true } tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } diff --git a/examples/quickstart-monai/monaiexample/task.py b/examples/quickstart-monai/monaiexample/task.py index 09597562a1f2..4f7972d455fd 100644 --- a/examples/quickstart-monai/monaiexample/task.py +++ b/examples/quickstart-monai/monaiexample/task.py @@ -189,9 +189,10 @@ def _download_and_extract_if_needed(url, dest_folder): # Download the tar.gz file tar_gz_filename = url.split("/")[-1] if not os.path.isfile(tar_gz_filename): - with request.urlopen(url) as response, open( - tar_gz_filename, "wb" - ) as out_file: + with ( + request.urlopen(url) as response, + open(tar_gz_filename, "wb") as out_file, + ): out_file.write(response.read()) # Extract the tar.gz file diff --git a/examples/quickstart-pandas/pyproject.toml b/examples/quickstart-pandas/pyproject.toml index 2e6b1424bb54..4111815660f7 100644 --- a/examples/quickstart-pandas/pyproject.toml +++ b/examples/quickstart-pandas/pyproject.toml @@ -10,7 +10,7 @@ authors = ["Ragy Haddad "] maintainers = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = ">=1.0,<2.0" flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } numpy = "1.23.2" diff --git a/examples/quickstart-tabnet/pyproject.toml b/examples/quickstart-tabnet/pyproject.toml index 6b7311f068f0..8345d6bd3da2 100644 --- a/examples/quickstart-tabnet/pyproject.toml +++ b/examples/quickstart-tabnet/pyproject.toml @@ -9,7 +9,7 @@ description = "Tabnet Federated Learning Quickstart with Flower" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = ">=1.0,<2.0" tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } diff --git a/examples/vertical-fl/pyproject.toml b/examples/vertical-fl/pyproject.toml index 19dcd0e7a842..f39af987cd72 100644 --- a/examples/vertical-fl/pyproject.toml +++ b/examples/vertical-fl/pyproject.toml @@ -9,7 +9,7 @@ description = "PyTorch Vertical FL with Flower" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = { extras = ["simulation"], version = ">=1.0,<2.0" } torch = "2.1.0" matplotlib = "3.7.3" diff --git a/examples/whisper-federated-finetuning/pyproject.toml b/examples/whisper-federated-finetuning/pyproject.toml index 27a89578c5a0..3d7bb023537c 100644 --- a/examples/whisper-federated-finetuning/pyproject.toml +++ b/examples/whisper-federated-finetuning/pyproject.toml @@ -9,7 +9,7 @@ description = "On-device Federated Downstreaming for Speech Classification" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = { extras = ["simulation"], version = ">=1.0,<2.0" } transformers = "4.32.1" tokenizers = "0.13.3" diff --git a/examples/xgboost-comprehensive/pyproject.toml b/examples/xgboost-comprehensive/pyproject.toml index c9259ffa1db4..4b20edd55047 100644 --- a/examples/xgboost-comprehensive/pyproject.toml +++ b/examples/xgboost-comprehensive/pyproject.toml @@ -9,7 +9,7 @@ description = "Federated XGBoost with Flower (comprehensive)" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" +python = ">=3.9,<3.11" flwr = { extras = ["simulation"], version = ">=1.7.0,<2.0" } flwr-datasets = ">=0.2.0,<1.0.0" xgboost = ">=2.0.0,<3.0.0" diff --git a/glossary/aggregation.mdx b/glossary/aggregation.mdx new file mode 100644 index 000000000000..82cadd6948bb --- /dev/null +++ b/glossary/aggregation.mdx @@ -0,0 +1,18 @@ +--- +title: "Aggregation" +description: "Combine model weights from sampled clients to update the global model. This process enables the global model to learn from each client's data." +date: "2024-05-23" +author: + name: "Charles Beauville" + position: "Machine Learning Engineer" + website: "https://www.linkedin.com/in/charles-beauville/" + github: "github.com/charlesbvll" +related: + - text: "Federated Learning" + link: "/glossary/federated-learning" + - text: "Tutorial: What is Federated Learning?" + link: "/docs/framework/tutorial-series-what-is-federated-learning.html" +--- + +During each Federated Learning round, the server will receive model weights from sampled clients and needs a function to improve its global model using those weights. This is what is called `aggregation`. It can be a simple weighted average function (like `FedAvg`), or can be more complex (e.g. incorporating optimization techniques). The aggregation is where FL's magic happens, it allows the global model to learn and improve from each client's particular data distribution with only their trained weights. + diff --git a/glossary/client.mdx b/glossary/client.mdx new file mode 100644 index 000000000000..52b14f124add --- /dev/null +++ b/glossary/client.mdx @@ -0,0 +1,17 @@ +--- +title: "Client" +description: "A client is any machine with local data that connects to a server, trains on received global model weights, and sends back updated weights. Clients may also evaluate global model weights." +date: "2024-05-23" +author: + name: "Charles Beauville" + position: "Machine Learning Engineer" + website: "https://www.linkedin.com/in/charles-beauville/" + github: "github.com/charlesbvll" +related: + - text: "Federated Learning" + link: "/glossary/federated-learning" + - text: "Tutorial: What is Federated Learning?" + link: "/docs/framework/tutorial-series-what-is-federated-learning.html" +--- + +Any machine with access to some data that connects to a server to perform Federated Learning. During each round of FL (if it is sampled), it will receive global model weights from the server, train on the data they have access to, and send the resulting trained weights back to the server. Clients can also be sampled to evaluate the global server weights on the data they have access to, this is called federated evaluation. diff --git a/glossary/docker.mdx b/glossary/docker.mdx new file mode 100644 index 000000000000..9ca079b90f06 --- /dev/null +++ b/glossary/docker.mdx @@ -0,0 +1,22 @@ +--- +title: "Docker" +description: "Docker is a containerization tool that allows for consistent and reliable deployment of applications across different environments." +date: "2024-07-08" +author: + name: "Robert Steiner" + position: "DevOps Engineer at Flower Labs" + website: "https://github.com/Robert-Steiner" +--- + +Docker is an open-source containerization tool for deploying and running applications. Docker +containers encapsulate an application's code, dependencies, and configuration files, allowing +for consistent and reliable deployment across different environments. + +In the context of federated learning, Docker containers can be used to package the entire client +and server application, including all the necessary dependencies, and then deployed on various +devices such as edge devices, cloud servers, or even on-premises servers. + +In Flower, Docker containers are used to containerize various applications like `SuperLink`, +`SuperNode`, and `SuperExec`. Flower's Docker images allow users to quickly get Flower up and +running, reducing the time and effort required to set up and configure the necessary software +and dependencies. diff --git a/glossary/edge-computing.mdx b/glossary/edge-computing.mdx new file mode 100644 index 000000000000..6499a48e8f07 --- /dev/null +++ b/glossary/edge-computing.mdx @@ -0,0 +1,40 @@ +--- +title: "Edge Computing" +description: "Edge computing is a distributed computing concept of bringing compute and data storage as close as possible to the source of data generation and consumption by users." +date: "2024-09-10" +author: + name: "Chong Shen Ng" + position: "Research Engineer @ Flower Labs" + website: "https://discuss.flower.ai/u/chongshenng" + github: "github.com/chongshenng" +related: + - text: "IoT" + link: "/glossary/iot" + - text: "Run Flower using Docker" + link: "/docs/framework/docker/index.html" + - text: "Flower Clients in C++" + link: "/docs/examples/quickstart-cpp.html" + - text: "Federated Learning on Embedded Devices with Flower" + link: "/docs/examples/embedded-devices.html" +--- + +### Introduction to Edge Computing + +Edge computing is a distributed computing concept of bringing compute and data storage as close as possible to the source of data generation and consumption by users. By performing computation close to the data source, edge computing aims to address limitations typically encountered in centralized computing, such as bandwidth, latency, privacy, and autonomy. + +Edge computing works alongside cloud and fog computing, but each serves different purposes. Cloud computing delivers on-demand resources like data storage, servers, analytics, and networking via the Internet. Fog computing, however, brings computing closer to devices by distributing communication and computation across clusters of IoT or edge devices. While edge computing is sometimes used interchangeably with fog computing, edge computing specifically handles data processing directly at or near the devices themselves, whereas fog computing distributes tasks across multiple nodes, bridging the gap between edge devices and the cloud. + +### Advantages and Use Cases of Edge Computing + +The key benefit of edge computing is that the volume of data moved is significantly reduced because computation runs directly on board the device on the acquired data. This reduces the amount of long-distance communication between machines, which improves latency and reduces transmissions costs. Examples of edge computing that benefit from offloading computation include: +1. Smart watches and fitness monitors that measure live health metrics. +2. Facial recognition and wake word detection on smartphones. +3. Real-time lane departure warning systems in road transport that detect lane lines using on-board videos and sensors. + +### Federated Learning in Edge Computing + +When deploying federated learning systems, edge computing is an important component to consider. Edge computing typically take the role of "clients" in federated learning. In a healthcare use case, servers in different hospitals can train models on their local data. In mobile computing, smartphones perform local training (and inference) on user data such as for next word prediction. + +### Edge Computing with Flower + +With the Flower framework, you can easily deploy federated learning workflows and maximise the use of edge computing resources. Flower provides the infrastructure to perform federated learning, federated evaluation, and federated analytics, all in a easy, scalable and secure way. Start with our tutorial on running Federated Learning on Embedded Devices (link [here](https://github.com/adap/flower/tree/main/examples/embedded-devices)), which shows you how to run Flower on NVidia Jetson devices and Raspberry Pis as your edge compute. diff --git a/glossary/evaluation.mdx b/glossary/evaluation.mdx new file mode 100644 index 000000000000..bf6b36cd0c4b --- /dev/null +++ b/glossary/evaluation.mdx @@ -0,0 +1,19 @@ +--- +title: "Evaluation" +description: "Evaluation measures how well the trained model performs by testing it on each client's local data, providing insights into its generalizability across varied data sources." +date: "2024-07-08" +author: + name: "Heng Pan" + position: "Research Scientist" + website: "https://discuss.flower.ai/u/pan-h/summary" + github: "github.com/panh99" +related: + - text: "Server" + link: "/glossary/server" + - text: "Client" + link: "/glossary/client" +--- + +Evaluation in machine learning is the process of assessing a model's performance on unseen data to determine its ability to generalize beyond the training set. This typically involves using a separate test set and various metrics like accuracy or F1-score to measure how well the model performs on new data, ensuring it isn't overfitting or underfitting. + +In federated learning, evaluation (or distributed evaluation) refers to the process of assessing a model's performance across multiple clients, such as devices or data centers. Each client evaluates the model locally using its own data and then sends the results to the server, which aggregates all the evaluation outcomes. This process allows for understanding how well the model generalizes to different data distributions without centralizing sensitive data. \ No newline at end of file diff --git a/glossary/federated-learning.mdx b/glossary/federated-learning.mdx new file mode 100644 index 000000000000..5f6b8a7f1732 --- /dev/null +++ b/glossary/federated-learning.mdx @@ -0,0 +1,14 @@ +--- +title: "Federated Learning" +description: "Federated Learning is a machine learning approach where model training occurs on decentralized devices, preserving data privacy and leveraging local computations." +date: "2024-05-23" +author: + name: "Julian Rußmeyer" + position: "UX/UI Designer" + website: "https://www.linkedin.com/in/julian-russmeyer/" +related: + - text: "Tutorial: What is Federated Learning?" + link: "/docs/framework/tutorial-series-what-is-federated-learning.html" +--- + +Federated learning is an approach to machine learning in which the model is trained on multiple decentralized devices or servers with local data samples without exchanging them. Instead of sending raw data to a central server, updates to the model are calculated locally and only the model parameters are aggregated centrally. In this way, user privacy is maintained and communication costs are reduced, while collaborative model training is enabled. diff --git a/glossary/grpc.mdx b/glossary/grpc.mdx new file mode 100644 index 000000000000..af58758d10bd --- /dev/null +++ b/glossary/grpc.mdx @@ -0,0 +1,44 @@ +--- +title: "gRPC" +description: "gRPC is an inter-process communication technology for building distributed apps. It allows developers to connect, invoke, operate, and debug apps as easily as making a local function call." +date: "2024-09-10" +author: + name: "Chong Shen Ng" + position: "Research Engineer @ Flower Labs" + website: "https://discuss.flower.ai/u/chongshenng" + github: "github.com/chongshenng" +related: + - text: "Federated Learning" + link: "/glossary/federated-learning" + - text: "Tutorial: What is Federated Learning?" + link: "/docs/framework/tutorial-series-what-is-federated-learning.html" + - text: "Protocol Buffers" + link: "/glossary/protocol-buffers" + - text: "Google: gRPC - A true internet scale RPC framework" + link: "https://cloud.google.com/blog/products/gcp/grpc-a-true-internet-scale-rpc-framework-is-now-1-and-ready-for-production-deployments" +--- + +### Introduction to gRPC + +gRPC is an inter-process communication technology for building distributed applications. It allows you to connect, invoke, operate, and debug these applications as easily as making a local function call. It can efficiently connect services in and across data centers. It is also applicable in the last mile of distributed computing to connect devices, mobile applications, and browsers to backend services. Supporting various languages like C++, Go, Java, and Python, and platforms like Android and the web, gRPC is a versatile framework for any environment. + +Google first [open-sourced gRPC in 2016](https://cloud.google.com/blog/products/gcp/grpc-a-true-internet-scale-rpc-framework-is-now-1-and-ready-for-production-deployments), basing it on their internal remote procedure call (RPC) framework, Stubby, designed to handle tens of billions of requests per second. Built on HTTP/2 and protocol buffers, gRPC is a popular high-performance framework for developers to built micro-services. Notable early adopters of gRPC include Square, Netflix, CockroachDB, Cisco, and Juniper Networks. + +By default, gRPC uses protocol buffers - Google's language-neutral and platform-neutral mechanism for efficiently serializing structured data - as its interface definition language and its underlying message interchange format. The recommended protocol buffer version as of writing is `proto3`, though other formats like JSON can also be used. + +### How does it work? + +gRPC operates similarly to many RPC systems. First, you specify the methods that can be called remotely on the server application, along with their parameters and return type. Then, with the appropriate code (more on this below), a gRPC client application can directly call these methods on the gRPC server application on a different machine as if it were a local object. Note that the definitions of client and server in gRPC is different to federated learning. For clarity, we will refer to client (server) applications in gRPC as gRPC client (server) applications. + +To use gRPC, follow these steps: +1. Define structure for the data you want to serialize in a proto file definition. `*.proto`. +2. Run the protocol buffer compiler `protoc` to generate to data access classes in the preferred language from the `*.proto` service definitions. This step generates the gRPC client and server code, as well as the regular protocol buffer code for handling your message types. +3. Use the generated class in your application to populate, serialize, and retrieve the class protocol buffer messages. + +### Use cases in Federated Learning + +There are several reasons why gRPC is particularly useful in federated learning. First, clients and server in a federation rely on stable and efficient communication. Using Protobuf, a highly efficient binary serialization format, gRPC overcomes the bandwidth limitations in federated learning, such as in low-bandwidth mobile connections. Second, gRPC’s language-independent communication allows developers to use a variety of programming languages, enabling broader adoption for on-device executions. + +### gRPC in Flower + +gRPC's benefits for distributed computing make it a natural choice for the Flower framework. Flower uses gRPC as its primary communication protocol. To make it easier to build your federated learning systems, we have introduced high-level APIs to take care of the serialization and deserialization of the model parameters, configurations, and metrics. For more details on how to use Flower, follow our "Get started with Flower" tutorial here. diff --git a/glossary/inference.mdx b/glossary/inference.mdx new file mode 100644 index 000000000000..06c93a834d2d --- /dev/null +++ b/glossary/inference.mdx @@ -0,0 +1,21 @@ +--- +title: "Inference" +description: "Inference is the phase in which a trained machine learning model applies its learned patterns to new, unseen data to make predictions or decisions." +date: "2024-07-12" +author: + name: "Yan Gao" + position: "Research Scientist" + website: "https://discuss.flower.ai/u/yan-gao/" + github: "github.com/yan-gao-GY" +related: + - text: "Federated Learning" + link: "/glossary/federated-learning" + - text: "Server" + link: "/glossary/server" + - text: "Client" + link: "/glossary/client" +--- + +Inference, also known as model prediction, is the stage in the machine learning workflow where a trained model is used to make predictions based on new, unseen data. In a typical machine learning setting, model inference involves the following steps: model loading, where the trained model is loaded into the application or service where it will be used; data preparation, which preprocess the new data in the same way as the training data; and model prediction, where the prepared data is fed into the model to compute outputs based on the learned patterns during training. + +In the context of federated learning (FL), inference can be performed locally on the user's device. A global model updated from FL process is deployed and loaded on individual nodes (e.g., smartphones, hospital servers) for local inference. This allows for keeping all data on-device, enhancing privacy and reducing latency. diff --git a/glossary/iot.mdx b/glossary/iot.mdx new file mode 100644 index 000000000000..ec1932c444f3 --- /dev/null +++ b/glossary/iot.mdx @@ -0,0 +1,48 @@ +--- +title: "IoT" +description: "The Internet of Things (IoT) refers to devices with sensors, software, and tech that connect and exchange data with other systems via the internet or communication networks." +date: "2024-09-10" +author: + name: "Chong Shen Ng" + position: "Research Engineer @ Flower Labs" + website: "https://discuss.flower.ai/u/chongshenng" + github: "github.com/chongshenng" +related: + - text: "Edge Computing" + link: "/glossary/edge-computing" + - text: "Run Flower using Docker" + link: "/docs/framework/docker/index.html" + - text: "Flower Clients in C++" + link: "/docs/examples/quickstart-cpp.html" + - text: "Federated Learning on Embedded Devices with Flower" + link: "/docs/examples/embedded-devices.html" + - text: "Cisco: Redefine Connectivity by Building a Network to Support the Internet of Things" + link: "https://www.cisco.com/c/en/us/solutions/service-provider/a-network-to-support-iot.html" +--- + +### Introduction to IoT + +The Internet of Things (IoT) describe devices with sensors, processing ability, software, and other technologies that connect and exchange data with other devices and systems over the Internet or other communications networks. IoT is often also referred as Machine-to-Machine (M2M) connections. Examples of IoT include embedded systems, wireless sensor networks, control systems, automation (home and building). In the consumer market, IoT technology is synonymous with smart home products. The IoT architecture bears resemblance to edge computing, but more broadly encompasses edge devices, gateways, and the cloud. + +### Use cases in Federated Learning + +From the perspective of federated learning, IoT systems provide two common configurations: first as a data source for training, and second as a point for running inference/analytics. + +Cisco's Global Cloud Index estimated that nearly 850 Zettabytes (ZB) of data will be generated by all people, machines and things in 2021 ([link](https://www.cisco.com/c/en/us/solutions/service-provider/a-network-to-support-iot.html) to article). In IoT, the data is different because not all of the data needs to be stored and instead, the most impactful business values come from running computations on the data. This positions IoT as an ideal candidate for implementing federated learning systems, where a model trained on a datastream from a single device may not be useful, but when trained collaboratively on hundreds or thousands of devices, yields a better performing and generalisable model. The key benefit is that the generated data remains local on the device and can even be offloaded after multiple rounds of federated learning. Some examples are presented below. + +Once a model is trained (e.g. in a federated way), the model can be put into production. What this means is to deploy the model on the IoT device and compute predictions based on the newly generated/acquired data. + +Federated learning in IoT can be organized on two axes: by industry and by use cases. + +For industry applications, examples include: +1. Healthcare - e.g. vital sign, activity levels, or sleep pattern monitoring using fitness trackers. +2. Transportation - e.g. trajectory prediction, object detection, driver drowsiness detection using on-board sensors and cameras. + +For use cases, examples include: +1. Predictive maintenance - e.g. using data acquired from physical sensors (impedence, temperature, vibration, pressure, viscosity, etc ...) +2. Anomaly detection - e.g. using environmental monitoring sensors for predicting air, noise, or water pollution, using internet network traffic data for network intrusion detection, using fiber optic sensors for remote sensing and monitoring, etc ... +3. Quality assurance and quality control - e.g. using in-line optical, acoustic, or sensor data during manufacturing processes to identify faulty products, etc ... + +### Using Flower for Federated Learning with IoT + +Flower is developed with a deployment engine that allows you to easily deploy your federated learning system on IoT devices. As a Data Scientist/ML Engineer, you will only need to write ClientApps and deploy them to IoT devices without needing to deal with the infrastructure and networking. To further help deployment, we provide [Docker images](https://hub.docker.com/u/flwr) for the SuperLink, SuperNode, and ServerApp so that you can easily ship the requirements of your Flower applications in containers in a production environment. Lastly, Flower supports the development of both Python and C++ clients, which provides developers with flexible ways of building ClientApps for resource-contrained devices. diff --git a/glossary/medical-ai.mdx b/glossary/medical-ai.mdx new file mode 100644 index 000000000000..d557f457c189 --- /dev/null +++ b/glossary/medical-ai.mdx @@ -0,0 +1,24 @@ +--- +title: "Medical AI" +description: "Medical AI involves the application of artificial intelligence technologies to healthcare, enhancing diagnosis, treatment planning, and patient monitoring by analyzing complex medical data." +date: "2024-07-12" +author: + name: "Yan Gao" + position: "Research Scientist" + website: "https://discuss.flower.ai/u/yan-gao/" + github: "github.com/yan-gao-GY" +related: + - text: "Federated Learning" + link: "/glossary/federated-learning" + - text: "Server" + link: "/glossary/server" + - text: "Client" + link: "/glossary/client" +--- + +Medical AI refers to the application of artificial intelligence technologies, particularly machine learning algorithms, to medical and healthcare-related fields. This includes, but is not limited to, tasks such as disease diagnosis, personalized treatment plans, drug development, medical imaging analysis, and healthcare management. The goal of Medical AI is to enhance healthcare services, improve treatment outcomes, reduce costs, and increase efficiency within healthcare systems. + +Federated learning (FL) introduces a novel approach to the training of machine learning models across multiple decentralized devices or servers holding local data samples, without exchanging them. This is particularly appropriate in the medical field due to the sensitive nature of medical data and strict privacy requirements. It leverages the strength of diverse datasets without compromising patient confidentiality, making it an increasingly popular choice in Medical AI applications. + +#### Medical AI in Flower +Flower, a friendly FL framework, is developing a more versatile and privacy-enhancing solution for Medical AI through the use of FL. Please check out [Flower industry healthcare](flower.ai/industry/healthcare) website for more detailed information. diff --git a/glossary/model-training.mdx b/glossary/model-training.mdx new file mode 100644 index 000000000000..ba5923962f1b --- /dev/null +++ b/glossary/model-training.mdx @@ -0,0 +1,24 @@ +--- +title: "Model Training" +description: "Model training is the process of teaching an algorithm to learn from data to make predictions or decisions." +date: "2024-07-12" +author: + name: "Yan Gao" + position: "Research Scientist" + website: "https://discuss.flower.ai/u/yan-gao/" + github: "github.com/yan-gao-GY" +related: + - text: "Federated Learning" + link: "/glossary/federated-learning" + - text: "Server" + link: "/glossary/server" + - text: "Client" + link: "/glossary/client" +--- + +Model training is a core component of developing machine learning (ML) systems, where an algorithm learns from data to make predictions or decisions. A typical model training process involves several key steps: dataset preparation, feature selection and engineering, choice of model based on the task (e.g., classification, regression), choice of training algorithm (e.g. optimizer), and model iteration for updating its weights and biases to minimize the loss function, which measures the difference between the predicted and actual outcomes on the training data. The traditional ML model training process typically involves considerable manual effort, whereas deep learning (DL) offers an end-to-end automated process. + +This approach assumes easy access to data and often requires substantial computational resources, depending on the size of the dataset and complexity of the model. However, large amounts of the data in the real world is distributed and protected due to privacy concerns, making it inaccessible for typical (centralized) model training. Federated learning (FL) migrates the model training from data center to local user ends. After local training, each participant sends only their model's updates (not the data) to a central server for aggregation. The updated global model is sent back to the participants for further rounds of local training and updates. This way, the model training benefits from diverse, real-world data without compromising individual data privacy. + +#### Model training in Flower +Flower, a friendly FL framework, offers a wealth of model training examples and baselines tailored for federated environments. Please refer to the [examples](https://flower.ai/docs/examples/) and [baselines](https://flower.ai/docs/baselines/) documentation for more detailed information. diff --git a/glossary/platform-independence.mdx b/glossary/platform-independence.mdx new file mode 100644 index 000000000000..9582fae057ff --- /dev/null +++ b/glossary/platform-independence.mdx @@ -0,0 +1,19 @@ +--- +title: "Platform Independence" +description: "The capability to run program across different hardware and operating systems." +date: "2024-07-08" +author: + name: "Heng Pan" + position: "Research Scientist" + website: "https://discuss.flower.ai/u/pan-h/summary" + github: "github.com/panh99" +related: + - text: "Federated Learning" + link: "/glossary/federated-learning" +--- + +Platform independence in federated learning refers to the capability of machine learning systems to operate seamlessly across various hardware and operating system environments. This ensures that the federated learning process can function effectively on various devices with different operating systems such as Windows, Linux, Mac OS, iOS, and Android without requiring platform-specific modifications. By achieving platform independence, federated learning frameworks enable efficient data analysis and model training across heterogeneous edge devices, enhancing scalability and flexibility in distributed machine learning scenarios. + +### Platform Independence in Flower + +Flower is interoperable with different operating systems and hardware platforms to work well in heterogeneous edge device environments. \ No newline at end of file diff --git a/glossary/protocol-buffers.mdx b/glossary/protocol-buffers.mdx new file mode 100644 index 000000000000..7e9bf6c7bbc7 --- /dev/null +++ b/glossary/protocol-buffers.mdx @@ -0,0 +1,31 @@ +--- +title: "Protocol Buffers" +description: "Protocol Buffers, often abbreviated as Protobuf, are a language-neutral, platform-neutral, extensible mechanism for serializing structured data, similar to XML but smaller, faster, and simpler." +date: "2024-05-24" +author: + name: "Taner Topal" + position: "Co-Creator and CTO @ Flower Labs" + website: "https://www.linkedin.com/in/tanertopal/" + github: "github.com/tanertopal" +related: + - text: "Federated Learning" + link: "/glossary/federated-learning" + - text: "Tutorial: What is Federated Learning?" + link: "/docs/framework/tutorial-series-what-is-federated-learning.html" +--- + +### Introduction to Protocol Buffers + +Protocol Buffers, often abbreviated as Protobuf, are a language-neutral, platform-neutral, extensible mechanism for serializing structured data, similar to XML but smaller, faster, and simpler. The method involves defining how you want your data to be structured once, then using language specific generated source code to write and read structured data to and from a variety of data streams. + +### How Protocol Buffers Work + +Protocol Buffers require a `.proto` file where the data structure (the messages) is defined. This is essentially a schema describing the data to be serialized. Once the `.proto` file is prepared, it is compiled using the Protobuf compiler (`protoc`), which generates data access classes in supported languages like Java, C++, Python, Swift, Kotlin, and more. These classes provide simple accessors for each field (like standard getters and setters) and methods to serialize the entire structure to a binary format that can be easily transmitted over network protocols or written to a file. + +### Advantages and Use Cases + +The primary advantages of Protocol Buffers include their simplicity, efficiency, and backward compatibility. They are more efficient than XML or JSON as they serialize to a binary format, which makes them both smaller and faster. They support backward compatibility, allowing to modify data structures without breaking deployed programs that are communicating using the protocol. This makes Protobuf an excellent choice for data storage or RPC (Remote Procedure Call) applications where small size, low latency, and schema evolution are critical. + +### Protocol Buffers in Flower + +In the context of Flower, Protocol Buffers play a crucial role in ensuring efficient and reliable communication between the server and clients. Federated learning involves heterogeneous clients (e.g., servers, mobile devices, edge devices) running different environments and programming languages. This setup requires frequent exchanges of model updates and other metadata between the server and clients. Protocol Buffers, with their efficient binary serialization, enable Flower to handle these exchanges with minimal overhead, ensuring low latency and reducing the bandwidth required for communication. Moreover, the backward compatibility feature of Protobuf allows Flower to evolve and update its communication protocols without disrupting existing deployments. Best of all, Flower users typically do not have to deal directly with Protobuf, as Flower provides language-specific abstractions that simplify interaction with the underlying communication protocols. diff --git a/glossary/scalability.mdx b/glossary/scalability.mdx new file mode 100644 index 000000000000..4bfb736ff08c --- /dev/null +++ b/glossary/scalability.mdx @@ -0,0 +1,22 @@ +--- +title: "Scalability" +description: "Scalability ensures systems grow with demand. In Federated Learning, it involves efficiently managing dynamic clients and diverse devices. Flower supports large-scale FL on various devices/ resources." +date: "2024-05-23" +author: + name: "Daniel Nata Nugraha" + position: "Software Engineer" + image: "daniel_nata_nugraha.png" + website: "https://www.linkedin.com/in/daniel-nugraha/" + github: "github.com/danielnugraha" +related: + - text: "Flower Paper" + link: "https://arxiv.org/pdf/2007.14390" + - text: "Federated Learning" + link: "/glossary/federated-learning" + - text: "Tutorial: What is Federated Learning?" + link: "/docs/framework/tutorial-series-what-is-federated-learning.html" +--- + +Scalability is the ability of a system, network, or process to accommodate an increasing amount of work. This involves adding resources (like servers) or optimizing existing ones to maintain or enhance performance. There are two main types of scalability: horizontal scalability (adding more nodes, such as servers) and vertical scalability (adding more power to existing nodes, like increasing CPU or RAM). Ideally, a scalable system can do both, seamlessly adapting to increased demands without significant downtime. Scalability is essential for businesses to grow while ensuring services remain reliable and responsive. +Scalability in Federated Learning involves managing dynamic client participation, as clients may join or leave unpredictably. This requires algorithms that adapt to varying availability and efficiently aggregate updates from numerous models. Additionally, scalable federated learning systems must handle heterogeneous client devices with different processing powers, network conditions, and data distributions, ensuring balanced contributions to the global model. +Scalability in Flower means efficiently conducting large-scale federated learning (FL) training and evaluation. Flower enables researchers to launch FL experiments with many clients using reasonable computing resources, such as a single machine or a multi-GPU rack. Flower supports scaling workloads to millions of clients, including diverse devices like Raspberry Pis, Android and iOS mobile devices, laptops, etc. It offers complete control over connection management and includes a virtual client engine for large-scale simulations. diff --git a/glossary/server.mdx b/glossary/server.mdx new file mode 100644 index 000000000000..efc25a227791 --- /dev/null +++ b/glossary/server.mdx @@ -0,0 +1,17 @@ +--- +title: "Server" +description: "The central entity coordinating the aggregation of local model updates from multiple clients to build a comprehensive, privacy-preserving global model." +date: "2024-07-08" +author: + name: "Heng Pan" + position: "Research Scientist" + website: "https://discuss.flower.ai/u/pan-h/summary" + github: "github.com/panh99" +related: + - text: "Client" + link: "/glossary/client" + - text: "Federated Learning" + link: "/glossary/federated-learning" +--- + +A server in federated learning plays a pivotal role by managing the distributed training process across various clients. Each client independently trains its local model using the local data and then sends the model updates to the server. The server aggregates the received updates to create a new global model, which is subsequently sent back to the clients. This iterative process allows the global model to improve over time without the need for the clients to share their raw data, ensuring data privacy and minimizing data transfer. \ No newline at end of file diff --git a/glossary/xgboost.mdx b/glossary/xgboost.mdx new file mode 100644 index 000000000000..51b5a2912e0b --- /dev/null +++ b/glossary/xgboost.mdx @@ -0,0 +1,34 @@ +--- +title: "XGBoost" +description: "XGBoost - or eXtreme Gradient Boosting - is an open-source library providing a regularizing gradient boosting decisiong tree framework for many programming languages including Python, C++, and Java." +date: "2024-09-10" +author: + name: "Chong Shen Ng" + position: "Research Engineer @ Flower Labs" + website: "https://discuss.flower.ai/u/chongshenng" + github: "github.com/chongshenng" +related: + - text: "Quickstart Federated Learning with XGBoost and Flower" + link: "/docs/framework/tutorial-quickstart-xgboost.html" + - text: "Flower Example using XGBoost (Comprehensive)" + link: "/docs/examples/xgboost-comprehensive.html" +--- + +### Introduction to XGBoost + +XGBoost - or eXtreme Gradient Boosting - is an open-source library which provides a regularizing gradient boosting framework for Python, C++, Java, R, Julia, Perl, and Scala. It implements machine learning algorithms based on the gradient boosting concept, where a single model is created from an ensemble of weak learners (decision trees). This is commonly referred as a Gradient Boosting Decision Trees (GBDT), a decision tree ensemble learning algorithm. + +GBDTs are commonly compared with the random forest algorithm. They are similar in the sense that they build multiple decision trees. But the key differences are in how they are built and combined. Random forest first builds full decision trees in parallel from bootstrap samples of the dataset, and then generates the final prediction based on an average of all of the predictions. In contrast, GBDT iteratively trains decision trees with the objective that each subsequent tree reduces the error residuals of the previous model - this is the concept of boosting. The final prediction in a GBDT is a weighted sum of all of the tree predictions. While the bootstrap aggregation method of random forest minimizes variance and overfitting, the boosting method of GBDT minimizes bias and underfitting. + +XGBoost includes many features that optimizes the implementation of GBDT, including parallelized trees training (instead of sequential) and integration with distributed processing frameworks like Apache Spark and Dask. These various performance improvements have historically made XGBoost the preferred framework of choice when training models for supervised learning tasks, and have seen widespread success in Kaggle competitions on structured data. + +### Use cases in Federated Learning + +While there is no way to know before hand what model would perform the best in federated learning, XGBoost is appealing for several reasons: +1. To train the first model, XGBoost hyperparameters require significantly less tuning compared to neural network-based models. +2. XGBoost is known to produce models that perform far better than neural networks on tabular datasets, which can be encountered in real-world federated learning systems such as in healthcare or IoT applications. +3. Feature scaling is unnecessary when training XGBoost models. This not only facilitates fine-tuning on new data distributions, but also supports cross-device and cross-silo federated learning, where the data distributions from participating clients are not know a priori. + +### XGBoost in Flower + +In Flower, we have provided two strategies for performing federated learning with XGBoost: [`FedXgbBagging`](https://github.com/adap/flower/blob/main/src/py/flwr/server/strategy/fedxgb_bagging.py) and [`FedXgbCyclic`](https://github.com/adap/flower/blob/main/src/py/flwr/server/strategy/fedxgb_cyclic.py), which are inspired from the work at Nvidia NVFlare. These implementations allow Flower users to use different aggregation strategies for the XGBoost model. `FedXgbBagging` aggregates trees from all participating clients every round, whereas `FedXgbCyclic` aggregates clients' trees sequentially in a round-robin manner. With these strategies, Flower users can very quickly and easily run and compare the performance of federated learning systems on distributed tabular datasets using state-of-the-art XGBoost aggregation strategies, without needing to implement them from scratch. diff --git a/pyproject.toml b/pyproject.toml index b4ff7be41bf3..e8708b5fa56c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -66,10 +65,10 @@ flwr-clientapp = "flwr.client.clientapp:flwr_clientapp" flower-client-app = "flwr.client.supernode:run_client_app" # Deprecated [tool.poetry.dependencies] -python = "^3.8" +python = "^3.9" # Mandatory dependencies numpy = "^1.21.0" -grpcio = "^1.60.0,!=1.64.2,!=1.65.1,!=1.65.2,!=1.65.4" +grpcio = "^1.60.0,!=1.64.2,!=1.65.1,!=1.65.2,!=1.65.4,!=1.65.5,!=1.66.0,!=1.66.1" protobuf = "^4.25.2" cryptography = "^42.0.4" pycryptodome = "^3.18.0" @@ -79,7 +78,7 @@ tomli = "^2.0.1" tomli-w = "^1.0.0" pathspec = "^0.12.1" # Optional dependencies (Simulation Engine) -ray = { version = "==2.10.0", optional = true, python = ">=3.8,<3.12" } +ray = { version = "==2.10.0", optional = true, python = ">=3.9,<3.12" } # Optional dependencies (REST transport layer) requests = { version = "^2.31.0", optional = true } starlette = { version = "^0.31.0", optional = true } @@ -144,7 +143,7 @@ known_first_party = ["flwr", "flwr_tool"] [tool.black] line-length = 88 -target-version = ["py38", "py39", "py310", "py311"] +target-version = ["py39", "py310", "py311"] [tool.pylint."MESSAGES CONTROL"] disable = "duplicate-code,too-few-public-methods,useless-import-alias" @@ -193,7 +192,7 @@ wrap-summaries = 88 wrap-descriptions = 88 [tool.ruff] -target-version = "py38" +target-version = "py39" line-length = 88 select = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] fixable = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] diff --git a/src/docker/base/README.md b/src/docker/base/README.md index 197e3be2c95a..b17c3d6e5c6f 100644 --- a/src/docker/base/README.md +++ b/src/docker/base/README.md @@ -9,7 +9,7 @@ ## Quick reference - **Learn more:**
- [Flower Docs](https://flower.ai/docs/framework/how-to-run-flower-using-docker.html) + [Quickstart with Docker](https://flower.ai/docs/framework/docker/tutorial-quickstart-docker.html) and [Quickstart with Docker Compose](https://flower.ai/docs/framework/docker/tutorial-quickstart-docker-compose.html) - **Where to get help:**
[Flower Discuss](https://discuss.flower.ai), [Slack](https://flower.ai/join-slack) or [GitHub](https://github.com/adap/flower) diff --git a/src/docker/clientapp/README.md b/src/docker/clientapp/README.md index 871e4e4afc25..c7975ccd762c 100644 --- a/src/docker/clientapp/README.md +++ b/src/docker/clientapp/README.md @@ -9,7 +9,7 @@ ## Quick reference - **Learn more:**
- [Flower Docs](https://flower.ai/docs/framework/how-to-run-flower-using-docker.html) + [Quickstart with Docker](https://flower.ai/docs/framework/docker/tutorial-quickstart-docker.html) and [Quickstart with Docker Compose](https://flower.ai/docs/framework/docker/tutorial-quickstart-docker-compose.html) - **Where to get help:**
[Flower Discuss](https://discuss.flower.ai), [Slack](https://flower.ai/join-slack) or [GitHub](https://github.com/adap/flower) diff --git a/src/docker/serverapp/README.md b/src/docker/serverapp/README.md index 9084d5d61c1b..da49eb3596b9 100644 --- a/src/docker/serverapp/README.md +++ b/src/docker/serverapp/README.md @@ -9,7 +9,7 @@ ## Quick reference - **Learn more:**
- [Flower Docs](https://flower.ai/docs/framework/how-to-run-flower-using-docker.html) + [Quickstart with Docker](https://flower.ai/docs/framework/docker/tutorial-quickstart-docker.html) and [Quickstart with Docker Compose](https://flower.ai/docs/framework/docker/tutorial-quickstart-docker-compose.html) - **Where to get help:**
[Flower Discuss](https://discuss.flower.ai), [Slack](https://flower.ai/join-slack) or [GitHub](https://github.com/adap/flower) diff --git a/src/docker/superexec/README.md b/src/docker/superexec/README.md index b7c623fb3494..ed44f24ca7ae 100644 --- a/src/docker/superexec/README.md +++ b/src/docker/superexec/README.md @@ -9,7 +9,7 @@ ## Quick reference - **Learn more:**
- [Flower Docs](https://flower.ai/docs/framework/how-to-run-flower-using-docker.html) + [Quickstart with Docker](https://flower.ai/docs/framework/docker/tutorial-quickstart-docker.html) and [Quickstart with Docker Compose](https://flower.ai/docs/framework/docker/tutorial-quickstart-docker-compose.html) - **Where to get help:**
[Flower Discuss](https://discuss.flower.ai), [Slack](https://flower.ai/join-slack) or [GitHub](https://github.com/adap/flower) diff --git a/src/docker/superlink/README.md b/src/docker/superlink/README.md index 3543b2de1c00..0e20bf3d039f 100644 --- a/src/docker/superlink/README.md +++ b/src/docker/superlink/README.md @@ -9,7 +9,7 @@ ## Quick reference - **Learn more:**
- [Flower Docs](https://flower.ai/docs/framework/how-to-run-flower-using-docker.html) + [Quickstart with Docker](https://flower.ai/docs/framework/docker/tutorial-quickstart-docker.html) and [Quickstart with Docker Compose](https://flower.ai/docs/framework/docker/tutorial-quickstart-docker-compose.html) - **Where to get help:**
[Flower Discuss](https://discuss.flower.ai), [Slack](https://flower.ai/join-slack) or [GitHub](https://github.com/adap/flower) diff --git a/src/docker/supernode/README.md b/src/docker/supernode/README.md index 46f598134d0a..c2d99a500da4 100644 --- a/src/docker/supernode/README.md +++ b/src/docker/supernode/README.md @@ -9,7 +9,7 @@ ## Quick reference - **Learn more:**
- [Flower Docs](https://flower.ai/docs/framework/how-to-run-flower-using-docker.html) + [Quickstart with Docker](https://flower.ai/docs/framework/docker/tutorial-quickstart-docker.html) and [Quickstart with Docker Compose](https://flower.ai/docs/framework/docker/tutorial-quickstart-docker-compose.html) - **Where to get help:**
[Flower Discuss](https://discuss.flower.ai), [Slack](https://flower.ai/join-slack) or [GitHub](https://github.com/adap/flower) diff --git a/src/py/flwr/cli/build.py b/src/py/flwr/cli/build.py index 676bc1723568..137e2dc31aff 100644 --- a/src/py/flwr/cli/build.py +++ b/src/py/flwr/cli/build.py @@ -17,12 +17,11 @@ import os import zipfile from pathlib import Path -from typing import Optional +from typing import Annotated, Optional import pathspec import tomli_w import typer -from typing_extensions import Annotated from .config_utils import load_and_validate from .utils import get_sha256_hash, is_valid_project_name diff --git a/src/py/flwr/cli/config_utils.py b/src/py/flwr/cli/config_utils.py index 233d35a5fa17..79e4973ccf9c 100644 --- a/src/py/flwr/cli/config_utils.py +++ b/src/py/flwr/cli/config_utils.py @@ -17,7 +17,7 @@ import zipfile from io import BytesIO from pathlib import Path -from typing import IO, Any, Dict, List, Optional, Tuple, Union, get_args +from typing import IO, Any, Optional, Union, get_args import tomli @@ -25,7 +25,7 @@ from flwr.common.typing import UserConfigValue -def get_fab_config(fab_file: Union[Path, bytes]) -> Dict[str, Any]: +def get_fab_config(fab_file: Union[Path, bytes]) -> dict[str, Any]: """Extract the config from a FAB file or path. Parameters @@ -62,7 +62,7 @@ def get_fab_config(fab_file: Union[Path, bytes]) -> Dict[str, Any]: return conf -def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]: +def get_fab_metadata(fab_file: Union[Path, bytes]) -> tuple[str, str]: """Extract the fab_id and the fab_version from a FAB file or path. Parameters @@ -87,7 +87,7 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]: def load_and_validate( path: Optional[Path] = None, check_module: bool = True, -) -> Tuple[Optional[Dict[str, Any]], List[str], List[str]]: +) -> tuple[Optional[dict[str, Any]], list[str], list[str]]: """Load and validate pyproject.toml as dict. Returns @@ -116,7 +116,7 @@ def load_and_validate( return (config, errors, warnings) -def load(toml_path: Path) -> Optional[Dict[str, Any]]: +def load(toml_path: Path) -> Optional[dict[str, Any]]: """Load pyproject.toml and return as dict.""" if not toml_path.is_file(): return None @@ -125,7 +125,7 @@ def load(toml_path: Path) -> Optional[Dict[str, Any]]: return load_from_string(toml_file.read()) -def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None: +def _validate_run_config(config_dict: dict[str, Any], errors: list[str]) -> None: for key, value in config_dict.items(): if isinstance(value, dict): _validate_run_config(config_dict[key], errors) @@ -137,7 +137,7 @@ def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None # pylint: disable=too-many-branches -def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]: +def validate_fields(config: dict[str, Any]) -> tuple[bool, list[str], list[str]]: """Validate pyproject.toml fields.""" errors = [] warnings = [] @@ -183,10 +183,10 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]] def validate( - config: Dict[str, Any], + config: dict[str, Any], check_module: bool = True, project_dir: Optional[Union[str, Path]] = None, -) -> Tuple[bool, List[str], List[str]]: +) -> tuple[bool, list[str], list[str]]: """Validate pyproject.toml.""" is_valid, errors, warnings = validate_fields(config) @@ -210,7 +210,7 @@ def validate( return True, [], [] -def load_from_string(toml_content: str) -> Optional[Dict[str, Any]]: +def load_from_string(toml_content: str) -> Optional[dict[str, Any]]: """Load TOML content from a string and return as dict.""" try: data = tomli.loads(toml_content) diff --git a/src/py/flwr/cli/config_utils_test.py b/src/py/flwr/cli/config_utils_test.py index cad6714521e3..ddabc152bc0f 100644 --- a/src/py/flwr/cli/config_utils_test.py +++ b/src/py/flwr/cli/config_utils_test.py @@ -17,7 +17,7 @@ import os import textwrap from pathlib import Path -from typing import Any, Dict +from typing import Any from .config_utils import load, validate, validate_fields @@ -155,7 +155,7 @@ def test_load_pyproject_toml_from_path(tmp_path: Path) -> None: def test_validate_pyproject_toml_fields_empty() -> None: """Test that validate_pyproject_toml_fields fails correctly.""" # Prepare - config: Dict[str, Any] = {} + config: dict[str, Any] = {} # Execute is_valid, errors, warnings = validate_fields(config) diff --git a/src/py/flwr/cli/install.py b/src/py/flwr/cli/install.py index 4318ccdf9ffb..8e3e9505898c 100644 --- a/src/py/flwr/cli/install.py +++ b/src/py/flwr/cli/install.py @@ -21,10 +21,9 @@ import zipfile from io import BytesIO from pathlib import Path -from typing import IO, Optional, Union +from typing import IO, Annotated, Optional, Union import typer -from typing_extensions import Annotated from flwr.common.config import get_flwr_dir diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index 90e4970d5928..d2f7179b45b4 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -18,10 +18,9 @@ from enum import Enum from pathlib import Path from string import Template -from typing import Dict, Optional +from typing import Annotated, Optional import typer -from typing_extensions import Annotated from ..utils import ( is_valid_project_name, @@ -70,7 +69,7 @@ def load_template(name: str) -> str: return tpl_file.read() -def render_template(template: str, data: Dict[str, str]) -> str: +def render_template(template: str, data: dict[str, str]) -> str: """Render template.""" tpl_file = load_template(template) tpl = Template(tpl_file) @@ -85,7 +84,7 @@ def create_file(file_path: Path, content: str) -> None: file_path.write_text(content) -def render_and_create(file_path: Path, template: str, context: Dict[str, str]) -> None: +def render_and_create(file_path: Path, template: str, context: dict[str, str]) -> None: """Render template and write to file.""" content = render_template(template, context) create_file(file_path, content) diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index 6375e71522de..905055ac70c0 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -20,10 +20,9 @@ import sys from logging import DEBUG from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Annotated, Any, Optional import typer -from typing_extensions import Annotated from flwr.cli.build import build from flwr.cli.config_utils import load_and_validate @@ -52,7 +51,7 @@ def run( typer.Argument(help="Name of the federation to run the app on."), ] = None, config_overrides: Annotated[ - Optional[List[str]], + Optional[list[str]], typer.Option( "--run-config", "-c", @@ -125,8 +124,8 @@ def run( def _run_with_superexec( app: Path, - federation_config: Dict[str, Any], - config_overrides: Optional[List[str]], + federation_config: dict[str, Any], + config_overrides: Optional[list[str]], ) -> None: insecure_str = federation_config.get("insecure") @@ -187,8 +186,8 @@ def _run_with_superexec( def _run_without_superexec( app: Optional[Path], - federation_config: Dict[str, Any], - config_overrides: Optional[List[str]], + federation_config: dict[str, Any], + config_overrides: Optional[list[str]], federation: str, ) -> None: try: diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py index 2f5a8831fa7c..e725fdd3f951 100644 --- a/src/py/flwr/cli/utils.py +++ b/src/py/flwr/cli/utils.py @@ -17,7 +17,7 @@ import hashlib import re from pathlib import Path -from typing import Callable, List, Optional, cast +from typing import Callable, Optional, cast import typer @@ -40,7 +40,7 @@ def prompt_text( return cast(str, result) -def prompt_options(text: str, options: List[str]) -> str: +def prompt_options(text: str, options: list[str]) -> str: """Ask user to select one of the given options and return the selected item.""" # Turn options into a list with index as in " [ 0] quickstart-pytorch" options_formatted = [ diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 78db5639ff0f..90c50aba7fad 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -18,10 +18,11 @@ import subprocess import sys import time +from contextlib import AbstractContextManager from dataclasses import dataclass from logging import ERROR, INFO, WARN from pathlib import Path -from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union, cast +from typing import Callable, Optional, Union, cast import grpc from cryptography.hazmat.primitives.asymmetric import ec @@ -94,7 +95,7 @@ def start_client( insecure: Optional[bool] = None, transport: Optional[str] = None, authentication_keys: Optional[ - Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, @@ -204,7 +205,7 @@ def start_client_internal( insecure: Optional[bool] = None, transport: Optional[str] = None, authentication_keys: Optional[ - Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, @@ -356,7 +357,7 @@ def _on_backoff(retry_state: RetryState) -> None: # NodeState gets initialized when the first connection is established node_state: Optional[NodeState] = None - runs: Dict[int, Run] = {} + runs: dict[int, Run] = {} while not app_state_tracker.interrupt: sleep_duration: int = 0 @@ -689,7 +690,7 @@ def start_numpy_client( ) -def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ +def _init_connection(transport: Optional[str], server_address: str) -> tuple[ Callable[ [ str, @@ -697,10 +698,10 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ RetryInvoker, int, Union[bytes, str, None], - Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]], + Optional[tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]], ], - ContextManager[ - Tuple[ + AbstractContextManager[ + tuple[ Callable[[], Optional[Message]], Callable[[Message], None], Optional[Callable[[], Optional[int]]], @@ -711,7 +712,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ ], ], str, - Type[Exception], + type[Exception], ]: # Parse IP address parsed_address = parse_address(server_address) @@ -769,7 +770,7 @@ def signal_handler(sig, frame): # type: ignore signal.signal(signal.SIGTERM, signal_handler) -def run_clientappio_api_grpc(address: str) -> Tuple[grpc.Server, ClientAppIoServicer]: +def run_clientappio_api_grpc(address: str) -> tuple[grpc.Server, ClientAppIoServicer]: """Run ClientAppIo API gRPC server.""" clientappio_servicer: grpc.Server = ClientAppIoServicer() clientappio_add_servicer_to_server_fn = add_ClientAppIoServicer_to_server diff --git a/src/py/flwr/client/app_test.py b/src/py/flwr/client/app_test.py index 74ade03f973a..723a066ea0bc 100644 --- a/src/py/flwr/client/app_test.py +++ b/src/py/flwr/client/app_test.py @@ -15,8 +15,6 @@ """Flower Client app tests.""" -from typing import Dict, Tuple - from flwr.common import ( Config, EvaluateIns, @@ -59,7 +57,7 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: class NeedsWrappingClient(NumPyClient): """Client implementation extending the high-level NumPyClient.""" - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """Raise an Exception because this method is not expected to be called.""" raise NotImplementedError() @@ -69,13 +67,13 @@ def get_parameters(self, config: Config) -> NDArrays: def fit( self, parameters: NDArrays, config: Config - ) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + ) -> tuple[NDArrays, int, dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" raise NotImplementedError() def evaluate( self, parameters: NDArrays, config: Config - ) -> Tuple[float, int, Dict[str, Scalar]]: + ) -> tuple[float, int, dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" raise NotImplementedError() diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index c322ba747114..234d84f27782 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -16,7 +16,7 @@ import inspect -from typing import Callable, List, Optional +from typing import Callable, Optional from flwr.client.client import Client from flwr.client.message_handler.message_handler import ( @@ -109,9 +109,9 @@ class ClientApp: def __init__( self, client_fn: Optional[ClientFnExt] = None, # Only for backward compatibility - mods: Optional[List[Mod]] = None, + mods: Optional[list[Mod]] = None, ) -> None: - self._mods: List[Mod] = mods if mods is not None else [] + self._mods: list[Mod] = mods if mods is not None else [] # Create wrapper function for `handle` self._call: Optional[ClientAppCallable] = None diff --git a/src/py/flwr/client/clientapp/app.py b/src/py/flwr/client/clientapp/app.py index 69d334fead14..f493128bebac 100644 --- a/src/py/flwr/client/clientapp/app.py +++ b/src/py/flwr/client/clientapp/app.py @@ -17,7 +17,7 @@ import argparse import time from logging import DEBUG, ERROR, INFO -from typing import Optional, Tuple +from typing import Optional import grpc @@ -196,7 +196,7 @@ def get_token(stub: grpc.Channel) -> Optional[int]: def pull_message( stub: grpc.Channel, token: int -) -> Tuple[Message, Context, Run, Optional[Fab]]: +) -> tuple[Message, Context, Run, Optional[Fab]]: """Pull message from SuperNode to ClientApp.""" log(INFO, "Pulling ClientAppInputs for token %s", token) try: diff --git a/src/py/flwr/client/dpfedavg_numpy_client.py b/src/py/flwr/client/dpfedavg_numpy_client.py index c592d10936d5..bade811b48ce 100644 --- a/src/py/flwr/client/dpfedavg_numpy_client.py +++ b/src/py/flwr/client/dpfedavg_numpy_client.py @@ -16,7 +16,6 @@ import copy -from typing import Dict, Tuple import numpy as np @@ -39,7 +38,7 @@ def __init__(self, client: NumPyClient) -> None: super().__init__() self.client = client - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """Get client properties using the given Numpy client. Parameters @@ -58,7 +57,7 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]: """ return self.client.get_properties(config) - def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: + def get_parameters(self, config: dict[str, Scalar]) -> NDArrays: """Return the current local model parameters. Parameters @@ -76,8 +75,8 @@ def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: return self.client.get_parameters(config) def fit( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + self, parameters: NDArrays, config: dict[str, Scalar] + ) -> tuple[NDArrays, int, dict[str, Scalar]]: """Train the provided parameters using the locally held dataset. This method first updates the local model using the original parameters @@ -153,8 +152,8 @@ def fit( return updated_params, num_examples, metrics def evaluate( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[float, int, Dict[str, Scalar]]: + self, parameters: NDArrays, config: dict[str, Scalar] + ) -> tuple[float, int, dict[str, Scalar]]: """Evaluate the provided parameters using the locally held dataset. Parameters diff --git a/src/py/flwr/client/grpc_adapter_client/connection.py b/src/py/flwr/client/grpc_adapter_client/connection.py index f9f7b1043524..9b84545eacdb 100644 --- a/src/py/flwr/client/grpc_adapter_client/connection.py +++ b/src/py/flwr/client/grpc_adapter_client/connection.py @@ -15,9 +15,10 @@ """Contextmanager for a GrpcAdapter channel to the Flower server.""" +from collections.abc import Iterator from contextlib import contextmanager from logging import ERROR -from typing import Callable, Iterator, Optional, Tuple, Union +from typing import Callable, Optional, Union from cryptography.hazmat.primitives.asymmetric import ec @@ -38,10 +39,10 @@ def grpc_adapter( # pylint: disable=R0913 max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[Union[bytes, str]] = None, authentication_keys: Optional[ # pylint: disable=unused-argument - Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, ) -> Iterator[ - Tuple[ + tuple[ Callable[[], Optional[Message]], Callable[[Message], None], Optional[Callable[[], Optional[int]]], diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 489891f55436..29479cf5479d 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -16,11 +16,12 @@ import uuid +from collections.abc import Iterator from contextlib import contextmanager from logging import DEBUG, ERROR from pathlib import Path from queue import Queue -from typing import Callable, Iterator, Optional, Tuple, Union, cast +from typing import Callable, Optional, Union, cast from cryptography.hazmat.primitives.asymmetric import ec @@ -66,10 +67,10 @@ def grpc_connection( # pylint: disable=R0913, R0915 max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, authentication_keys: Optional[ # pylint: disable=unused-argument - Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, ) -> Iterator[ - Tuple[ + tuple[ Callable[[], Optional[Message]], Callable[[Message], None], Optional[Callable[[], Optional[int]]], diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index bd377ef3470a..13bd2c6af8e7 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -17,8 +17,9 @@ import concurrent.futures import socket +from collections.abc import Iterator from contextlib import closing -from typing import Iterator, cast +from typing import cast from unittest.mock import patch import grpc diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index 8e8b701ca272..653e384aff96 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -17,8 +17,9 @@ import base64 import collections +from collections.abc import Sequence from logging import WARNING -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Union import grpc from cryptography.hazmat.primitives.asymmetric import ec @@ -53,7 +54,7 @@ def _get_value_from_tuples( - key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] + key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]] ) -> bytes: value = next((value for key, value in tuples if key == key_string), "") if isinstance(value, str): diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py index 72ac20738ad6..27f759a71713 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py @@ -18,9 +18,10 @@ import base64 import threading import unittest +from collections.abc import Sequence from concurrent import futures from logging import DEBUG, INFO, WARN -from typing import Optional, Sequence, Tuple, Union +from typing import Optional, Union import grpc @@ -60,7 +61,7 @@ def __init__(self) -> None: """Initialize mock servicer.""" self._lock = threading.Lock() self._received_client_metadata: Optional[ - Sequence[Tuple[str, Union[str, bytes]]] + Sequence[tuple[str, Union[str, bytes]]] ] = None self.server_private_key, self.server_public_key = generate_key_pairs() self._received_message_bytes: bytes = b"" @@ -105,7 +106,7 @@ def unary_unary( def received_client_metadata( self, - ) -> Optional[Sequence[Tuple[str, Union[str, bytes]]]]: + ) -> Optional[Sequence[tuple[str, Union[str, bytes]]]]: """Return received client metadata.""" with self._lock: return self._received_client_metadata @@ -151,7 +152,7 @@ def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None: def _get_value_from_tuples( - key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] + key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]] ) -> bytes: value = next((value for key, value in tuples if key == key_string), "") if isinstance(value, str): diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 8bae253c819a..7ce3d37b7a17 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -17,11 +17,12 @@ import random import threading +from collections.abc import Iterator, Sequence from contextlib import contextmanager from copy import copy from logging import DEBUG, ERROR from pathlib import Path -from typing import Callable, Iterator, Optional, Sequence, Tuple, Type, Union, cast +from typing import Callable, Optional, Union, cast import grpc from cryptography.hazmat.primitives.asymmetric import ec @@ -77,11 +78,11 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[Union[bytes, str]] = None, authentication_keys: Optional[ - Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, - adapter_cls: Optional[Union[Type[FleetStub], Type[GrpcAdapter]]] = None, + adapter_cls: Optional[Union[type[FleetStub], type[GrpcAdapter]]] = None, ) -> Iterator[ - Tuple[ + tuple[ Callable[[], Optional[Message]], Callable[[Message], None], Optional[Callable[[], Optional[int]]], diff --git a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py index fde03943a852..3dce14c14956 100644 --- a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py +++ b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py @@ -17,7 +17,7 @@ import sys from logging import DEBUG -from typing import Any, Type, TypeVar, cast +from typing import Any, TypeVar, cast import grpc from google.protobuf.message import Message as GrpcMessage @@ -59,7 +59,7 @@ def __init__(self, channel: grpc.Channel) -> None: self.stub = GrpcAdapterStub(channel) def _send_and_receive( - self, request: GrpcMessage, response_type: Type[T], **kwargs: Any + self, request: GrpcMessage, response_type: type[T], **kwargs: Any ) -> T: # Serialize request container_req = MessageContainer( diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 1ab84eb01468..765c6a6b2e91 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -15,7 +15,7 @@ """Client-side message handler.""" from logging import WARN -from typing import Optional, Tuple, cast +from typing import Optional, cast from flwr.client.client import ( maybe_call_evaluate, @@ -52,7 +52,7 @@ class UnknownServerMessage(Exception): """Exception indicating that the received message is unknown.""" -def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: +def handle_control_message(message: Message) -> tuple[Optional[Message], int]: """Handle control part of the incoming message. Parameters @@ -147,7 +147,7 @@ def handle_legacy_message_from_msgtype( def _reconnect( reconnect_msg: ServerMessage.ReconnectIns, -) -> Tuple[ClientMessage, int]: +) -> tuple[ClientMessage, int]: # Determine the reason for sending DisconnectRes message reason = Reason.ACK sleep_duration = None diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 557d61ffb32a..311f8c37e1b1 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -19,7 +19,6 @@ import unittest import uuid from copy import copy -from typing import List from flwr.client import Client from flwr.client.typing import ClientFnExt @@ -294,7 +293,7 @@ def test_invalid_message_run_id(self) -> None: msg = Message(metadata=self.valid_out_metadata, content=RecordSet()) # Execute - invalid_metadata_list: List[Metadata] = [] + invalid_metadata_list: list[Metadata] = [] attrs = list(vars(self.valid_out_metadata).keys()) for attr in attrs: if attr == "_partition_id": diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py index 5b196ad84321..f9d3c433157d 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -18,7 +18,7 @@ import os from dataclasses import dataclass, field from logging import DEBUG, WARNING -from typing import Any, Dict, List, Tuple, cast +from typing import Any, cast from flwr.client.typing import ClientAppCallable from flwr.common import ( @@ -91,11 +91,11 @@ class SecAggPlusState: # Random seed for generating the private mask rd_seed: bytes = b"" - rd_seed_share_dict: Dict[int, bytes] = field(default_factory=dict) - sk1_share_dict: Dict[int, bytes] = field(default_factory=dict) + rd_seed_share_dict: dict[int, bytes] = field(default_factory=dict) + sk1_share_dict: dict[int, bytes] = field(default_factory=dict) # The dict of the shared secrets from sk2 - ss2_dict: Dict[int, bytes] = field(default_factory=dict) - public_keys_dict: Dict[int, Tuple[bytes, bytes]] = field(default_factory=dict) + ss2_dict: dict[int, bytes] = field(default_factory=dict) + public_keys_dict: dict[int, tuple[bytes, bytes]] = field(default_factory=dict) def __init__(self, **kwargs: ConfigsRecordValues) -> None: for k, v in kwargs.items(): @@ -104,8 +104,8 @@ def __init__(self, **kwargs: ConfigsRecordValues) -> None: new_v: Any = v if k.endswith(":K"): k = k[:-2] - keys = cast(List[int], v) - values = cast(List[bytes], kwargs[f"{k}:V"]) + keys = cast(list[int], v) + values = cast(list[bytes], kwargs[f"{k}:V"]) if len(values) > len(keys): updated_values = [ tuple(values[i : i + 2]) for i in range(0, len(values), 2) @@ -115,17 +115,17 @@ def __init__(self, **kwargs: ConfigsRecordValues) -> None: new_v = dict(zip(keys, values)) self.__setattr__(k, new_v) - def to_dict(self) -> Dict[str, ConfigsRecordValues]: + def to_dict(self) -> dict[str, ConfigsRecordValues]: """Convert the state to a dictionary.""" ret = vars(self) for k in list(ret.keys()): if isinstance(ret[k], dict): # Replace dict with two lists - v = cast(Dict[str, Any], ret.pop(k)) + v = cast(dict[str, Any], ret.pop(k)) ret[f"{k}:K"] = list(v.keys()) if k == "public_keys_dict": - v_list: List[bytes] = [] - for b1_b2 in cast(List[Tuple[bytes, bytes]], v.values()): + v_list: list[bytes] = [] + for b1_b2 in cast(list[tuple[bytes, bytes]], v.values()): v_list.extend(b1_b2) ret[f"{k}:V"] = v_list else: @@ -276,7 +276,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None: ) if not isinstance(configs[key], list) or any( elm - for elm in cast(List[Any], configs[key]) + for elm in cast(list[Any], configs[key]) # pylint: disable-next=unidiomatic-typecheck if type(elm) is not expected_type ): @@ -299,7 +299,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None: ) if not isinstance(configs[key], list) or any( elm - for elm in cast(List[Any], configs[key]) + for elm in cast(list[Any], configs[key]) # pylint: disable-next=unidiomatic-typecheck if type(elm) is not expected_type ): @@ -314,7 +314,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None: def _setup( state: SecAggPlusState, configs: ConfigsRecord -) -> Dict[str, ConfigsRecordValues]: +) -> dict[str, ConfigsRecordValues]: # Assigning parameter values to object fields sec_agg_param_dict = configs state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER]) @@ -350,8 +350,8 @@ def _setup( # pylint: disable-next=too-many-locals def _share_keys( state: SecAggPlusState, configs: ConfigsRecord -) -> Dict[str, ConfigsRecordValues]: - named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs) +) -> dict[str, ConfigsRecordValues]: + named_bytes_tuples = cast(dict[str, tuple[bytes, bytes]], configs) key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()} log(DEBUG, "Node %d: starting stage 1...", state.nid) state.public_keys_dict = key_dict @@ -361,7 +361,7 @@ def _share_keys( raise ValueError("Available neighbours number smaller than threshold") # Check if all public keys are unique - pk_list: List[bytes] = [] + pk_list: list[bytes] = [] for pk1, pk2 in state.public_keys_dict.values(): pk_list.append(pk1) pk_list.append(pk2) @@ -415,11 +415,11 @@ def _collect_masked_vectors( configs: ConfigsRecord, num_examples: int, updated_parameters: Parameters, -) -> Dict[str, ConfigsRecordValues]: +) -> dict[str, ConfigsRecordValues]: log(DEBUG, "Node %d: starting stage 2...", state.nid) - available_clients: List[int] = [] - ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST]) - srcs = cast(List[int], configs[Key.SOURCE_LIST]) + available_clients: list[int] = [] + ciphertexts = cast(list[bytes], configs[Key.CIPHERTEXT_LIST]) + srcs = cast(list[int], configs[Key.SOURCE_LIST]) if len(ciphertexts) + 1 < state.threshold: raise ValueError("Not enough available neighbour clients.") @@ -467,7 +467,7 @@ def _collect_masked_vectors( quantized_parameters = factor_combine(q_ratio, quantized_parameters) - dimensions_list: List[Tuple[int, ...]] = [a.shape for a in quantized_parameters] + dimensions_list: list[tuple[int, ...]] = [a.shape for a in quantized_parameters] # Add private mask private_mask = pseudo_rand_gen(state.rd_seed, state.mod_range, dimensions_list) @@ -499,11 +499,11 @@ def _collect_masked_vectors( def _unmask( state: SecAggPlusState, configs: ConfigsRecord -) -> Dict[str, ConfigsRecordValues]: +) -> dict[str, ConfigsRecordValues]: log(DEBUG, "Node %d: starting stage 3...", state.nid) - active_nids = cast(List[int], configs[Key.ACTIVE_NODE_ID_LIST]) - dead_nids = cast(List[int], configs[Key.DEAD_NODE_ID_LIST]) + active_nids = cast(list[int], configs[Key.ACTIVE_NODE_ID_LIST]) + dead_nids = cast(list[int], configs[Key.DEAD_NODE_ID_LIST]) # Send private mask seed share for every avaliable client (including itself) # Send first private key share for building pairwise mask for every dropped client if len(active_nids) < state.threshold: diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index 2832576fb4fc..e68bf5177797 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -16,7 +16,7 @@ import unittest from itertools import product -from typing import Callable, Dict, List +from typing import Callable from flwr.client.mod import make_ffn from flwr.common import ( @@ -41,7 +41,7 @@ def get_test_handler( ctxt: Context, -) -> Callable[[Dict[str, ConfigsRecordValues]], ConfigsRecord]: +) -> Callable[[dict[str, ConfigsRecordValues]], ConfigsRecord]: """.""" def empty_ffn(_msg: Message, _2: Context) -> Message: @@ -49,7 +49,7 @@ def empty_ffn(_msg: Message, _2: Context) -> Message: app = make_ffn(empty_ffn, [secaggplus_mod]) - def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord: + def func(configs: dict[str, ConfigsRecordValues]) -> ConfigsRecord: in_msg = Message( metadata=Metadata( run_id=0, @@ -158,7 +158,7 @@ def test_stage_setup_check(self) -> None: (Key.MOD_RANGE, int), ] - type_to_test_value: Dict[type, ConfigsRecordValues] = { + type_to_test_value: dict[type, ConfigsRecordValues] = { int: 10, bool: True, float: 1.0, @@ -166,7 +166,7 @@ def test_stage_setup_check(self) -> None: bytes: b"test", } - valid_configs: Dict[str, ConfigsRecordValues] = { + valid_configs: dict[str, ConfigsRecordValues] = { key: type_to_test_value[value_type] for key, value_type in valid_key_type_pairs } @@ -208,7 +208,7 @@ def test_stage_share_keys_check(self) -> None: handler = get_test_handler(ctxt) set_stage = _make_set_state_fn(ctxt) - valid_configs: Dict[str, ConfigsRecordValues] = { + valid_configs: dict[str, ConfigsRecordValues] = { "1": [b"public key 1", b"public key 2"], "2": [b"public key 1", b"public key 2"], "3": [b"public key 1", b"public key 2"], @@ -225,7 +225,7 @@ def test_stage_share_keys_check(self) -> None: valid_configs[Key.STAGE] = Stage.SHARE_KEYS # Test invalid configs - invalid_values: List[ConfigsRecordValues] = [ + invalid_values: list[ConfigsRecordValues] = [ b"public key 1", [b"public key 1"], [b"public key 1", b"public key 2", b"public key 3"], @@ -245,7 +245,7 @@ def test_stage_collect_masked_vectors_check(self) -> None: handler = get_test_handler(ctxt) set_stage = _make_set_state_fn(ctxt) - valid_configs: Dict[str, ConfigsRecordValues] = { + valid_configs: dict[str, ConfigsRecordValues] = { Key.CIPHERTEXT_LIST: [b"ctxt!", b"ctxt@", b"ctxt#", b"ctxt?"], Key.SOURCE_LIST: [32, 51324, 32324123, -3], } @@ -289,7 +289,7 @@ def test_stage_unmask_check(self) -> None: handler = get_test_handler(ctxt) set_stage = _make_set_state_fn(ctxt) - valid_configs: Dict[str, ConfigsRecordValues] = { + valid_configs: dict[str, ConfigsRecordValues] = { Key.ACTIVE_NODE_ID_LIST: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], Key.DEAD_NODE_ID_LIST: [32, 51324, 32324123, -3], } diff --git a/src/py/flwr/client/mod/utils.py b/src/py/flwr/client/mod/utils.py index c8fb21379783..c76902cf263f 100644 --- a/src/py/flwr/client/mod/utils.py +++ b/src/py/flwr/client/mod/utils.py @@ -15,13 +15,11 @@ """Utility functions for mods.""" -from typing import List - from flwr.client.typing import ClientAppCallable, Mod from flwr.common import Context, Message -def make_ffn(ffn: ClientAppCallable, mods: List[Mod]) -> ClientAppCallable: +def make_ffn(ffn: ClientAppCallable, mods: list[Mod]) -> ClientAppCallable: """.""" def wrap_ffn(_ffn: ClientAppCallable, _mod: Mod) -> ClientAppCallable: diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py index a5bbd0a0bb4d..e75fb5530b2c 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -16,7 +16,7 @@ import unittest -from typing import List, cast +from typing import cast from flwr.client.typing import ClientAppCallable, Mod from flwr.common import ( @@ -43,7 +43,7 @@ def _increment_context_counter(context: Context) -> None: context.state.metrics_records[METRIC] = MetricsRecord({COUNTER: current_counter}) -def make_mock_mod(name: str, footprint: List[str]) -> Mod: +def make_mock_mod(name: str, footprint: list[str]) -> Mod: """Make a mock mod.""" def mod(message: Message, context: Context, app: ClientAppCallable) -> Message: @@ -61,7 +61,7 @@ def mod(message: Message, context: Context, app: ClientAppCallable) -> Message: return mod -def make_mock_app(name: str, footprint: List[str]) -> ClientAppCallable: +def make_mock_app(name: str, footprint: list[str]) -> ClientAppCallable: """Make a mock app.""" def app(message: Message, context: Context) -> Message: @@ -97,7 +97,7 @@ class TestMakeApp(unittest.TestCase): def test_multiple_mods(self) -> None: """Test if multiple mods are called in the correct order.""" # Prepare - footprint: List[str] = [] + footprint: list[str] = [] mock_app = make_mock_app("app", footprint) mock_mod_names = [f"mod{i}" for i in range(1, 15)] mock_mods = [make_mock_mod(name, footprint) for name in mock_mod_names] @@ -127,7 +127,7 @@ def test_multiple_mods(self) -> None: def test_filter(self) -> None: """Test if a mod can filter incoming TaskIns.""" # Prepare - footprint: List[str] = [] + footprint: list[str] = [] mock_app = make_mock_app("app", footprint) context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) message = _get_dummy_flower_message() diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index e16d7e34715d..e7967dfc8bee 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from pathlib import Path -from typing import Dict, Optional +from typing import Optional from flwr.common import Context, RecordSet from flwr.common.config import ( @@ -46,7 +46,7 @@ def __init__( ) -> None: self.node_id = node_id self.node_config = node_config - self.run_infos: Dict[int, RunInfo] = {} + self.run_infos: dict[int, RunInfo] = {} # pylint: disable=too-many-arguments def register_context( diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index b21a51b38e9b..6a656cb661d2 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -16,7 +16,7 @@ from abc import ABC -from typing import Callable, Dict, Tuple +from typing import Callable from flwr.client.client import Client from flwr.common import ( @@ -73,7 +73,7 @@ class NumPyClient(ABC): _context: Context - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """Return a client's set of properties. Parameters @@ -93,7 +93,7 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]: _ = (self, config) return {} - def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: + def get_parameters(self, config: dict[str, Scalar]) -> NDArrays: """Return the current local model parameters. Parameters @@ -112,8 +112,8 @@ def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: return [] def fit( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + self, parameters: NDArrays, config: dict[str, Scalar] + ) -> tuple[NDArrays, int, dict[str, Scalar]]: """Train the provided parameters using the locally held dataset. Parameters @@ -141,8 +141,8 @@ def fit( return [], 0, {} def evaluate( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[float, int, Dict[str, Scalar]]: + self, parameters: NDArrays, config: dict[str, Scalar] + ) -> tuple[float, int, dict[str, Scalar]]: """Evaluate the provided parameters using the locally held dataset. Parameters @@ -310,7 +310,7 @@ def _set_context(self: Client, context: Context) -> None: def _wrap_numpy_client(client: NumPyClient) -> Client: - member_dict: Dict[str, Callable] = { # type: ignore + member_dict: dict[str, Callable] = { # type: ignore "__init__": _constructor, "get_context": _get_context, "set_context": _set_context, diff --git a/src/py/flwr/client/numpy_client_test.py b/src/py/flwr/client/numpy_client_test.py index 06a0deafe2c9..c5d520a73ce1 100644 --- a/src/py/flwr/client/numpy_client_test.py +++ b/src/py/flwr/client/numpy_client_test.py @@ -15,8 +15,6 @@ """Flower NumPyClient tests.""" -from typing import Dict, Tuple - from flwr.common import Config, NDArrays, Properties, Scalar from .numpy_client import ( @@ -40,14 +38,14 @@ def get_parameters(self, config: Config) -> NDArrays: return [] def fit( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + self, parameters: NDArrays, config: dict[str, Scalar] + ) -> tuple[NDArrays, int, dict[str, Scalar]]: """Simulate training by returning empty weights, 0 samples, empty metrics.""" return [], 0, {} def evaluate( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[float, int, Dict[str, Scalar]]: + self, parameters: NDArrays, config: dict[str, Scalar] + ) -> tuple[float, int, dict[str, Scalar]]: """Simulate evaluate by returning 0.0 loss, 0 samples, empty metrics.""" return 0.0, 0, {} diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index d5f005fbaf77..72b6be25a708 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -18,10 +18,11 @@ import random import sys import threading +from collections.abc import Iterator from contextlib import contextmanager from copy import copy from logging import ERROR, INFO, WARN -from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar, Union +from typing import Callable, Optional, TypeVar, Union from cryptography.hazmat.primitives.asymmetric import ec from google.protobuf.message import Message as GrpcMessage @@ -90,10 +91,10 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 Union[bytes, str] ] = None, # pylint: disable=unused-argument authentication_keys: Optional[ # pylint: disable=unused-argument - Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, ) -> Iterator[ - Tuple[ + tuple[ Callable[[], Optional[Message]], Callable[[Message], None], Optional[Callable[[], Optional[int]]], @@ -173,7 +174,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 ########################################################################### def _request( - req: GrpcMessage, res_type: Type[T], api_path: str, retry: bool = True + req: GrpcMessage, res_type: type[T], api_path: str, retry: bool = True ) -> Optional[T]: # Serialize the request req_bytes = req.SerializeToString() diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 425c7f7133a4..d9af001bba53 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -18,7 +18,7 @@ import sys from logging import DEBUG, ERROR, INFO, WARN from pathlib import Path -from typing import Optional, Tuple +from typing import Optional from cryptography.exceptions import UnsupportedAlgorithm from cryptography.hazmat.primitives.asymmetric import ec @@ -291,7 +291,7 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None: def _try_setup_client_authentication( args: argparse.Namespace, -) -> Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]: +) -> Optional[tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]: if not args.auth_supernode_private_key and not args.auth_supernode_public_key: return None diff --git a/src/py/flwr/common/address.py b/src/py/flwr/common/address.py index 7a70925c0fc9..2b10097ccb71 100644 --- a/src/py/flwr/common/address.py +++ b/src/py/flwr/common/address.py @@ -16,12 +16,12 @@ import socket from ipaddress import ip_address -from typing import Optional, Tuple +from typing import Optional IPV6: int = 6 -def parse_address(address: str) -> Optional[Tuple[str, int, Optional[bool]]]: +def parse_address(address: str) -> Optional[tuple[str, int, Optional[bool]]]: """Parse an IP address into host, port, and version. Parameters diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 42039fa959ac..071d41a3ab5e 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -17,7 +17,7 @@ import os import re from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args +from typing import Any, Optional, Union, cast, get_args import tomli @@ -53,7 +53,7 @@ def get_project_dir( return Path(flwr_dir) / APP_DIR / publisher / project_name / fab_version -def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]: +def get_project_config(project_dir: Union[str, Path]) -> dict[str, Any]: """Return pyproject.toml in the given project directory.""" # Load pyproject.toml file toml_path = Path(project_dir) / FAB_CONFIG_FILE @@ -137,13 +137,13 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig: def flatten_dict( - raw_dict: Optional[Dict[str, Any]], parent_key: str = "" + raw_dict: Optional[dict[str, Any]], parent_key: str = "" ) -> UserConfig: """Flatten dict by joining nested keys with a given separator.""" if raw_dict is None: return {} - items: List[Tuple[str, UserConfigValue]] = [] + items: list[tuple[str, UserConfigValue]] = [] separator: str = "." for k, v in raw_dict.items(): new_key = f"{parent_key}{separator}{k}" if parent_key else k @@ -159,9 +159,9 @@ def flatten_dict( return dict(items) -def unflatten_dict(flat_dict: Dict[str, Any]) -> Dict[str, Any]: +def unflatten_dict(flat_dict: dict[str, Any]) -> dict[str, Any]: """Unflatten a dict with keys containing separators into a nested dict.""" - unflattened_dict: Dict[str, Any] = {} + unflattened_dict: dict[str, Any] = {} separator: str = "." for key, value in flat_dict.items(): @@ -177,7 +177,7 @@ def unflatten_dict(flat_dict: Dict[str, Any]) -> Dict[str, Any]: def parse_config_args( - config: Optional[List[str]], + config: Optional[list[str]], ) -> UserConfig: """Parse separator separated list of key-value pairs separated by '='.""" overrides: UserConfig = {} @@ -209,7 +209,7 @@ def parse_config_args( return overrides -def get_metadata_from_config(config: Dict[str, Any]) -> Tuple[str, str]: +def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]: """Extract `fab_version` and `fab_id` from a project config.""" return ( config["project"]["version"], diff --git a/src/py/flwr/common/differential_privacy.py b/src/py/flwr/common/differential_privacy.py index 85dc198ef8a0..56da98a3c805 100644 --- a/src/py/flwr/common/differential_privacy.py +++ b/src/py/flwr/common/differential_privacy.py @@ -16,7 +16,7 @@ from logging import WARNING -from typing import Optional, Tuple +from typing import Optional import numpy as np @@ -125,7 +125,7 @@ def compute_adaptive_noise_params( noise_multiplier: float, num_sampled_clients: float, clipped_count_stddev: Optional[float], -) -> Tuple[float, float]: +) -> tuple[float, float]: """Compute noising parameters for the adaptive clipping. Paper: https://arxiv.org/abs/1905.03871 diff --git a/src/py/flwr/common/dp.py b/src/py/flwr/common/dp.py index 527805c8ef42..13ae94461ef9 100644 --- a/src/py/flwr/common/dp.py +++ b/src/py/flwr/common/dp.py @@ -15,8 +15,6 @@ """Building block functions for DP algorithms.""" -from typing import Tuple - import numpy as np from flwr.common.logger import warn_deprecated_feature @@ -41,7 +39,7 @@ def add_gaussian_noise(update: NDArrays, std_dev: float) -> NDArrays: return update_noised -def clip_by_l2(update: NDArrays, threshold: float) -> Tuple[NDArrays, bool]: +def clip_by_l2(update: NDArrays, threshold: float) -> tuple[NDArrays, bool]: """Scales the update so thats its L2 norm is upper-bound to threshold.""" warn_deprecated_feature("`clip_by_l2` method") update_norm = _get_update_norm(update) diff --git a/src/py/flwr/common/exit_handlers.py b/src/py/flwr/common/exit_handlers.py index 30750c28a450..e5898b46a537 100644 --- a/src/py/flwr/common/exit_handlers.py +++ b/src/py/flwr/common/exit_handlers.py @@ -19,7 +19,7 @@ from signal import SIGINT, SIGTERM, signal from threading import Thread from types import FrameType -from typing import List, Optional +from typing import Optional from grpc import Server @@ -28,8 +28,8 @@ def register_exit_handlers( event_type: EventType, - grpc_servers: Optional[List[Server]] = None, - bckg_threads: Optional[List[Thread]] = None, + grpc_servers: Optional[list[Server]] = None, + bckg_threads: Optional[list[Thread]] = None, ) -> None: """Register exit handlers for `SIGINT` and `SIGTERM` signals. diff --git a/src/py/flwr/common/grpc.py b/src/py/flwr/common/grpc.py index ec8fe823a7eb..5a29c595119c 100644 --- a/src/py/flwr/common/grpc.py +++ b/src/py/flwr/common/grpc.py @@ -15,8 +15,9 @@ """Utility functions for gRPC.""" +from collections.abc import Sequence from logging import DEBUG -from typing import Optional, Sequence +from typing import Optional import grpc diff --git a/src/py/flwr/common/logger.py b/src/py/flwr/common/logger.py index 2077f9beaca0..303780fc0b5d 100644 --- a/src/py/flwr/common/logger.py +++ b/src/py/flwr/common/logger.py @@ -18,7 +18,7 @@ import logging from logging import WARN, LogRecord from logging.handlers import HTTPHandler -from typing import TYPE_CHECKING, Any, Dict, Optional, TextIO, Tuple +from typing import TYPE_CHECKING, Any, Optional, TextIO # Create logger LOGGER_NAME = "flwr" @@ -119,12 +119,12 @@ def __init__( url: str, method: str = "GET", secure: bool = False, - credentials: Optional[Tuple[str, str]] = None, + credentials: Optional[tuple[str, str]] = None, ) -> None: super().__init__(host, url, method, secure, credentials) self.identifier = identifier - def mapLogRecord(self, record: LogRecord) -> Dict[str, Any]: + def mapLogRecord(self, record: LogRecord) -> dict[str, Any]: """Filter for the properties to be send to the logserver.""" record_dict = record.__dict__ return { diff --git a/src/py/flwr/common/message_test.py b/src/py/flwr/common/message_test.py index c6142cb18256..57c57eb41bd9 100644 --- a/src/py/flwr/common/message_test.py +++ b/src/py/flwr/common/message_test.py @@ -17,7 +17,7 @@ import time from collections import namedtuple from contextlib import ExitStack -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import pytest @@ -193,7 +193,7 @@ def test_create_reply( ), ], ) -def test_repr(cls: type, kwargs: Dict[str, Any]) -> None: +def test_repr(cls: type, kwargs: dict[str, Any]) -> None: """Test string representations of Metadata/Message/Error.""" # Prepare anon_cls = namedtuple(cls.__qualname__, kwargs.keys()) # type: ignore diff --git a/src/py/flwr/common/object_ref.py b/src/py/flwr/common/object_ref.py index 9723c14037a0..6259b5ab557d 100644 --- a/src/py/flwr/common/object_ref.py +++ b/src/py/flwr/common/object_ref.py @@ -21,7 +21,7 @@ from importlib.util import find_spec from logging import WARN from pathlib import Path -from typing import Any, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from .logger import log @@ -40,7 +40,7 @@ def validate( module_attribute_str: str, check_module: bool = True, project_dir: Optional[Union[str, Path]] = None, -) -> Tuple[bool, Optional[str]]: +) -> tuple[bool, Optional[str]]: """Validate object reference. Parameters @@ -106,7 +106,7 @@ def validate( def load_app( # pylint: disable= too-many-branches module_attribute_str: str, - error_type: Type[Exception], + error_type: type[Exception], project_dir: Optional[Union[str, Path]] = None, ) -> Any: """Return the object specified in a module attribute string. diff --git a/src/py/flwr/common/record/configsrecord.py b/src/py/flwr/common/record/configsrecord.py index aeb311089bcd..f570e000cc9b 100644 --- a/src/py/flwr/common/record/configsrecord.py +++ b/src/py/flwr/common/record/configsrecord.py @@ -15,7 +15,7 @@ """ConfigsRecord.""" -from typing import Dict, List, Optional, get_args +from typing import Optional, get_args from flwr.common.typing import ConfigsRecordValues, ConfigsScalar @@ -109,7 +109,7 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]): def __init__( self, - configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None, + configs_dict: Optional[dict[str, ConfigsRecordValues]] = None, keep_input: bool = True, ) -> None: @@ -141,7 +141,7 @@ def get_var_bytes(value: ConfigsScalar) -> int: num_bytes = 0 for k, v in self.items(): - if isinstance(v, List): + if isinstance(v, list): if isinstance(v[0], (bytes, str)): # not all str are of equal length necessarily # for both the footprint of each element is 1 Byte diff --git a/src/py/flwr/common/record/metricsrecord.py b/src/py/flwr/common/record/metricsrecord.py index 868ed82e79ca..d0a6123c807f 100644 --- a/src/py/flwr/common/record/metricsrecord.py +++ b/src/py/flwr/common/record/metricsrecord.py @@ -15,7 +15,7 @@ """MetricsRecord.""" -from typing import Dict, List, Optional, get_args +from typing import Optional, get_args from flwr.common.typing import MetricsRecordValues, MetricsScalar @@ -115,7 +115,7 @@ class MetricsRecord(TypedDict[str, MetricsRecordValues]): def __init__( self, - metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None, + metrics_dict: Optional[dict[str, MetricsRecordValues]] = None, keep_input: bool = True, ): super().__init__(_check_key, _check_value) @@ -130,7 +130,7 @@ def count_bytes(self) -> int: num_bytes = 0 for k, v in self.items(): - if isinstance(v, List): + if isinstance(v, list): # both int and float normally take 4 bytes # But MetricRecords are mapped to 64bit int/float # during protobuffing diff --git a/src/py/flwr/common/record/parametersrecord.py b/src/py/flwr/common/record/parametersrecord.py index f088d682497b..10ec65ca0277 100644 --- a/src/py/flwr/common/record/parametersrecord.py +++ b/src/py/flwr/common/record/parametersrecord.py @@ -14,9 +14,10 @@ # ============================================================================== """ParametersRecord and Array.""" +from collections import OrderedDict from dataclasses import dataclass from io import BytesIO -from typing import List, Optional, OrderedDict, cast +from typing import Optional, cast import numpy as np @@ -51,7 +52,7 @@ class Array: """ dtype: str - shape: List[int] + shape: list[int] stype: str data: bytes diff --git a/src/py/flwr/common/record/parametersrecord_test.py b/src/py/flwr/common/record/parametersrecord_test.py index e840e5e266e4..9ac18a3ec854 100644 --- a/src/py/flwr/common/record/parametersrecord_test.py +++ b/src/py/flwr/common/record/parametersrecord_test.py @@ -17,7 +17,6 @@ import unittest from collections import OrderedDict from io import BytesIO -from typing import List import numpy as np import pytest @@ -81,7 +80,7 @@ def test_numpy_conversion_invalid(self) -> None: ([31, 153], "bool_"), # bool_ is represented as a whole Byte in NumPy ], ) -def test_count_bytes(shape: List[int], dtype: str) -> None: +def test_count_bytes(shape: list[int], dtype: str) -> None: """Test bytes in a ParametersRecord are computed correctly.""" original_array = np.random.randn(*shape).astype(np.dtype(dtype)) diff --git a/src/py/flwr/common/record/recordset_test.py b/src/py/flwr/common/record/recordset_test.py index 96556d335f4c..154e320e5f0b 100644 --- a/src/py/flwr/common/record/recordset_test.py +++ b/src/py/flwr/common/record/recordset_test.py @@ -15,9 +15,9 @@ """RecordSet tests.""" import pickle -from collections import namedtuple +from collections import OrderedDict, namedtuple from copy import deepcopy -from typing import Callable, Dict, List, OrderedDict, Type, Union +from typing import Callable, Union import numpy as np import pytest @@ -158,8 +158,8 @@ def test_set_parameters_with_correct_types() -> None: ], ) def test_set_parameters_with_incorrect_types( - key_type: Type[Union[int, str]], - value_fn: Callable[[NDArray], Union[NDArray, List[float]]], + key_type: type[Union[int, str]], + value_fn: Callable[[NDArray], Union[NDArray, list[float]]], ) -> None: """Test adding dictionary of unsupported types to ParametersRecord.""" p_record = ParametersRecord() @@ -183,7 +183,7 @@ def test_set_parameters_with_incorrect_types( ], ) def test_set_metrics_to_metricsrecord_with_correct_types( - key_type: Type[str], + key_type: type[str], value_fn: Callable[[NDArray], MetricsRecordValues], ) -> None: """Test adding metrics of various types to a MetricsRecord.""" @@ -236,8 +236,8 @@ def test_set_metrics_to_metricsrecord_with_correct_types( ], ) def test_set_metrics_to_metricsrecord_with_incorrect_types( - key_type: Type[Union[str, int, float, bool]], - value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]], + key_type: type[Union[str, int, float, bool]], + value_fn: Callable[[NDArray], Union[NDArray, dict[str, NDArray], list[float]]], ) -> None: """Test adding metrics of various unsupported types to a MetricsRecord.""" m_record = MetricsRecord() @@ -302,7 +302,7 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input( ], ) def test_set_configs_to_configsrecord_with_correct_types( - key_type: Type[str], + key_type: type[str], value_fn: Callable[[NDArray], ConfigsRecordValues], ) -> None: """Test adding configs of various types to a ConfigsRecord.""" @@ -346,8 +346,8 @@ def test_set_configs_to_configsrecord_with_correct_types( ], ) def test_set_configs_to_configsrecord_with_incorrect_types( - key_type: Type[Union[str, int, float]], - value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]], + key_type: type[Union[str, int, float]], + value_fn: Callable[[NDArray], Union[NDArray, dict[str, NDArray], list[float]]], ) -> None: """Test adding configs of various unsupported types to a ConfigsRecord.""" c_record = ConfigsRecord() diff --git a/src/py/flwr/common/record/typeddict.py b/src/py/flwr/common/record/typeddict.py index 37d98b01a306..c2c8548c4de3 100644 --- a/src/py/flwr/common/record/typeddict.py +++ b/src/py/flwr/common/record/typeddict.py @@ -15,18 +15,8 @@ """Typed dict base class for *Records.""" -from typing import ( - Callable, - Dict, - Generic, - ItemsView, - Iterator, - KeysView, - MutableMapping, - TypeVar, - ValuesView, - cast, -) +from collections.abc import ItemsView, Iterator, KeysView, MutableMapping, ValuesView +from typing import Callable, Generic, TypeVar, cast K = TypeVar("K") # Key type V = TypeVar("V") # Value type @@ -49,37 +39,37 @@ def __setitem__(self, key: K, value: V) -> None: cast(Callable[[V], None], self.__dict__["_check_value_fn"])(value) # Set key-value pair - cast(Dict[K, V], self.__dict__["_data"])[key] = value + cast(dict[K, V], self.__dict__["_data"])[key] = value def __delitem__(self, key: K) -> None: """Remove the item with the specified key.""" - del cast(Dict[K, V], self.__dict__["_data"])[key] + del cast(dict[K, V], self.__dict__["_data"])[key] def __getitem__(self, item: K) -> V: """Return the value for the specified key.""" - return cast(Dict[K, V], self.__dict__["_data"])[item] + return cast(dict[K, V], self.__dict__["_data"])[item] def __iter__(self) -> Iterator[K]: """Yield an iterator over the keys of the dictionary.""" - return iter(cast(Dict[K, V], self.__dict__["_data"])) + return iter(cast(dict[K, V], self.__dict__["_data"])) def __repr__(self) -> str: """Return a string representation of the dictionary.""" - return cast(Dict[K, V], self.__dict__["_data"]).__repr__() + return cast(dict[K, V], self.__dict__["_data"]).__repr__() def __len__(self) -> int: """Return the number of items in the dictionary.""" - return len(cast(Dict[K, V], self.__dict__["_data"])) + return len(cast(dict[K, V], self.__dict__["_data"])) def __contains__(self, key: object) -> bool: """Check if the dictionary contains the specified key.""" - return key in cast(Dict[K, V], self.__dict__["_data"]) + return key in cast(dict[K, V], self.__dict__["_data"]) def __eq__(self, other: object) -> bool: """Compare this instance to another dictionary or TypedDict.""" - data = cast(Dict[K, V], self.__dict__["_data"]) + data = cast(dict[K, V], self.__dict__["_data"]) if isinstance(other, TypedDict): - other_data = cast(Dict[K, V], other.__dict__["_data"]) + other_data = cast(dict[K, V], other.__dict__["_data"]) return data == other_data if isinstance(other, dict): return data == other @@ -87,12 +77,12 @@ def __eq__(self, other: object) -> bool: def keys(self) -> KeysView[K]: """D.keys() -> a set-like object providing a view on D's keys.""" - return cast(Dict[K, V], self.__dict__["_data"]).keys() + return cast(dict[K, V], self.__dict__["_data"]).keys() def values(self) -> ValuesView[V]: """D.values() -> an object providing a view on D's values.""" - return cast(Dict[K, V], self.__dict__["_data"]).values() + return cast(dict[K, V], self.__dict__["_data"]).values() def items(self) -> ItemsView[K, V]: """D.items() -> a set-like object providing a view on D's items.""" - return cast(Dict[K, V], self.__dict__["_data"]).items() + return cast(dict[K, V], self.__dict__["_data"]).items() diff --git a/src/py/flwr/common/recordset_compat.py b/src/py/flwr/common/recordset_compat.py index 8bf884c30e58..35024fcd67d1 100644 --- a/src/py/flwr/common/recordset_compat.py +++ b/src/py/flwr/common/recordset_compat.py @@ -15,7 +15,9 @@ """RecordSet utilities.""" -from typing import Dict, Mapping, OrderedDict, Tuple, Union, cast, get_args +from collections import OrderedDict +from collections.abc import Mapping +from typing import Union, cast, get_args from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet from .typing import ( @@ -115,7 +117,7 @@ def parameters_to_parametersrecord( def _check_mapping_from_recordscalartype_to_scalar( record_data: Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]] -) -> Dict[str, Scalar]: +) -> dict[str, Scalar]: """Check mapping `common.*RecordValues` into `common.Scalar` is possible.""" for value in record_data.values(): if not isinstance(value, get_args(Scalar)): @@ -126,14 +128,14 @@ def _check_mapping_from_recordscalartype_to_scalar( "supported by the `common.RecordSet` infrastructure. " f"You used type: {type(value)}" ) - return cast(Dict[str, Scalar], record_data) + return cast(dict[str, Scalar], record_data) def _recordset_to_fit_or_evaluate_ins_components( recordset: RecordSet, ins_str: str, keep_input: bool, -) -> Tuple[Parameters, Dict[str, Scalar]]: +) -> tuple[Parameters, dict[str, Scalar]]: """Derive Fit/Evaluate Ins from a RecordSet.""" # get Array and construct Parameters parameters_record = recordset.parameters_records[f"{ins_str}.parameters"] @@ -169,7 +171,7 @@ def _fit_or_evaluate_ins_to_recordset( def _embed_status_into_recordset( res_str: str, status: Status, recordset: RecordSet ) -> RecordSet: - status_dict: Dict[str, ConfigsRecordValues] = { + status_dict: dict[str, ConfigsRecordValues] = { "code": int(status.code.value), "message": status.message, } diff --git a/src/py/flwr/common/recordset_compat_test.py b/src/py/flwr/common/recordset_compat_test.py index e0ac7f216af9..05d821e37e40 100644 --- a/src/py/flwr/common/recordset_compat_test.py +++ b/src/py/flwr/common/recordset_compat_test.py @@ -15,7 +15,7 @@ """RecordSet from legacy messages tests.""" from copy import deepcopy -from typing import Callable, Dict +from typing import Callable import numpy as np import pytest @@ -82,7 +82,7 @@ def _get_valid_fitins_with_empty_ndarrays() -> FitIns: def _get_valid_fitres() -> FitRes: """Returnn Valid parameters but potentially invalid config.""" arrays = get_ndarrays() - metrics: Dict[str, Scalar] = {"a": 1.0, "b": 0} + metrics: dict[str, Scalar] = {"a": 1.0, "b": 0} return FitRes( parameters=ndarrays_to_parameters(arrays), num_examples=1, @@ -98,7 +98,7 @@ def _get_valid_evaluateins() -> EvaluateIns: def _get_valid_evaluateres() -> EvaluateRes: """Return potentially invalid config.""" - metrics: Dict[str, Scalar] = {"a": 1.0, "b": 0} + metrics: dict[str, Scalar] = {"a": 1.0, "b": 0} return EvaluateRes( num_examples=1, loss=0.1, @@ -108,7 +108,7 @@ def _get_valid_evaluateres() -> EvaluateRes: def _get_valid_getparametersins() -> GetParametersIns: - config_dict: Dict[str, Scalar] = { + config_dict: dict[str, Scalar] = { "a": 1.0, "b": 3, "c": True, @@ -131,7 +131,7 @@ def _get_valid_getpropertiesins() -> GetPropertiesIns: def _get_valid_getpropertiesres() -> GetPropertiesRes: - config_dict: Dict[str, Scalar] = { + config_dict: dict[str, Scalar] = { "a": 1.0, "b": 3, "c": True, diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index d12124b89840..303d5596f237 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -18,20 +18,9 @@ import itertools import random import time +from collections.abc import Generator, Iterable from dataclasses import dataclass -from typing import ( - Any, - Callable, - Dict, - Generator, - Iterable, - List, - Optional, - Tuple, - Type, - Union, - cast, -) +from typing import Any, Callable, Optional, Union, cast def exponential( @@ -93,8 +82,8 @@ class RetryState: """State for callbacks in RetryInvoker.""" target: Callable[..., Any] - args: Tuple[Any, ...] - kwargs: Dict[str, Any] + args: tuple[Any, ...] + kwargs: dict[str, Any] tries: int elapsed_time: float exception: Optional[Exception] = None @@ -167,7 +156,7 @@ class RetryInvoker: def __init__( self, wait_gen_factory: Callable[[], Generator[float, None, None]], - recoverable_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]], + recoverable_exceptions: Union[type[Exception], tuple[type[Exception], ...]], max_tries: Optional[int], max_time: Optional[float], *, @@ -244,7 +233,7 @@ def try_call_event_handler( try_cnt = 0 wait_generator = self.wait_gen_factory() start = time.monotonic() - ref_state: List[Optional[RetryState]] = [None] + ref_state: list[Optional[RetryState]] = [None] while True: try_cnt += 1 diff --git a/src/py/flwr/common/retry_invoker_test.py b/src/py/flwr/common/retry_invoker_test.py index 2259ae47ded4..a9f2625ff443 100644 --- a/src/py/flwr/common/retry_invoker_test.py +++ b/src/py/flwr/common/retry_invoker_test.py @@ -15,7 +15,7 @@ """Tests for `RetryInvoker`.""" -from typing import Generator +from collections.abc import Generator from unittest.mock import MagicMock, Mock, patch import pytest diff --git a/src/py/flwr/common/secure_aggregation/crypto/shamir.py b/src/py/flwr/common/secure_aggregation/crypto/shamir.py index 688bfa2153ea..9c7e67abf94f 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/shamir.py +++ b/src/py/flwr/common/secure_aggregation/crypto/shamir.py @@ -17,20 +17,20 @@ import pickle from concurrent.futures import ThreadPoolExecutor -from typing import List, Tuple, cast +from typing import cast from Crypto.Protocol.SecretSharing import Shamir from Crypto.Util.Padding import pad, unpad -def create_shares(secret: bytes, threshold: int, num: int) -> List[bytes]: +def create_shares(secret: bytes, threshold: int, num: int) -> list[bytes]: """Return list of shares (bytes).""" secret_padded = pad(secret, 16) secret_padded_chunk = [ (threshold, num, secret_padded[i : i + 16]) for i in range(0, len(secret_padded), 16) ] - share_list: List[List[Tuple[int, bytes]]] = [[] for _ in range(num)] + share_list: list[list[tuple[int, bytes]]] = [[] for _ in range(num)] with ThreadPoolExecutor(max_workers=10) as executor: for chunk_shares in executor.map( @@ -43,22 +43,22 @@ def create_shares(secret: bytes, threshold: int, num: int) -> List[bytes]: return [pickle.dumps(shares) for shares in share_list] -def _shamir_split(threshold: int, num: int, chunk: bytes) -> List[Tuple[int, bytes]]: +def _shamir_split(threshold: int, num: int, chunk: bytes) -> list[tuple[int, bytes]]: return Shamir.split(threshold, num, chunk, ssss=False) # Reconstructing secret with PyCryptodome -def combine_shares(share_list: List[bytes]) -> bytes: +def combine_shares(share_list: list[bytes]) -> bytes: """Reconstruct secret from shares.""" - unpickled_share_list: List[List[Tuple[int, bytes]]] = [ - cast(List[Tuple[int, bytes]], pickle.loads(share)) for share in share_list + unpickled_share_list: list[list[tuple[int, bytes]]] = [ + cast(list[tuple[int, bytes]], pickle.loads(share)) for share in share_list ] chunk_num = len(unpickled_share_list[0]) secret_padded = bytearray(0) - chunk_shares_list: List[List[Tuple[int, bytes]]] = [] + chunk_shares_list: list[list[tuple[int, bytes]]] = [] for i in range(chunk_num): - chunk_shares: List[Tuple[int, bytes]] = [] + chunk_shares: list[tuple[int, bytes]] = [] for share in unpickled_share_list: chunk_shares.append(share[i]) chunk_shares_list.append(chunk_shares) @@ -71,5 +71,5 @@ def combine_shares(share_list: List[bytes]) -> bytes: return bytes(secret) -def _shamir_combine(shares: List[Tuple[int, bytes]]) -> bytes: +def _shamir_combine(shares: list[tuple[int, bytes]]) -> bytes: return Shamir.combine(shares, ssss=False) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 59ca84d604b8..f5c130fb2663 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -16,7 +16,7 @@ import base64 -from typing import Tuple, cast +from typing import cast from cryptography.exceptions import InvalidSignature from cryptography.fernet import Fernet @@ -26,7 +26,7 @@ def generate_key_pairs() -> ( - Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ): """Generate private and public key pairs with Cryptography.""" private_key = ec.generate_private_key(ec.SECP384R1()) diff --git a/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py b/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py index 207c15b61518..3197fd852f3d 100644 --- a/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py +++ b/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py @@ -15,51 +15,51 @@ """Utility functions for performing operations on Numpy NDArrays.""" -from typing import Any, List, Tuple, Union +from typing import Any, Union import numpy as np from numpy.typing import DTypeLike, NDArray -def factor_combine(factor: int, parameters: List[NDArray[Any]]) -> List[NDArray[Any]]: +def factor_combine(factor: int, parameters: list[NDArray[Any]]) -> list[NDArray[Any]]: """Combine factor with parameters.""" return [np.array([factor])] + parameters def factor_extract( - parameters: List[NDArray[Any]], -) -> Tuple[int, List[NDArray[Any]]]: + parameters: list[NDArray[Any]], +) -> tuple[int, list[NDArray[Any]]]: """Extract factor from parameters.""" return parameters[0][0], parameters[1:] -def get_parameters_shape(parameters: List[NDArray[Any]]) -> List[Tuple[int, ...]]: +def get_parameters_shape(parameters: list[NDArray[Any]]) -> list[tuple[int, ...]]: """Get dimensions of each NDArray in parameters.""" return [arr.shape for arr in parameters] def get_zero_parameters( - dimensions_list: List[Tuple[int, ...]], dtype: DTypeLike = np.int64 -) -> List[NDArray[Any]]: + dimensions_list: list[tuple[int, ...]], dtype: DTypeLike = np.int64 +) -> list[NDArray[Any]]: """Generate zero parameters based on the dimensions list.""" return [np.zeros(dimensions, dtype=dtype) for dimensions in dimensions_list] def parameters_addition( - parameters1: List[NDArray[Any]], parameters2: List[NDArray[Any]] -) -> List[NDArray[Any]]: + parameters1: list[NDArray[Any]], parameters2: list[NDArray[Any]] +) -> list[NDArray[Any]]: """Add two parameters.""" return [parameters1[idx] + parameters2[idx] for idx in range(len(parameters1))] def parameters_subtraction( - parameters1: List[NDArray[Any]], parameters2: List[NDArray[Any]] -) -> List[NDArray[Any]]: + parameters1: list[NDArray[Any]], parameters2: list[NDArray[Any]] +) -> list[NDArray[Any]]: """Subtract parameters from the other parameters.""" return [parameters1[idx] - parameters2[idx] for idx in range(len(parameters1))] -def parameters_mod(parameters: List[NDArray[Any]], divisor: int) -> List[NDArray[Any]]: +def parameters_mod(parameters: list[NDArray[Any]], divisor: int) -> list[NDArray[Any]]: """Take mod of parameters with an integer divisor.""" if bin(divisor).count("1") == 1: msk = divisor - 1 @@ -68,14 +68,14 @@ def parameters_mod(parameters: List[NDArray[Any]], divisor: int) -> List[NDArray def parameters_multiply( - parameters: List[NDArray[Any]], multiplier: Union[int, float] -) -> List[NDArray[Any]]: + parameters: list[NDArray[Any]], multiplier: Union[int, float] +) -> list[NDArray[Any]]: """Multiply parameters by an integer/float multiplier.""" return [parameters[idx] * multiplier for idx in range(len(parameters))] def parameters_divide( - parameters: List[NDArray[Any]], divisor: Union[int, float] -) -> List[NDArray[Any]]: + parameters: list[NDArray[Any]], divisor: Union[int, float] +) -> list[NDArray[Any]]: """Divide weight by an integer/float divisor.""" return [parameters[idx] / divisor for idx in range(len(parameters))] diff --git a/src/py/flwr/common/secure_aggregation/quantization.py b/src/py/flwr/common/secure_aggregation/quantization.py index 7946276b6a4f..ab8521eed981 100644 --- a/src/py/flwr/common/secure_aggregation/quantization.py +++ b/src/py/flwr/common/secure_aggregation/quantization.py @@ -15,7 +15,7 @@ """Utility functions for model quantization.""" -from typing import List, cast +from typing import cast import numpy as np @@ -30,10 +30,10 @@ def _stochastic_round(arr: NDArrayFloat) -> NDArrayInt: def quantize( - parameters: List[NDArrayFloat], clipping_range: float, target_range: int -) -> List[NDArrayInt]: + parameters: list[NDArrayFloat], clipping_range: float, target_range: int +) -> list[NDArrayInt]: """Quantize float Numpy arrays to integer Numpy arrays.""" - quantized_list: List[NDArrayInt] = [] + quantized_list: list[NDArrayInt] = [] quantizer = target_range / (2 * clipping_range) for arr in parameters: # Stochastic quantization @@ -49,12 +49,12 @@ def quantize( # Dequantize parameters to range [-clipping_range, clipping_range] def dequantize( - quantized_parameters: List[NDArrayInt], + quantized_parameters: list[NDArrayInt], clipping_range: float, target_range: int, -) -> List[NDArrayFloat]: +) -> list[NDArrayFloat]: """Dequantize integer Numpy arrays to float Numpy arrays.""" - reverse_quantized_list: List[NDArrayFloat] = [] + reverse_quantized_list: list[NDArrayFloat] = [] quantizer = (2 * clipping_range) / target_range shift = -clipping_range for arr in quantized_parameters: diff --git a/src/py/flwr/common/secure_aggregation/secaggplus_utils.py b/src/py/flwr/common/secure_aggregation/secaggplus_utils.py index cf6ac3bfb003..7bfb80f57891 100644 --- a/src/py/flwr/common/secure_aggregation/secaggplus_utils.py +++ b/src/py/flwr/common/secure_aggregation/secaggplus_utils.py @@ -15,8 +15,6 @@ """Utility functions for the SecAgg/SecAgg+ protocol.""" -from typing import List, Tuple - import numpy as np from flwr.common.typing import NDArrayInt @@ -54,7 +52,7 @@ def share_keys_plaintext_concat( ) -def share_keys_plaintext_separate(plaintext: bytes) -> Tuple[int, int, bytes, bytes]: +def share_keys_plaintext_separate(plaintext: bytes) -> tuple[int, int, bytes, bytes]: """Retrieve arguments from bytes. Parameters @@ -83,8 +81,8 @@ def share_keys_plaintext_separate(plaintext: bytes) -> Tuple[int, int, bytes, by def pseudo_rand_gen( - seed: bytes, num_range: int, dimensions_list: List[Tuple[int, ...]] -) -> List[NDArrayInt]: + seed: bytes, num_range: int, dimensions_list: list[tuple[int, ...]] +) -> list[NDArrayInt]: """Seeded pseudo-random number generator for noise generation with Numpy.""" assert len(seed) & 0x3 == 0 seed32 = 0 diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 76265b9836d1..87e01b05d341 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -15,7 +15,9 @@ """ProtoBuf serialization and deserialization.""" -from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast +from collections import OrderedDict +from collections.abc import MutableMapping +from typing import Any, TypeVar, cast from google.protobuf.message import Message as GrpcMessage @@ -72,7 +74,7 @@ def parameters_to_proto(parameters: typing.Parameters) -> Parameters: def parameters_from_proto(msg: Parameters) -> typing.Parameters: """Deserialize `Parameters` from ProtoBuf.""" - tensors: List[bytes] = list(msg.tensors) + tensors: list[bytes] = list(msg.tensors) return typing.Parameters(tensors=tensors, tensor_type=msg.tensor_type) @@ -390,7 +392,7 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: def _record_value_to_proto( - value: Any, allowed_types: List[type], proto_class: Type[T] + value: Any, allowed_types: list[type], proto_class: type[T] ) -> T: """Serialize `*RecordValue` to ProtoBuf. @@ -427,9 +429,9 @@ def _record_value_from_proto(value_proto: GrpcMessage) -> Any: def _record_value_dict_to_proto( value_dict: TypedDict[str, Any], - allowed_types: List[type], - value_proto_class: Type[T], -) -> Dict[str, T]: + allowed_types: list[type], + value_proto_class: type[T], +) -> dict[str, T]: """Serialize the record value dict to ProtoBuf. Note: `bool` MUST be put in the front of allowd_types if it exists. @@ -447,7 +449,7 @@ def proto(_v: Any) -> T: def _record_value_dict_from_proto( value_dict_proto: MutableMapping[str, Any] -) -> Dict[str, Any]: +) -> dict[str, Any]: """Deserialize the record value dict from ProtoBuf.""" return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()} @@ -498,7 +500,7 @@ def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord """Deserialize MetricsRecord from ProtoBuf.""" return MetricsRecord( metrics_dict=cast( - Dict[str, typing.MetricsRecordValues], + dict[str, typing.MetricsRecordValues], _record_value_dict_from_proto(record_proto.data), ), keep_input=False, @@ -520,7 +522,7 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord """Deserialize ConfigsRecord from ProtoBuf.""" return ConfigsRecord( configs_dict=cast( - Dict[str, typing.ConfigsRecordValues], + dict[str, typing.ConfigsRecordValues], _record_value_dict_from_proto(record_proto.data), ), keep_input=False, diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 013d04a32fd4..49d1e38fa897 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -16,7 +16,8 @@ import random import string -from typing import Any, Callable, Optional, OrderedDict, Type, TypeVar, Union, cast +from collections import OrderedDict +from typing import Any, Callable, Optional, TypeVar, Union, cast import pytest @@ -169,7 +170,7 @@ def get_str(self, length: Optional[int] = None) -> str: length = self.rng.randint(1, 10) return "".join(self.rng.choices(char_pool, k=length)) - def get_value(self, dtype: Type[T]) -> T: + def get_value(self, dtype: type[T]) -> T: """Create a value of a given type.""" ret: Any = None if dtype == bool: diff --git a/src/py/flwr/common/telemetry.py b/src/py/flwr/common/telemetry.py index 981cfe79966a..724f36d2b98f 100644 --- a/src/py/flwr/common/telemetry.py +++ b/src/py/flwr/common/telemetry.py @@ -25,7 +25,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from enum import Enum, auto from pathlib import Path -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Optional, Union, cast from flwr.common.version import package_name, package_version @@ -126,7 +126,7 @@ class EventType(str, Enum): # The type signature is not compatible with mypy, pylint and flake8 # so each of those needs to be disabled for this line. # pylint: disable-next=no-self-argument,arguments-differ,line-too-long - def _generate_next_value_(name: str, start: int, count: int, last_values: List[Any]) -> Any: # type: ignore # noqa: E501 + def _generate_next_value_(name: str, start: int, count: int, last_values: list[Any]) -> Any: # type: ignore # noqa: E501 return name # Ping @@ -189,7 +189,7 @@ def _generate_next_value_(name: str, start: int, count: int, last_values: List[A # Use the ThreadPoolExecutor with max_workers=1 to have a queue # and also ensure that telemetry calls are not blocking. -state: Dict[str, Union[Optional[str], Optional[ThreadPoolExecutor]]] = { +state: dict[str, Union[Optional[str], Optional[ThreadPoolExecutor]]] = { # Will be assigned ThreadPoolExecutor(max_workers=1) # in event() the first time it's required "executor": None, @@ -201,7 +201,7 @@ def _generate_next_value_(name: str, start: int, count: int, last_values: List[A def event( event_type: EventType, - event_details: Optional[Dict[str, Any]] = None, + event_details: Optional[dict[str, Any]] = None, ) -> Future: # type: ignore """Submit create_event to ThreadPoolExecutor to avoid blocking.""" if state["executor"] is None: @@ -213,7 +213,7 @@ def event( return result -def create_event(event_type: EventType, event_details: Optional[Dict[str, Any]]) -> str: +def create_event(event_type: EventType, event_details: Optional[dict[str, Any]]) -> str: """Create telemetry event.""" if state["source"] is None: state["source"] = _get_source_id() diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index b1dec8d0420b..081a957f28ff 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import numpy as np import numpy.typing as npt @@ -25,7 +25,7 @@ NDArray = npt.NDArray[Any] NDArrayInt = npt.NDArray[np.int_] NDArrayFloat = npt.NDArray[np.float_] -NDArrays = List[NDArray] +NDArrays = list[NDArray] # The following union type contains Python types corresponding to ProtoBuf types that # ProtoBuf considers to be "Scalar Value Types", even though some of them arguably do @@ -38,31 +38,31 @@ float, int, str, - List[bool], - List[bytes], - List[float], - List[int], - List[str], + list[bool], + list[bytes], + list[float], + list[int], + list[str], ] # Value types for common.MetricsRecord MetricsScalar = Union[int, float] -MetricsScalarList = Union[List[int], List[float]] +MetricsScalarList = Union[list[int], list[float]] MetricsRecordValues = Union[MetricsScalar, MetricsScalarList] # Value types for common.ConfigsRecord ConfigsScalar = Union[MetricsScalar, str, bytes, bool] -ConfigsScalarList = Union[MetricsScalarList, List[str], List[bytes], List[bool]] +ConfigsScalarList = Union[MetricsScalarList, list[str], list[bytes], list[bool]] ConfigsRecordValues = Union[ConfigsScalar, ConfigsScalarList] -Metrics = Dict[str, Scalar] -MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics] +Metrics = dict[str, Scalar] +MetricsAggregationFn = Callable[[list[tuple[int, Metrics]]], Metrics] -Config = Dict[str, Scalar] -Properties = Dict[str, Scalar] +Config = dict[str, Scalar] +Properties = dict[str, Scalar] # Value type for user configs UserConfigValue = Union[bool, float, int, str] -UserConfig = Dict[str, UserConfigValue] +UserConfig = dict[str, UserConfigValue] class Code(Enum): @@ -103,7 +103,7 @@ class ClientAppOutputStatus: class Parameters: """Model parameters.""" - tensors: List[bytes] + tensors: list[bytes] tensor_type: str @@ -127,7 +127,7 @@ class FitIns: """Fit instructions for a client.""" parameters: Parameters - config: Dict[str, Scalar] + config: dict[str, Scalar] @dataclass @@ -137,7 +137,7 @@ class FitRes: status: Status parameters: Parameters num_examples: int - metrics: Dict[str, Scalar] + metrics: dict[str, Scalar] @dataclass @@ -145,7 +145,7 @@ class EvaluateIns: """Evaluate instructions for a client.""" parameters: Parameters - config: Dict[str, Scalar] + config: dict[str, Scalar] @dataclass @@ -155,7 +155,7 @@ class EvaluateRes: status: Status loss: float num_examples: int - metrics: Dict[str, Scalar] + metrics: dict[str, Scalar] @dataclass diff --git a/src/py/flwr/common/version.py b/src/py/flwr/common/version.py index ac13f70d8a88..141c16ac9367 100644 --- a/src/py/flwr/common/version.py +++ b/src/py/flwr/common/version.py @@ -15,15 +15,14 @@ """Flower package version helper.""" import importlib.metadata as importlib_metadata -from typing import Tuple -def _check_package(name: str) -> Tuple[str, str]: +def _check_package(name: str) -> tuple[str, str]: version: str = importlib_metadata.version(name) return name, version -def _version() -> Tuple[str, str]: +def _version() -> tuple[str, str]: """Read and return Flower package name and version. Returns diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 67fd54bfcae2..d156edaa3c99 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -19,10 +19,11 @@ import importlib.util import sys import threading +from collections.abc import Sequence from logging import INFO, WARN from os.path import isfile from pathlib import Path -from typing import Optional, Sequence, Set, Tuple +from typing import Optional import grpc from cryptography.exceptions import UnsupportedAlgorithm @@ -84,7 +85,7 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals strategy: Optional[Strategy] = None, client_manager: Optional[ClientManager] = None, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, - certificates: Optional[Tuple[bytes, bytes, bytes]] = None, + certificates: Optional[tuple[bytes, bytes, bytes]] = None, ) -> History: """Start a Flower server using the gRPC transport layer. @@ -333,7 +334,7 @@ def run_superlink() -> None: driver_server.wait_for_termination(timeout=1) -def _format_address(address: str) -> Tuple[str, str, int]: +def _format_address(address: str) -> tuple[str, str, int]: parsed_address = parse_address(address) if not parsed_address: sys.exit( @@ -345,8 +346,8 @@ def _format_address(address: str) -> Tuple[str, str, int]: def _try_setup_node_authentication( args: argparse.Namespace, - certificates: Optional[Tuple[bytes, bytes, bytes]], -) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]: + certificates: Optional[tuple[bytes, bytes, bytes]], +) -> Optional[tuple[set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]: if ( not args.auth_list_public_keys and not args.auth_superlink_private_key @@ -381,7 +382,7 @@ def _try_setup_node_authentication( "to '--auth-list-public-keys'." ) - node_public_keys: Set[bytes] = set() + node_public_keys: set[bytes] = set() try: ssh_private_key = load_ssh_private_key( @@ -434,7 +435,7 @@ def _try_setup_node_authentication( def _try_obtain_certificates( args: argparse.Namespace, -) -> Optional[Tuple[bytes, bytes, bytes]]: +) -> Optional[tuple[bytes, bytes, bytes]]: # Obtain certificates if args.insecure: log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.") @@ -490,7 +491,7 @@ def _run_fleet_api_grpc_rere( address: str, state_factory: StateFactory, ffs_factory: FfsFactory, - certificates: Optional[Tuple[bytes, bytes, bytes]], + certificates: Optional[tuple[bytes, bytes, bytes]], interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, ) -> grpc.Server: """Run Fleet API (gRPC, request-response).""" @@ -518,7 +519,7 @@ def _run_fleet_api_grpc_adapter( address: str, state_factory: StateFactory, ffs_factory: FfsFactory, - certificates: Optional[Tuple[bytes, bytes, bytes]], + certificates: Optional[tuple[bytes, bytes, bytes]], ) -> grpc.Server: """Run Fleet API (GrpcAdapter).""" # Create Fleet API gRPC server diff --git a/src/py/flwr/server/client_manager.py b/src/py/flwr/server/client_manager.py index 7956e282bd2c..175bd4a786ea 100644 --- a/src/py/flwr/server/client_manager.py +++ b/src/py/flwr/server/client_manager.py @@ -19,7 +19,7 @@ import threading from abc import ABC, abstractmethod from logging import INFO -from typing import Dict, List, Optional +from typing import Optional from flwr.common.logger import log @@ -67,7 +67,7 @@ def unregister(self, client: ClientProxy) -> None: """ @abstractmethod - def all(self) -> Dict[str, ClientProxy]: + def all(self) -> dict[str, ClientProxy]: """Return all available clients.""" @abstractmethod @@ -80,7 +80,7 @@ def sample( num_clients: int, min_num_clients: Optional[int] = None, criterion: Optional[Criterion] = None, - ) -> List[ClientProxy]: + ) -> list[ClientProxy]: """Sample a number of Flower ClientProxy instances.""" @@ -88,7 +88,7 @@ class SimpleClientManager(ClientManager): """Provides a pool of available clients.""" def __init__(self) -> None: - self.clients: Dict[str, ClientProxy] = {} + self.clients: dict[str, ClientProxy] = {} self._cv = threading.Condition() def __len__(self) -> int: @@ -170,7 +170,7 @@ def unregister(self, client: ClientProxy) -> None: with self._cv: self._cv.notify_all() - def all(self) -> Dict[str, ClientProxy]: + def all(self) -> dict[str, ClientProxy]: """Return all available clients.""" return self.clients @@ -179,7 +179,7 @@ def sample( num_clients: int, min_num_clients: Optional[int] = None, criterion: Optional[Criterion] = None, - ) -> List[ClientProxy]: + ) -> list[ClientProxy]: """Sample a number of Flower ClientProxy instances.""" # Block until at least num_clients are connected. if min_num_clients is None: diff --git a/src/py/flwr/server/compat/app_utils.py b/src/py/flwr/server/compat/app_utils.py index baff27307b88..8d2479f47d40 100644 --- a/src/py/flwr/server/compat/app_utils.py +++ b/src/py/flwr/server/compat/app_utils.py @@ -16,7 +16,6 @@ import threading -from typing import Dict, Tuple from ..client_manager import ClientManager from ..compat.driver_client_proxy import DriverClientProxy @@ -26,7 +25,7 @@ def start_update_client_manager_thread( driver: Driver, client_manager: ClientManager, -) -> Tuple[threading.Thread, threading.Event]: +) -> tuple[threading.Thread, threading.Event]: """Periodically update the nodes list in the client manager in a thread. This function starts a thread that periodically uses the associated driver to @@ -73,7 +72,7 @@ def _update_client_manager( ) -> None: """Update the nodes list in the client manager.""" # Loop until the driver is disconnected - registered_nodes: Dict[int, DriverClientProxy] = {} + registered_nodes: dict[int, DriverClientProxy] = {} while not f_stop.is_set(): all_node_ids = set(driver.get_node_ids()) dead_nodes = set(registered_nodes).difference(all_node_ids) diff --git a/src/py/flwr/server/compat/driver_client_proxy_test.py b/src/py/flwr/server/compat/driver_client_proxy_test.py index 31b917fa869b..a5b454c79f90 100644 --- a/src/py/flwr/server/compat/driver_client_proxy_test.py +++ b/src/py/flwr/server/compat/driver_client_proxy_test.py @@ -17,7 +17,8 @@ import unittest import unittest.mock -from typing import Any, Callable, Iterable, Optional, Union, cast +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union, cast from unittest.mock import Mock import numpy as np diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index 4f888323e586..e8429e865db6 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -16,7 +16,8 @@ from abc import ABC, abstractmethod -from typing import Iterable, List, Optional +from collections.abc import Iterable +from typing import Optional from flwr.common import Message, RecordSet from flwr.common.typing import Run @@ -70,7 +71,7 @@ def create_message( # pylint: disable=too-many-arguments """ @abstractmethod - def get_node_ids(self) -> List[int]: + def get_node_ids(self) -> list[int]: """Get node IDs.""" @abstractmethod diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 2fe2c8a2e4aa..421dfd30ecb2 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -16,8 +16,9 @@ import time import warnings +from collections.abc import Iterable from logging import DEBUG, WARNING -from typing import Iterable, List, Optional, cast +from typing import Optional, cast import grpc @@ -192,7 +193,7 @@ def create_message( # pylint: disable=too-many-arguments ) return Message(metadata=metadata, content=content) - def get_node_ids(self) -> List[int]: + def get_node_ids(self) -> list[int]: """Get node IDs.""" self._init_run() # Call GrpcDriverStub method @@ -209,7 +210,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: """ self._init_run() # Construct TaskIns - task_ins_list: List[TaskIns] = [] + task_ins_list: list[TaskIns] = [] for msg in messages: # Check message self._check_message(msg) @@ -255,7 +256,7 @@ def send_and_receive( # Pull messages end_time = time.time() + (timeout if timeout is not None else 0.0) - ret: List[Message] = [] + ret: list[Message] = [] while timeout is None or time.time() < end_time: res_msgs = self.pull_messages(msg_ids) ret.extend(res_msgs) diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 53406796750f..3a8a4b1bc73d 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -17,7 +17,8 @@ import time import warnings -from typing import Iterable, List, Optional, cast +from collections.abc import Iterable +from typing import Optional, cast from uuid import UUID from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet @@ -112,7 +113,7 @@ def create_message( # pylint: disable=too-many-arguments ) return Message(metadata=metadata, content=content) - def get_node_ids(self) -> List[int]: + def get_node_ids(self) -> list[int]: """Get node IDs.""" self._init_run() return list(self.state.get_nodes(cast(Run, self._run).run_id)) @@ -123,7 +124,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: This method takes an iterable of messages and sends each message to the node specified in `dst_node_id`. """ - task_ids: List[str] = [] + task_ids: list[str] = [] for msg in messages: # Check message self._check_message(msg) @@ -169,7 +170,7 @@ def send_and_receive( # Pull messages end_time = time.time() + (timeout if timeout is not None else 0.0) - ret: List[Message] = [] + ret: list[Message] = [] while timeout is None or time.time() < end_time: res_msgs = self.pull_messages(msg_ids) ret.extend(res_msgs) diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index ddfdb249c1b4..9e5aaeaa9ca7 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -17,7 +17,7 @@ import time import unittest -from typing import Iterable, List, Tuple +from collections.abc import Iterable from unittest.mock import MagicMock, patch from uuid import uuid4 @@ -38,7 +38,7 @@ from .inmemory_driver import InMemoryDriver -def push_messages(driver: InMemoryDriver, num_nodes: int) -> Tuple[Iterable[str], int]: +def push_messages(driver: InMemoryDriver, num_nodes: int) -> tuple[Iterable[str], int]: """Help push messages to state.""" for _ in range(num_nodes): driver.state.create_node(ping_interval=PING_MAX_INTERVAL) @@ -55,7 +55,7 @@ def push_messages(driver: InMemoryDriver, num_nodes: int) -> Tuple[Iterable[str] def get_replies( driver: InMemoryDriver, msg_ids: Iterable[str], node_id: int -) -> List[str]: +) -> list[str]: """Help create message replies and pull taskres from state.""" taskins = driver.state.get_task_ins(node_id, limit=len(list(msg_ids))) for taskin in taskins: diff --git a/src/py/flwr/server/history.py b/src/py/flwr/server/history.py index 291974a4323c..50daf2e04de6 100644 --- a/src/py/flwr/server/history.py +++ b/src/py/flwr/server/history.py @@ -17,7 +17,6 @@ import pprint from functools import reduce -from typing import Dict, List, Tuple from flwr.common.typing import Scalar @@ -26,11 +25,11 @@ class History: """History class for training and/or evaluation metrics collection.""" def __init__(self) -> None: - self.losses_distributed: List[Tuple[int, float]] = [] - self.losses_centralized: List[Tuple[int, float]] = [] - self.metrics_distributed_fit: Dict[str, List[Tuple[int, Scalar]]] = {} - self.metrics_distributed: Dict[str, List[Tuple[int, Scalar]]] = {} - self.metrics_centralized: Dict[str, List[Tuple[int, Scalar]]] = {} + self.losses_distributed: list[tuple[int, float]] = [] + self.losses_centralized: list[tuple[int, float]] = [] + self.metrics_distributed_fit: dict[str, list[tuple[int, Scalar]]] = {} + self.metrics_distributed: dict[str, list[tuple[int, Scalar]]] = {} + self.metrics_centralized: dict[str, list[tuple[int, Scalar]]] = {} def add_loss_distributed(self, server_round: int, loss: float) -> None: """Add one loss entry (from distributed evaluation).""" @@ -41,7 +40,7 @@ def add_loss_centralized(self, server_round: int, loss: float) -> None: self.losses_centralized.append((server_round, loss)) def add_metrics_distributed_fit( - self, server_round: int, metrics: Dict[str, Scalar] + self, server_round: int, metrics: dict[str, Scalar] ) -> None: """Add metrics entries (from distributed fit).""" for key in metrics: @@ -52,7 +51,7 @@ def add_metrics_distributed_fit( self.metrics_distributed_fit[key].append((server_round, metrics[key])) def add_metrics_distributed( - self, server_round: int, metrics: Dict[str, Scalar] + self, server_round: int, metrics: dict[str, Scalar] ) -> None: """Add metrics entries (from distributed evaluation).""" for key in metrics: @@ -63,7 +62,7 @@ def add_metrics_distributed( self.metrics_distributed[key].append((server_round, metrics[key])) def add_metrics_centralized( - self, server_round: int, metrics: Dict[str, Scalar] + self, server_round: int, metrics: dict[str, Scalar] ) -> None: """Add metrics entries (from centralized evaluation).""" for key in metrics: diff --git a/src/py/flwr/server/server.py b/src/py/flwr/server/server.py index 5e2a0c6b2719..bdaa11ba20a2 100644 --- a/src/py/flwr/server/server.py +++ b/src/py/flwr/server/server.py @@ -19,7 +19,7 @@ import io import timeit from logging import INFO, WARN -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from flwr.common import ( Code, @@ -41,17 +41,17 @@ from .server_config import ServerConfig -FitResultsAndFailures = Tuple[ - List[Tuple[ClientProxy, FitRes]], - List[Union[Tuple[ClientProxy, FitRes], BaseException]], +FitResultsAndFailures = tuple[ + list[tuple[ClientProxy, FitRes]], + list[Union[tuple[ClientProxy, FitRes], BaseException]], ] -EvaluateResultsAndFailures = Tuple[ - List[Tuple[ClientProxy, EvaluateRes]], - List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], +EvaluateResultsAndFailures = tuple[ + list[tuple[ClientProxy, EvaluateRes]], + list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], ] -ReconnectResultsAndFailures = Tuple[ - List[Tuple[ClientProxy, DisconnectRes]], - List[Union[Tuple[ClientProxy, DisconnectRes], BaseException]], +ReconnectResultsAndFailures = tuple[ + list[tuple[ClientProxy, DisconnectRes]], + list[Union[tuple[ClientProxy, DisconnectRes], BaseException]], ] @@ -84,7 +84,7 @@ def client_manager(self) -> ClientManager: return self._client_manager # pylint: disable=too-many-locals - def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: + def fit(self, num_rounds: int, timeout: Optional[float]) -> tuple[History, float]: """Run federated averaging for a number of rounds.""" history = History() @@ -163,7 +163,7 @@ def evaluate_round( server_round: int, timeout: Optional[float], ) -> Optional[ - Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures] + tuple[Optional[float], dict[str, Scalar], EvaluateResultsAndFailures] ]: """Validate current global model on a number of clients.""" # Get clients and their respective instructions from strategy @@ -197,9 +197,9 @@ def evaluate_round( ) # Aggregate the evaluation results - aggregated_result: Tuple[ + aggregated_result: tuple[ Optional[float], - Dict[str, Scalar], + dict[str, Scalar], ] = self.strategy.aggregate_evaluate(server_round, results, failures) loss_aggregated, metrics_aggregated = aggregated_result @@ -210,7 +210,7 @@ def fit_round( server_round: int, timeout: Optional[float], ) -> Optional[ - Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures] + tuple[Optional[Parameters], dict[str, Scalar], FitResultsAndFailures] ]: """Perform a single round of federated averaging.""" # Get clients and their respective instructions from strategy @@ -245,9 +245,9 @@ def fit_round( ) # Aggregate training results - aggregated_result: Tuple[ + aggregated_result: tuple[ Optional[Parameters], - Dict[str, Scalar], + dict[str, Scalar], ] = self.strategy.aggregate_fit(server_round, results, failures) parameters_aggregated, metrics_aggregated = aggregated_result @@ -296,7 +296,7 @@ def _get_initial_parameters( def reconnect_clients( - client_instructions: List[Tuple[ClientProxy, ReconnectIns]], + client_instructions: list[tuple[ClientProxy, ReconnectIns]], max_workers: Optional[int], timeout: Optional[float], ) -> ReconnectResultsAndFailures: @@ -312,8 +312,8 @@ def reconnect_clients( ) # Gather results - results: List[Tuple[ClientProxy, DisconnectRes]] = [] - failures: List[Union[Tuple[ClientProxy, DisconnectRes], BaseException]] = [] + results: list[tuple[ClientProxy, DisconnectRes]] = [] + failures: list[Union[tuple[ClientProxy, DisconnectRes], BaseException]] = [] for future in finished_fs: failure = future.exception() if failure is not None: @@ -328,7 +328,7 @@ def reconnect_client( client: ClientProxy, reconnect: ReconnectIns, timeout: Optional[float], -) -> Tuple[ClientProxy, DisconnectRes]: +) -> tuple[ClientProxy, DisconnectRes]: """Instruct client to disconnect and (optionally) reconnect later.""" disconnect = client.reconnect( reconnect, @@ -339,7 +339,7 @@ def reconnect_client( def fit_clients( - client_instructions: List[Tuple[ClientProxy, FitIns]], + client_instructions: list[tuple[ClientProxy, FitIns]], max_workers: Optional[int], timeout: Optional[float], group_id: int, @@ -356,8 +356,8 @@ def fit_clients( ) # Gather results - results: List[Tuple[ClientProxy, FitRes]] = [] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + results: list[tuple[ClientProxy, FitRes]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] for future in finished_fs: _handle_finished_future_after_fit( future=future, results=results, failures=failures @@ -367,7 +367,7 @@ def fit_clients( def fit_client( client: ClientProxy, ins: FitIns, timeout: Optional[float], group_id: int -) -> Tuple[ClientProxy, FitRes]: +) -> tuple[ClientProxy, FitRes]: """Refine parameters on a single client.""" fit_res = client.fit(ins, timeout=timeout, group_id=group_id) return client, fit_res @@ -375,8 +375,8 @@ def fit_client( def _handle_finished_future_after_fit( future: concurrent.futures.Future, # type: ignore - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], ) -> None: """Convert finished future into either a result or a failure.""" # Check if there was an exception @@ -386,7 +386,7 @@ def _handle_finished_future_after_fit( return # Successfully received a result from a client - result: Tuple[ClientProxy, FitRes] = future.result() + result: tuple[ClientProxy, FitRes] = future.result() _, res = result # Check result status code @@ -399,7 +399,7 @@ def _handle_finished_future_after_fit( def evaluate_clients( - client_instructions: List[Tuple[ClientProxy, EvaluateIns]], + client_instructions: list[tuple[ClientProxy, EvaluateIns]], max_workers: Optional[int], timeout: Optional[float], group_id: int, @@ -416,8 +416,8 @@ def evaluate_clients( ) # Gather results - results: List[Tuple[ClientProxy, EvaluateRes]] = [] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [] + results: list[tuple[ClientProxy, EvaluateRes]] = [] + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [] for future in finished_fs: _handle_finished_future_after_evaluate( future=future, results=results, failures=failures @@ -430,7 +430,7 @@ def evaluate_client( ins: EvaluateIns, timeout: Optional[float], group_id: int, -) -> Tuple[ClientProxy, EvaluateRes]: +) -> tuple[ClientProxy, EvaluateRes]: """Evaluate parameters on a single client.""" evaluate_res = client.evaluate(ins, timeout=timeout, group_id=group_id) return client, evaluate_res @@ -438,8 +438,8 @@ def evaluate_client( def _handle_finished_future_after_evaluate( future: concurrent.futures.Future, # type: ignore - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], ) -> None: """Convert finished future into either a result or a failure.""" # Check if there was an exception @@ -449,7 +449,7 @@ def _handle_finished_future_after_evaluate( return # Successfully received a result from a client - result: Tuple[ClientProxy, EvaluateRes] = future.result() + result: tuple[ClientProxy, EvaluateRes] = future.result() _, res = result # Check result status code @@ -466,7 +466,7 @@ def init_defaults( config: Optional[ServerConfig], strategy: Optional[Strategy], client_manager: Optional[ClientManager], -) -> Tuple[Server, ServerConfig]: +) -> tuple[Server, ServerConfig]: """Create server instance if none was given.""" if server is None: if client_manager is None: diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index b80811a6f730..6e8f423fe115 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -19,7 +19,7 @@ import csv import tempfile from pathlib import Path -from typing import List, Optional +from typing import Optional import numpy as np from cryptography.hazmat.primitives.asymmetric import ec @@ -143,7 +143,7 @@ def reconnect( def test_fit_clients() -> None: """Test fit_clients.""" # Prepare - clients: List[ClientProxy] = [ + clients: list[ClientProxy] = [ FailingClient("0"), SuccessClient("1"), ] @@ -164,7 +164,7 @@ def test_fit_clients() -> None: def test_eval_clients() -> None: """Test eval_clients.""" # Prepare - clients: List[ClientProxy] = [ + clients: list[ClientProxy] = [ FailingClient("0"), SuccessClient("1"), ] diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index c668b55eebe6..d5ee7340f8ea 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -16,7 +16,7 @@ # mypy: disallow_untyped_calls=False from functools import reduce -from typing import Any, Callable, List, Tuple +from typing import Any, Callable import numpy as np @@ -24,7 +24,7 @@ from flwr.server.client_proxy import ClientProxy -def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: +def aggregate(results: list[tuple[NDArrays, int]]) -> NDArrays: """Compute weighted average.""" # Calculate the total number of examples used during training num_examples_total = sum(num_examples for (_, num_examples) in results) @@ -42,7 +42,7 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: return weights_prime -def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: +def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays: """Compute in-place weighted average.""" # Count total examples num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results) @@ -67,7 +67,7 @@ def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: return params -def aggregate_median(results: List[Tuple[NDArrays, int]]) -> NDArrays: +def aggregate_median(results: list[tuple[NDArrays, int]]) -> NDArrays: """Compute median.""" # Create a list of weights and ignore the number of examples weights = [weights for weights, _ in results] @@ -80,7 +80,7 @@ def aggregate_median(results: List[Tuple[NDArrays, int]]) -> NDArrays: def aggregate_krum( - results: List[Tuple[NDArrays, int]], num_malicious: int, to_keep: int + results: list[tuple[NDArrays, int]], num_malicious: int, to_keep: int ) -> NDArrays: """Choose one parameter vector according to the Krum function. @@ -119,7 +119,7 @@ def aggregate_krum( # pylint: disable=too-many-locals def aggregate_bulyan( - results: List[Tuple[NDArrays, int]], + results: list[tuple[NDArrays, int]], num_malicious: int, aggregation_rule: Callable, # type: ignore **aggregation_rule_kwargs: Any, @@ -155,7 +155,7 @@ def aggregate_bulyan( "It is needed to ensure that the method reduces the attacker's leeway to " "the one proved in the paper." ) - selected_models_set: List[Tuple[NDArrays, int]] = [] + selected_models_set: list[tuple[NDArrays, int]] = [] theta = len(results) - 2 * num_malicious beta = theta - 2 * num_malicious @@ -200,7 +200,7 @@ def aggregate_bulyan( return parameters_aggregated -def weighted_loss_avg(results: List[Tuple[int, float]]) -> float: +def weighted_loss_avg(results: list[tuple[int, float]]) -> float: """Aggregate evaluation results obtained from multiple clients.""" num_total_evaluation_examples = sum(num_examples for (num_examples, _) in results) weighted_losses = [num_examples * loss for num_examples, loss in results] @@ -208,7 +208,7 @@ def weighted_loss_avg(results: List[Tuple[int, float]]) -> float: def aggregate_qffl( - parameters: NDArrays, deltas: List[NDArrays], hs_fll: List[NDArrays] + parameters: NDArrays, deltas: list[NDArrays], hs_fll: list[NDArrays] ) -> NDArrays: """Compute weighted average based on Q-FFL paper.""" demominator: float = np.sum(np.asarray(hs_fll)) @@ -225,7 +225,7 @@ def aggregate_qffl( return new_parameters -def _compute_distances(weights: List[NDArrays]) -> NDArray: +def _compute_distances(weights: list[NDArrays]) -> NDArray: """Compute distances between vectors. Input: weights - list of weights vectors @@ -265,7 +265,7 @@ def _trim_mean(array: NDArray, proportiontocut: float) -> NDArray: def aggregate_trimmed_avg( - results: List[Tuple[NDArrays, int]], proportiontocut: float + results: list[tuple[NDArrays, int]], proportiontocut: float ) -> NDArrays: """Compute trimmed average.""" # Create a list of weights and ignore the number of examples @@ -290,7 +290,7 @@ def _check_weights_equality(weights1: NDArrays, weights2: NDArrays) -> bool: def _find_reference_weights( - reference_weights: NDArrays, list_of_weights: List[NDArrays] + reference_weights: NDArrays, list_of_weights: list[NDArrays] ) -> int: """Find the reference weights by looping through the `list_of_weights`. @@ -320,7 +320,7 @@ def _find_reference_weights( def _aggregate_n_closest_weights( - reference_weights: NDArrays, results: List[Tuple[NDArrays, int]], beta_closest: int + reference_weights: NDArrays, results: list[tuple[NDArrays, int]], beta_closest: int ) -> NDArrays: """Calculate element-wise mean of the `N` closest values. diff --git a/src/py/flwr/server/strategy/aggregate_test.py b/src/py/flwr/server/strategy/aggregate_test.py index f8b4e3c03b50..9f9dba79ec7c 100644 --- a/src/py/flwr/server/strategy/aggregate_test.py +++ b/src/py/flwr/server/strategy/aggregate_test.py @@ -15,8 +15,6 @@ """Aggregation function tests.""" -from typing import List, Tuple - import numpy as np from .aggregate import ( @@ -49,7 +47,7 @@ def test_aggregate() -> None: def test_weighted_loss_avg_single_value() -> None: """Test weighted loss averaging.""" # Prepare - results: List[Tuple[int, float]] = [(5, 0.5)] + results: list[tuple[int, float]] = [(5, 0.5)] expected = 0.5 # Execute @@ -62,7 +60,7 @@ def test_weighted_loss_avg_single_value() -> None: def test_weighted_loss_avg_multiple_values() -> None: """Test weighted loss averaging.""" # Prepare - results: List[Tuple[int, float]] = [(1, 2.0), (2, 1.0), (1, 2.0)] + results: list[tuple[int, float]] = [(1, 2.0), (2, 1.0), (1, 2.0)] expected = 1.5 # Execute diff --git a/src/py/flwr/server/strategy/bulyan.py b/src/py/flwr/server/strategy/bulyan.py index a81406c255ad..84a261237ac5 100644 --- a/src/py/flwr/server/strategy/bulyan.py +++ b/src/py/flwr/server/strategy/bulyan.py @@ -19,7 +19,7 @@ from logging import WARNING -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union from flwr.common import ( FitRes, @@ -86,12 +86,12 @@ def __init__( num_malicious_clients: int = 0, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -125,9 +125,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using Bulyan.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/bulyan_test.py b/src/py/flwr/server/strategy/bulyan_test.py index 93a9ebda3783..c0b87c82a036 100644 --- a/src/py/flwr/server/strategy/bulyan_test.py +++ b/src/py/flwr/server/strategy/bulyan_test.py @@ -15,7 +15,6 @@ """Bulyan tests.""" -from typing import List, Tuple from unittest.mock import MagicMock from numpy import array, float32 @@ -62,7 +61,7 @@ def test_aggregate_fit() -> None: param_5: Parameters = ndarrays_to_parameters( [array([0.1, 0.1, 0.1, 0.1], dtype=float32)] ) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( diff --git a/src/py/flwr/server/strategy/dp_adaptive_clipping.py b/src/py/flwr/server/strategy/dp_adaptive_clipping.py index b25e1efdf0e9..77e70bb9af04 100644 --- a/src/py/flwr/server/strategy/dp_adaptive_clipping.py +++ b/src/py/flwr/server/strategy/dp_adaptive_clipping.py @@ -20,7 +20,7 @@ import math from logging import INFO, WARNING -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union import numpy as np @@ -156,14 +156,14 @@ def initialize_parameters( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" self.current_round_params = parameters_to_ndarrays(parameters) return self.strategy.configure_fit(server_round, parameters, client_manager) def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" return self.strategy.configure_evaluate( server_round, parameters, client_manager @@ -172,9 +172,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate training results and update clip norms.""" if failures: return None, {} @@ -245,15 +245,15 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using the given strategy.""" return self.strategy.aggregate_evaluate(server_round, results, failures) def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function from the strategy.""" return self.strategy.evaluate(server_round, parameters) @@ -372,7 +372,7 @@ def initialize_parameters( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" additional_config = {KEY_CLIPPING_NORM: self.clipping_norm} inner_strategy_config_result = self.strategy.configure_fit( @@ -385,7 +385,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" return self.strategy.configure_evaluate( server_round, parameters, client_manager @@ -394,9 +394,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate training results and update clip norms.""" if failures: return None, {} @@ -432,7 +432,7 @@ def aggregate_fit( return aggregated_params, metrics - def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: + def _update_clip_norm(self, results: list[tuple[ClientProxy, FitRes]]) -> None: # Calculate the number of clients which set the norm indicator bit norm_bit_set_count = 0 for client_proxy, fit_res in results: @@ -457,14 +457,14 @@ def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using the given strategy.""" return self.strategy.aggregate_evaluate(server_round, results, failures) def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function from the strategy.""" return self.strategy.evaluate(server_round, parameters) diff --git a/src/py/flwr/server/strategy/dp_fixed_clipping.py b/src/py/flwr/server/strategy/dp_fixed_clipping.py index 92b2845fd846..2ca253c96370 100644 --- a/src/py/flwr/server/strategy/dp_fixed_clipping.py +++ b/src/py/flwr/server/strategy/dp_fixed_clipping.py @@ -19,7 +19,7 @@ from logging import INFO, WARNING -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from flwr.common import ( EvaluateIns, @@ -117,14 +117,14 @@ def initialize_parameters( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" self.current_round_params = parameters_to_ndarrays(parameters) return self.strategy.configure_fit(server_round, parameters, client_manager) def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" return self.strategy.configure_evaluate( server_round, parameters, client_manager @@ -133,9 +133,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Compute the updates, clip, and pass them for aggregation. Afterward, add noise to the aggregated parameters. @@ -191,15 +191,15 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using the given strategy.""" return self.strategy.aggregate_evaluate(server_round, results, failures) def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function from the strategy.""" return self.strategy.evaluate(server_round, parameters) @@ -285,7 +285,7 @@ def initialize_parameters( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" additional_config = {KEY_CLIPPING_NORM: self.clipping_norm} inner_strategy_config_result = self.strategy.configure_fit( @@ -298,7 +298,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" return self.strategy.configure_evaluate( server_round, parameters, client_manager @@ -307,9 +307,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Add noise to the aggregated parameters.""" if failures: return None, {} @@ -348,14 +348,14 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using the given strategy.""" return self.strategy.aggregate_evaluate(server_round, results, failures) def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function from the strategy.""" return self.strategy.evaluate(server_round, parameters) diff --git a/src/py/flwr/server/strategy/dpfedavg_adaptive.py b/src/py/flwr/server/strategy/dpfedavg_adaptive.py index 423ddddeb379..ab513aba2269 100644 --- a/src/py/flwr/server/strategy/dpfedavg_adaptive.py +++ b/src/py/flwr/server/strategy/dpfedavg_adaptive.py @@ -19,7 +19,7 @@ import math -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union import numpy as np @@ -80,7 +80,7 @@ def __repr__(self) -> str: def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" additional_config = {"dpfedavg_adaptive_clip_enabled": True} @@ -93,7 +93,7 @@ def configure_fit( return client_instructions - def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: + def _update_clip_norm(self, results: list[tuple[ClientProxy, FitRes]]) -> None: # Calculating number of clients which set the norm indicator bit norm_bit_set_count = 0 for client_proxy, fit_res in results: @@ -118,9 +118,9 @@ def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate training results as in DPFedAvgFixed and update clip norms.""" if failures: return None, {} diff --git a/src/py/flwr/server/strategy/dpfedavg_fixed.py b/src/py/flwr/server/strategy/dpfedavg_fixed.py index d122f0688922..4ea84db30cd4 100644 --- a/src/py/flwr/server/strategy/dpfedavg_fixed.py +++ b/src/py/flwr/server/strategy/dpfedavg_fixed.py @@ -17,7 +17,7 @@ Paper: arxiv.org/pdf/1710.06963.pdf """ -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar from flwr.common.dp import add_gaussian_noise @@ -79,7 +79,7 @@ def initialize_parameters( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training incorporating Differential Privacy (DP). Configuration of the next training round includes information related to DP, @@ -119,7 +119,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation using the specified strategy. Parameters @@ -147,9 +147,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate training results using unweighted aggregation.""" if failures: return None, {} @@ -168,14 +168,14 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using the given strategy.""" return self.strategy.aggregate_evaluate(server_round, results, failures) def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function from the strategy.""" return self.strategy.evaluate(server_round, parameters) diff --git a/src/py/flwr/server/strategy/fault_tolerant_fedavg.py b/src/py/flwr/server/strategy/fault_tolerant_fedavg.py index 663ac8872c39..60213db2efeb 100644 --- a/src/py/flwr/server/strategy/fault_tolerant_fedavg.py +++ b/src/py/flwr/server/strategy/fault_tolerant_fedavg.py @@ -16,7 +16,7 @@ from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union from flwr.common import ( EvaluateRes, @@ -49,12 +49,12 @@ def __init__( min_available_clients: int = 1, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, min_completion_rate_fit: float = 0.5, min_completion_rate_evaluate: float = 0.5, initial_parameters: Optional[Parameters] = None, @@ -85,9 +85,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" if not results: return None, {} @@ -117,9 +117,9 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using weighted average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fault_tolerant_fedavg_test.py b/src/py/flwr/server/strategy/fault_tolerant_fedavg_test.py index 98f4cac032cb..a01a3a5c0ad5 100644 --- a/src/py/flwr/server/strategy/fault_tolerant_fedavg_test.py +++ b/src/py/flwr/server/strategy/fault_tolerant_fedavg_test.py @@ -15,7 +15,7 @@ """FaultTolerantFedAvg tests.""" -from typing import List, Optional, Tuple, Union +from typing import Optional, Union from unittest.mock import MagicMock from flwr.common import ( @@ -36,8 +36,8 @@ def test_aggregate_fit_no_results_no_failures() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_fit=0.1) - results: List[Tuple[ClientProxy, FitRes]] = [] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + results: list[tuple[ClientProxy, FitRes]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] expected: Optional[Parameters] = None # Execute @@ -51,8 +51,8 @@ def test_aggregate_fit_no_results() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_fit=0.1) - results: List[Tuple[ClientProxy, FitRes]] = [] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [Exception()] + results: list[tuple[ClientProxy, FitRes]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [Exception()] expected: Optional[Parameters] = None # Execute @@ -66,7 +66,7 @@ def test_aggregate_fit_not_enough_results() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_fit=0.5) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( @@ -77,7 +77,7 @@ def test_aggregate_fit_not_enough_results() -> None: ), ) ] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [ + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [ Exception(), Exception(), ] @@ -94,7 +94,7 @@ def test_aggregate_fit_just_enough_results() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_fit=0.5) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( @@ -105,7 +105,7 @@ def test_aggregate_fit_just_enough_results() -> None: ), ) ] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [Exception()] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [Exception()] expected: Optional[NDArrays] = [] # Execute @@ -120,7 +120,7 @@ def test_aggregate_fit_no_failures() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_fit=0.99) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( @@ -131,7 +131,7 @@ def test_aggregate_fit_no_failures() -> None: ), ) ] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] expected: Optional[NDArrays] = [] # Execute @@ -146,8 +146,8 @@ def test_aggregate_evaluate_no_results_no_failures() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_evaluate=0.1) - results: List[Tuple[ClientProxy, EvaluateRes]] = [] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [] + results: list[tuple[ClientProxy, EvaluateRes]] = [] + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [] expected: Optional[float] = None # Execute @@ -161,8 +161,8 @@ def test_aggregate_evaluate_no_results() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_evaluate=0.1) - results: List[Tuple[ClientProxy, EvaluateRes]] = [] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [ + results: list[tuple[ClientProxy, EvaluateRes]] = [] + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [ Exception() ] expected: Optional[float] = None @@ -178,7 +178,7 @@ def test_aggregate_evaluate_not_enough_results() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_evaluate=0.5) - results: List[Tuple[ClientProxy, EvaluateRes]] = [ + results: list[tuple[ClientProxy, EvaluateRes]] = [ ( MagicMock(), EvaluateRes( @@ -189,7 +189,7 @@ def test_aggregate_evaluate_not_enough_results() -> None: ), ) ] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [ + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [ Exception(), Exception(), ] @@ -206,7 +206,7 @@ def test_aggregate_evaluate_just_enough_results() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_evaluate=0.5) - results: List[Tuple[ClientProxy, EvaluateRes]] = [ + results: list[tuple[ClientProxy, EvaluateRes]] = [ ( MagicMock(), EvaluateRes( @@ -217,7 +217,7 @@ def test_aggregate_evaluate_just_enough_results() -> None: ), ) ] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [ + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [ Exception() ] expected: Optional[float] = 2.3 @@ -233,7 +233,7 @@ def test_aggregate_evaluate_no_failures() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_evaluate=0.99) - results: List[Tuple[ClientProxy, EvaluateRes]] = [ + results: list[tuple[ClientProxy, EvaluateRes]] = [ ( MagicMock(), EvaluateRes( @@ -244,7 +244,7 @@ def test_aggregate_evaluate_no_failures() -> None: ), ) ] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [] + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [] expected: Optional[float] = 2.3 # Execute diff --git a/src/py/flwr/server/strategy/fedadagrad.py b/src/py/flwr/server/strategy/fedadagrad.py index f13c5358da25..75befdd0e796 100644 --- a/src/py/flwr/server/strategy/fedadagrad.py +++ b/src/py/flwr/server/strategy/fedadagrad.py @@ -20,7 +20,7 @@ """ -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import numpy as np @@ -89,12 +89,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, accept_failures: bool = True, @@ -131,9 +131,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit( server_round=server_round, results=results, failures=failures diff --git a/src/py/flwr/server/strategy/fedadagrad_test.py b/src/py/flwr/server/strategy/fedadagrad_test.py index b43a4c75d123..96d98fe750f3 100644 --- a/src/py/flwr/server/strategy/fedadagrad_test.py +++ b/src/py/flwr/server/strategy/fedadagrad_test.py @@ -15,7 +15,6 @@ """FedAdagrad tests.""" -from typing import List, Tuple from unittest.mock import MagicMock from numpy import array, float32 @@ -54,7 +53,7 @@ def test_aggregate_fit() -> None: bridge = MagicMock() client_0 = GrpcClientProxy(cid="0", bridge=bridge) client_1 = GrpcClientProxy(cid="1", bridge=bridge) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( client_0, FitRes( diff --git a/src/py/flwr/server/strategy/fedadam.py b/src/py/flwr/server/strategy/fedadam.py index dc90e90c7568..d0f87a43f79b 100644 --- a/src/py/flwr/server/strategy/fedadam.py +++ b/src/py/flwr/server/strategy/fedadam.py @@ -20,7 +20,7 @@ """ -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import numpy as np @@ -93,12 +93,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Parameters, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -137,9 +137,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit( server_round=server_round, results=results, failures=failures diff --git a/src/py/flwr/server/strategy/fedavg.py b/src/py/flwr/server/strategy/fedavg.py index 3b9b2640c2b5..2d0b855c3186 100644 --- a/src/py/flwr/server/strategy/fedavg.py +++ b/src/py/flwr/server/strategy/fedavg.py @@ -19,7 +19,7 @@ from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union from flwr.common import ( EvaluateIns, @@ -99,12 +99,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -138,12 +138,12 @@ def __repr__(self) -> str: rep = f"FedAvg(accept_failures={self.accept_failures})" return rep - def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]: + def num_fit_clients(self, num_available_clients: int) -> tuple[int, int]: """Return the sample size and the required number of available clients.""" num_clients = int(num_available_clients * self.fraction_fit) return max(num_clients, self.min_fit_clients), self.min_available_clients - def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]: + def num_evaluation_clients(self, num_available_clients: int) -> tuple[int, int]: """Use a fraction of available clients for evaluation.""" num_clients = int(num_available_clients * self.fraction_evaluate) return max(num_clients, self.min_evaluate_clients), self.min_available_clients @@ -158,7 +158,7 @@ def initialize_parameters( def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function.""" if self.evaluate_fn is None: # No evaluation function provided @@ -172,7 +172,7 @@ def evaluate( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" config = {} if self.on_fit_config_fn is not None: @@ -193,7 +193,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" # Do not configure federated evaluation if fraction eval is 0. if self.fraction_evaluate == 0.0: @@ -220,9 +220,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" if not results: return None, {} @@ -256,9 +256,9 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using weighted average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fedavg_android.py b/src/py/flwr/server/strategy/fedavg_android.py index 2f49cf8784c9..bcecf8efb504 100644 --- a/src/py/flwr/server/strategy/fedavg_android.py +++ b/src/py/flwr/server/strategy/fedavg_android.py @@ -18,7 +18,7 @@ """ -from typing import Callable, Dict, List, Optional, Tuple, Union, cast +from typing import Callable, Optional, Union, cast import numpy as np @@ -81,12 +81,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, ) -> None: @@ -107,12 +107,12 @@ def __repr__(self) -> str: rep = f"FedAvg(accept_failures={self.accept_failures})" return rep - def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]: + def num_fit_clients(self, num_available_clients: int) -> tuple[int, int]: """Return the sample size and the required number of available clients.""" num_clients = int(num_available_clients * self.fraction_fit) return max(num_clients, self.min_fit_clients), self.min_available_clients - def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]: + def num_evaluation_clients(self, num_available_clients: int) -> tuple[int, int]: """Use a fraction of available clients for evaluation.""" num_clients = int(num_available_clients * self.fraction_evaluate) return max(num_clients, self.min_evaluate_clients), self.min_available_clients @@ -127,7 +127,7 @@ def initialize_parameters( def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function.""" if self.evaluate_fn is None: # No evaluation function provided @@ -141,7 +141,7 @@ def evaluate( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" config = {} if self.on_fit_config_fn is not None: @@ -162,7 +162,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" # Do not configure federated evaluation if fraction_evaluate is 0 if self.fraction_evaluate == 0.0: @@ -189,9 +189,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" if not results: return None, {} @@ -208,9 +208,9 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using weighted average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fedavg_test.py b/src/py/flwr/server/strategy/fedavg_test.py index e62eaa5c5832..66241c3ab66a 100644 --- a/src/py/flwr/server/strategy/fedavg_test.py +++ b/src/py/flwr/server/strategy/fedavg_test.py @@ -15,7 +15,7 @@ """FedAvg tests.""" -from typing import List, Tuple, Union +from typing import Union from unittest.mock import MagicMock import numpy as np @@ -140,7 +140,7 @@ def test_inplace_aggregate_fit_equivalence() -> None: weights1_0 = np.random.randn(100, 64) weights1_1 = np.random.randn(314, 628, 3) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( @@ -160,7 +160,7 @@ def test_inplace_aggregate_fit_equivalence() -> None: ), ), ] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] fedavg_reference = FedAvg(inplace=False) fedavg_inplace = FedAvg() diff --git a/src/py/flwr/server/strategy/fedavgm.py b/src/py/flwr/server/strategy/fedavgm.py index ab3d37249db6..a7c37c38770f 100644 --- a/src/py/flwr/server/strategy/fedavgm.py +++ b/src/py/flwr/server/strategy/fedavgm.py @@ -19,7 +19,7 @@ from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union from flwr.common import ( FitRes, @@ -84,12 +84,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -132,9 +132,9 @@ def initialize_parameters( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fedavgm_test.py b/src/py/flwr/server/strategy/fedavgm_test.py index 39da5f4b82c4..400fa3c97247 100644 --- a/src/py/flwr/server/strategy/fedavgm_test.py +++ b/src/py/flwr/server/strategy/fedavgm_test.py @@ -15,7 +15,7 @@ """FedAvgM tests.""" -from typing import List, Tuple, Union +from typing import Union from unittest.mock import MagicMock from numpy import array, float32 @@ -41,7 +41,7 @@ def test_aggregate_fit_using_near_one_server_lr_and_no_momentum() -> None: array([0, 0, 0, 0], dtype=float32), ] - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( @@ -61,7 +61,7 @@ def test_aggregate_fit_using_near_one_server_lr_and_no_momentum() -> None: ), ), ] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] expected: NDArrays = [ array([[1, 2, 3], [4, 5, 6]], dtype=float32), array([7, 8, 9, 10], dtype=float32), @@ -94,7 +94,7 @@ def test_aggregate_fit_server_learning_rate_and_momentum() -> None: array([0, 0, 0, 0], dtype=float32), ] - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( @@ -114,7 +114,7 @@ def test_aggregate_fit_server_learning_rate_and_momentum() -> None: ), ), ] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] expected: NDArrays = [ array([[1, 2, 3], [4, 5, 6]], dtype=float32), array([7, 8, 9, 10], dtype=float32), diff --git a/src/py/flwr/server/strategy/fedmedian.py b/src/py/flwr/server/strategy/fedmedian.py index e7cba5324fa8..35044d42b22c 100644 --- a/src/py/flwr/server/strategy/fedmedian.py +++ b/src/py/flwr/server/strategy/fedmedian.py @@ -19,7 +19,7 @@ from logging import WARNING -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from flwr.common import ( FitRes, @@ -46,9 +46,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using median.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fedmedian_test.py b/src/py/flwr/server/strategy/fedmedian_test.py index 3960ad70b145..2c9881635319 100644 --- a/src/py/flwr/server/strategy/fedmedian_test.py +++ b/src/py/flwr/server/strategy/fedmedian_test.py @@ -15,7 +15,6 @@ """FedMedian tests.""" -from typing import List, Tuple from unittest.mock import MagicMock from numpy import array, float32 @@ -159,7 +158,7 @@ def test_aggregate_fit() -> None: client_0 = GrpcClientProxy(cid="0", bridge=bridge) client_1 = GrpcClientProxy(cid="1", bridge=bridge) client_2 = GrpcClientProxy(cid="2", bridge=bridge) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( client_0, FitRes( diff --git a/src/py/flwr/server/strategy/fedopt.py b/src/py/flwr/server/strategy/fedopt.py index c581d4797123..3e143fc3ca59 100644 --- a/src/py/flwr/server/strategy/fedopt.py +++ b/src/py/flwr/server/strategy/fedopt.py @@ -18,7 +18,7 @@ """ -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Optional from flwr.common import ( MetricsAggregationFn, @@ -86,12 +86,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Parameters, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, diff --git a/src/py/flwr/server/strategy/fedprox.py b/src/py/flwr/server/strategy/fedprox.py index f15271e06060..218fece0491f 100644 --- a/src/py/flwr/server/strategy/fedprox.py +++ b/src/py/flwr/server/strategy/fedprox.py @@ -18,7 +18,7 @@ """ -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Optional from flwr.common import FitIns, MetricsAggregationFn, NDArrays, Parameters, Scalar from flwr.server.client_manager import ClientManager @@ -113,12 +113,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -148,7 +148,7 @@ def __repr__(self) -> str: def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training. Sends the proximal factor mu to the clients diff --git a/src/py/flwr/server/strategy/fedtrimmedavg.py b/src/py/flwr/server/strategy/fedtrimmedavg.py index 96b0d35e7a61..8a0e4e50fbff 100644 --- a/src/py/flwr/server/strategy/fedtrimmedavg.py +++ b/src/py/flwr/server/strategy/fedtrimmedavg.py @@ -17,7 +17,7 @@ Paper: arxiv.org/abs/1803.01498 """ from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union from flwr.common import ( FitRes, @@ -78,12 +78,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -114,9 +114,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using trimmed average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index a74ee81976a6..1e55466808f8 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -17,7 +17,7 @@ import json from logging import WARNING -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Callable, Optional, Union, cast from flwr.common import EvaluateRes, FitRes, Parameters, Scalar from flwr.common.logger import log @@ -34,8 +34,8 @@ def __init__( self, evaluate_function: Optional[ Callable[ - [int, Parameters, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, Parameters, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, **kwargs: Any, @@ -52,9 +52,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using bagging.""" if not results: return None, {} @@ -79,9 +79,9 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation metrics using average.""" if not results: return None, {} @@ -101,7 +101,7 @@ def aggregate_evaluate( def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function.""" if self.evaluate_function is None: # No evaluation function provided @@ -152,7 +152,7 @@ def aggregate( return bst_prev_bytes -def _get_tree_nums(xgb_model_org: bytes) -> Tuple[int, int]: +def _get_tree_nums(xgb_model_org: bytes) -> tuple[int, int]: xgb_model = json.loads(bytearray(xgb_model_org)) # Get the number of trees tree_num = int( diff --git a/src/py/flwr/server/strategy/fedxgb_cyclic.py b/src/py/flwr/server/strategy/fedxgb_cyclic.py index 75025a89728b..c2dc3d797c7e 100644 --- a/src/py/flwr/server/strategy/fedxgb_cyclic.py +++ b/src/py/flwr/server/strategy/fedxgb_cyclic.py @@ -16,7 +16,7 @@ from logging import WARNING -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar from flwr.common.logger import log @@ -45,9 +45,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using bagging.""" if not results: return None, {} @@ -69,9 +69,9 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation metrics using average.""" if not results: return None, {} @@ -91,7 +91,7 @@ def aggregate_evaluate( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" config = {} if self.on_fit_config_fn is not None: @@ -117,7 +117,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" # Do not configure federated evaluation if fraction eval is 0. if self.fraction_evaluate == 0.0: diff --git a/src/py/flwr/server/strategy/fedxgb_nn_avg.py b/src/py/flwr/server/strategy/fedxgb_nn_avg.py index 4562663287ae..a7da4a919af7 100644 --- a/src/py/flwr/server/strategy/fedxgb_nn_avg.py +++ b/src/py/flwr/server/strategy/fedxgb_nn_avg.py @@ -22,7 +22,7 @@ from logging import WARNING -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union from flwr.common import FitRes, Scalar, ndarrays_to_parameters, parameters_to_ndarrays from flwr.common.logger import log, warn_deprecated_feature @@ -56,7 +56,7 @@ def __repr__(self) -> str: def evaluate( self, server_round: int, parameters: Any - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function.""" if self.evaluate_fn is None: # No evaluation function provided @@ -70,9 +70,9 @@ def evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Any], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Any], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fedyogi.py b/src/py/flwr/server/strategy/fedyogi.py index c7b2ebb51667..11873d1b781f 100644 --- a/src/py/flwr/server/strategy/fedyogi.py +++ b/src/py/flwr/server/strategy/fedyogi.py @@ -18,7 +18,7 @@ """ -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import numpy as np @@ -93,12 +93,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Parameters, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -137,9 +137,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit( server_round=server_round, results=results, failures=failures diff --git a/src/py/flwr/server/strategy/krum.py b/src/py/flwr/server/strategy/krum.py index 074d018c35a3..5d33874b9789 100644 --- a/src/py/flwr/server/strategy/krum.py +++ b/src/py/flwr/server/strategy/krum.py @@ -21,7 +21,7 @@ from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union from flwr.common import ( FitRes, @@ -87,12 +87,12 @@ def __init__( num_clients_to_keep: int = 0, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -123,9 +123,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using Krum.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/krum_test.py b/src/py/flwr/server/strategy/krum_test.py index b34982325b39..dc996b480630 100644 --- a/src/py/flwr/server/strategy/krum_test.py +++ b/src/py/flwr/server/strategy/krum_test.py @@ -15,7 +15,6 @@ """Krum tests.""" -from typing import List, Tuple from unittest.mock import MagicMock from numpy import array, float32 @@ -160,7 +159,7 @@ def test_aggregate_fit() -> None: client_0 = GrpcClientProxy(cid="0", bridge=bridge) client_1 = GrpcClientProxy(cid="1", bridge=bridge) client_2 = GrpcClientProxy(cid="2", bridge=bridge) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( client_0, FitRes( diff --git a/src/py/flwr/server/strategy/multikrum_test.py b/src/py/flwr/server/strategy/multikrum_test.py index 7a1a4c3ecf38..90607e2c0edc 100644 --- a/src/py/flwr/server/strategy/multikrum_test.py +++ b/src/py/flwr/server/strategy/multikrum_test.py @@ -15,7 +15,6 @@ """Krum tests.""" -from typing import List, Tuple from unittest.mock import MagicMock from numpy import array, float32 @@ -59,7 +58,7 @@ def test_aggregate_fit() -> None: client_1 = GrpcClientProxy(cid="1", bridge=bridge) client_2 = GrpcClientProxy(cid="2", bridge=bridge) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( client_0, FitRes( diff --git a/src/py/flwr/server/strategy/qfedavg.py b/src/py/flwr/server/strategy/qfedavg.py index 26a397d4cf8c..30a3cc53ee94 100644 --- a/src/py/flwr/server/strategy/qfedavg.py +++ b/src/py/flwr/server/strategy/qfedavg.py @@ -19,7 +19,7 @@ from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import numpy as np @@ -60,12 +60,12 @@ def __init__( min_available_clients: int = 1, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -95,19 +95,19 @@ def __repr__(self) -> str: rep += f"q_param={self.q_param}, pre_weights={self.pre_weights})" return rep - def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]: + def num_fit_clients(self, num_available_clients: int) -> tuple[int, int]: """Return the sample size and the required number of available clients.""" num_clients = int(num_available_clients * self.fraction_fit) return max(num_clients, self.min_fit_clients), self.min_available_clients - def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]: + def num_evaluation_clients(self, num_available_clients: int) -> tuple[int, int]: """Use a fraction of available clients for evaluation.""" num_clients = int(num_available_clients * self.fraction_evaluate) return max(num_clients, self.min_evaluate_clients), self.min_available_clients def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" weights = parameters_to_ndarrays(parameters) self.pre_weights = weights @@ -131,7 +131,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" # Do not configure federated evaluation if fraction_evaluate is 0 if self.fraction_evaluate == 0.0: @@ -158,9 +158,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" if not results: return None, {} @@ -229,9 +229,9 @@ def norm_grad(grad_list: NDArrays) -> float: def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using weighted average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/strategy.py b/src/py/flwr/server/strategy/strategy.py index cfdfe2e246c5..14999e9a8993 100644 --- a/src/py/flwr/server/strategy/strategy.py +++ b/src/py/flwr/server/strategy/strategy.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar from flwr.server.client_manager import ClientManager @@ -47,7 +47,7 @@ def initialize_parameters( @abstractmethod def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training. Parameters @@ -72,9 +72,9 @@ def configure_fit( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate training results. Parameters @@ -108,7 +108,7 @@ def aggregate_fit( @abstractmethod def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation. Parameters @@ -134,9 +134,9 @@ def configure_evaluate( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation results. Parameters @@ -164,7 +164,7 @@ def aggregate_evaluate( @abstractmethod def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate the current model parameters. This function can be used to perform centralized (i.e., server-side) evaluation diff --git a/src/py/flwr/server/superlink/driver/driver_grpc.py b/src/py/flwr/server/superlink/driver/driver_grpc.py index b7b914206f72..70354387812e 100644 --- a/src/py/flwr/server/superlink/driver/driver_grpc.py +++ b/src/py/flwr/server/superlink/driver/driver_grpc.py @@ -15,7 +15,7 @@ """Driver gRPC API.""" from logging import INFO -from typing import Optional, Tuple +from typing import Optional import grpc @@ -35,7 +35,7 @@ def run_driver_api_grpc( address: str, state_factory: StateFactory, ffs_factory: FfsFactory, - certificates: Optional[Tuple[bytes, bytes, bytes]], + certificates: Optional[tuple[bytes, bytes, bytes]], ) -> grpc.Server: """Run Driver API (gRPC, request-response).""" # Create Driver API gRPC server diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 73cd1c73a6fd..4d7d6cb6ce89 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -17,7 +17,7 @@ import time from logging import DEBUG -from typing import List, Optional, Set +from typing import Optional from uuid import UUID import grpc @@ -68,8 +68,8 @@ def GetNodes( """Get available nodes.""" log(DEBUG, "DriverServicer.GetNodes") state: State = self.state_factory.state() - all_ids: Set[int] = state.get_nodes(request.run_id) - nodes: List[Node] = [ + all_ids: set[int] = state.get_nodes(request.run_id) + nodes: list[Node] = [ Node(node_id=node_id, anonymous=False) for node_id in all_ids ] return GetNodesResponse(nodes=nodes) @@ -119,7 +119,7 @@ def PushTaskIns( state: State = self.state_factory.state() # Store each TaskIns - task_ids: List[Optional[UUID]] = [] + task_ids: list[Optional[UUID]] = [] for task_ins in request.task_ins_list: task_id: Optional[UUID] = state.store_task_ins(task_ins=task_ins) task_ids.append(task_id) @@ -135,7 +135,7 @@ def PullTaskRes( log(DEBUG, "DriverServicer.PullTaskRes") # Convert each task_id str to UUID - task_ids: Set[UUID] = {UUID(task_id) for task_id in request.task_ids} + task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids} # Init state state: State = self.state_factory.state() @@ -155,7 +155,7 @@ def on_rpc_done() -> None: context.add_callback(on_rpc_done) # Read from state - task_res_list: List[TaskRes] = state.get_task_res(task_ids=task_ids, limit=None) + task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids, limit=None) context.set_code(grpc.StatusCode.OK) return PullTaskResResponse(task_res_list=task_res_list) diff --git a/src/py/flwr/server/superlink/ffs/disk_ffs.py b/src/py/flwr/server/superlink/ffs/disk_ffs.py index 98ec4f93498f..4f1ab05be9a2 100644 --- a/src/py/flwr/server/superlink/ffs/disk_ffs.py +++ b/src/py/flwr/server/superlink/ffs/disk_ffs.py @@ -17,7 +17,7 @@ import hashlib import json from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Optional from flwr.server.superlink.ffs.ffs import Ffs @@ -35,7 +35,7 @@ def __init__(self, base_dir: str) -> None: """ self.base_dir = Path(base_dir) - def put(self, content: bytes, meta: Dict[str, str]) -> str: + def put(self, content: bytes, meta: dict[str, str]) -> str: """Store bytes and metadata and return key (hash of content). Parameters @@ -58,7 +58,7 @@ def put(self, content: bytes, meta: Dict[str, str]) -> str: return content_hash - def get(self, key: str) -> Optional[Tuple[bytes, Dict[str, str]]]: + def get(self, key: str) -> Optional[tuple[bytes, dict[str, str]]]: """Return tuple containing the object content and metadata. Parameters @@ -90,7 +90,7 @@ def delete(self, key: str) -> None: (self.base_dir / key).unlink() (self.base_dir / f"{key}.META").unlink() - def list(self) -> List[str]: + def list(self) -> list[str]: """List all keys. Return all available keys in this `Ffs` instance. diff --git a/src/py/flwr/server/superlink/ffs/ffs.py b/src/py/flwr/server/superlink/ffs/ffs.py index fab3b1fdfb3e..b1d26e74c157 100644 --- a/src/py/flwr/server/superlink/ffs/ffs.py +++ b/src/py/flwr/server/superlink/ffs/ffs.py @@ -16,14 +16,14 @@ import abc -from typing import Dict, List, Optional, Tuple +from typing import Optional class Ffs(abc.ABC): # pylint: disable=R0904 """Abstract Flower File Storage interface for large objects.""" @abc.abstractmethod - def put(self, content: bytes, meta: Dict[str, str]) -> str: + def put(self, content: bytes, meta: dict[str, str]) -> str: """Store bytes and metadata and return sha256hex hash of data as str. Parameters @@ -40,7 +40,7 @@ def put(self, content: bytes, meta: Dict[str, str]) -> str: """ @abc.abstractmethod - def get(self, key: str) -> Optional[Tuple[bytes, Dict[str, str]]]: + def get(self, key: str) -> Optional[tuple[bytes, dict[str, str]]]: """Return tuple containing the object content and metadata. Parameters @@ -65,7 +65,7 @@ def delete(self, key: str) -> None: """ @abc.abstractmethod - def list(self) -> List[str]: + def list(self) -> list[str]: """List keys of all stored objects. Return all available keys in this `Ffs` instance. diff --git a/src/py/flwr/server/superlink/ffs/ffs_test.py b/src/py/flwr/server/superlink/ffs/ffs_test.py index f7fbbf1218e1..5cf28cfd2cbe 100644 --- a/src/py/flwr/server/superlink/ffs/ffs_test.py +++ b/src/py/flwr/server/superlink/ffs/ffs_test.py @@ -21,7 +21,6 @@ import tempfile import unittest from abc import abstractmethod -from typing import Dict from flwr.server.superlink.ffs import DiskFfs, Ffs @@ -65,7 +64,7 @@ def test_get(self) -> None: ffs: Ffs = self.ffs_factory() content_expected = b"content" hash_expected = hashlib.sha256(content_expected).hexdigest() - meta_expected: Dict[str, str] = {"meta_key": "meta_value"} + meta_expected: dict[str, str] = {"meta_key": "meta_value"} with open(os.path.join(self.tmp_dir.name, hash_expected), "wb") as file: file.write(content_expected) @@ -93,7 +92,7 @@ def test_delete(self) -> None: ffs: Ffs = self.ffs_factory() content_expected = b"content" hash_expected = hashlib.sha256(content_expected).hexdigest() - meta_expected: Dict[str, str] = {"meta_key": "meta_value"} + meta_expected: dict[str, str] = {"meta_key": "meta_value"} with open(os.path.join(self.tmp_dir.name, hash_expected), "wb") as file: file.write(content_expected) @@ -117,7 +116,7 @@ def test_list(self) -> None: ffs: Ffs = self.ffs_factory() content_expected = b"content" hash_expected = hashlib.sha256(content_expected).hexdigest() - meta_expected: Dict[str, str] = {"meta_key": "meta_value"} + meta_expected: dict[str, str] = {"meta_key": "meta_value"} with open(os.path.join(self.tmp_dir.name, hash_expected), "wb") as file: file.write(content_expected) diff --git a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py index 278e20eb1d69..dbfbb236a7e4 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py @@ -16,7 +16,7 @@ from logging import DEBUG, INFO -from typing import Callable, Type, TypeVar +from typing import Callable, TypeVar import grpc from google.protobuf.message import Message as GrpcMessage @@ -47,7 +47,7 @@ def _handle( msg_container: MessageContainer, - request_type: Type[T], + request_type: type[T], handler: Callable[[T], GrpcMessage], ) -> MessageContainer: req = request_type.FromString(msg_container.grpc_message_content) diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py index 79f1a8f9902b..38f0dfdae299 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py @@ -19,7 +19,8 @@ """ import uuid -from typing import Callable, Iterator +from collections.abc import Iterator +from typing import Callable import grpc from iterators import TimeoutIterator diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py index 5fe0396696ab..476e2914f4d9 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py @@ -15,10 +15,11 @@ """Provides class GrpcBridge.""" +from collections.abc import Iterator from dataclasses import dataclass from enum import Enum from threading import Condition -from typing import Iterator, Optional +from typing import Optional from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py index f9b6b97030f0..6d9e081d8dd4 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py @@ -17,7 +17,7 @@ import time from threading import Thread -from typing import List, Union +from typing import Union from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, @@ -32,7 +32,7 @@ def start_worker( - rounds: int, bridge: GrpcBridge, results: List[ClientMessage] + rounds: int, bridge: GrpcBridge, results: list[ClientMessage] ) -> Thread: """Simulate processing loop with five calls.""" @@ -59,7 +59,7 @@ def test_workflow_successful() -> None: """Test full workflow.""" # Prepare rounds = 5 - client_messages_received: List[ClientMessage] = [] + client_messages_received: list[ClientMessage] = [] bridge = GrpcBridge() ins_wrapper_iterator = bridge.ins_wrapper_iterator() @@ -90,7 +90,7 @@ def test_workflow_close() -> None: """ # Prepare rounds = 5 - client_messages_received: List[ClientMessage] = [] + client_messages_received: list[ClientMessage] = [] bridge = GrpcBridge() ins_wrapper_iterator = bridge.ins_wrapper_iterator() @@ -135,7 +135,7 @@ def test_ins_wrapper_iterator_close_while_blocking() -> None: """ # Prepare rounds = 5 - client_messages_received: List[ClientMessage] = [] + client_messages_received: list[ClientMessage] = [] bridge = GrpcBridge() ins_wrapper_iterator = bridge.ins_wrapper_iterator() diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py index dd78acb72fb1..b161492000f2 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py @@ -17,8 +17,9 @@ import concurrent.futures import sys +from collections.abc import Sequence from logging import ERROR -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Union import grpc @@ -46,7 +47,7 @@ AddServicerToServerFn = Callable[..., Any] -def valid_certificates(certificates: Tuple[bytes, bytes, bytes]) -> bool: +def valid_certificates(certificates: tuple[bytes, bytes, bytes]) -> bool: """Validate certificates tuple.""" is_valid = ( all(isinstance(certificate, bytes) for certificate in certificates) @@ -65,7 +66,7 @@ def start_grpc_server( # pylint: disable=too-many-arguments max_concurrent_workers: int = 1000, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, keepalive_time_ms: int = 210000, - certificates: Optional[Tuple[bytes, bytes, bytes]] = None, + certificates: Optional[tuple[bytes, bytes, bytes]] = None, ) -> grpc.Server: """Create and start a gRPC server running FlowerServiceServicer. @@ -157,16 +158,16 @@ def start_grpc_server( # pylint: disable=too-many-arguments def generic_create_grpc_server( # pylint: disable=too-many-arguments servicer_and_add_fn: Union[ - Tuple[FleetServicer, AddServicerToServerFn], - Tuple[GrpcAdapterServicer, AddServicerToServerFn], - Tuple[FlowerServiceServicer, AddServicerToServerFn], - Tuple[DriverServicer, AddServicerToServerFn], + tuple[FleetServicer, AddServicerToServerFn], + tuple[GrpcAdapterServicer, AddServicerToServerFn], + tuple[FlowerServiceServicer, AddServicerToServerFn], + tuple[DriverServicer, AddServicerToServerFn], ], server_address: str, max_concurrent_workers: int = 1000, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, keepalive_time_ms: int = 210000, - certificates: Optional[Tuple[bytes, bytes, bytes]] = None, + certificates: Optional[tuple[bytes, bytes, bytes]] = None, interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, ) -> grpc.Server: """Create a gRPC server with a single servicer. diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py index 7ff730b17afa..9635993e0ad5 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py @@ -20,7 +20,7 @@ from contextlib import closing from os.path import abspath, dirname, join from pathlib import Path -from typing import Tuple, cast +from typing import cast from flwr.server.client_manager import SimpleClientManager from flwr.server.superlink.fleet.grpc_bidi.grpc_server import ( @@ -31,7 +31,7 @@ root_dir = dirname(abspath(join(__file__, "../../../../../../.."))) -def load_certificates() -> Tuple[str, str, str]: +def load_certificates() -> tuple[str, str, str]: """Generate and load SSL credentials/certificates. Utility function for loading for SSL-enabled gRPC servertests. diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index 2c58d0049849..d836a74bef2e 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -16,8 +16,9 @@ import base64 +from collections.abc import Sequence from logging import INFO, WARNING -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Union import grpc from cryptography.hazmat.primitives.asymmetric import ec @@ -68,7 +69,7 @@ def _get_value_from_tuples( - key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] + key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]] ) -> bytes: value = next((value for key, value in tuples if key == key_string), "") if isinstance(value, str): diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index 64f9ac609998..85f3fa34e0ac 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -16,7 +16,7 @@ import time -from typing import List, Optional +from typing import Optional from uuid import UUID from flwr.common.serde import fab_to_proto, user_config_to_proto @@ -83,7 +83,7 @@ def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsRespo node_id: Optional[int] = None if node.anonymous else node.node_id # Retrieve TaskIns from State - task_ins_list: List[TaskIns] = state.get_task_ins(node_id=node_id, limit=1) + task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1) # Build response response = PullTaskInsResponse( diff --git a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py index cf5ad16f7999..a988252b3ea2 100644 --- a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py +++ b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py @@ -18,7 +18,8 @@ from __future__ import annotations import sys -from typing import Awaitable, Callable, TypeVar +from collections.abc import Awaitable +from typing import Callable, TypeVar from google.protobuf.message import Message as GrpcMessage diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py b/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py index a8c671810a51..31129fce1b1b 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py @@ -15,17 +15,16 @@ """Simulation Engine Backends.""" import importlib -from typing import Dict, Type from .backend import Backend, BackendConfig is_ray_installed = importlib.util.find_spec("ray") is not None # Mapping of supported backends -supported_backends: Dict[str, Type[Backend]] = {} +supported_backends: dict[str, type[Backend]] = {} # To log backend-specific error message when chosen backend isn't available -error_messages_backends: Dict[str, str] = {} +error_messages_backends: dict[str, str] = {} if is_ray_installed: from .raybackend import RayBackend diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py index 89341c0d238f..38be6032e3a5 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py @@ -16,14 +16,14 @@ from abc import ABC, abstractmethod -from typing import Callable, Dict, Tuple +from typing import Callable from flwr.client.client_app import ClientApp from flwr.common.context import Context from flwr.common.message import Message from flwr.common.typing import ConfigsRecordValues -BackendConfig = Dict[str, Dict[str, ConfigsRecordValues]] +BackendConfig = dict[str, dict[str, ConfigsRecordValues]] class Backend(ABC): @@ -62,5 +62,5 @@ def process_message( self, message: Message, context: Context, - ) -> Tuple[Message, Context]: + ) -> tuple[Message, Context]: """Submit a job to the backend.""" diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py index 2024b8760d95..dd79d2ef7f62 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -16,7 +16,7 @@ import sys from logging import DEBUG, ERROR -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Optional, Union import ray @@ -31,8 +31,8 @@ from .backend import Backend, BackendConfig -ClientResourcesDict = Dict[str, Union[int, float]] -ActorArgsDict = Dict[str, Union[int, float, Callable[[], None]]] +ClientResourcesDict = dict[str, Union[int, float]] +ActorArgsDict = dict[str, Union[int, float, Callable[[], None]]] class RayBackend(Backend): @@ -101,7 +101,7 @@ def _validate_actor_arguments(self, config: BackendConfig) -> ActorArgsDict: def init_ray(self, backend_config: BackendConfig) -> None: """Intialises Ray if not already initialised.""" if not ray.is_initialized(): - ray_init_args: Dict[ + ray_init_args: dict[ str, ConfigsRecordValues, ] = {} @@ -144,7 +144,7 @@ def process_message( self, message: Message, context: Context, - ) -> Tuple[Message, Context]: + ) -> tuple[Message, Context]: """Run ClientApp that process a given message. Return output message and updated context. diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index cdb11401c29c..1cbdc230c938 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -15,7 +15,7 @@ """Test for Ray backend for the Fleet API using the Simulation Engine.""" from math import pi -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Optional, Union from unittest import TestCase import ray @@ -47,7 +47,7 @@ class DummyClient(NumPyClient): def __init__(self, state: RecordSet) -> None: self.client_state = state - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """Return properties by doing a simple calculation.""" result = float(config["factor"]) * pi @@ -69,8 +69,8 @@ def _load_app() -> ClientApp: def backend_build_process_and_termination( backend: RayBackend, app_fn: Callable[[], ClientApp], - process_args: Optional[Tuple[Message, Context]] = None, -) -> Union[Tuple[Message, Context], None]: + process_args: Optional[tuple[Message, Context]] = None, +) -> Union[tuple[Message, Context], None]: """Build, process job and terminate RayBackend.""" backend.build(app_fn) to_return = None @@ -83,7 +83,7 @@ def backend_build_process_and_termination( return to_return -def _create_message_and_context() -> Tuple[Message, Context, float]: +def _create_message_and_context() -> tuple[Message, Context, float]: # Construct a Message mult_factor = 2024 diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 165c2de73c21..8f4e18e14e28 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -24,7 +24,7 @@ from pathlib import Path from queue import Empty, Queue from time import sleep -from typing import Callable, Dict, Optional +from typing import Callable, Optional from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.clientapp.utils import get_load_client_app_fn @@ -44,7 +44,7 @@ from .backend import Backend, error_messages_backends, supported_backends -NodeToPartitionMapping = Dict[int, int] +NodeToPartitionMapping = dict[int, int] def _register_nodes( @@ -64,9 +64,9 @@ def _register_node_states( nodes_mapping: NodeToPartitionMapping, run: Run, app_dir: Optional[str] = None, -) -> Dict[int, NodeState]: +) -> dict[int, NodeState]: """Create NodeState objects and pre-register the context for the run.""" - node_states: Dict[int, NodeState] = {} + node_states: dict[int, NodeState] = {} num_partitions = len(set(nodes_mapping.values())) for node_id, partition_id in nodes_mapping.items(): node_states[node_id] = NodeState( @@ -89,7 +89,7 @@ def _register_node_states( def worker( taskins_queue: "Queue[TaskIns]", taskres_queue: "Queue[TaskRes]", - node_states: Dict[int, NodeState], + node_states: dict[int, NodeState], backend: Backend, f_stop: threading.Event, ) -> None: @@ -177,7 +177,7 @@ def run_api( backend_fn: Callable[[], Backend], nodes_mapping: NodeToPartitionMapping, state_factory: StateFactory, - node_states: Dict[int, NodeState], + node_states: dict[int, NodeState], f_stop: threading.Event, ) -> None: """Run the VCE.""" diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index 76e8ac9156d2..1cc3a8f128b6 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -22,7 +22,7 @@ from math import pi from pathlib import Path from time import sleep -from typing import Dict, Optional, Set, Tuple +from typing import Optional from unittest import TestCase from uuid import UUID @@ -57,7 +57,7 @@ class DummyClient(NumPyClient): def __init__(self, state: RecordSet) -> None: self.client_state = state - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """Return properties by doing a simple calculation.""" result = float(config["factor"]) * pi @@ -86,7 +86,7 @@ def terminate_simulation(f_stop: threading.Event, sleep_duration: int) -> None: def init_state_factory_nodes_mapping( num_nodes: int, num_messages: int, -) -> Tuple[StateFactory, NodeToPartitionMapping, Dict[UUID, float]]: +) -> tuple[StateFactory, NodeToPartitionMapping, dict[UUID, float]]: """Instatiate StateFactory, register nodes and pre-insert messages in the state.""" # Register a state and a run_id in it run_id = 1234 @@ -110,7 +110,7 @@ def register_messages_into_state( nodes_mapping: NodeToPartitionMapping, run_id: int, num_messages: int, -) -> Dict[UUID, float]: +) -> dict[UUID, float]: """Register `num_messages` into the state factory.""" state: InMemoryState = state_factory.state() # type: ignore state.run_ids[run_id] = Run( @@ -123,7 +123,7 @@ def register_messages_into_state( # Artificially add TaskIns to state so they can be processed # by the Simulation Engine logic nodes_cycle = cycle(nodes_mapping.keys()) # we have more messages than supernodes - task_ids: Set[UUID] = set() # so we can retrieve them later + task_ids: set[UUID] = set() # so we can retrieve them later expected_results = {} for i in range(num_messages): dst_node_id = next(nodes_cycle) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index c87ba86e47e7..e34d15374350 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -18,7 +18,7 @@ import threading import time from logging import ERROR -from typing import Dict, List, Optional, Set, Tuple +from typing import Optional from uuid import UUID, uuid4 from flwr.common import log, now @@ -37,15 +37,15 @@ class InMemoryState(State): # pylint: disable=R0902,R0904 def __init__(self) -> None: # Map node_id to (online_until, ping_interval) - self.node_ids: Dict[int, Tuple[float, float]] = {} - self.public_key_to_node_id: Dict[bytes, int] = {} + self.node_ids: dict[int, tuple[float, float]] = {} + self.public_key_to_node_id: dict[bytes, int] = {} # Map run_id to (fab_id, fab_version) - self.run_ids: Dict[int, Run] = {} - self.task_ins_store: Dict[UUID, TaskIns] = {} - self.task_res_store: Dict[UUID, TaskRes] = {} + self.run_ids: dict[int, Run] = {} + self.task_ins_store: dict[UUID, TaskIns] = {} + self.task_res_store: dict[UUID, TaskRes] = {} - self.node_public_keys: Set[bytes] = set() + self.node_public_keys: set[bytes] = set() self.server_public_key: Optional[bytes] = None self.server_private_key: Optional[bytes] = None @@ -76,13 +76,13 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: def get_task_ins( self, node_id: Optional[int], limit: Optional[int] - ) -> List[TaskIns]: + ) -> list[TaskIns]: """Get all TaskIns that have not been delivered yet.""" if limit is not None and limit < 1: raise AssertionError("`limit` must be >= 1") # Find TaskIns for node_id that were not delivered yet - task_ins_list: List[TaskIns] = [] + task_ins_list: list[TaskIns] = [] with self.lock: for _, task_ins in self.task_ins_store.items(): # pylint: disable=too-many-boolean-expressions @@ -133,15 +133,15 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: # Return the new task_id return task_id - def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]: + def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]: """Get all TaskRes that have not been delivered yet.""" if limit is not None and limit < 1: raise AssertionError("`limit` must be >= 1") with self.lock: # Find TaskRes that were not delivered yet - task_res_list: List[TaskRes] = [] - replied_task_ids: Set[UUID] = set() + task_res_list: list[TaskRes] = [] + replied_task_ids: set[UUID] = set() for _, task_res in self.task_res_store.items(): reply_to = UUID(task_res.task.ancestry[0]) if reply_to in task_ids and task_res.task.delivered_at == "": @@ -175,10 +175,10 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe # Return TaskRes return task_res_list - def delete_tasks(self, task_ids: Set[UUID]) -> None: + def delete_tasks(self, task_ids: set[UUID]) -> None: """Delete all delivered TaskIns/TaskRes pairs.""" - task_ins_to_be_deleted: Set[UUID] = set() - task_res_to_be_deleted: Set[UUID] = set() + task_ins_to_be_deleted: set[UUID] = set() + task_res_to_be_deleted: set[UUID] = set() with self.lock: for task_ins_id in task_ids: @@ -253,7 +253,7 @@ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: del self.node_ids[node_id] - def get_nodes(self, run_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> set[int]: """Return all available nodes. Constraints @@ -318,7 +318,7 @@ def get_server_public_key(self) -> Optional[bytes]: """Retrieve `server_public_key` in urlsafe bytes.""" return self.server_public_key - def store_node_public_keys(self, public_keys: Set[bytes]) -> None: + def store_node_public_keys(self, public_keys: set[bytes]) -> None: """Store a set of `node_public_keys` in state.""" with self.lock: self.node_public_keys = public_keys @@ -328,7 +328,7 @@ def store_node_public_key(self, public_key: bytes) -> None: with self.lock: self.node_public_keys.add(public_key) - def get_node_public_keys(self) -> Set[bytes]: + def get_node_public_keys(self) -> set[bytes]: """Retrieve all currently stored `node_public_keys` as a set.""" return self.node_public_keys diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index daa211560912..4bb31fa6cea5 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -19,8 +19,9 @@ import re import sqlite3 import time +from collections.abc import Sequence from logging import DEBUG, ERROR -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast +from typing import Any, Optional, Union, cast from uuid import UUID, uuid4 from flwr.common import log, now @@ -110,7 +111,7 @@ ); """ -DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]] +DictOrTuple = Union[tuple[Any, ...], dict[str, Any]] class SqliteState(State): # pylint: disable=R0904 @@ -131,7 +132,7 @@ def __init__( self.database_path = database_path self.conn: Optional[sqlite3.Connection] = None - def initialize(self, log_queries: bool = False) -> List[Tuple[str]]: + def initialize(self, log_queries: bool = False) -> list[tuple[str]]: """Create tables if they don't exist yet. Parameters @@ -162,7 +163,7 @@ def query( self, query: str, data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Execute a SQL query.""" if self.conn is None: raise AttributeError("State is not initialized.") @@ -237,7 +238,7 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: def get_task_ins( self, node_id: Optional[int], limit: Optional[int] - ) -> List[TaskIns]: + ) -> list[TaskIns]: """Get undelivered TaskIns for one node (either anonymous or with ID). Usually, the Fleet API calls this for Nodes planning to work on one or more @@ -271,7 +272,7 @@ def get_task_ins( ) raise AssertionError(msg) - data: Dict[str, Union[str, int]] = {} + data: dict[str, Union[str, int]] = {} if node_id is None: # Retrieve all anonymous Tasks @@ -367,7 +368,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: return task_id # pylint: disable-next=R0914 - def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]: + def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]: """Get TaskRes for task_ids. Usually, the Driver API calls this method to get results for instructions it has @@ -397,7 +398,7 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe AND delivered_at = "" """ - data: Dict[str, Union[str, float, int]] = {} + data: dict[str, Union[str, float, int]] = {} if limit is not None: query += " LIMIT :limit" @@ -435,7 +436,7 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe # 1. Query: Fetch consumer_node_id of remaining task_ids # Assume the ancestry field only contains one element data.clear() - replied_task_ids: Set[UUID] = {UUID(str(row["ancestry"])) for row in rows} + replied_task_ids: set[UUID] = {UUID(str(row["ancestry"])) for row in rows} remaining_task_ids = task_ids - replied_task_ids placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))]) query = f""" @@ -499,10 +500,10 @@ def num_task_res(self) -> int: """ query = "SELECT count(*) AS num FROM task_res;" rows = self.query(query) - result: Dict[str, int] = rows[0] + result: dict[str, int] = rows[0] return result["num"] - def delete_tasks(self, task_ids: Set[UUID]) -> None: + def delete_tasks(self, task_ids: set[UUID]) -> None: """Delete all delivered TaskIns/TaskRes pairs.""" ids = list(task_ids) if len(ids) == 0: @@ -588,7 +589,7 @@ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: except KeyError as exc: log(ERROR, {"query": query, "data": params, "exception": exc}) - def get_nodes(self, run_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> set[int]: """Retrieve all currently stored node IDs as a set. Constraints @@ -604,7 +605,7 @@ def get_nodes(self, run_id: int) -> Set[int]: # Get nodes query = "SELECT node_id FROM node WHERE online_until > ?;" rows = self.query(query, (time.time(),)) - result: Set[int] = {row["node_id"] for row in rows} + result: set[int] = {row["node_id"] for row in rows} return result def get_node_id(self, node_public_key: bytes) -> Optional[int]: @@ -684,7 +685,7 @@ def get_server_public_key(self) -> Optional[bytes]: public_key = None return public_key - def store_node_public_keys(self, public_keys: Set[bytes]) -> None: + def store_node_public_keys(self, public_keys: set[bytes]) -> None: """Store a set of `node_public_keys` in state.""" query = "INSERT INTO public_key (public_key) VALUES (?)" data = [(key,) for key in public_keys] @@ -695,11 +696,11 @@ def store_node_public_key(self, public_key: bytes) -> None: query = "INSERT INTO public_key (public_key) VALUES (:public_key)" self.query(query, {"public_key": public_key}) - def get_node_public_keys(self) -> Set[bytes]: + def get_node_public_keys(self) -> set[bytes]: """Retrieve all currently stored `node_public_keys` as a set.""" query = "SELECT public_key FROM public_key" rows = self.query(query) - result: Set[bytes] = {row["public_key"] for row in rows} + result: set[bytes] = {row["public_key"] for row in rows} return result def get_run(self, run_id: int) -> Optional[Run]: @@ -733,7 +734,7 @@ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: def dict_factory( cursor: sqlite3.Cursor, row: sqlite3.Row, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Turn SQLite results into dicts. Less efficent for retrival of large amounts of data but easier to use. @@ -742,7 +743,7 @@ def dict_factory( return dict(zip(fields, row)) -def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]: +def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]: """Transform TaskIns to dict.""" result = { "task_id": task_msg.task_id, @@ -763,7 +764,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]: return result -def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]: +def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]: """Transform TaskRes to dict.""" result = { "task_id": task_msg.task_id, @@ -784,7 +785,7 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]: return result -def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns: +def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns: """Turn task_dict into protobuf message.""" recordset = RecordSet() recordset.ParseFromString(task_dict["recordset"]) @@ -814,7 +815,7 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns: return result -def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes: +def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes: """Turn task_dict into protobuf message.""" recordset = RecordSet() recordset.ParseFromString(task_dict["recordset"]) diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index fea53105b23f..39da052fb0aa 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -16,7 +16,7 @@ import abc -from typing import List, Optional, Set +from typing import Optional from uuid import UUID from flwr.common.typing import Run, UserConfig @@ -51,7 +51,7 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: @abc.abstractmethod def get_task_ins( self, node_id: Optional[int], limit: Optional[int] - ) -> List[TaskIns]: + ) -> list[TaskIns]: """Get TaskIns optionally filtered by node_id. Usually, the Fleet API calls this for Nodes planning to work on one or more @@ -98,7 +98,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: """ @abc.abstractmethod - def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]: + def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]: """Get TaskRes for task_ids. Usually, the Driver API calls this method to get results for instructions it has @@ -129,7 +129,7 @@ def num_task_res(self) -> int: """ @abc.abstractmethod - def delete_tasks(self, task_ids: Set[UUID]) -> None: + def delete_tasks(self, task_ids: set[UUID]) -> None: """Delete all delivered TaskIns/TaskRes pairs.""" @abc.abstractmethod @@ -143,7 +143,7 @@ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: """Remove `node_id` from state.""" @abc.abstractmethod - def get_nodes(self, run_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> set[int]: """Retrieve all currently stored node IDs as a set. Constraints @@ -199,7 +199,7 @@ def get_server_public_key(self) -> Optional[bytes]: """Retrieve `server_public_key` in urlsafe bytes.""" @abc.abstractmethod - def store_node_public_keys(self, public_keys: Set[bytes]) -> None: + def store_node_public_keys(self, public_keys: set[bytes]) -> None: """Store a set of `node_public_keys` in state.""" @abc.abstractmethod @@ -207,7 +207,7 @@ def store_node_public_key(self, public_key: bytes) -> None: """Store a `node_public_key` in state.""" @abc.abstractmethod - def get_node_public_keys(self) -> Set[bytes]: + def get_node_public_keys(self) -> set[bytes]: """Retrieve all currently stored `node_public_keys` as a set.""" @abc.abstractmethod diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 0cf30a42ca2c..42c0768f1c7d 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -20,7 +20,6 @@ import unittest from abc import abstractmethod from datetime import datetime, timezone -from typing import List from unittest.mock import patch from uuid import uuid4 @@ -655,7 +654,7 @@ def test_node_unavailable_error(self) -> None: # Execute current_time = time.time() - task_res_list: List[TaskRes] = [] + task_res_list: list[TaskRes] = [] with patch("time.time", side_effect=lambda: current_time + 50): task_res_list = state.get_task_res({task_id_0, task_id_1}, limit=None) @@ -698,7 +697,7 @@ def create_task_ins( def create_task_res( producer_node_id: int, anonymous: bool, - ancestry: List[str], + ancestry: list[str], run_id: int, ) -> TaskRes: """Create a TaskRes for testing.""" diff --git a/src/py/flwr/server/utils/tensorboard.py b/src/py/flwr/server/utils/tensorboard.py index 5d38fc159657..281e8949c53c 100644 --- a/src/py/flwr/server/utils/tensorboard.py +++ b/src/py/flwr/server/utils/tensorboard.py @@ -18,7 +18,7 @@ import os from datetime import datetime from logging import WARN -from typing import Callable, Dict, List, Optional, Tuple, Union, cast +from typing import Callable, Optional, Union, cast from flwr.common import EvaluateRes, Scalar from flwr.common.logger import log @@ -92,9 +92,9 @@ class TBWrapper(strategy_class): # type: ignore def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Hooks into aggregate_evaluate for TensorBoard logging purpose.""" # Execute decorated function and extract results for logging # They will be returned at the end of this function but also diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index c0b0ec85761c..fb3d0425db86 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -15,13 +15,13 @@ """Validators.""" -from typing import List, Union +from typing import Union from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 # pylint: disable-next=too-many-branches,too-many-statements -def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str]: +def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str]: """Validate a TaskIns or TaskRes.""" validation_errors = [] diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index 61fe094c23d4..20162883efea 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -17,7 +17,6 @@ import time import unittest -from typing import List, Tuple from flwr.common import DEFAULT_TTL from flwr.proto.node_pb2 import Node # pylint: disable=E0611 @@ -52,12 +51,12 @@ def test_is_valid_task_res(self) -> None: """Test is_valid task_res.""" # Prepare # (producer_node_id, anonymous, ancestry) - valid_res: List[Tuple[int, bool, List[str]]] = [ + valid_res: list[tuple[int, bool, list[str]]] = [ (0, True, ["1"]), (1, False, ["1"]), ] - invalid_res: List[Tuple[int, bool, List[str]]] = [ + invalid_res: list[tuple[int, bool, list[str]]] = [ (0, False, []), (0, False, ["1"]), (0, True, []), @@ -110,7 +109,7 @@ def create_task_ins( def create_task_res( producer_node_id: int, anonymous: bool, - ancestry: List[str], + ancestry: list[str], ) -> TaskRes: """Create a TaskRes for testing.""" task_res = TaskRes( diff --git a/src/py/flwr/server/workflow/default_workflows.py b/src/py/flwr/server/workflow/default_workflows.py index 82d8d5d4ccb6..484a747292d5 100644 --- a/src/py/flwr/server/workflow/default_workflows.py +++ b/src/py/flwr/server/workflow/default_workflows.py @@ -18,7 +18,7 @@ import io import timeit from logging import INFO, WARN -from typing import List, Optional, Tuple, Union, cast +from typing import Optional, Union, cast import flwr.common.recordset_compat as compat from flwr.common import ( @@ -276,8 +276,8 @@ def default_fit_workflow( # pylint: disable=R0914 ) # Aggregate training results - results: List[Tuple[ClientProxy, FitRes]] = [] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + results: list[tuple[ClientProxy, FitRes]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] for msg in messages: if msg.has_content(): proxy = node_id_to_proxy[msg.metadata.src_node_id] @@ -362,8 +362,8 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None: ) # Aggregate the evaluation results - results: List[Tuple[ClientProxy, EvaluateRes]] = [] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [] + results: list[tuple[ClientProxy, EvaluateRes]] = [] + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [] for msg in messages: if msg.has_content(): proxy = node_id_to_proxy[msg.metadata.src_node_id] diff --git a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py index 322e32ed5019..d84a5496dfe1 100644 --- a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +++ b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py @@ -18,7 +18,7 @@ import random from dataclasses import dataclass, field from logging import DEBUG, ERROR, INFO, WARN -from typing import Dict, List, Optional, Set, Tuple, Union, cast +from typing import Optional, Union, cast import flwr.common.recordset_compat as compat from flwr.common import ( @@ -65,23 +65,23 @@ class WorkflowState: # pylint: disable=R0902 """The state of the SecAgg+ protocol.""" - nid_to_proxies: Dict[int, ClientProxy] = field(default_factory=dict) - nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict) - sampled_node_ids: Set[int] = field(default_factory=set) - active_node_ids: Set[int] = field(default_factory=set) + nid_to_proxies: dict[int, ClientProxy] = field(default_factory=dict) + nid_to_fitins: dict[int, RecordSet] = field(default_factory=dict) + sampled_node_ids: set[int] = field(default_factory=set) + active_node_ids: set[int] = field(default_factory=set) num_shares: int = 0 threshold: int = 0 clipping_range: float = 0.0 quantization_range: int = 0 mod_range: int = 0 max_weight: float = 0.0 - nid_to_neighbours: Dict[int, Set[int]] = field(default_factory=dict) - nid_to_publickeys: Dict[int, List[bytes]] = field(default_factory=dict) - forward_srcs: Dict[int, List[int]] = field(default_factory=dict) - forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict) + nid_to_neighbours: dict[int, set[int]] = field(default_factory=dict) + nid_to_publickeys: dict[int, list[bytes]] = field(default_factory=dict) + forward_srcs: dict[int, list[int]] = field(default_factory=dict) + forward_ciphertexts: dict[int, list[bytes]] = field(default_factory=dict) aggregate_ndarrays: NDArrays = field(default_factory=list) - legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list) - failures: List[Exception] = field(default_factory=list) + legacy_results: list[tuple[ClientProxy, FitRes]] = field(default_factory=list) + failures: list[Exception] = field(default_factory=list) class SecAggPlusWorkflow: @@ -444,13 +444,13 @@ def make(nid: int) -> Message: ) # Build forward packet list dictionary - srcs: List[int] = [] - dsts: List[int] = [] - ciphertexts: List[bytes] = [] - fwd_ciphertexts: Dict[int, List[bytes]] = { + srcs: list[int] = [] + dsts: list[int] = [] + ciphertexts: list[bytes] = [] + fwd_ciphertexts: dict[int, list[bytes]] = { nid: [] for nid in state.active_node_ids } # dest node ID -> list of ciphertexts - fwd_srcs: Dict[int, List[int]] = { + fwd_srcs: dict[int, list[int]] = { nid: [] for nid in state.active_node_ids } # dest node ID -> list of src node IDs for msg in msgs: @@ -459,8 +459,8 @@ def make(nid: int) -> Message: continue node_id = msg.metadata.src_node_id res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS] - dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST]) - ctxt_lst = cast(List[bytes], res_dict[Key.CIPHERTEXT_LIST]) + dst_lst = cast(list[int], res_dict[Key.DESTINATION_LIST]) + ctxt_lst = cast(list[bytes], res_dict[Key.CIPHERTEXT_LIST]) srcs += [node_id] * len(dst_lst) dsts += dst_lst ciphertexts += ctxt_lst @@ -525,7 +525,7 @@ def make(nid: int) -> Message: state.failures.append(Exception(msg.error)) continue res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS] - bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS]) + bytes_list = cast(list[bytes], res_dict[Key.MASKED_PARAMETERS]) client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list] if masked_vector is None: masked_vector = client_masked_vec @@ -592,7 +592,7 @@ def make(nid: int) -> Message: ) # Build collected shares dict - collected_shares_dict: Dict[int, List[bytes]] = {} + collected_shares_dict: dict[int, list[bytes]] = {} for nid in state.sampled_node_ids: collected_shares_dict[nid] = [] for msg in msgs: @@ -600,8 +600,8 @@ def make(nid: int) -> Message: state.failures.append(Exception(msg.error)) continue res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS] - nids = cast(List[int], res_dict[Key.NODE_ID_LIST]) - shares = cast(List[bytes], res_dict[Key.SHARE_LIST]) + nids = cast(list[int], res_dict[Key.NODE_ID_LIST]) + shares = cast(list[bytes], res_dict[Key.SHARE_LIST]) for owner_nid, share in zip(nids, shares): collected_shares_dict[owner_nid].append(share) diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index 973a9a89e652..0070d75c53dc 100644 --- a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -22,7 +22,7 @@ import traceback import warnings from logging import ERROR, INFO -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Optional, Union import ray from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -72,7 +72,7 @@ """ -NodeToPartitionMapping = Dict[int, int] +NodeToPartitionMapping = dict[int, int] def _create_node_id_to_partition_mapping( @@ -94,16 +94,16 @@ def start_simulation( *, client_fn: ClientFnExt, num_clients: int, - clients_ids: Optional[List[str]] = None, # UNSUPPORTED, WILL BE REMOVED - client_resources: Optional[Dict[str, float]] = None, + clients_ids: Optional[list[str]] = None, # UNSUPPORTED, WILL BE REMOVED + client_resources: Optional[dict[str, float]] = None, server: Optional[Server] = None, config: Optional[ServerConfig] = None, strategy: Optional[Strategy] = None, client_manager: Optional[ClientManager] = None, - ray_init_args: Optional[Dict[str, Any]] = None, + ray_init_args: Optional[dict[str, Any]] = None, keep_initialised: Optional[bool] = False, - actor_type: Type[VirtualClientEngineActor] = ClientAppActor, - actor_kwargs: Optional[Dict[str, Any]] = None, + actor_type: type[VirtualClientEngineActor] = ClientAppActor, + actor_kwargs: Optional[dict[str, Any]] = None, actor_scheduling: Union[str, NodeAffinitySchedulingStrategy] = "DEFAULT", ) -> History: """Start a Ray-based Flower simulation server. @@ -279,7 +279,7 @@ def start_simulation( # An actor factory. This is called N times to add N actors # to the pool. If at some point the pool can accommodate more actors # this will be called again. - def create_actor_fn() -> Type[VirtualClientEngineActor]: + def create_actor_fn() -> type[VirtualClientEngineActor]: return actor_type.options( # type: ignore **client_resources, scheduling_strategy=actor_scheduling, diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 698eb78f2aef..4fb48a99b689 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -17,7 +17,7 @@ import threading from abc import ABC from logging import DEBUG, ERROR, WARNING -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Optional, Union import ray from ray import ObjectRef @@ -44,7 +44,7 @@ def run( message: Message, cid: str, context: Context, - ) -> Tuple[str, Message, Context]: + ) -> tuple[str, Message, Context]: """Run a client run.""" # Pass message through ClientApp and return a message # return also cid which is needed to ensure results @@ -81,7 +81,7 @@ def __init__(self, on_actor_init_fn: Optional[Callable[[], None]] = None) -> Non on_actor_init_fn() -def pool_size_from_resources(client_resources: Dict[str, Union[int, float]]) -> int: +def pool_size_from_resources(client_resources: dict[str, Union[int, float]]) -> int: """Calculate number of Actors that fit in the cluster. For this we consider the resources available on each node and those required per @@ -162,9 +162,9 @@ class VirtualClientEngineActorPool(ActorPool): def __init__( self, - create_actor_fn: Callable[[], Type[VirtualClientEngineActor]], - client_resources: Dict[str, Union[int, float]], - actor_list: Optional[List[Type[VirtualClientEngineActor]]] = None, + create_actor_fn: Callable[[], type[VirtualClientEngineActor]], + client_resources: dict[str, Union[int, float]], + actor_list: Optional[list[type[VirtualClientEngineActor]]] = None, ): self.client_resources = client_resources self.create_actor_fn = create_actor_fn @@ -183,10 +183,10 @@ def __init__( # A dict that maps cid to another dict containing: a reference to the remote job # and its status (i.e. whether it is ready or not) - self._cid_to_future: Dict[ - str, Dict[str, Union[bool, Optional[ObjectRef[Any]]]] + self._cid_to_future: dict[ + str, dict[str, Union[bool, Optional[ObjectRef[Any]]]] ] = {} - self.actor_to_remove: Set[str] = set() # a set + self.actor_to_remove: set[str] = set() # a set self.num_actors = len(actors) self.lock = threading.RLock() @@ -210,7 +210,7 @@ def add_actors_to_pool(self, num_actors: int) -> None: self._idle_actors.extend(new_actors) self.num_actors += num_actors - def submit(self, fn: Any, value: Tuple[ClientAppFn, Message, str, Context]) -> None: + def submit(self, fn: Any, value: tuple[ClientAppFn, Message, str, Context]) -> None: """Take an idle actor and assign it to run a client app and Message. Submit a job to an actor by first removing it from the list of idle actors, then @@ -220,7 +220,7 @@ def submit(self, fn: Any, value: Tuple[ClientAppFn, Message, str, Context]) -> N actor = self._idle_actors.pop() if self._check_and_remove_actor_from_pool(actor): future = fn(actor, app_fn, mssg, cid, context) - future_key = tuple(future) if isinstance(future, List) else future + future_key = tuple(future) if isinstance(future, list) else future self._future_to_actor[future_key] = (self._next_task_index, actor, cid) self._next_task_index += 1 @@ -228,7 +228,7 @@ def submit(self, fn: Any, value: Tuple[ClientAppFn, Message, str, Context]) -> N self._cid_to_future[cid]["future"] = future_key def submit_client_job( - self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context] + self, actor_fn: Any, job: tuple[ClientAppFn, Message, str, Context] ) -> None: """Submit a job while tracking client ids.""" _, _, cid, _ = job @@ -268,7 +268,7 @@ def _is_future_ready(self, cid: str) -> bool: return self._cid_to_future[cid]["ready"] # type: ignore - def _fetch_future_result(self, cid: str) -> Tuple[Message, Context]: + def _fetch_future_result(self, cid: str) -> tuple[Message, Context]: """Fetch result and updated context for a VirtualClient from Object Store. The job submitted by the ClientProxy interfacing with client with cid=cid is @@ -382,7 +382,7 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None: def get_client_result( self, cid: str, timeout: Optional[float] - ) -> Tuple[Message, Context]: + ) -> tuple[Message, Context]: """Get result from VirtualClient with specific cid.""" # Loop until all jobs submitted to the pool are completed. Break early # if the result for the ClientProxy calling this method is ready @@ -403,14 +403,14 @@ class BasicActorPool: def __init__( self, - actor_type: Type[VirtualClientEngineActor], - client_resources: Dict[str, Union[int, float]], - actor_kwargs: Dict[str, Any], + actor_type: type[VirtualClientEngineActor], + client_resources: dict[str, Union[int, float]], + actor_kwargs: dict[str, Any], ): self.client_resources = client_resources # Queue of idle actors - self.pool: List[VirtualClientEngineActor] = [] + self.pool: list[VirtualClientEngineActor] = [] self.num_actors = 0 # Resolve arguments to pass during actor init @@ -424,7 +424,7 @@ def __init__( # Figure out how many actors can be created given the cluster resources # and the resources the user indicates each VirtualClient will need self.actors_capacity = pool_size_from_resources(client_resources) - self._future_to_actor: Dict[Any, VirtualClientEngineActor] = {} + self._future_to_actor: dict[Any, VirtualClientEngineActor] = {} def is_actor_available(self) -> bool: """Return true if there is an idle actor.""" @@ -450,7 +450,7 @@ def terminate_all_actors(self) -> None: log(DEBUG, "Terminated %i actors", num_terminated) def submit( - self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context] + self, actor_fn: Any, job: tuple[ClientAppFn, Message, str, Context] ) -> Any: """On idle actor, submit job and return future.""" # Remove idle actor from pool @@ -470,7 +470,7 @@ def add_actor_back_to_pool(self, future: Any) -> None: def fetch_result_and_return_actor_to_pool( self, future: Any - ) -> Tuple[Message, Context]: + ) -> tuple[Message, Context]: """Pull result given a future and add actor back to pool.""" # Retrieve result for object store # Instead of doing ray.get(future) we await it diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 1c2aa455d9cd..ce0ef46d135f 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -17,7 +17,6 @@ from math import pi from random import shuffle -from typing import Dict, List, Tuple, Type import ray @@ -60,7 +59,7 @@ def __init__(self, node_id: int, state: RecordSet) -> None: self.node_id = node_id self.client_state = state - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """Return properties by doing a simple calculation.""" result = self.node_id * pi # store something in context @@ -76,14 +75,14 @@ def get_dummy_client(context: Context) -> Client: def prep( - actor_type: Type[VirtualClientEngineActor] = ClientAppActor, -) -> Tuple[ - List[RayActorClientProxy], VirtualClientEngineActorPool, NodeToPartitionMapping + actor_type: type[VirtualClientEngineActor] = ClientAppActor, +) -> tuple[ + list[RayActorClientProxy], VirtualClientEngineActorPool, NodeToPartitionMapping ]: # pragma: no cover """Prepare ClientProxies and pool for tests.""" client_resources = {"num_cpus": 1, "num_gpus": 0.0} - def create_actor_fn() -> Type[VirtualClientEngineActor]: + def create_actor_fn() -> type[VirtualClientEngineActor]: return actor_type.options(**client_resources).remote() # type: ignore # Create actor pool @@ -195,7 +194,7 @@ def test_cid_consistency_without_proxies() -> None: node_ids = list(mapping.keys()) # register node states - node_states: Dict[int, NodeState] = {} + node_states: dict[int, NodeState] = {} for node_id, partition_id in mapping.items(): node_states[node_id] = NodeState( node_id=node_id, diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index be6410dcbd6b..2d29629c4f01 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -25,7 +25,7 @@ from logging import DEBUG, ERROR, INFO, WARNING from pathlib import Path from time import sleep -from typing import Any, List, Optional +from typing import Any, Optional from flwr.cli.config_utils import load_and_validate from flwr.client import ClientApp @@ -56,7 +56,7 @@ def _check_args_do_not_interfere(args: Namespace) -> bool: mode_one_args = ["app", "run_config"] mode_two_args = ["client_app", "server_app"] - def _resolve_message(conflict_keys: List[str]) -> str: + def _resolve_message(conflict_keys: list[str]) -> str: return ",".join([f"`--{key}`".replace("_", "-") for key in conflict_keys]) # When passing `--app`, `--app-dir` is ignored diff --git a/src/py/flwr/superexec/app.py b/src/py/flwr/superexec/app.py index 36f781706146..c00aa0f88e7b 100644 --- a/src/py/flwr/superexec/app.py +++ b/src/py/flwr/superexec/app.py @@ -18,7 +18,7 @@ import sys from logging import INFO, WARN from pathlib import Path -from typing import Optional, Tuple +from typing import Optional import grpc @@ -130,7 +130,7 @@ def _parse_args_run_superexec() -> argparse.ArgumentParser: def _try_obtain_certificates( args: argparse.Namespace, -) -> Optional[Tuple[bytes, bytes, bytes]]: +) -> Optional[tuple[bytes, bytes, bytes]]: # Obtain certificates if args.insecure: log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.") diff --git a/src/py/flwr/superexec/exec_grpc.py b/src/py/flwr/superexec/exec_grpc.py index a32ebc1b3e35..017395bc8002 100644 --- a/src/py/flwr/superexec/exec_grpc.py +++ b/src/py/flwr/superexec/exec_grpc.py @@ -15,7 +15,7 @@ """SuperExec gRPC API.""" from logging import INFO -from typing import Optional, Tuple +from typing import Optional import grpc @@ -32,7 +32,7 @@ def run_superexec_api_grpc( address: str, executor: Executor, - certificates: Optional[Tuple[bytes, bytes, bytes]], + certificates: Optional[tuple[bytes, bytes, bytes]], config: UserConfig, ) -> grpc.Server: """Run SuperExec API (gRPC, request-response).""" diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index dda3e96994de..5b729dbc2b8e 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -15,8 +15,9 @@ """SuperExec API servicer.""" +from collections.abc import Generator from logging import ERROR, INFO -from typing import Any, Dict, Generator +from typing import Any import grpc @@ -38,7 +39,7 @@ class ExecServicer(exec_pb2_grpc.ExecServicer): def __init__(self, executor: Executor) -> None: self.executor = executor - self.runs: Dict[int, RunTracker] = {} + self.runs: dict[int, RunTracker] = {} def StartRun( self, request: StartRunRequest, context: grpc.ServicerContext