diff --git a/examples/quickstart-monai/monaiexample/task.py b/examples/quickstart-monai/monaiexample/task.py index 09597562a1f2..4f7972d455fd 100644 --- a/examples/quickstart-monai/monaiexample/task.py +++ b/examples/quickstart-monai/monaiexample/task.py @@ -189,9 +189,10 @@ def _download_and_extract_if_needed(url, dest_folder): # Download the tar.gz file tar_gz_filename = url.split("/")[-1] if not os.path.isfile(tar_gz_filename): - with request.urlopen(url) as response, open( - tar_gz_filename, "wb" - ) as out_file: + with ( + request.urlopen(url) as response, + open(tar_gz_filename, "wb") as out_file, + ): out_file.write(response.read()) # Extract the tar.gz file diff --git a/pyproject.toml b/pyproject.toml index b4ff7be41bf3..fa543d3183f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,7 +144,7 @@ known_first_party = ["flwr", "flwr_tool"] [tool.black] line-length = 88 -target-version = ["py38", "py39", "py310", "py311"] +target-version = ["py39", "py310", "py311"] [tool.pylint."MESSAGES CONTROL"] disable = "duplicate-code,too-few-public-methods,useless-import-alias" @@ -193,7 +193,7 @@ wrap-summaries = 88 wrap-descriptions = 88 [tool.ruff] -target-version = "py38" +target-version = "py39" line-length = 88 select = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] fixable = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] diff --git a/src/py/flwr/cli/build.py b/src/py/flwr/cli/build.py index 676bc1723568..137e2dc31aff 100644 --- a/src/py/flwr/cli/build.py +++ b/src/py/flwr/cli/build.py @@ -17,12 +17,11 @@ import os import zipfile from pathlib import Path -from typing import Optional +from typing import Annotated, Optional import pathspec import tomli_w import typer -from typing_extensions import Annotated from .config_utils import load_and_validate from .utils import get_sha256_hash, is_valid_project_name diff --git a/src/py/flwr/cli/config_utils.py b/src/py/flwr/cli/config_utils.py index 233d35a5fa17..79e4973ccf9c 100644 --- a/src/py/flwr/cli/config_utils.py +++ b/src/py/flwr/cli/config_utils.py @@ -17,7 +17,7 @@ import zipfile from io import BytesIO from pathlib import Path -from typing import IO, Any, Dict, List, Optional, Tuple, Union, get_args +from typing import IO, Any, Optional, Union, get_args import tomli @@ -25,7 +25,7 @@ from flwr.common.typing import UserConfigValue -def get_fab_config(fab_file: Union[Path, bytes]) -> Dict[str, Any]: +def get_fab_config(fab_file: Union[Path, bytes]) -> dict[str, Any]: """Extract the config from a FAB file or path. Parameters @@ -62,7 +62,7 @@ def get_fab_config(fab_file: Union[Path, bytes]) -> Dict[str, Any]: return conf -def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]: +def get_fab_metadata(fab_file: Union[Path, bytes]) -> tuple[str, str]: """Extract the fab_id and the fab_version from a FAB file or path. Parameters @@ -87,7 +87,7 @@ def get_fab_metadata(fab_file: Union[Path, bytes]) -> Tuple[str, str]: def load_and_validate( path: Optional[Path] = None, check_module: bool = True, -) -> Tuple[Optional[Dict[str, Any]], List[str], List[str]]: +) -> tuple[Optional[dict[str, Any]], list[str], list[str]]: """Load and validate pyproject.toml as dict. 
Returns @@ -116,7 +116,7 @@ def load_and_validate( return (config, errors, warnings) -def load(toml_path: Path) -> Optional[Dict[str, Any]]: +def load(toml_path: Path) -> Optional[dict[str, Any]]: """Load pyproject.toml and return as dict.""" if not toml_path.is_file(): return None @@ -125,7 +125,7 @@ def load(toml_path: Path) -> Optional[Dict[str, Any]]: return load_from_string(toml_file.read()) -def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None: +def _validate_run_config(config_dict: dict[str, Any], errors: list[str]) -> None: for key, value in config_dict.items(): if isinstance(value, dict): _validate_run_config(config_dict[key], errors) @@ -137,7 +137,7 @@ def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None # pylint: disable=too-many-branches -def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]: +def validate_fields(config: dict[str, Any]) -> tuple[bool, list[str], list[str]]: """Validate pyproject.toml fields.""" errors = [] warnings = [] @@ -183,10 +183,10 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]] def validate( - config: Dict[str, Any], + config: dict[str, Any], check_module: bool = True, project_dir: Optional[Union[str, Path]] = None, -) -> Tuple[bool, List[str], List[str]]: +) -> tuple[bool, list[str], list[str]]: """Validate pyproject.toml.""" is_valid, errors, warnings = validate_fields(config) @@ -210,7 +210,7 @@ def validate( return True, [], [] -def load_from_string(toml_content: str) -> Optional[Dict[str, Any]]: +def load_from_string(toml_content: str) -> Optional[dict[str, Any]]: """Load TOML content from a string and return as dict.""" try: data = tomli.loads(toml_content) diff --git a/src/py/flwr/cli/config_utils_test.py b/src/py/flwr/cli/config_utils_test.py index cad6714521e3..ddabc152bc0f 100644 --- a/src/py/flwr/cli/config_utils_test.py +++ b/src/py/flwr/cli/config_utils_test.py @@ -17,7 +17,7 @@ import os import textwrap from pathlib import Path -from typing import Any, Dict +from typing import Any from .config_utils import load, validate, validate_fields @@ -155,7 +155,7 @@ def test_load_pyproject_toml_from_path(tmp_path: Path) -> None: def test_validate_pyproject_toml_fields_empty() -> None: """Test that validate_pyproject_toml_fields fails correctly.""" # Prepare - config: Dict[str, Any] = {} + config: dict[str, Any] = {} # Execute is_valid, errors, warnings = validate_fields(config) diff --git a/src/py/flwr/cli/install.py b/src/py/flwr/cli/install.py index 4318ccdf9ffb..8e3e9505898c 100644 --- a/src/py/flwr/cli/install.py +++ b/src/py/flwr/cli/install.py @@ -21,10 +21,9 @@ import zipfile from io import BytesIO from pathlib import Path -from typing import IO, Optional, Union +from typing import IO, Annotated, Optional, Union import typer -from typing_extensions import Annotated from flwr.common.config import get_flwr_dir diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index 90e4970d5928..d2f7179b45b4 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -18,10 +18,9 @@ from enum import Enum from pathlib import Path from string import Template -from typing import Dict, Optional +from typing import Annotated, Optional import typer -from typing_extensions import Annotated from ..utils import ( is_valid_project_name, @@ -70,7 +69,7 @@ def load_template(name: str) -> str: return tpl_file.read() -def render_template(template: str, data: Dict[str, str]) -> str: +def render_template(template: str, 
data: dict[str, str]) -> str: """Render template.""" tpl_file = load_template(template) tpl = Template(tpl_file) @@ -85,7 +84,7 @@ def create_file(file_path: Path, content: str) -> None: file_path.write_text(content) -def render_and_create(file_path: Path, template: str, context: Dict[str, str]) -> None: +def render_and_create(file_path: Path, template: str, context: dict[str, str]) -> None: """Render template and write to file.""" content = render_template(template, context) create_file(file_path, content) diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index 6375e71522de..905055ac70c0 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -20,10 +20,9 @@ import sys from logging import DEBUG from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Annotated, Any, Optional import typer -from typing_extensions import Annotated from flwr.cli.build import build from flwr.cli.config_utils import load_and_validate @@ -52,7 +51,7 @@ def run( typer.Argument(help="Name of the federation to run the app on."), ] = None, config_overrides: Annotated[ - Optional[List[str]], + Optional[list[str]], typer.Option( "--run-config", "-c", @@ -125,8 +124,8 @@ def run( def _run_with_superexec( app: Path, - federation_config: Dict[str, Any], - config_overrides: Optional[List[str]], + federation_config: dict[str, Any], + config_overrides: Optional[list[str]], ) -> None: insecure_str = federation_config.get("insecure") @@ -187,8 +186,8 @@ def _run_with_superexec( def _run_without_superexec( app: Optional[Path], - federation_config: Dict[str, Any], - config_overrides: Optional[List[str]], + federation_config: dict[str, Any], + config_overrides: Optional[list[str]], federation: str, ) -> None: try: diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py index 2f5a8831fa7c..e725fdd3f951 100644 --- a/src/py/flwr/cli/utils.py +++ b/src/py/flwr/cli/utils.py @@ -17,7 +17,7 @@ import hashlib import re from pathlib import Path -from typing import Callable, List, Optional, cast +from typing import Callable, Optional, cast import typer @@ -40,7 +40,7 @@ def prompt_text( return cast(str, result) -def prompt_options(text: str, options: List[str]) -> str: +def prompt_options(text: str, options: list[str]) -> str: """Ask user to select one of the given options and return the selected item.""" # Turn options into a list with index as in " [ 0] quickstart-pytorch" options_formatted = [ diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 78db5639ff0f..90c50aba7fad 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -18,10 +18,11 @@ import subprocess import sys import time +from contextlib import AbstractContextManager from dataclasses import dataclass from logging import ERROR, INFO, WARN from pathlib import Path -from typing import Callable, ContextManager, Dict, Optional, Tuple, Type, Union, cast +from typing import Callable, Optional, Union, cast import grpc from cryptography.hazmat.primitives.asymmetric import ec @@ -94,7 +95,7 @@ def start_client( insecure: Optional[bool] = None, transport: Optional[str] = None, authentication_keys: Optional[ - Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, @@ -204,7 +205,7 @@ def start_client_internal( insecure: Optional[bool] = None, transport: Optional[str] = None, authentication_keys: Optional[ - 
Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, @@ -356,7 +357,7 @@ def _on_backoff(retry_state: RetryState) -> None: # NodeState gets initialized when the first connection is established node_state: Optional[NodeState] = None - runs: Dict[int, Run] = {} + runs: dict[int, Run] = {} while not app_state_tracker.interrupt: sleep_duration: int = 0 @@ -689,7 +690,7 @@ def start_numpy_client( ) -def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ +def _init_connection(transport: Optional[str], server_address: str) -> tuple[ Callable[ [ str, @@ -697,10 +698,10 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ RetryInvoker, int, Union[bytes, str, None], - Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]], + Optional[tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]], ], - ContextManager[ - Tuple[ + AbstractContextManager[ + tuple[ Callable[[], Optional[Message]], Callable[[Message], None], Optional[Callable[[], Optional[int]]], @@ -711,7 +712,7 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[ ], ], str, - Type[Exception], + type[Exception], ]: # Parse IP address parsed_address = parse_address(server_address) @@ -769,7 +770,7 @@ def signal_handler(sig, frame): # type: ignore signal.signal(signal.SIGTERM, signal_handler) -def run_clientappio_api_grpc(address: str) -> Tuple[grpc.Server, ClientAppIoServicer]: +def run_clientappio_api_grpc(address: str) -> tuple[grpc.Server, ClientAppIoServicer]: """Run ClientAppIo API gRPC server.""" clientappio_servicer: grpc.Server = ClientAppIoServicer() clientappio_add_servicer_to_server_fn = add_ClientAppIoServicer_to_server diff --git a/src/py/flwr/client/app_test.py b/src/py/flwr/client/app_test.py index 74ade03f973a..723a066ea0bc 100644 --- a/src/py/flwr/client/app_test.py +++ b/src/py/flwr/client/app_test.py @@ -15,8 +15,6 @@ """Flower Client app tests.""" -from typing import Dict, Tuple - from flwr.common import ( Config, EvaluateIns, @@ -59,7 +57,7 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: class NeedsWrappingClient(NumPyClient): """Client implementation extending the high-level NumPyClient.""" - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """Raise an Exception because this method is not expected to be called.""" raise NotImplementedError() @@ -69,13 +67,13 @@ def get_parameters(self, config: Config) -> NDArrays: def fit( self, parameters: NDArrays, config: Config - ) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + ) -> tuple[NDArrays, int, dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" raise NotImplementedError() def evaluate( self, parameters: NDArrays, config: Config - ) -> Tuple[float, int, Dict[str, Scalar]]: + ) -> tuple[float, int, dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" raise NotImplementedError() diff --git a/src/py/flwr/client/client_app.py b/src/py/flwr/client/client_app.py index c322ba747114..234d84f27782 100644 --- a/src/py/flwr/client/client_app.py +++ b/src/py/flwr/client/client_app.py @@ -16,7 +16,7 @@ import inspect -from typing import Callable, List, Optional +from typing import Callable, Optional from flwr.client.client import Client from 
flwr.client.message_handler.message_handler import ( @@ -109,9 +109,9 @@ class ClientApp: def __init__( self, client_fn: Optional[ClientFnExt] = None, # Only for backward compatibility - mods: Optional[List[Mod]] = None, + mods: Optional[list[Mod]] = None, ) -> None: - self._mods: List[Mod] = mods if mods is not None else [] + self._mods: list[Mod] = mods if mods is not None else [] # Create wrapper function for `handle` self._call: Optional[ClientAppCallable] = None diff --git a/src/py/flwr/client/clientapp/app.py b/src/py/flwr/client/clientapp/app.py index 69d334fead14..f493128bebac 100644 --- a/src/py/flwr/client/clientapp/app.py +++ b/src/py/flwr/client/clientapp/app.py @@ -17,7 +17,7 @@ import argparse import time from logging import DEBUG, ERROR, INFO -from typing import Optional, Tuple +from typing import Optional import grpc @@ -196,7 +196,7 @@ def get_token(stub: grpc.Channel) -> Optional[int]: def pull_message( stub: grpc.Channel, token: int -) -> Tuple[Message, Context, Run, Optional[Fab]]: +) -> tuple[Message, Context, Run, Optional[Fab]]: """Pull message from SuperNode to ClientApp.""" log(INFO, "Pulling ClientAppInputs for token %s", token) try: diff --git a/src/py/flwr/client/dpfedavg_numpy_client.py b/src/py/flwr/client/dpfedavg_numpy_client.py index c592d10936d5..bade811b48ce 100644 --- a/src/py/flwr/client/dpfedavg_numpy_client.py +++ b/src/py/flwr/client/dpfedavg_numpy_client.py @@ -16,7 +16,6 @@ import copy -from typing import Dict, Tuple import numpy as np @@ -39,7 +38,7 @@ def __init__(self, client: NumPyClient) -> None: super().__init__() self.client = client - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """Get client properties using the given Numpy client. Parameters @@ -58,7 +57,7 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]: """ return self.client.get_properties(config) - def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: + def get_parameters(self, config: dict[str, Scalar]) -> NDArrays: """Return the current local model parameters. Parameters @@ -76,8 +75,8 @@ def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: return self.client.get_parameters(config) def fit( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + self, parameters: NDArrays, config: dict[str, Scalar] + ) -> tuple[NDArrays, int, dict[str, Scalar]]: """Train the provided parameters using the locally held dataset. This method first updates the local model using the original parameters @@ -153,8 +152,8 @@ def fit( return updated_params, num_examples, metrics def evaluate( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[float, int, Dict[str, Scalar]]: + self, parameters: NDArrays, config: dict[str, Scalar] + ) -> tuple[float, int, dict[str, Scalar]]: """Evaluate the provided parameters using the locally held dataset. 
Parameters diff --git a/src/py/flwr/client/grpc_adapter_client/connection.py b/src/py/flwr/client/grpc_adapter_client/connection.py index f9f7b1043524..9b84545eacdb 100644 --- a/src/py/flwr/client/grpc_adapter_client/connection.py +++ b/src/py/flwr/client/grpc_adapter_client/connection.py @@ -15,9 +15,10 @@ """Contextmanager for a GrpcAdapter channel to the Flower server.""" +from collections.abc import Iterator from contextlib import contextmanager from logging import ERROR -from typing import Callable, Iterator, Optional, Tuple, Union +from typing import Callable, Optional, Union from cryptography.hazmat.primitives.asymmetric import ec @@ -38,10 +39,10 @@ def grpc_adapter( # pylint: disable=R0913 max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[Union[bytes, str]] = None, authentication_keys: Optional[ # pylint: disable=unused-argument - Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, ) -> Iterator[ - Tuple[ + tuple[ Callable[[], Optional[Message]], Callable[[Message], None], Optional[Callable[[], Optional[int]]], diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 489891f55436..29479cf5479d 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -16,11 +16,12 @@ import uuid +from collections.abc import Iterator from contextlib import contextmanager from logging import DEBUG, ERROR from pathlib import Path from queue import Queue -from typing import Callable, Iterator, Optional, Tuple, Union, cast +from typing import Callable, Optional, Union, cast from cryptography.hazmat.primitives.asymmetric import ec @@ -66,10 +67,10 @@ def grpc_connection( # pylint: disable=R0913, R0915 max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, authentication_keys: Optional[ # pylint: disable=unused-argument - Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, ) -> Iterator[ - Tuple[ + tuple[ Callable[[], Optional[Message]], Callable[[Message], None], Optional[Callable[[], Optional[int]]], diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index bd377ef3470a..13bd2c6af8e7 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -17,8 +17,9 @@ import concurrent.futures import socket +from collections.abc import Iterator from contextlib import closing -from typing import Iterator, cast +from typing import cast from unittest.mock import patch import grpc diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index 8e8b701ca272..653e384aff96 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -17,8 +17,9 @@ import base64 import collections +from collections.abc import Sequence from logging import WARNING -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Union import grpc from cryptography.hazmat.primitives.asymmetric import ec @@ -53,7 +54,7 @@ def _get_value_from_tuples( - key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] + key_string: str, tuples: 
Sequence[tuple[str, Union[str, bytes]]] ) -> bytes: value = next((value for key, value in tuples if key == key_string), "") if isinstance(value, str): diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py index 72ac20738ad6..27f759a71713 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py @@ -18,9 +18,10 @@ import base64 import threading import unittest +from collections.abc import Sequence from concurrent import futures from logging import DEBUG, INFO, WARN -from typing import Optional, Sequence, Tuple, Union +from typing import Optional, Union import grpc @@ -60,7 +61,7 @@ def __init__(self) -> None: """Initialize mock servicer.""" self._lock = threading.Lock() self._received_client_metadata: Optional[ - Sequence[Tuple[str, Union[str, bytes]]] + Sequence[tuple[str, Union[str, bytes]]] ] = None self.server_private_key, self.server_public_key = generate_key_pairs() self._received_message_bytes: bytes = b"" @@ -105,7 +106,7 @@ def unary_unary( def received_client_metadata( self, - ) -> Optional[Sequence[Tuple[str, Union[str, bytes]]]]: + ) -> Optional[Sequence[tuple[str, Union[str, bytes]]]]: """Return received client metadata.""" with self._lock: return self._received_client_metadata @@ -151,7 +152,7 @@ def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None: def _get_value_from_tuples( - key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] + key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]] ) -> bytes: value = next((value for key, value in tuples if key == key_string), "") if isinstance(value, str): diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 8bae253c819a..7ce3d37b7a17 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -17,11 +17,12 @@ import random import threading +from collections.abc import Iterator, Sequence from contextlib import contextmanager from copy import copy from logging import DEBUG, ERROR from pathlib import Path -from typing import Callable, Iterator, Optional, Sequence, Tuple, Type, Union, cast +from typing import Callable, Optional, Union, cast import grpc from cryptography.hazmat.primitives.asymmetric import ec @@ -77,11 +78,11 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[Union[bytes, str]] = None, authentication_keys: Optional[ - Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, - adapter_cls: Optional[Union[Type[FleetStub], Type[GrpcAdapter]]] = None, + adapter_cls: Optional[Union[type[FleetStub], type[GrpcAdapter]]] = None, ) -> Iterator[ - Tuple[ + tuple[ Callable[[], Optional[Message]], Callable[[Message], None], Optional[Callable[[], Optional[int]]], diff --git a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py index fde03943a852..3dce14c14956 100644 --- a/src/py/flwr/client/grpc_rere_client/grpc_adapter.py +++ b/src/py/flwr/client/grpc_rere_client/grpc_adapter.py @@ -17,7 +17,7 @@ import sys from logging import DEBUG -from typing import Any, Type, TypeVar, cast +from typing import Any, TypeVar, cast import grpc from 
google.protobuf.message import Message as GrpcMessage @@ -59,7 +59,7 @@ def __init__(self, channel: grpc.Channel) -> None: self.stub = GrpcAdapterStub(channel) def _send_and_receive( - self, request: GrpcMessage, response_type: Type[T], **kwargs: Any + self, request: GrpcMessage, response_type: type[T], **kwargs: Any ) -> T: # Serialize request container_req = MessageContainer( diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 1ab84eb01468..765c6a6b2e91 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -15,7 +15,7 @@ """Client-side message handler.""" from logging import WARN -from typing import Optional, Tuple, cast +from typing import Optional, cast from flwr.client.client import ( maybe_call_evaluate, @@ -52,7 +52,7 @@ class UnknownServerMessage(Exception): """Exception indicating that the received message is unknown.""" -def handle_control_message(message: Message) -> Tuple[Optional[Message], int]: +def handle_control_message(message: Message) -> tuple[Optional[Message], int]: """Handle control part of the incoming message. Parameters @@ -147,7 +147,7 @@ def handle_legacy_message_from_msgtype( def _reconnect( reconnect_msg: ServerMessage.ReconnectIns, -) -> Tuple[ClientMessage, int]: +) -> tuple[ClientMessage, int]: # Determine the reason for sending DisconnectRes message reason = Reason.ACK sleep_duration = None diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 557d61ffb32a..311f8c37e1b1 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -19,7 +19,6 @@ import unittest import uuid from copy import copy -from typing import List from flwr.client import Client from flwr.client.typing import ClientFnExt @@ -294,7 +293,7 @@ def test_invalid_message_run_id(self) -> None: msg = Message(metadata=self.valid_out_metadata, content=RecordSet()) # Execute - invalid_metadata_list: List[Metadata] = [] + invalid_metadata_list: list[Metadata] = [] attrs = list(vars(self.valid_out_metadata).keys()) for attr in attrs: if attr == "_partition_id": diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py index 5b196ad84321..f9d3c433157d 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -18,7 +18,7 @@ import os from dataclasses import dataclass, field from logging import DEBUG, WARNING -from typing import Any, Dict, List, Tuple, cast +from typing import Any, cast from flwr.client.typing import ClientAppCallable from flwr.common import ( @@ -91,11 +91,11 @@ class SecAggPlusState: # Random seed for generating the private mask rd_seed: bytes = b"" - rd_seed_share_dict: Dict[int, bytes] = field(default_factory=dict) - sk1_share_dict: Dict[int, bytes] = field(default_factory=dict) + rd_seed_share_dict: dict[int, bytes] = field(default_factory=dict) + sk1_share_dict: dict[int, bytes] = field(default_factory=dict) # The dict of the shared secrets from sk2 - ss2_dict: Dict[int, bytes] = field(default_factory=dict) - public_keys_dict: Dict[int, Tuple[bytes, bytes]] = field(default_factory=dict) + ss2_dict: dict[int, bytes] = field(default_factory=dict) + public_keys_dict: dict[int, tuple[bytes, bytes]] = 
field(default_factory=dict) def __init__(self, **kwargs: ConfigsRecordValues) -> None: for k, v in kwargs.items(): @@ -104,8 +104,8 @@ def __init__(self, **kwargs: ConfigsRecordValues) -> None: new_v: Any = v if k.endswith(":K"): k = k[:-2] - keys = cast(List[int], v) - values = cast(List[bytes], kwargs[f"{k}:V"]) + keys = cast(list[int], v) + values = cast(list[bytes], kwargs[f"{k}:V"]) if len(values) > len(keys): updated_values = [ tuple(values[i : i + 2]) for i in range(0, len(values), 2) @@ -115,17 +115,17 @@ def __init__(self, **kwargs: ConfigsRecordValues) -> None: new_v = dict(zip(keys, values)) self.__setattr__(k, new_v) - def to_dict(self) -> Dict[str, ConfigsRecordValues]: + def to_dict(self) -> dict[str, ConfigsRecordValues]: """Convert the state to a dictionary.""" ret = vars(self) for k in list(ret.keys()): if isinstance(ret[k], dict): # Replace dict with two lists - v = cast(Dict[str, Any], ret.pop(k)) + v = cast(dict[str, Any], ret.pop(k)) ret[f"{k}:K"] = list(v.keys()) if k == "public_keys_dict": - v_list: List[bytes] = [] - for b1_b2 in cast(List[Tuple[bytes, bytes]], v.values()): + v_list: list[bytes] = [] + for b1_b2 in cast(list[tuple[bytes, bytes]], v.values()): v_list.extend(b1_b2) ret[f"{k}:V"] = v_list else: @@ -276,7 +276,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None: ) if not isinstance(configs[key], list) or any( elm - for elm in cast(List[Any], configs[key]) + for elm in cast(list[Any], configs[key]) # pylint: disable-next=unidiomatic-typecheck if type(elm) is not expected_type ): @@ -299,7 +299,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None: ) if not isinstance(configs[key], list) or any( elm - for elm in cast(List[Any], configs[key]) + for elm in cast(list[Any], configs[key]) # pylint: disable-next=unidiomatic-typecheck if type(elm) is not expected_type ): @@ -314,7 +314,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None: def _setup( state: SecAggPlusState, configs: ConfigsRecord -) -> Dict[str, ConfigsRecordValues]: +) -> dict[str, ConfigsRecordValues]: # Assigning parameter values to object fields sec_agg_param_dict = configs state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER]) @@ -350,8 +350,8 @@ def _setup( # pylint: disable-next=too-many-locals def _share_keys( state: SecAggPlusState, configs: ConfigsRecord -) -> Dict[str, ConfigsRecordValues]: - named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs) +) -> dict[str, ConfigsRecordValues]: + named_bytes_tuples = cast(dict[str, tuple[bytes, bytes]], configs) key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()} log(DEBUG, "Node %d: starting stage 1...", state.nid) state.public_keys_dict = key_dict @@ -361,7 +361,7 @@ def _share_keys( raise ValueError("Available neighbours number smaller than threshold") # Check if all public keys are unique - pk_list: List[bytes] = [] + pk_list: list[bytes] = [] for pk1, pk2 in state.public_keys_dict.values(): pk_list.append(pk1) pk_list.append(pk2) @@ -415,11 +415,11 @@ def _collect_masked_vectors( configs: ConfigsRecord, num_examples: int, updated_parameters: Parameters, -) -> Dict[str, ConfigsRecordValues]: +) -> dict[str, ConfigsRecordValues]: log(DEBUG, "Node %d: starting stage 2...", state.nid) - available_clients: List[int] = [] - ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST]) - srcs = cast(List[int], configs[Key.SOURCE_LIST]) + available_clients: list[int] = [] + ciphertexts = cast(list[bytes], configs[Key.CIPHERTEXT_LIST]) + srcs = 
cast(list[int], configs[Key.SOURCE_LIST]) if len(ciphertexts) + 1 < state.threshold: raise ValueError("Not enough available neighbour clients.") @@ -467,7 +467,7 @@ def _collect_masked_vectors( quantized_parameters = factor_combine(q_ratio, quantized_parameters) - dimensions_list: List[Tuple[int, ...]] = [a.shape for a in quantized_parameters] + dimensions_list: list[tuple[int, ...]] = [a.shape for a in quantized_parameters] # Add private mask private_mask = pseudo_rand_gen(state.rd_seed, state.mod_range, dimensions_list) @@ -499,11 +499,11 @@ def _collect_masked_vectors( def _unmask( state: SecAggPlusState, configs: ConfigsRecord -) -> Dict[str, ConfigsRecordValues]: +) -> dict[str, ConfigsRecordValues]: log(DEBUG, "Node %d: starting stage 3...", state.nid) - active_nids = cast(List[int], configs[Key.ACTIVE_NODE_ID_LIST]) - dead_nids = cast(List[int], configs[Key.DEAD_NODE_ID_LIST]) + active_nids = cast(list[int], configs[Key.ACTIVE_NODE_ID_LIST]) + dead_nids = cast(list[int], configs[Key.DEAD_NODE_ID_LIST]) # Send private mask seed share for every avaliable client (including itself) # Send first private key share for building pairwise mask for every dropped client if len(active_nids) < state.threshold: diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index 2832576fb4fc..e68bf5177797 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -16,7 +16,7 @@ import unittest from itertools import product -from typing import Callable, Dict, List +from typing import Callable from flwr.client.mod import make_ffn from flwr.common import ( @@ -41,7 +41,7 @@ def get_test_handler( ctxt: Context, -) -> Callable[[Dict[str, ConfigsRecordValues]], ConfigsRecord]: +) -> Callable[[dict[str, ConfigsRecordValues]], ConfigsRecord]: """.""" def empty_ffn(_msg: Message, _2: Context) -> Message: @@ -49,7 +49,7 @@ def empty_ffn(_msg: Message, _2: Context) -> Message: app = make_ffn(empty_ffn, [secaggplus_mod]) - def func(configs: Dict[str, ConfigsRecordValues]) -> ConfigsRecord: + def func(configs: dict[str, ConfigsRecordValues]) -> ConfigsRecord: in_msg = Message( metadata=Metadata( run_id=0, @@ -158,7 +158,7 @@ def test_stage_setup_check(self) -> None: (Key.MOD_RANGE, int), ] - type_to_test_value: Dict[type, ConfigsRecordValues] = { + type_to_test_value: dict[type, ConfigsRecordValues] = { int: 10, bool: True, float: 1.0, @@ -166,7 +166,7 @@ def test_stage_setup_check(self) -> None: bytes: b"test", } - valid_configs: Dict[str, ConfigsRecordValues] = { + valid_configs: dict[str, ConfigsRecordValues] = { key: type_to_test_value[value_type] for key, value_type in valid_key_type_pairs } @@ -208,7 +208,7 @@ def test_stage_share_keys_check(self) -> None: handler = get_test_handler(ctxt) set_stage = _make_set_state_fn(ctxt) - valid_configs: Dict[str, ConfigsRecordValues] = { + valid_configs: dict[str, ConfigsRecordValues] = { "1": [b"public key 1", b"public key 2"], "2": [b"public key 1", b"public key 2"], "3": [b"public key 1", b"public key 2"], @@ -225,7 +225,7 @@ def test_stage_share_keys_check(self) -> None: valid_configs[Key.STAGE] = Stage.SHARE_KEYS # Test invalid configs - invalid_values: List[ConfigsRecordValues] = [ + invalid_values: list[ConfigsRecordValues] = [ b"public key 1", [b"public key 1"], [b"public key 1", b"public key 2", b"public key 3"], @@ -245,7 +245,7 @@ def test_stage_collect_masked_vectors_check(self) -> 
None: handler = get_test_handler(ctxt) set_stage = _make_set_state_fn(ctxt) - valid_configs: Dict[str, ConfigsRecordValues] = { + valid_configs: dict[str, ConfigsRecordValues] = { Key.CIPHERTEXT_LIST: [b"ctxt!", b"ctxt@", b"ctxt#", b"ctxt?"], Key.SOURCE_LIST: [32, 51324, 32324123, -3], } @@ -289,7 +289,7 @@ def test_stage_unmask_check(self) -> None: handler = get_test_handler(ctxt) set_stage = _make_set_state_fn(ctxt) - valid_configs: Dict[str, ConfigsRecordValues] = { + valid_configs: dict[str, ConfigsRecordValues] = { Key.ACTIVE_NODE_ID_LIST: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], Key.DEAD_NODE_ID_LIST: [32, 51324, 32324123, -3], } diff --git a/src/py/flwr/client/mod/utils.py b/src/py/flwr/client/mod/utils.py index c8fb21379783..c76902cf263f 100644 --- a/src/py/flwr/client/mod/utils.py +++ b/src/py/flwr/client/mod/utils.py @@ -15,13 +15,11 @@ """Utility functions for mods.""" -from typing import List - from flwr.client.typing import ClientAppCallable, Mod from flwr.common import Context, Message -def make_ffn(ffn: ClientAppCallable, mods: List[Mod]) -> ClientAppCallable: +def make_ffn(ffn: ClientAppCallable, mods: list[Mod]) -> ClientAppCallable: """.""" def wrap_ffn(_ffn: ClientAppCallable, _mod: Mod) -> ClientAppCallable: diff --git a/src/py/flwr/client/mod/utils_test.py b/src/py/flwr/client/mod/utils_test.py index a5bbd0a0bb4d..e75fb5530b2c 100644 --- a/src/py/flwr/client/mod/utils_test.py +++ b/src/py/flwr/client/mod/utils_test.py @@ -16,7 +16,7 @@ import unittest -from typing import List, cast +from typing import cast from flwr.client.typing import ClientAppCallable, Mod from flwr.common import ( @@ -43,7 +43,7 @@ def _increment_context_counter(context: Context) -> None: context.state.metrics_records[METRIC] = MetricsRecord({COUNTER: current_counter}) -def make_mock_mod(name: str, footprint: List[str]) -> Mod: +def make_mock_mod(name: str, footprint: list[str]) -> Mod: """Make a mock mod.""" def mod(message: Message, context: Context, app: ClientAppCallable) -> Message: @@ -61,7 +61,7 @@ def mod(message: Message, context: Context, app: ClientAppCallable) -> Message: return mod -def make_mock_app(name: str, footprint: List[str]) -> ClientAppCallable: +def make_mock_app(name: str, footprint: list[str]) -> ClientAppCallable: """Make a mock app.""" def app(message: Message, context: Context) -> Message: @@ -97,7 +97,7 @@ class TestMakeApp(unittest.TestCase): def test_multiple_mods(self) -> None: """Test if multiple mods are called in the correct order.""" # Prepare - footprint: List[str] = [] + footprint: list[str] = [] mock_app = make_mock_app("app", footprint) mock_mod_names = [f"mod{i}" for i in range(1, 15)] mock_mods = [make_mock_mod(name, footprint) for name in mock_mod_names] @@ -127,7 +127,7 @@ def test_multiple_mods(self) -> None: def test_filter(self) -> None: """Test if a mod can filter incoming TaskIns.""" # Prepare - footprint: List[str] = [] + footprint: list[str] = [] mock_app = make_mock_app("app", footprint) context = Context(node_id=0, node_config={}, state=RecordSet(), run_config={}) message = _get_dummy_flower_message() diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index e16d7e34715d..e7967dfc8bee 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from pathlib import Path -from typing import Dict, Optional +from typing import Optional from flwr.common import Context, RecordSet from flwr.common.config import ( @@ -46,7 +46,7 @@ def __init__( ) -> None: 
self.node_id = node_id self.node_config = node_config - self.run_infos: Dict[int, RunInfo] = {} + self.run_infos: dict[int, RunInfo] = {} # pylint: disable=too-many-arguments def register_context( diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index b21a51b38e9b..6a656cb661d2 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -16,7 +16,7 @@ from abc import ABC -from typing import Callable, Dict, Tuple +from typing import Callable from flwr.client.client import Client from flwr.common import ( @@ -73,7 +73,7 @@ class NumPyClient(ABC): _context: Context - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """Return a client's set of properties. Parameters @@ -93,7 +93,7 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]: _ = (self, config) return {} - def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: + def get_parameters(self, config: dict[str, Scalar]) -> NDArrays: """Return the current local model parameters. Parameters @@ -112,8 +112,8 @@ def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: return [] def fit( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + self, parameters: NDArrays, config: dict[str, Scalar] + ) -> tuple[NDArrays, int, dict[str, Scalar]]: """Train the provided parameters using the locally held dataset. Parameters @@ -141,8 +141,8 @@ def fit( return [], 0, {} def evaluate( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[float, int, Dict[str, Scalar]]: + self, parameters: NDArrays, config: dict[str, Scalar] + ) -> tuple[float, int, dict[str, Scalar]]: """Evaluate the provided parameters using the locally held dataset. 
Parameters @@ -310,7 +310,7 @@ def _set_context(self: Client, context: Context) -> None: def _wrap_numpy_client(client: NumPyClient) -> Client: - member_dict: Dict[str, Callable] = { # type: ignore + member_dict: dict[str, Callable] = { # type: ignore "__init__": _constructor, "get_context": _get_context, "set_context": _set_context, diff --git a/src/py/flwr/client/numpy_client_test.py b/src/py/flwr/client/numpy_client_test.py index 06a0deafe2c9..c5d520a73ce1 100644 --- a/src/py/flwr/client/numpy_client_test.py +++ b/src/py/flwr/client/numpy_client_test.py @@ -15,8 +15,6 @@ """Flower NumPyClient tests.""" -from typing import Dict, Tuple - from flwr.common import Config, NDArrays, Properties, Scalar from .numpy_client import ( @@ -40,14 +38,14 @@ def get_parameters(self, config: Config) -> NDArrays: return [] def fit( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + self, parameters: NDArrays, config: dict[str, Scalar] + ) -> tuple[NDArrays, int, dict[str, Scalar]]: """Simulate training by returning empty weights, 0 samples, empty metrics.""" return [], 0, {} def evaluate( - self, parameters: NDArrays, config: Dict[str, Scalar] - ) -> Tuple[float, int, Dict[str, Scalar]]: + self, parameters: NDArrays, config: dict[str, Scalar] + ) -> tuple[float, int, dict[str, Scalar]]: """Simulate evaluate by returning 0.0 loss, 0 samples, empty metrics.""" return 0.0, 0, {} diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index d5f005fbaf77..72b6be25a708 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -18,10 +18,11 @@ import random import sys import threading +from collections.abc import Iterator from contextlib import contextmanager from copy import copy from logging import ERROR, INFO, WARN -from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar, Union +from typing import Callable, Optional, TypeVar, Union from cryptography.hazmat.primitives.asymmetric import ec from google.protobuf.message import Message as GrpcMessage @@ -90,10 +91,10 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 Union[bytes, str] ] = None, # pylint: disable=unused-argument authentication_keys: Optional[ # pylint: disable=unused-argument - Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, ) -> Iterator[ - Tuple[ + tuple[ Callable[[], Optional[Message]], Callable[[Message], None], Optional[Callable[[], Optional[int]]], @@ -173,7 +174,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 ########################################################################### def _request( - req: GrpcMessage, res_type: Type[T], api_path: str, retry: bool = True + req: GrpcMessage, res_type: type[T], api_path: str, retry: bool = True ) -> Optional[T]: # Serialize the request req_bytes = req.SerializeToString() diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 425c7f7133a4..d9af001bba53 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -18,7 +18,7 @@ import sys from logging import DEBUG, ERROR, INFO, WARN from pathlib import Path -from typing import Optional, Tuple +from typing import Optional from cryptography.exceptions import UnsupportedAlgorithm from cryptography.hazmat.primitives.asymmetric import ec @@ -291,7 +291,7 @@ def _parse_args_common(parser: 
argparse.ArgumentParser) -> None: def _try_setup_client_authentication( args: argparse.Namespace, -) -> Optional[Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]: +) -> Optional[tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]: if not args.auth_supernode_private_key and not args.auth_supernode_public_key: return None diff --git a/src/py/flwr/common/address.py b/src/py/flwr/common/address.py index 7a70925c0fc9..2b10097ccb71 100644 --- a/src/py/flwr/common/address.py +++ b/src/py/flwr/common/address.py @@ -16,12 +16,12 @@ import socket from ipaddress import ip_address -from typing import Optional, Tuple +from typing import Optional IPV6: int = 6 -def parse_address(address: str) -> Optional[Tuple[str, int, Optional[bool]]]: +def parse_address(address: str) -> Optional[tuple[str, int, Optional[bool]]]: """Parse an IP address into host, port, and version. Parameters diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 42039fa959ac..071d41a3ab5e 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -17,7 +17,7 @@ import os import re from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union, cast, get_args +from typing import Any, Optional, Union, cast, get_args import tomli @@ -53,7 +53,7 @@ def get_project_dir( return Path(flwr_dir) / APP_DIR / publisher / project_name / fab_version -def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]: +def get_project_config(project_dir: Union[str, Path]) -> dict[str, Any]: """Return pyproject.toml in the given project directory.""" # Load pyproject.toml file toml_path = Path(project_dir) / FAB_CONFIG_FILE @@ -137,13 +137,13 @@ def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> UserConfig: def flatten_dict( - raw_dict: Optional[Dict[str, Any]], parent_key: str = "" + raw_dict: Optional[dict[str, Any]], parent_key: str = "" ) -> UserConfig: """Flatten dict by joining nested keys with a given separator.""" if raw_dict is None: return {} - items: List[Tuple[str, UserConfigValue]] = [] + items: list[tuple[str, UserConfigValue]] = [] separator: str = "." for k, v in raw_dict.items(): new_key = f"{parent_key}{separator}{k}" if parent_key else k @@ -159,9 +159,9 @@ def flatten_dict( return dict(items) -def unflatten_dict(flat_dict: Dict[str, Any]) -> Dict[str, Any]: +def unflatten_dict(flat_dict: dict[str, Any]) -> dict[str, Any]: """Unflatten a dict with keys containing separators into a nested dict.""" - unflattened_dict: Dict[str, Any] = {} + unflattened_dict: dict[str, Any] = {} separator: str = "." 
for key, value in flat_dict.items(): @@ -177,7 +177,7 @@ def unflatten_dict(flat_dict: Dict[str, Any]) -> Dict[str, Any]: def parse_config_args( - config: Optional[List[str]], + config: Optional[list[str]], ) -> UserConfig: """Parse separator separated list of key-value pairs separated by '='.""" overrides: UserConfig = {} @@ -209,7 +209,7 @@ def parse_config_args( return overrides -def get_metadata_from_config(config: Dict[str, Any]) -> Tuple[str, str]: +def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]: """Extract `fab_version` and `fab_id` from a project config.""" return ( config["project"]["version"], diff --git a/src/py/flwr/common/differential_privacy.py b/src/py/flwr/common/differential_privacy.py index 85dc198ef8a0..56da98a3c805 100644 --- a/src/py/flwr/common/differential_privacy.py +++ b/src/py/flwr/common/differential_privacy.py @@ -16,7 +16,7 @@ from logging import WARNING -from typing import Optional, Tuple +from typing import Optional import numpy as np @@ -125,7 +125,7 @@ def compute_adaptive_noise_params( noise_multiplier: float, num_sampled_clients: float, clipped_count_stddev: Optional[float], -) -> Tuple[float, float]: +) -> tuple[float, float]: """Compute noising parameters for the adaptive clipping. Paper: https://arxiv.org/abs/1905.03871 diff --git a/src/py/flwr/common/dp.py b/src/py/flwr/common/dp.py index 527805c8ef42..13ae94461ef9 100644 --- a/src/py/flwr/common/dp.py +++ b/src/py/flwr/common/dp.py @@ -15,8 +15,6 @@ """Building block functions for DP algorithms.""" -from typing import Tuple - import numpy as np from flwr.common.logger import warn_deprecated_feature @@ -41,7 +39,7 @@ def add_gaussian_noise(update: NDArrays, std_dev: float) -> NDArrays: return update_noised -def clip_by_l2(update: NDArrays, threshold: float) -> Tuple[NDArrays, bool]: +def clip_by_l2(update: NDArrays, threshold: float) -> tuple[NDArrays, bool]: """Scales the update so thats its L2 norm is upper-bound to threshold.""" warn_deprecated_feature("`clip_by_l2` method") update_norm = _get_update_norm(update) diff --git a/src/py/flwr/common/exit_handlers.py b/src/py/flwr/common/exit_handlers.py index 30750c28a450..e5898b46a537 100644 --- a/src/py/flwr/common/exit_handlers.py +++ b/src/py/flwr/common/exit_handlers.py @@ -19,7 +19,7 @@ from signal import SIGINT, SIGTERM, signal from threading import Thread from types import FrameType -from typing import List, Optional +from typing import Optional from grpc import Server @@ -28,8 +28,8 @@ def register_exit_handlers( event_type: EventType, - grpc_servers: Optional[List[Server]] = None, - bckg_threads: Optional[List[Thread]] = None, + grpc_servers: Optional[list[Server]] = None, + bckg_threads: Optional[list[Thread]] = None, ) -> None: """Register exit handlers for `SIGINT` and `SIGTERM` signals. 
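For readers skimming the hunks above: the recurring change across these files is the PEP 585 migration. `Dict`, `List`, `Tuple`, and `Type` from `typing` are replaced with the built-in `dict`, `list`, `tuple`, and `type`, which accept subscripts in annotations on Python 3.9+, the new minimum set in `pyproject.toml` earlier in this diff. `Optional` and `Union` stay imported from `typing`, since the `X | Y` union syntax only evaluates at runtime on Python 3.10+. A minimal before/after sketch of the spelling change (hypothetical names, not code from the repository):

```python
# Illustrative sketch only (hypothetical names, not Flower code): the
# annotation spelling change applied throughout this changeset.
from typing import Optional

# Before, Python 3.8 style:
#   from typing import Dict, List, Tuple
#   def summarize(metrics: Dict[str, float]) -> Tuple[float, List[str]]: ...

# After, Python 3.9+ (PEP 585 built-in generics):
def summarize(metrics: dict[str, float]) -> tuple[float, list[str]]:
    """Return the mean metric value and the sorted metric names."""
    mean = sum(metrics.values()) / len(metrics) if metrics else 0.0
    return mean, sorted(metrics)


def first_key(metrics: dict[str, float]) -> Optional[str]:
    """Optional/Union still come from typing on a 3.9 baseline."""
    return next(iter(metrics), None)
```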
diff --git a/src/py/flwr/common/grpc.py b/src/py/flwr/common/grpc.py index ec8fe823a7eb..5a29c595119c 100644 --- a/src/py/flwr/common/grpc.py +++ b/src/py/flwr/common/grpc.py @@ -15,8 +15,9 @@ """Utility functions for gRPC.""" +from collections.abc import Sequence from logging import DEBUG -from typing import Optional, Sequence +from typing import Optional import grpc diff --git a/src/py/flwr/common/logger.py b/src/py/flwr/common/logger.py index 2077f9beaca0..303780fc0b5d 100644 --- a/src/py/flwr/common/logger.py +++ b/src/py/flwr/common/logger.py @@ -18,7 +18,7 @@ import logging from logging import WARN, LogRecord from logging.handlers import HTTPHandler -from typing import TYPE_CHECKING, Any, Dict, Optional, TextIO, Tuple +from typing import TYPE_CHECKING, Any, Optional, TextIO # Create logger LOGGER_NAME = "flwr" @@ -119,12 +119,12 @@ def __init__( url: str, method: str = "GET", secure: bool = False, - credentials: Optional[Tuple[str, str]] = None, + credentials: Optional[tuple[str, str]] = None, ) -> None: super().__init__(host, url, method, secure, credentials) self.identifier = identifier - def mapLogRecord(self, record: LogRecord) -> Dict[str, Any]: + def mapLogRecord(self, record: LogRecord) -> dict[str, Any]: """Filter for the properties to be send to the logserver.""" record_dict = record.__dict__ return { diff --git a/src/py/flwr/common/message_test.py b/src/py/flwr/common/message_test.py index c6142cb18256..57c57eb41bd9 100644 --- a/src/py/flwr/common/message_test.py +++ b/src/py/flwr/common/message_test.py @@ -17,7 +17,7 @@ import time from collections import namedtuple from contextlib import ExitStack -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import pytest @@ -193,7 +193,7 @@ def test_create_reply( ), ], ) -def test_repr(cls: type, kwargs: Dict[str, Any]) -> None: +def test_repr(cls: type, kwargs: dict[str, Any]) -> None: """Test string representations of Metadata/Message/Error.""" # Prepare anon_cls = namedtuple(cls.__qualname__, kwargs.keys()) # type: ignore diff --git a/src/py/flwr/common/object_ref.py b/src/py/flwr/common/object_ref.py index 9723c14037a0..6259b5ab557d 100644 --- a/src/py/flwr/common/object_ref.py +++ b/src/py/flwr/common/object_ref.py @@ -21,7 +21,7 @@ from importlib.util import find_spec from logging import WARN from pathlib import Path -from typing import Any, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from .logger import log @@ -40,7 +40,7 @@ def validate( module_attribute_str: str, check_module: bool = True, project_dir: Optional[Union[str, Path]] = None, -) -> Tuple[bool, Optional[str]]: +) -> tuple[bool, Optional[str]]: """Validate object reference. Parameters @@ -106,7 +106,7 @@ def validate( def load_app( # pylint: disable= too-many-branches module_attribute_str: str, - error_type: Type[Exception], + error_type: type[Exception], project_dir: Optional[Union[str, Path]] = None, ) -> Any: """Return the object specified in a module attribute string. 
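The companion change, visible in the connection and gRPC modules above, moves the abstract collection types (`Iterator`, `Sequence`, `Mapping`, and so on) from `typing` to `collections.abc`, and swaps `typing.ContextManager` for `contextlib.AbstractContextManager`; all of these are likewise subscriptable on Python 3.9+. A short sketch of that pattern for a generator-based context manager (hypothetical names and behavior, not the repository's API):

```python
# Illustrative sketch (hypothetical names): annotate a @contextmanager
# generator with collections.abc.Iterator instead of typing.Iterator.
from collections.abc import Iterator, Sequence
from contextlib import contextmanager


@contextmanager
def open_channel(addresses: Sequence[str]) -> Iterator[tuple[str, int]]:
    """Yield a (host, attempt) pair for the first address, then clean up."""
    host, attempt = addresses[0], 1
    try:
        yield (host, attempt)
    finally:
        # Teardown (e.g. closing the channel) would go here.
        pass
```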
diff --git a/src/py/flwr/common/record/configsrecord.py b/src/py/flwr/common/record/configsrecord.py index aeb311089bcd..f570e000cc9b 100644 --- a/src/py/flwr/common/record/configsrecord.py +++ b/src/py/flwr/common/record/configsrecord.py @@ -15,7 +15,7 @@ """ConfigsRecord.""" -from typing import Dict, List, Optional, get_args +from typing import Optional, get_args from flwr.common.typing import ConfigsRecordValues, ConfigsScalar @@ -109,7 +109,7 @@ class ConfigsRecord(TypedDict[str, ConfigsRecordValues]): def __init__( self, - configs_dict: Optional[Dict[str, ConfigsRecordValues]] = None, + configs_dict: Optional[dict[str, ConfigsRecordValues]] = None, keep_input: bool = True, ) -> None: @@ -141,7 +141,7 @@ def get_var_bytes(value: ConfigsScalar) -> int: num_bytes = 0 for k, v in self.items(): - if isinstance(v, List): + if isinstance(v, list): if isinstance(v[0], (bytes, str)): # not all str are of equal length necessarily # for both the footprint of each element is 1 Byte diff --git a/src/py/flwr/common/record/metricsrecord.py b/src/py/flwr/common/record/metricsrecord.py index 868ed82e79ca..d0a6123c807f 100644 --- a/src/py/flwr/common/record/metricsrecord.py +++ b/src/py/flwr/common/record/metricsrecord.py @@ -15,7 +15,7 @@ """MetricsRecord.""" -from typing import Dict, List, Optional, get_args +from typing import Optional, get_args from flwr.common.typing import MetricsRecordValues, MetricsScalar @@ -115,7 +115,7 @@ class MetricsRecord(TypedDict[str, MetricsRecordValues]): def __init__( self, - metrics_dict: Optional[Dict[str, MetricsRecordValues]] = None, + metrics_dict: Optional[dict[str, MetricsRecordValues]] = None, keep_input: bool = True, ): super().__init__(_check_key, _check_value) @@ -130,7 +130,7 @@ def count_bytes(self) -> int: num_bytes = 0 for k, v in self.items(): - if isinstance(v, List): + if isinstance(v, list): # both int and float normally take 4 bytes # But MetricRecords are mapped to 64bit int/float # during protobuffing diff --git a/src/py/flwr/common/record/parametersrecord.py b/src/py/flwr/common/record/parametersrecord.py index f088d682497b..10ec65ca0277 100644 --- a/src/py/flwr/common/record/parametersrecord.py +++ b/src/py/flwr/common/record/parametersrecord.py @@ -14,9 +14,10 @@ # ============================================================================== """ParametersRecord and Array.""" +from collections import OrderedDict from dataclasses import dataclass from io import BytesIO -from typing import List, Optional, OrderedDict, cast +from typing import Optional, cast import numpy as np @@ -51,7 +52,7 @@ class Array: """ dtype: str - shape: List[int] + shape: list[int] stype: str data: bytes diff --git a/src/py/flwr/common/record/parametersrecord_test.py b/src/py/flwr/common/record/parametersrecord_test.py index e840e5e266e4..9ac18a3ec854 100644 --- a/src/py/flwr/common/record/parametersrecord_test.py +++ b/src/py/flwr/common/record/parametersrecord_test.py @@ -17,7 +17,6 @@ import unittest from collections import OrderedDict from io import BytesIO -from typing import List import numpy as np import pytest @@ -81,7 +80,7 @@ def test_numpy_conversion_invalid(self) -> None: ([31, 153], "bool_"), # bool_ is represented as a whole Byte in NumPy ], ) -def test_count_bytes(shape: List[int], dtype: str) -> None: +def test_count_bytes(shape: list[int], dtype: str) -> None: """Test bytes in a ParametersRecord are computed correctly.""" original_array = np.random.randn(*shape).astype(np.dtype(dtype)) diff --git a/src/py/flwr/common/record/recordset_test.py 
b/src/py/flwr/common/record/recordset_test.py index 96556d335f4c..154e320e5f0b 100644 --- a/src/py/flwr/common/record/recordset_test.py +++ b/src/py/flwr/common/record/recordset_test.py @@ -15,9 +15,9 @@ """RecordSet tests.""" import pickle -from collections import namedtuple +from collections import OrderedDict, namedtuple from copy import deepcopy -from typing import Callable, Dict, List, OrderedDict, Type, Union +from typing import Callable, Union import numpy as np import pytest @@ -158,8 +158,8 @@ def test_set_parameters_with_correct_types() -> None: ], ) def test_set_parameters_with_incorrect_types( - key_type: Type[Union[int, str]], - value_fn: Callable[[NDArray], Union[NDArray, List[float]]], + key_type: type[Union[int, str]], + value_fn: Callable[[NDArray], Union[NDArray, list[float]]], ) -> None: """Test adding dictionary of unsupported types to ParametersRecord.""" p_record = ParametersRecord() @@ -183,7 +183,7 @@ def test_set_parameters_with_incorrect_types( ], ) def test_set_metrics_to_metricsrecord_with_correct_types( - key_type: Type[str], + key_type: type[str], value_fn: Callable[[NDArray], MetricsRecordValues], ) -> None: """Test adding metrics of various types to a MetricsRecord.""" @@ -236,8 +236,8 @@ def test_set_metrics_to_metricsrecord_with_correct_types( ], ) def test_set_metrics_to_metricsrecord_with_incorrect_types( - key_type: Type[Union[str, int, float, bool]], - value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]], + key_type: type[Union[str, int, float, bool]], + value_fn: Callable[[NDArray], Union[NDArray, dict[str, NDArray], list[float]]], ) -> None: """Test adding metrics of various unsupported types to a MetricsRecord.""" m_record = MetricsRecord() @@ -302,7 +302,7 @@ def test_set_metrics_to_metricsrecord_with_and_without_keeping_input( ], ) def test_set_configs_to_configsrecord_with_correct_types( - key_type: Type[str], + key_type: type[str], value_fn: Callable[[NDArray], ConfigsRecordValues], ) -> None: """Test adding configs of various types to a ConfigsRecord.""" @@ -346,8 +346,8 @@ def test_set_configs_to_configsrecord_with_correct_types( ], ) def test_set_configs_to_configsrecord_with_incorrect_types( - key_type: Type[Union[str, int, float]], - value_fn: Callable[[NDArray], Union[NDArray, Dict[str, NDArray], List[float]]], + key_type: type[Union[str, int, float]], + value_fn: Callable[[NDArray], Union[NDArray, dict[str, NDArray], list[float]]], ) -> None: """Test adding configs of various unsupported types to a ConfigsRecord.""" c_record = ConfigsRecord() diff --git a/src/py/flwr/common/record/typeddict.py b/src/py/flwr/common/record/typeddict.py index 37d98b01a306..c2c8548c4de3 100644 --- a/src/py/flwr/common/record/typeddict.py +++ b/src/py/flwr/common/record/typeddict.py @@ -15,18 +15,8 @@ """Typed dict base class for *Records.""" -from typing import ( - Callable, - Dict, - Generic, - ItemsView, - Iterator, - KeysView, - MutableMapping, - TypeVar, - ValuesView, - cast, -) +from collections.abc import ItemsView, Iterator, KeysView, MutableMapping, ValuesView +from typing import Callable, Generic, TypeVar, cast K = TypeVar("K") # Key type V = TypeVar("V") # Value type @@ -49,37 +39,37 @@ def __setitem__(self, key: K, value: V) -> None: cast(Callable[[V], None], self.__dict__["_check_value_fn"])(value) # Set key-value pair - cast(Dict[K, V], self.__dict__["_data"])[key] = value + cast(dict[K, V], self.__dict__["_data"])[key] = value def __delitem__(self, key: K) -> None: """Remove the item with the specified key.""" - del 
cast(Dict[K, V], self.__dict__["_data"])[key] + del cast(dict[K, V], self.__dict__["_data"])[key] def __getitem__(self, item: K) -> V: """Return the value for the specified key.""" - return cast(Dict[K, V], self.__dict__["_data"])[item] + return cast(dict[K, V], self.__dict__["_data"])[item] def __iter__(self) -> Iterator[K]: """Yield an iterator over the keys of the dictionary.""" - return iter(cast(Dict[K, V], self.__dict__["_data"])) + return iter(cast(dict[K, V], self.__dict__["_data"])) def __repr__(self) -> str: """Return a string representation of the dictionary.""" - return cast(Dict[K, V], self.__dict__["_data"]).__repr__() + return cast(dict[K, V], self.__dict__["_data"]).__repr__() def __len__(self) -> int: """Return the number of items in the dictionary.""" - return len(cast(Dict[K, V], self.__dict__["_data"])) + return len(cast(dict[K, V], self.__dict__["_data"])) def __contains__(self, key: object) -> bool: """Check if the dictionary contains the specified key.""" - return key in cast(Dict[K, V], self.__dict__["_data"]) + return key in cast(dict[K, V], self.__dict__["_data"]) def __eq__(self, other: object) -> bool: """Compare this instance to another dictionary or TypedDict.""" - data = cast(Dict[K, V], self.__dict__["_data"]) + data = cast(dict[K, V], self.__dict__["_data"]) if isinstance(other, TypedDict): - other_data = cast(Dict[K, V], other.__dict__["_data"]) + other_data = cast(dict[K, V], other.__dict__["_data"]) return data == other_data if isinstance(other, dict): return data == other @@ -87,12 +77,12 @@ def __eq__(self, other: object) -> bool: def keys(self) -> KeysView[K]: """D.keys() -> a set-like object providing a view on D's keys.""" - return cast(Dict[K, V], self.__dict__["_data"]).keys() + return cast(dict[K, V], self.__dict__["_data"]).keys() def values(self) -> ValuesView[V]: """D.values() -> an object providing a view on D's values.""" - return cast(Dict[K, V], self.__dict__["_data"]).values() + return cast(dict[K, V], self.__dict__["_data"]).values() def items(self) -> ItemsView[K, V]: """D.items() -> a set-like object providing a view on D's items.""" - return cast(Dict[K, V], self.__dict__["_data"]).items() + return cast(dict[K, V], self.__dict__["_data"]).items() diff --git a/src/py/flwr/common/recordset_compat.py b/src/py/flwr/common/recordset_compat.py index 8bf884c30e58..35024fcd67d1 100644 --- a/src/py/flwr/common/recordset_compat.py +++ b/src/py/flwr/common/recordset_compat.py @@ -15,7 +15,9 @@ """RecordSet utilities.""" -from typing import Dict, Mapping, OrderedDict, Tuple, Union, cast, get_args +from collections import OrderedDict +from collections.abc import Mapping +from typing import Union, cast, get_args from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet from .typing import ( @@ -115,7 +117,7 @@ def parameters_to_parametersrecord( def _check_mapping_from_recordscalartype_to_scalar( record_data: Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]] -) -> Dict[str, Scalar]: +) -> dict[str, Scalar]: """Check mapping `common.*RecordValues` into `common.Scalar` is possible.""" for value in record_data.values(): if not isinstance(value, get_args(Scalar)): @@ -126,14 +128,14 @@ def _check_mapping_from_recordscalartype_to_scalar( "supported by the `common.RecordSet` infrastructure. 
" f"You used type: {type(value)}" ) - return cast(Dict[str, Scalar], record_data) + return cast(dict[str, Scalar], record_data) def _recordset_to_fit_or_evaluate_ins_components( recordset: RecordSet, ins_str: str, keep_input: bool, -) -> Tuple[Parameters, Dict[str, Scalar]]: +) -> tuple[Parameters, dict[str, Scalar]]: """Derive Fit/Evaluate Ins from a RecordSet.""" # get Array and construct Parameters parameters_record = recordset.parameters_records[f"{ins_str}.parameters"] @@ -169,7 +171,7 @@ def _fit_or_evaluate_ins_to_recordset( def _embed_status_into_recordset( res_str: str, status: Status, recordset: RecordSet ) -> RecordSet: - status_dict: Dict[str, ConfigsRecordValues] = { + status_dict: dict[str, ConfigsRecordValues] = { "code": int(status.code.value), "message": status.message, } diff --git a/src/py/flwr/common/recordset_compat_test.py b/src/py/flwr/common/recordset_compat_test.py index e0ac7f216af9..05d821e37e40 100644 --- a/src/py/flwr/common/recordset_compat_test.py +++ b/src/py/flwr/common/recordset_compat_test.py @@ -15,7 +15,7 @@ """RecordSet from legacy messages tests.""" from copy import deepcopy -from typing import Callable, Dict +from typing import Callable import numpy as np import pytest @@ -82,7 +82,7 @@ def _get_valid_fitins_with_empty_ndarrays() -> FitIns: def _get_valid_fitres() -> FitRes: """Returnn Valid parameters but potentially invalid config.""" arrays = get_ndarrays() - metrics: Dict[str, Scalar] = {"a": 1.0, "b": 0} + metrics: dict[str, Scalar] = {"a": 1.0, "b": 0} return FitRes( parameters=ndarrays_to_parameters(arrays), num_examples=1, @@ -98,7 +98,7 @@ def _get_valid_evaluateins() -> EvaluateIns: def _get_valid_evaluateres() -> EvaluateRes: """Return potentially invalid config.""" - metrics: Dict[str, Scalar] = {"a": 1.0, "b": 0} + metrics: dict[str, Scalar] = {"a": 1.0, "b": 0} return EvaluateRes( num_examples=1, loss=0.1, @@ -108,7 +108,7 @@ def _get_valid_evaluateres() -> EvaluateRes: def _get_valid_getparametersins() -> GetParametersIns: - config_dict: Dict[str, Scalar] = { + config_dict: dict[str, Scalar] = { "a": 1.0, "b": 3, "c": True, @@ -131,7 +131,7 @@ def _get_valid_getpropertiesins() -> GetPropertiesIns: def _get_valid_getpropertiesres() -> GetPropertiesRes: - config_dict: Dict[str, Scalar] = { + config_dict: dict[str, Scalar] = { "a": 1.0, "b": 3, "c": True, diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index d12124b89840..303d5596f237 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -18,20 +18,9 @@ import itertools import random import time +from collections.abc import Generator, Iterable from dataclasses import dataclass -from typing import ( - Any, - Callable, - Dict, - Generator, - Iterable, - List, - Optional, - Tuple, - Type, - Union, - cast, -) +from typing import Any, Callable, Optional, Union, cast def exponential( @@ -93,8 +82,8 @@ class RetryState: """State for callbacks in RetryInvoker.""" target: Callable[..., Any] - args: Tuple[Any, ...] - kwargs: Dict[str, Any] + args: tuple[Any, ...] 
+ kwargs: dict[str, Any] tries: int elapsed_time: float exception: Optional[Exception] = None @@ -167,7 +156,7 @@ class RetryInvoker: def __init__( self, wait_gen_factory: Callable[[], Generator[float, None, None]], - recoverable_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]], + recoverable_exceptions: Union[type[Exception], tuple[type[Exception], ...]], max_tries: Optional[int], max_time: Optional[float], *, @@ -244,7 +233,7 @@ def try_call_event_handler( try_cnt = 0 wait_generator = self.wait_gen_factory() start = time.monotonic() - ref_state: List[Optional[RetryState]] = [None] + ref_state: list[Optional[RetryState]] = [None] while True: try_cnt += 1 diff --git a/src/py/flwr/common/retry_invoker_test.py b/src/py/flwr/common/retry_invoker_test.py index 2259ae47ded4..a9f2625ff443 100644 --- a/src/py/flwr/common/retry_invoker_test.py +++ b/src/py/flwr/common/retry_invoker_test.py @@ -15,7 +15,7 @@ """Tests for `RetryInvoker`.""" -from typing import Generator +from collections.abc import Generator from unittest.mock import MagicMock, Mock, patch import pytest diff --git a/src/py/flwr/common/secure_aggregation/crypto/shamir.py b/src/py/flwr/common/secure_aggregation/crypto/shamir.py index 688bfa2153ea..9c7e67abf94f 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/shamir.py +++ b/src/py/flwr/common/secure_aggregation/crypto/shamir.py @@ -17,20 +17,20 @@ import pickle from concurrent.futures import ThreadPoolExecutor -from typing import List, Tuple, cast +from typing import cast from Crypto.Protocol.SecretSharing import Shamir from Crypto.Util.Padding import pad, unpad -def create_shares(secret: bytes, threshold: int, num: int) -> List[bytes]: +def create_shares(secret: bytes, threshold: int, num: int) -> list[bytes]: """Return list of shares (bytes).""" secret_padded = pad(secret, 16) secret_padded_chunk = [ (threshold, num, secret_padded[i : i + 16]) for i in range(0, len(secret_padded), 16) ] - share_list: List[List[Tuple[int, bytes]]] = [[] for _ in range(num)] + share_list: list[list[tuple[int, bytes]]] = [[] for _ in range(num)] with ThreadPoolExecutor(max_workers=10) as executor: for chunk_shares in executor.map( @@ -43,22 +43,22 @@ def create_shares(secret: bytes, threshold: int, num: int) -> List[bytes]: return [pickle.dumps(shares) for shares in share_list] -def _shamir_split(threshold: int, num: int, chunk: bytes) -> List[Tuple[int, bytes]]: +def _shamir_split(threshold: int, num: int, chunk: bytes) -> list[tuple[int, bytes]]: return Shamir.split(threshold, num, chunk, ssss=False) # Reconstructing secret with PyCryptodome -def combine_shares(share_list: List[bytes]) -> bytes: +def combine_shares(share_list: list[bytes]) -> bytes: """Reconstruct secret from shares.""" - unpickled_share_list: List[List[Tuple[int, bytes]]] = [ - cast(List[Tuple[int, bytes]], pickle.loads(share)) for share in share_list + unpickled_share_list: list[list[tuple[int, bytes]]] = [ + cast(list[tuple[int, bytes]], pickle.loads(share)) for share in share_list ] chunk_num = len(unpickled_share_list[0]) secret_padded = bytearray(0) - chunk_shares_list: List[List[Tuple[int, bytes]]] = [] + chunk_shares_list: list[list[tuple[int, bytes]]] = [] for i in range(chunk_num): - chunk_shares: List[Tuple[int, bytes]] = [] + chunk_shares: list[tuple[int, bytes]] = [] for share in unpickled_share_list: chunk_shares.append(share[i]) chunk_shares_list.append(chunk_shares) @@ -71,5 +71,5 @@ def combine_shares(share_list: List[bytes]) -> bytes: return bytes(secret) -def _shamir_combine(shares: 
List[Tuple[int, bytes]]) -> bytes: +def _shamir_combine(shares: list[tuple[int, bytes]]) -> bytes: return Shamir.combine(shares, ssss=False) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 59ca84d604b8..f5c130fb2663 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -16,7 +16,7 @@ import base64 -from typing import Tuple, cast +from typing import cast from cryptography.exceptions import InvalidSignature from cryptography.fernet import Fernet @@ -26,7 +26,7 @@ def generate_key_pairs() -> ( - Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ): """Generate private and public key pairs with Cryptography.""" private_key = ec.generate_private_key(ec.SECP384R1()) diff --git a/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py b/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py index 207c15b61518..3197fd852f3d 100644 --- a/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py +++ b/src/py/flwr/common/secure_aggregation/ndarrays_arithmetic.py @@ -15,51 +15,51 @@ """Utility functions for performing operations on Numpy NDArrays.""" -from typing import Any, List, Tuple, Union +from typing import Any, Union import numpy as np from numpy.typing import DTypeLike, NDArray -def factor_combine(factor: int, parameters: List[NDArray[Any]]) -> List[NDArray[Any]]: +def factor_combine(factor: int, parameters: list[NDArray[Any]]) -> list[NDArray[Any]]: """Combine factor with parameters.""" return [np.array([factor])] + parameters def factor_extract( - parameters: List[NDArray[Any]], -) -> Tuple[int, List[NDArray[Any]]]: + parameters: list[NDArray[Any]], +) -> tuple[int, list[NDArray[Any]]]: """Extract factor from parameters.""" return parameters[0][0], parameters[1:] -def get_parameters_shape(parameters: List[NDArray[Any]]) -> List[Tuple[int, ...]]: +def get_parameters_shape(parameters: list[NDArray[Any]]) -> list[tuple[int, ...]]: """Get dimensions of each NDArray in parameters.""" return [arr.shape for arr in parameters] def get_zero_parameters( - dimensions_list: List[Tuple[int, ...]], dtype: DTypeLike = np.int64 -) -> List[NDArray[Any]]: + dimensions_list: list[tuple[int, ...]], dtype: DTypeLike = np.int64 +) -> list[NDArray[Any]]: """Generate zero parameters based on the dimensions list.""" return [np.zeros(dimensions, dtype=dtype) for dimensions in dimensions_list] def parameters_addition( - parameters1: List[NDArray[Any]], parameters2: List[NDArray[Any]] -) -> List[NDArray[Any]]: + parameters1: list[NDArray[Any]], parameters2: list[NDArray[Any]] +) -> list[NDArray[Any]]: """Add two parameters.""" return [parameters1[idx] + parameters2[idx] for idx in range(len(parameters1))] def parameters_subtraction( - parameters1: List[NDArray[Any]], parameters2: List[NDArray[Any]] -) -> List[NDArray[Any]]: + parameters1: list[NDArray[Any]], parameters2: list[NDArray[Any]] +) -> list[NDArray[Any]]: """Subtract parameters from the other parameters.""" return [parameters1[idx] - parameters2[idx] for idx in range(len(parameters1))] -def parameters_mod(parameters: List[NDArray[Any]], divisor: int) -> List[NDArray[Any]]: +def parameters_mod(parameters: list[NDArray[Any]], divisor: int) -> list[NDArray[Any]]: """Take mod of parameters with an integer divisor.""" if bin(divisor).count("1") == 1: msk = divisor - 
1 @@ -68,14 +68,14 @@ def parameters_mod(parameters: List[NDArray[Any]], divisor: int) -> List[NDArray def parameters_multiply( - parameters: List[NDArray[Any]], multiplier: Union[int, float] -) -> List[NDArray[Any]]: + parameters: list[NDArray[Any]], multiplier: Union[int, float] +) -> list[NDArray[Any]]: """Multiply parameters by an integer/float multiplier.""" return [parameters[idx] * multiplier for idx in range(len(parameters))] def parameters_divide( - parameters: List[NDArray[Any]], divisor: Union[int, float] -) -> List[NDArray[Any]]: + parameters: list[NDArray[Any]], divisor: Union[int, float] +) -> list[NDArray[Any]]: """Divide weight by an integer/float divisor.""" return [parameters[idx] / divisor for idx in range(len(parameters))] diff --git a/src/py/flwr/common/secure_aggregation/quantization.py b/src/py/flwr/common/secure_aggregation/quantization.py index 7946276b6a4f..ab8521eed981 100644 --- a/src/py/flwr/common/secure_aggregation/quantization.py +++ b/src/py/flwr/common/secure_aggregation/quantization.py @@ -15,7 +15,7 @@ """Utility functions for model quantization.""" -from typing import List, cast +from typing import cast import numpy as np @@ -30,10 +30,10 @@ def _stochastic_round(arr: NDArrayFloat) -> NDArrayInt: def quantize( - parameters: List[NDArrayFloat], clipping_range: float, target_range: int -) -> List[NDArrayInt]: + parameters: list[NDArrayFloat], clipping_range: float, target_range: int +) -> list[NDArrayInt]: """Quantize float Numpy arrays to integer Numpy arrays.""" - quantized_list: List[NDArrayInt] = [] + quantized_list: list[NDArrayInt] = [] quantizer = target_range / (2 * clipping_range) for arr in parameters: # Stochastic quantization @@ -49,12 +49,12 @@ def quantize( # Dequantize parameters to range [-clipping_range, clipping_range] def dequantize( - quantized_parameters: List[NDArrayInt], + quantized_parameters: list[NDArrayInt], clipping_range: float, target_range: int, -) -> List[NDArrayFloat]: +) -> list[NDArrayFloat]: """Dequantize integer Numpy arrays to float Numpy arrays.""" - reverse_quantized_list: List[NDArrayFloat] = [] + reverse_quantized_list: list[NDArrayFloat] = [] quantizer = (2 * clipping_range) / target_range shift = -clipping_range for arr in quantized_parameters: diff --git a/src/py/flwr/common/secure_aggregation/secaggplus_utils.py b/src/py/flwr/common/secure_aggregation/secaggplus_utils.py index cf6ac3bfb003..7bfb80f57891 100644 --- a/src/py/flwr/common/secure_aggregation/secaggplus_utils.py +++ b/src/py/flwr/common/secure_aggregation/secaggplus_utils.py @@ -15,8 +15,6 @@ """Utility functions for the SecAgg/SecAgg+ protocol.""" -from typing import List, Tuple - import numpy as np from flwr.common.typing import NDArrayInt @@ -54,7 +52,7 @@ def share_keys_plaintext_concat( ) -def share_keys_plaintext_separate(plaintext: bytes) -> Tuple[int, int, bytes, bytes]: +def share_keys_plaintext_separate(plaintext: bytes) -> tuple[int, int, bytes, bytes]: """Retrieve arguments from bytes. 
Parameters @@ -83,8 +81,8 @@ def share_keys_plaintext_separate(plaintext: bytes) -> Tuple[int, int, bytes, by def pseudo_rand_gen( - seed: bytes, num_range: int, dimensions_list: List[Tuple[int, ...]] -) -> List[NDArrayInt]: + seed: bytes, num_range: int, dimensions_list: list[tuple[int, ...]] +) -> list[NDArrayInt]: """Seeded pseudo-random number generator for noise generation with Numpy.""" assert len(seed) & 0x3 == 0 seed32 = 0 diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 76265b9836d1..87e01b05d341 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -15,7 +15,9 @@ """ProtoBuf serialization and deserialization.""" -from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast +from collections import OrderedDict +from collections.abc import MutableMapping +from typing import Any, TypeVar, cast from google.protobuf.message import Message as GrpcMessage @@ -72,7 +74,7 @@ def parameters_to_proto(parameters: typing.Parameters) -> Parameters: def parameters_from_proto(msg: Parameters) -> typing.Parameters: """Deserialize `Parameters` from ProtoBuf.""" - tensors: List[bytes] = list(msg.tensors) + tensors: list[bytes] = list(msg.tensors) return typing.Parameters(tensors=tensors, tensor_type=msg.tensor_type) @@ -390,7 +392,7 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar: def _record_value_to_proto( - value: Any, allowed_types: List[type], proto_class: Type[T] + value: Any, allowed_types: list[type], proto_class: type[T] ) -> T: """Serialize `*RecordValue` to ProtoBuf. @@ -427,9 +429,9 @@ def _record_value_from_proto(value_proto: GrpcMessage) -> Any: def _record_value_dict_to_proto( value_dict: TypedDict[str, Any], - allowed_types: List[type], - value_proto_class: Type[T], -) -> Dict[str, T]: + allowed_types: list[type], + value_proto_class: type[T], +) -> dict[str, T]: """Serialize the record value dict to ProtoBuf. Note: `bool` MUST be put in the front of allowd_types if it exists. 
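The serde hunks above follow the same migration pattern as the rest of this changeset: the typing.Dict/List/Type aliases become built-in generics and abstract container types are imported from collections.abc, which is valid on Python 3.9+ (PEP 585). A minimal, standalone sketch of the target style, using hypothetical names that are not part of the Flower codebase:

    from collections.abc import Iterable
    from typing import Optional, Union


    def summarize(
        values: dict[str, Union[int, float]],
        skip: Optional[Iterable[str]] = None,
    ) -> list[tuple[str, float]]:
        # Built-in dict/list/tuple are directly subscriptable on Python 3.9+,
        # so no typing.Dict/List/Tuple imports are needed.
        skipped = set(skip or ())
        return [(k, float(v)) for k, v in values.items() if k not in skipped]


    print(summarize({"a": 1, "b": 2.5}, skip=["b"]))  # [('a', 1.0)]

Union and Optional are kept rather than the `X | Y` syntax because runtime-evaluated `|` annotations only work on Python 3.10+, while this changeset targets 3.9 as the minimum.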
@@ -447,7 +449,7 @@ def proto(_v: Any) -> T: def _record_value_dict_from_proto( value_dict_proto: MutableMapping[str, Any] -) -> Dict[str, Any]: +) -> dict[str, Any]: """Deserialize the record value dict from ProtoBuf.""" return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()} @@ -498,7 +500,7 @@ def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord """Deserialize MetricsRecord from ProtoBuf.""" return MetricsRecord( metrics_dict=cast( - Dict[str, typing.MetricsRecordValues], + dict[str, typing.MetricsRecordValues], _record_value_dict_from_proto(record_proto.data), ), keep_input=False, @@ -520,7 +522,7 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord """Deserialize ConfigsRecord from ProtoBuf.""" return ConfigsRecord( configs_dict=cast( - Dict[str, typing.ConfigsRecordValues], + dict[str, typing.ConfigsRecordValues], _record_value_dict_from_proto(record_proto.data), ), keep_input=False, diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index 013d04a32fd4..49d1e38fa897 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -16,7 +16,8 @@ import random import string -from typing import Any, Callable, Optional, OrderedDict, Type, TypeVar, Union, cast +from collections import OrderedDict +from typing import Any, Callable, Optional, TypeVar, Union, cast import pytest @@ -169,7 +170,7 @@ def get_str(self, length: Optional[int] = None) -> str: length = self.rng.randint(1, 10) return "".join(self.rng.choices(char_pool, k=length)) - def get_value(self, dtype: Type[T]) -> T: + def get_value(self, dtype: type[T]) -> T: """Create a value of a given type.""" ret: Any = None if dtype == bool: diff --git a/src/py/flwr/common/telemetry.py b/src/py/flwr/common/telemetry.py index 981cfe79966a..724f36d2b98f 100644 --- a/src/py/flwr/common/telemetry.py +++ b/src/py/flwr/common/telemetry.py @@ -25,7 +25,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from enum import Enum, auto from pathlib import Path -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Optional, Union, cast from flwr.common.version import package_name, package_version @@ -126,7 +126,7 @@ class EventType(str, Enum): # The type signature is not compatible with mypy, pylint and flake8 # so each of those needs to be disabled for this line. # pylint: disable-next=no-self-argument,arguments-differ,line-too-long - def _generate_next_value_(name: str, start: int, count: int, last_values: List[Any]) -> Any: # type: ignore # noqa: E501 + def _generate_next_value_(name: str, start: int, count: int, last_values: list[Any]) -> Any: # type: ignore # noqa: E501 return name # Ping @@ -189,7 +189,7 @@ def _generate_next_value_(name: str, start: int, count: int, last_values: List[A # Use the ThreadPoolExecutor with max_workers=1 to have a queue # and also ensure that telemetry calls are not blocking. 
-state: Dict[str, Union[Optional[str], Optional[ThreadPoolExecutor]]] = { +state: dict[str, Union[Optional[str], Optional[ThreadPoolExecutor]]] = { # Will be assigned ThreadPoolExecutor(max_workers=1) # in event() the first time it's required "executor": None, @@ -201,7 +201,7 @@ def _generate_next_value_(name: str, start: int, count: int, last_values: List[A def event( event_type: EventType, - event_details: Optional[Dict[str, Any]] = None, + event_details: Optional[dict[str, Any]] = None, ) -> Future: # type: ignore """Submit create_event to ThreadPoolExecutor to avoid blocking.""" if state["executor"] is None: @@ -213,7 +213,7 @@ def event( return result -def create_event(event_type: EventType, event_details: Optional[Dict[str, Any]]) -> str: +def create_event(event_type: EventType, event_details: Optional[dict[str, Any]]) -> str: """Create telemetry event.""" if state["source"] is None: state["source"] = _get_source_id() diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index b1dec8d0420b..081a957f28ff 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import numpy as np import numpy.typing as npt @@ -25,7 +25,7 @@ NDArray = npt.NDArray[Any] NDArrayInt = npt.NDArray[np.int_] NDArrayFloat = npt.NDArray[np.float_] -NDArrays = List[NDArray] +NDArrays = list[NDArray] # The following union type contains Python types corresponding to ProtoBuf types that # ProtoBuf considers to be "Scalar Value Types", even though some of them arguably do @@ -38,31 +38,31 @@ float, int, str, - List[bool], - List[bytes], - List[float], - List[int], - List[str], + list[bool], + list[bytes], + list[float], + list[int], + list[str], ] # Value types for common.MetricsRecord MetricsScalar = Union[int, float] -MetricsScalarList = Union[List[int], List[float]] +MetricsScalarList = Union[list[int], list[float]] MetricsRecordValues = Union[MetricsScalar, MetricsScalarList] # Value types for common.ConfigsRecord ConfigsScalar = Union[MetricsScalar, str, bytes, bool] -ConfigsScalarList = Union[MetricsScalarList, List[str], List[bytes], List[bool]] +ConfigsScalarList = Union[MetricsScalarList, list[str], list[bytes], list[bool]] ConfigsRecordValues = Union[ConfigsScalar, ConfigsScalarList] -Metrics = Dict[str, Scalar] -MetricsAggregationFn = Callable[[List[Tuple[int, Metrics]]], Metrics] +Metrics = dict[str, Scalar] +MetricsAggregationFn = Callable[[list[tuple[int, Metrics]]], Metrics] -Config = Dict[str, Scalar] -Properties = Dict[str, Scalar] +Config = dict[str, Scalar] +Properties = dict[str, Scalar] # Value type for user configs UserConfigValue = Union[bool, float, int, str] -UserConfig = Dict[str, UserConfigValue] +UserConfig = dict[str, UserConfigValue] class Code(Enum): @@ -103,7 +103,7 @@ class ClientAppOutputStatus: class Parameters: """Model parameters.""" - tensors: List[bytes] + tensors: list[bytes] tensor_type: str @@ -127,7 +127,7 @@ class FitIns: """Fit instructions for a client.""" parameters: Parameters - config: Dict[str, Scalar] + config: dict[str, Scalar] @dataclass @@ -137,7 +137,7 @@ class FitRes: status: Status parameters: Parameters num_examples: int - metrics: Dict[str, Scalar] + metrics: dict[str, Scalar] @dataclass @@ -145,7 +145,7 @@ class EvaluateIns: """Evaluate instructions for a client.""" parameters: Parameters - config: Dict[str, Scalar] + config: 
dict[str, Scalar] @dataclass @@ -155,7 +155,7 @@ class EvaluateRes: status: Status loss: float num_examples: int - metrics: Dict[str, Scalar] + metrics: dict[str, Scalar] @dataclass diff --git a/src/py/flwr/common/version.py b/src/py/flwr/common/version.py index ac13f70d8a88..141c16ac9367 100644 --- a/src/py/flwr/common/version.py +++ b/src/py/flwr/common/version.py @@ -15,15 +15,14 @@ """Flower package version helper.""" import importlib.metadata as importlib_metadata -from typing import Tuple -def _check_package(name: str) -> Tuple[str, str]: +def _check_package(name: str) -> tuple[str, str]: version: str = importlib_metadata.version(name) return name, version -def _version() -> Tuple[str, str]: +def _version() -> tuple[str, str]: """Read and return Flower package name and version. Returns diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index 67fd54bfcae2..d156edaa3c99 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -19,10 +19,11 @@ import importlib.util import sys import threading +from collections.abc import Sequence from logging import INFO, WARN from os.path import isfile from pathlib import Path -from typing import Optional, Sequence, Set, Tuple +from typing import Optional import grpc from cryptography.exceptions import UnsupportedAlgorithm @@ -84,7 +85,7 @@ def start_server( # pylint: disable=too-many-arguments,too-many-locals strategy: Optional[Strategy] = None, client_manager: Optional[ClientManager] = None, grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, - certificates: Optional[Tuple[bytes, bytes, bytes]] = None, + certificates: Optional[tuple[bytes, bytes, bytes]] = None, ) -> History: """Start a Flower server using the gRPC transport layer. @@ -333,7 +334,7 @@ def run_superlink() -> None: driver_server.wait_for_termination(timeout=1) -def _format_address(address: str) -> Tuple[str, str, int]: +def _format_address(address: str) -> tuple[str, str, int]: parsed_address = parse_address(address) if not parsed_address: sys.exit( @@ -345,8 +346,8 @@ def _format_address(address: str) -> Tuple[str, str, int]: def _try_setup_node_authentication( args: argparse.Namespace, - certificates: Optional[Tuple[bytes, bytes, bytes]], -) -> Optional[Tuple[Set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]: + certificates: Optional[tuple[bytes, bytes, bytes]], +) -> Optional[tuple[set[bytes], ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]]: if ( not args.auth_list_public_keys and not args.auth_superlink_private_key @@ -381,7 +382,7 @@ def _try_setup_node_authentication( "to '--auth-list-public-keys'." ) - node_public_keys: Set[bytes] = set() + node_public_keys: set[bytes] = set() try: ssh_private_key = load_ssh_private_key( @@ -434,7 +435,7 @@ def _try_setup_node_authentication( def _try_obtain_certificates( args: argparse.Namespace, -) -> Optional[Tuple[bytes, bytes, bytes]]: +) -> Optional[tuple[bytes, bytes, bytes]]: # Obtain certificates if args.insecure: log(WARN, "Option `--insecure` was set. 
Starting insecure HTTP server.") @@ -490,7 +491,7 @@ def _run_fleet_api_grpc_rere( address: str, state_factory: StateFactory, ffs_factory: FfsFactory, - certificates: Optional[Tuple[bytes, bytes, bytes]], + certificates: Optional[tuple[bytes, bytes, bytes]], interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, ) -> grpc.Server: """Run Fleet API (gRPC, request-response).""" @@ -518,7 +519,7 @@ def _run_fleet_api_grpc_adapter( address: str, state_factory: StateFactory, ffs_factory: FfsFactory, - certificates: Optional[Tuple[bytes, bytes, bytes]], + certificates: Optional[tuple[bytes, bytes, bytes]], ) -> grpc.Server: """Run Fleet API (GrpcAdapter).""" # Create Fleet API gRPC server diff --git a/src/py/flwr/server/client_manager.py b/src/py/flwr/server/client_manager.py index 7956e282bd2c..175bd4a786ea 100644 --- a/src/py/flwr/server/client_manager.py +++ b/src/py/flwr/server/client_manager.py @@ -19,7 +19,7 @@ import threading from abc import ABC, abstractmethod from logging import INFO -from typing import Dict, List, Optional +from typing import Optional from flwr.common.logger import log @@ -67,7 +67,7 @@ def unregister(self, client: ClientProxy) -> None: """ @abstractmethod - def all(self) -> Dict[str, ClientProxy]: + def all(self) -> dict[str, ClientProxy]: """Return all available clients.""" @abstractmethod @@ -80,7 +80,7 @@ def sample( num_clients: int, min_num_clients: Optional[int] = None, criterion: Optional[Criterion] = None, - ) -> List[ClientProxy]: + ) -> list[ClientProxy]: """Sample a number of Flower ClientProxy instances.""" @@ -88,7 +88,7 @@ class SimpleClientManager(ClientManager): """Provides a pool of available clients.""" def __init__(self) -> None: - self.clients: Dict[str, ClientProxy] = {} + self.clients: dict[str, ClientProxy] = {} self._cv = threading.Condition() def __len__(self) -> int: @@ -170,7 +170,7 @@ def unregister(self, client: ClientProxy) -> None: with self._cv: self._cv.notify_all() - def all(self) -> Dict[str, ClientProxy]: + def all(self) -> dict[str, ClientProxy]: """Return all available clients.""" return self.clients @@ -179,7 +179,7 @@ def sample( num_clients: int, min_num_clients: Optional[int] = None, criterion: Optional[Criterion] = None, - ) -> List[ClientProxy]: + ) -> list[ClientProxy]: """Sample a number of Flower ClientProxy instances.""" # Block until at least num_clients are connected. if min_num_clients is None: diff --git a/src/py/flwr/server/compat/app_utils.py b/src/py/flwr/server/compat/app_utils.py index baff27307b88..8d2479f47d40 100644 --- a/src/py/flwr/server/compat/app_utils.py +++ b/src/py/flwr/server/compat/app_utils.py @@ -16,7 +16,6 @@ import threading -from typing import Dict, Tuple from ..client_manager import ClientManager from ..compat.driver_client_proxy import DriverClientProxy @@ -26,7 +25,7 @@ def start_update_client_manager_thread( driver: Driver, client_manager: ClientManager, -) -> Tuple[threading.Thread, threading.Event]: +) -> tuple[threading.Thread, threading.Event]: """Periodically update the nodes list in the client manager in a thread. 
This function starts a thread that periodically uses the associated driver to @@ -73,7 +72,7 @@ def _update_client_manager( ) -> None: """Update the nodes list in the client manager.""" # Loop until the driver is disconnected - registered_nodes: Dict[int, DriverClientProxy] = {} + registered_nodes: dict[int, DriverClientProxy] = {} while not f_stop.is_set(): all_node_ids = set(driver.get_node_ids()) dead_nodes = set(registered_nodes).difference(all_node_ids) diff --git a/src/py/flwr/server/compat/driver_client_proxy_test.py b/src/py/flwr/server/compat/driver_client_proxy_test.py index 31b917fa869b..a5b454c79f90 100644 --- a/src/py/flwr/server/compat/driver_client_proxy_test.py +++ b/src/py/flwr/server/compat/driver_client_proxy_test.py @@ -17,7 +17,8 @@ import unittest import unittest.mock -from typing import Any, Callable, Iterable, Optional, Union, cast +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union, cast from unittest.mock import Mock import numpy as np diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index 4f888323e586..e8429e865db6 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -16,7 +16,8 @@ from abc import ABC, abstractmethod -from typing import Iterable, List, Optional +from collections.abc import Iterable +from typing import Optional from flwr.common import Message, RecordSet from flwr.common.typing import Run @@ -70,7 +71,7 @@ def create_message( # pylint: disable=too-many-arguments """ @abstractmethod - def get_node_ids(self) -> List[int]: + def get_node_ids(self) -> list[int]: """Get node IDs.""" @abstractmethod diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 2fe2c8a2e4aa..421dfd30ecb2 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -16,8 +16,9 @@ import time import warnings +from collections.abc import Iterable from logging import DEBUG, WARNING -from typing import Iterable, List, Optional, cast +from typing import Optional, cast import grpc @@ -192,7 +193,7 @@ def create_message( # pylint: disable=too-many-arguments ) return Message(metadata=metadata, content=content) - def get_node_ids(self) -> List[int]: + def get_node_ids(self) -> list[int]: """Get node IDs.""" self._init_run() # Call GrpcDriverStub method @@ -209,7 +210,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: """ self._init_run() # Construct TaskIns - task_ins_list: List[TaskIns] = [] + task_ins_list: list[TaskIns] = [] for msg in messages: # Check message self._check_message(msg) @@ -255,7 +256,7 @@ def send_and_receive( # Pull messages end_time = time.time() + (timeout if timeout is not None else 0.0) - ret: List[Message] = [] + ret: list[Message] = [] while timeout is None or time.time() < end_time: res_msgs = self.pull_messages(msg_ids) ret.extend(res_msgs) diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 53406796750f..3a8a4b1bc73d 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -17,7 +17,8 @@ import time import warnings -from typing import Iterable, List, Optional, cast +from collections.abc import Iterable +from typing import Optional, cast from uuid import UUID from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet @@ -112,7 +113,7 @@ def create_message( # pylint: disable=too-many-arguments ) return 
Message(metadata=metadata, content=content) - def get_node_ids(self) -> List[int]: + def get_node_ids(self) -> list[int]: """Get node IDs.""" self._init_run() return list(self.state.get_nodes(cast(Run, self._run).run_id)) @@ -123,7 +124,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: This method takes an iterable of messages and sends each message to the node specified in `dst_node_id`. """ - task_ids: List[str] = [] + task_ids: list[str] = [] for msg in messages: # Check message self._check_message(msg) @@ -169,7 +170,7 @@ def send_and_receive( # Pull messages end_time = time.time() + (timeout if timeout is not None else 0.0) - ret: List[Message] = [] + ret: list[Message] = [] while timeout is None or time.time() < end_time: res_msgs = self.pull_messages(msg_ids) ret.extend(res_msgs) diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index ddfdb249c1b4..9e5aaeaa9ca7 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -17,7 +17,7 @@ import time import unittest -from typing import Iterable, List, Tuple +from collections.abc import Iterable from unittest.mock import MagicMock, patch from uuid import uuid4 @@ -38,7 +38,7 @@ from .inmemory_driver import InMemoryDriver -def push_messages(driver: InMemoryDriver, num_nodes: int) -> Tuple[Iterable[str], int]: +def push_messages(driver: InMemoryDriver, num_nodes: int) -> tuple[Iterable[str], int]: """Help push messages to state.""" for _ in range(num_nodes): driver.state.create_node(ping_interval=PING_MAX_INTERVAL) @@ -55,7 +55,7 @@ def push_messages(driver: InMemoryDriver, num_nodes: int) -> Tuple[Iterable[str] def get_replies( driver: InMemoryDriver, msg_ids: Iterable[str], node_id: int -) -> List[str]: +) -> list[str]: """Help create message replies and pull taskres from state.""" taskins = driver.state.get_task_ins(node_id, limit=len(list(msg_ids))) for taskin in taskins: diff --git a/src/py/flwr/server/history.py b/src/py/flwr/server/history.py index 291974a4323c..50daf2e04de6 100644 --- a/src/py/flwr/server/history.py +++ b/src/py/flwr/server/history.py @@ -17,7 +17,6 @@ import pprint from functools import reduce -from typing import Dict, List, Tuple from flwr.common.typing import Scalar @@ -26,11 +25,11 @@ class History: """History class for training and/or evaluation metrics collection.""" def __init__(self) -> None: - self.losses_distributed: List[Tuple[int, float]] = [] - self.losses_centralized: List[Tuple[int, float]] = [] - self.metrics_distributed_fit: Dict[str, List[Tuple[int, Scalar]]] = {} - self.metrics_distributed: Dict[str, List[Tuple[int, Scalar]]] = {} - self.metrics_centralized: Dict[str, List[Tuple[int, Scalar]]] = {} + self.losses_distributed: list[tuple[int, float]] = [] + self.losses_centralized: list[tuple[int, float]] = [] + self.metrics_distributed_fit: dict[str, list[tuple[int, Scalar]]] = {} + self.metrics_distributed: dict[str, list[tuple[int, Scalar]]] = {} + self.metrics_centralized: dict[str, list[tuple[int, Scalar]]] = {} def add_loss_distributed(self, server_round: int, loss: float) -> None: """Add one loss entry (from distributed evaluation).""" @@ -41,7 +40,7 @@ def add_loss_centralized(self, server_round: int, loss: float) -> None: self.losses_centralized.append((server_round, loss)) def add_metrics_distributed_fit( - self, server_round: int, metrics: Dict[str, Scalar] + self, server_round: int, metrics: dict[str, Scalar] ) -> None: """Add metrics 
entries (from distributed fit).""" for key in metrics: @@ -52,7 +51,7 @@ def add_metrics_distributed_fit( self.metrics_distributed_fit[key].append((server_round, metrics[key])) def add_metrics_distributed( - self, server_round: int, metrics: Dict[str, Scalar] + self, server_round: int, metrics: dict[str, Scalar] ) -> None: """Add metrics entries (from distributed evaluation).""" for key in metrics: @@ -63,7 +62,7 @@ def add_metrics_distributed( self.metrics_distributed[key].append((server_round, metrics[key])) def add_metrics_centralized( - self, server_round: int, metrics: Dict[str, Scalar] + self, server_round: int, metrics: dict[str, Scalar] ) -> None: """Add metrics entries (from centralized evaluation).""" for key in metrics: diff --git a/src/py/flwr/server/server.py b/src/py/flwr/server/server.py index 5e2a0c6b2719..bdaa11ba20a2 100644 --- a/src/py/flwr/server/server.py +++ b/src/py/flwr/server/server.py @@ -19,7 +19,7 @@ import io import timeit from logging import INFO, WARN -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from flwr.common import ( Code, @@ -41,17 +41,17 @@ from .server_config import ServerConfig -FitResultsAndFailures = Tuple[ - List[Tuple[ClientProxy, FitRes]], - List[Union[Tuple[ClientProxy, FitRes], BaseException]], +FitResultsAndFailures = tuple[ + list[tuple[ClientProxy, FitRes]], + list[Union[tuple[ClientProxy, FitRes], BaseException]], ] -EvaluateResultsAndFailures = Tuple[ - List[Tuple[ClientProxy, EvaluateRes]], - List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], +EvaluateResultsAndFailures = tuple[ + list[tuple[ClientProxy, EvaluateRes]], + list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], ] -ReconnectResultsAndFailures = Tuple[ - List[Tuple[ClientProxy, DisconnectRes]], - List[Union[Tuple[ClientProxy, DisconnectRes], BaseException]], +ReconnectResultsAndFailures = tuple[ + list[tuple[ClientProxy, DisconnectRes]], + list[Union[tuple[ClientProxy, DisconnectRes], BaseException]], ] @@ -84,7 +84,7 @@ def client_manager(self) -> ClientManager: return self._client_manager # pylint: disable=too-many-locals - def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: + def fit(self, num_rounds: int, timeout: Optional[float]) -> tuple[History, float]: """Run federated averaging for a number of rounds.""" history = History() @@ -163,7 +163,7 @@ def evaluate_round( server_round: int, timeout: Optional[float], ) -> Optional[ - Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures] + tuple[Optional[float], dict[str, Scalar], EvaluateResultsAndFailures] ]: """Validate current global model on a number of clients.""" # Get clients and their respective instructions from strategy @@ -197,9 +197,9 @@ def evaluate_round( ) # Aggregate the evaluation results - aggregated_result: Tuple[ + aggregated_result: tuple[ Optional[float], - Dict[str, Scalar], + dict[str, Scalar], ] = self.strategy.aggregate_evaluate(server_round, results, failures) loss_aggregated, metrics_aggregated = aggregated_result @@ -210,7 +210,7 @@ def fit_round( server_round: int, timeout: Optional[float], ) -> Optional[ - Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures] + tuple[Optional[Parameters], dict[str, Scalar], FitResultsAndFailures] ]: """Perform a single round of federated averaging.""" # Get clients and their respective instructions from strategy @@ -245,9 +245,9 @@ def fit_round( ) # Aggregate training results - aggregated_result: Tuple[ + aggregated_result: tuple[ 
Optional[Parameters], - Dict[str, Scalar], + dict[str, Scalar], ] = self.strategy.aggregate_fit(server_round, results, failures) parameters_aggregated, metrics_aggregated = aggregated_result @@ -296,7 +296,7 @@ def _get_initial_parameters( def reconnect_clients( - client_instructions: List[Tuple[ClientProxy, ReconnectIns]], + client_instructions: list[tuple[ClientProxy, ReconnectIns]], max_workers: Optional[int], timeout: Optional[float], ) -> ReconnectResultsAndFailures: @@ -312,8 +312,8 @@ def reconnect_clients( ) # Gather results - results: List[Tuple[ClientProxy, DisconnectRes]] = [] - failures: List[Union[Tuple[ClientProxy, DisconnectRes], BaseException]] = [] + results: list[tuple[ClientProxy, DisconnectRes]] = [] + failures: list[Union[tuple[ClientProxy, DisconnectRes], BaseException]] = [] for future in finished_fs: failure = future.exception() if failure is not None: @@ -328,7 +328,7 @@ def reconnect_client( client: ClientProxy, reconnect: ReconnectIns, timeout: Optional[float], -) -> Tuple[ClientProxy, DisconnectRes]: +) -> tuple[ClientProxy, DisconnectRes]: """Instruct client to disconnect and (optionally) reconnect later.""" disconnect = client.reconnect( reconnect, @@ -339,7 +339,7 @@ def reconnect_client( def fit_clients( - client_instructions: List[Tuple[ClientProxy, FitIns]], + client_instructions: list[tuple[ClientProxy, FitIns]], max_workers: Optional[int], timeout: Optional[float], group_id: int, @@ -356,8 +356,8 @@ def fit_clients( ) # Gather results - results: List[Tuple[ClientProxy, FitRes]] = [] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + results: list[tuple[ClientProxy, FitRes]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] for future in finished_fs: _handle_finished_future_after_fit( future=future, results=results, failures=failures @@ -367,7 +367,7 @@ def fit_clients( def fit_client( client: ClientProxy, ins: FitIns, timeout: Optional[float], group_id: int -) -> Tuple[ClientProxy, FitRes]: +) -> tuple[ClientProxy, FitRes]: """Refine parameters on a single client.""" fit_res = client.fit(ins, timeout=timeout, group_id=group_id) return client, fit_res @@ -375,8 +375,8 @@ def fit_client( def _handle_finished_future_after_fit( future: concurrent.futures.Future, # type: ignore - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], ) -> None: """Convert finished future into either a result or a failure.""" # Check if there was an exception @@ -386,7 +386,7 @@ def _handle_finished_future_after_fit( return # Successfully received a result from a client - result: Tuple[ClientProxy, FitRes] = future.result() + result: tuple[ClientProxy, FitRes] = future.result() _, res = result # Check result status code @@ -399,7 +399,7 @@ def _handle_finished_future_after_fit( def evaluate_clients( - client_instructions: List[Tuple[ClientProxy, EvaluateIns]], + client_instructions: list[tuple[ClientProxy, EvaluateIns]], max_workers: Optional[int], timeout: Optional[float], group_id: int, @@ -416,8 +416,8 @@ def evaluate_clients( ) # Gather results - results: List[Tuple[ClientProxy, EvaluateRes]] = [] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [] + results: list[tuple[ClientProxy, EvaluateRes]] = [] + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [] for future in finished_fs: 
_handle_finished_future_after_evaluate( future=future, results=results, failures=failures @@ -430,7 +430,7 @@ def evaluate_client( ins: EvaluateIns, timeout: Optional[float], group_id: int, -) -> Tuple[ClientProxy, EvaluateRes]: +) -> tuple[ClientProxy, EvaluateRes]: """Evaluate parameters on a single client.""" evaluate_res = client.evaluate(ins, timeout=timeout, group_id=group_id) return client, evaluate_res @@ -438,8 +438,8 @@ def evaluate_client( def _handle_finished_future_after_evaluate( future: concurrent.futures.Future, # type: ignore - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], ) -> None: """Convert finished future into either a result or a failure.""" # Check if there was an exception @@ -449,7 +449,7 @@ def _handle_finished_future_after_evaluate( return # Successfully received a result from a client - result: Tuple[ClientProxy, EvaluateRes] = future.result() + result: tuple[ClientProxy, EvaluateRes] = future.result() _, res = result # Check result status code @@ -466,7 +466,7 @@ def init_defaults( config: Optional[ServerConfig], strategy: Optional[Strategy], client_manager: Optional[ClientManager], -) -> Tuple[Server, ServerConfig]: +) -> tuple[Server, ServerConfig]: """Create server instance if none was given.""" if server is None: if client_manager is None: diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index b80811a6f730..6e8f423fe115 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -19,7 +19,7 @@ import csv import tempfile from pathlib import Path -from typing import List, Optional +from typing import Optional import numpy as np from cryptography.hazmat.primitives.asymmetric import ec @@ -143,7 +143,7 @@ def reconnect( def test_fit_clients() -> None: """Test fit_clients.""" # Prepare - clients: List[ClientProxy] = [ + clients: list[ClientProxy] = [ FailingClient("0"), SuccessClient("1"), ] @@ -164,7 +164,7 @@ def test_fit_clients() -> None: def test_eval_clients() -> None: """Test eval_clients.""" # Prepare - clients: List[ClientProxy] = [ + clients: list[ClientProxy] = [ FailingClient("0"), SuccessClient("1"), ] diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index c668b55eebe6..d5ee7340f8ea 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -16,7 +16,7 @@ # mypy: disallow_untyped_calls=False from functools import reduce -from typing import Any, Callable, List, Tuple +from typing import Any, Callable import numpy as np @@ -24,7 +24,7 @@ from flwr.server.client_proxy import ClientProxy -def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: +def aggregate(results: list[tuple[NDArrays, int]]) -> NDArrays: """Compute weighted average.""" # Calculate the total number of examples used during training num_examples_total = sum(num_examples for (_, num_examples) in results) @@ -42,7 +42,7 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: return weights_prime -def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: +def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays: """Compute in-place weighted average.""" # Count total examples num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results) @@ -67,7 +67,7 @@ def 
aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: return params -def aggregate_median(results: List[Tuple[NDArrays, int]]) -> NDArrays: +def aggregate_median(results: list[tuple[NDArrays, int]]) -> NDArrays: """Compute median.""" # Create a list of weights and ignore the number of examples weights = [weights for weights, _ in results] @@ -80,7 +80,7 @@ def aggregate_median(results: List[Tuple[NDArrays, int]]) -> NDArrays: def aggregate_krum( - results: List[Tuple[NDArrays, int]], num_malicious: int, to_keep: int + results: list[tuple[NDArrays, int]], num_malicious: int, to_keep: int ) -> NDArrays: """Choose one parameter vector according to the Krum function. @@ -119,7 +119,7 @@ def aggregate_krum( # pylint: disable=too-many-locals def aggregate_bulyan( - results: List[Tuple[NDArrays, int]], + results: list[tuple[NDArrays, int]], num_malicious: int, aggregation_rule: Callable, # type: ignore **aggregation_rule_kwargs: Any, @@ -155,7 +155,7 @@ def aggregate_bulyan( "It is needed to ensure that the method reduces the attacker's leeway to " "the one proved in the paper." ) - selected_models_set: List[Tuple[NDArrays, int]] = [] + selected_models_set: list[tuple[NDArrays, int]] = [] theta = len(results) - 2 * num_malicious beta = theta - 2 * num_malicious @@ -200,7 +200,7 @@ def aggregate_bulyan( return parameters_aggregated -def weighted_loss_avg(results: List[Tuple[int, float]]) -> float: +def weighted_loss_avg(results: list[tuple[int, float]]) -> float: """Aggregate evaluation results obtained from multiple clients.""" num_total_evaluation_examples = sum(num_examples for (num_examples, _) in results) weighted_losses = [num_examples * loss for num_examples, loss in results] @@ -208,7 +208,7 @@ def weighted_loss_avg(results: List[Tuple[int, float]]) -> float: def aggregate_qffl( - parameters: NDArrays, deltas: List[NDArrays], hs_fll: List[NDArrays] + parameters: NDArrays, deltas: list[NDArrays], hs_fll: list[NDArrays] ) -> NDArrays: """Compute weighted average based on Q-FFL paper.""" demominator: float = np.sum(np.asarray(hs_fll)) @@ -225,7 +225,7 @@ def aggregate_qffl( return new_parameters -def _compute_distances(weights: List[NDArrays]) -> NDArray: +def _compute_distances(weights: list[NDArrays]) -> NDArray: """Compute distances between vectors. Input: weights - list of weights vectors @@ -265,7 +265,7 @@ def _trim_mean(array: NDArray, proportiontocut: float) -> NDArray: def aggregate_trimmed_avg( - results: List[Tuple[NDArrays, int]], proportiontocut: float + results: list[tuple[NDArrays, int]], proportiontocut: float ) -> NDArrays: """Compute trimmed average.""" # Create a list of weights and ignore the number of examples @@ -290,7 +290,7 @@ def _check_weights_equality(weights1: NDArrays, weights2: NDArrays) -> bool: def _find_reference_weights( - reference_weights: NDArrays, list_of_weights: List[NDArrays] + reference_weights: NDArrays, list_of_weights: list[NDArrays] ) -> int: """Find the reference weights by looping through the `list_of_weights`. @@ -320,7 +320,7 @@ def _find_reference_weights( def _aggregate_n_closest_weights( - reference_weights: NDArrays, results: List[Tuple[NDArrays, int]], beta_closest: int + reference_weights: NDArrays, results: list[tuple[NDArrays, int]], beta_closest: int ) -> NDArrays: """Calculate element-wise mean of the `N` closest values. 
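For reference, the weighted-average semantics that `aggregate` keeps under its new `list[tuple[NDArrays, int]]` signature: each client's layers are scaled by that client's example count, summed layer-wise, and divided by the total number of examples. A small self-contained sketch with toy values, equivalent in spirit to the reduce-based body shown in the hunk above:

    import numpy as np

    # Each entry: (list of layer arrays, number of local training examples).
    results = [
        ([np.array([1.0, 1.0])], 1),  # client A, 1 example
        ([np.array([3.0, 3.0])], 3),  # client B, 3 examples
    ]

    num_examples_total = sum(num_examples for _, num_examples in results)
    weighted = [
        [layer * num_examples for layer in layers]
        for layers, num_examples in results
    ]
    aggregated = [
        np.sum(layer_updates, axis=0) / num_examples_total
        for layer_updates in zip(*weighted)
    ]

    print(aggregated)  # [array([2.5, 2.5])] == (1*1 + 3*3) / 4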
diff --git a/src/py/flwr/server/strategy/aggregate_test.py b/src/py/flwr/server/strategy/aggregate_test.py index f8b4e3c03b50..9f9dba79ec7c 100644 --- a/src/py/flwr/server/strategy/aggregate_test.py +++ b/src/py/flwr/server/strategy/aggregate_test.py @@ -15,8 +15,6 @@ """Aggregation function tests.""" -from typing import List, Tuple - import numpy as np from .aggregate import ( @@ -49,7 +47,7 @@ def test_aggregate() -> None: def test_weighted_loss_avg_single_value() -> None: """Test weighted loss averaging.""" # Prepare - results: List[Tuple[int, float]] = [(5, 0.5)] + results: list[tuple[int, float]] = [(5, 0.5)] expected = 0.5 # Execute @@ -62,7 +60,7 @@ def test_weighted_loss_avg_single_value() -> None: def test_weighted_loss_avg_multiple_values() -> None: """Test weighted loss averaging.""" # Prepare - results: List[Tuple[int, float]] = [(1, 2.0), (2, 1.0), (1, 2.0)] + results: list[tuple[int, float]] = [(1, 2.0), (2, 1.0), (1, 2.0)] expected = 1.5 # Execute diff --git a/src/py/flwr/server/strategy/bulyan.py b/src/py/flwr/server/strategy/bulyan.py index a81406c255ad..84a261237ac5 100644 --- a/src/py/flwr/server/strategy/bulyan.py +++ b/src/py/flwr/server/strategy/bulyan.py @@ -19,7 +19,7 @@ from logging import WARNING -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union from flwr.common import ( FitRes, @@ -86,12 +86,12 @@ def __init__( num_malicious_clients: int = 0, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -125,9 +125,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using Bulyan.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/bulyan_test.py b/src/py/flwr/server/strategy/bulyan_test.py index 93a9ebda3783..c0b87c82a036 100644 --- a/src/py/flwr/server/strategy/bulyan_test.py +++ b/src/py/flwr/server/strategy/bulyan_test.py @@ -15,7 +15,6 @@ """Bulyan tests.""" -from typing import List, Tuple from unittest.mock import MagicMock from numpy import array, float32 @@ -62,7 +61,7 @@ def test_aggregate_fit() -> None: param_5: Parameters = ndarrays_to_parameters( [array([0.1, 0.1, 0.1, 0.1], dtype=float32)] ) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( diff --git a/src/py/flwr/server/strategy/dp_adaptive_clipping.py b/src/py/flwr/server/strategy/dp_adaptive_clipping.py index b25e1efdf0e9..77e70bb9af04 100644 --- a/src/py/flwr/server/strategy/dp_adaptive_clipping.py +++ b/src/py/flwr/server/strategy/dp_adaptive_clipping.py @@ -20,7 +20,7 
@@ import math from logging import INFO, WARNING -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union import numpy as np @@ -156,14 +156,14 @@ def initialize_parameters( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" self.current_round_params = parameters_to_ndarrays(parameters) return self.strategy.configure_fit(server_round, parameters, client_manager) def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" return self.strategy.configure_evaluate( server_round, parameters, client_manager @@ -172,9 +172,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate training results and update clip norms.""" if failures: return None, {} @@ -245,15 +245,15 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using the given strategy.""" return self.strategy.aggregate_evaluate(server_round, results, failures) def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function from the strategy.""" return self.strategy.evaluate(server_round, parameters) @@ -372,7 +372,7 @@ def initialize_parameters( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" additional_config = {KEY_CLIPPING_NORM: self.clipping_norm} inner_strategy_config_result = self.strategy.configure_fit( @@ -385,7 +385,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" return self.strategy.configure_evaluate( server_round, parameters, client_manager @@ -394,9 +394,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate training results and update clip norms.""" if failures: return None, {} @@ -432,7 +432,7 @@ 
def aggregate_fit( return aggregated_params, metrics - def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: + def _update_clip_norm(self, results: list[tuple[ClientProxy, FitRes]]) -> None: # Calculate the number of clients which set the norm indicator bit norm_bit_set_count = 0 for client_proxy, fit_res in results: @@ -457,14 +457,14 @@ def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using the given strategy.""" return self.strategy.aggregate_evaluate(server_round, results, failures) def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function from the strategy.""" return self.strategy.evaluate(server_round, parameters) diff --git a/src/py/flwr/server/strategy/dp_fixed_clipping.py b/src/py/flwr/server/strategy/dp_fixed_clipping.py index 92b2845fd846..2ca253c96370 100644 --- a/src/py/flwr/server/strategy/dp_fixed_clipping.py +++ b/src/py/flwr/server/strategy/dp_fixed_clipping.py @@ -19,7 +19,7 @@ from logging import INFO, WARNING -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from flwr.common import ( EvaluateIns, @@ -117,14 +117,14 @@ def initialize_parameters( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" self.current_round_params = parameters_to_ndarrays(parameters) return self.strategy.configure_fit(server_round, parameters, client_manager) def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" return self.strategy.configure_evaluate( server_round, parameters, client_manager @@ -133,9 +133,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Compute the updates, clip, and pass them for aggregation. Afterward, add noise to the aggregated parameters. 
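Note: the fixed-clipping wrapper above follows the usual server-side DP recipe: clip each client update to a fixed L2 norm, aggregate, then add Gaussian noise scaled by the clipping norm. Below is a minimal NumPy sketch of that flow, independent of the Flower classes; the helper names and the stdv formula are illustrative assumptions, not the library's internal API.

import numpy as np

def clip_update(update, clipping_norm):
    # Scale the whole update down if its global L2 norm exceeds the clipping norm.
    norm = np.sqrt(sum(np.sum(np.square(layer)) for layer in update))
    scale = min(1.0, clipping_norm / (norm + 1e-12))
    return [layer * scale for layer in update]

def aggregate_with_noise(updates, clipping_norm, noise_multiplier, rng=None):
    # Unweighted mean of clipped client updates plus Gaussian noise (illustrative stdv).
    rng = rng or np.random.default_rng()
    clipped = [clip_update(u, clipping_norm) for u in updates]
    mean = [np.mean(layers, axis=0) for layers in zip(*clipped)]
    stdv = noise_multiplier * clipping_norm / len(updates)
    return [m + rng.normal(0.0, stdv, size=m.shape) for m in mean]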
@@ -191,15 +191,15 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using the given strategy.""" return self.strategy.aggregate_evaluate(server_round, results, failures) def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function from the strategy.""" return self.strategy.evaluate(server_round, parameters) @@ -285,7 +285,7 @@ def initialize_parameters( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" additional_config = {KEY_CLIPPING_NORM: self.clipping_norm} inner_strategy_config_result = self.strategy.configure_fit( @@ -298,7 +298,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" return self.strategy.configure_evaluate( server_round, parameters, client_manager @@ -307,9 +307,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Add noise to the aggregated parameters.""" if failures: return None, {} @@ -348,14 +348,14 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using the given strategy.""" return self.strategy.aggregate_evaluate(server_round, results, failures) def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function from the strategy.""" return self.strategy.evaluate(server_round, parameters) diff --git a/src/py/flwr/server/strategy/dpfedavg_adaptive.py b/src/py/flwr/server/strategy/dpfedavg_adaptive.py index 423ddddeb379..ab513aba2269 100644 --- a/src/py/flwr/server/strategy/dpfedavg_adaptive.py +++ b/src/py/flwr/server/strategy/dpfedavg_adaptive.py @@ -19,7 +19,7 @@ import math -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union import numpy as np @@ -80,7 +80,7 @@ def __repr__(self) -> str: def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> 
List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" additional_config = {"dpfedavg_adaptive_clip_enabled": True} @@ -93,7 +93,7 @@ def configure_fit( return client_instructions - def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: + def _update_clip_norm(self, results: list[tuple[ClientProxy, FitRes]]) -> None: # Calculating number of clients which set the norm indicator bit norm_bit_set_count = 0 for client_proxy, fit_res in results: @@ -118,9 +118,9 @@ def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate training results as in DPFedAvgFixed and update clip norms.""" if failures: return None, {} diff --git a/src/py/flwr/server/strategy/dpfedavg_fixed.py b/src/py/flwr/server/strategy/dpfedavg_fixed.py index d122f0688922..4ea84db30cd4 100644 --- a/src/py/flwr/server/strategy/dpfedavg_fixed.py +++ b/src/py/flwr/server/strategy/dpfedavg_fixed.py @@ -17,7 +17,7 @@ Paper: arxiv.org/pdf/1710.06963.pdf """ -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar from flwr.common.dp import add_gaussian_noise @@ -79,7 +79,7 @@ def initialize_parameters( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training incorporating Differential Privacy (DP). Configuration of the next training round includes information related to DP, @@ -119,7 +119,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation using the specified strategy. 
Parameters @@ -147,9 +147,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate training results using unweighted aggregation.""" if failures: return None, {} @@ -168,14 +168,14 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using the given strategy.""" return self.strategy.aggregate_evaluate(server_round, results, failures) def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function from the strategy.""" return self.strategy.evaluate(server_round, parameters) diff --git a/src/py/flwr/server/strategy/fault_tolerant_fedavg.py b/src/py/flwr/server/strategy/fault_tolerant_fedavg.py index 663ac8872c39..60213db2efeb 100644 --- a/src/py/flwr/server/strategy/fault_tolerant_fedavg.py +++ b/src/py/flwr/server/strategy/fault_tolerant_fedavg.py @@ -16,7 +16,7 @@ from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union from flwr.common import ( EvaluateRes, @@ -49,12 +49,12 @@ def __init__( min_available_clients: int = 1, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, min_completion_rate_fit: float = 0.5, min_completion_rate_evaluate: float = 0.5, initial_parameters: Optional[Parameters] = None, @@ -85,9 +85,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" if not results: return None, {} @@ -117,9 +117,9 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses 
using weighted average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fault_tolerant_fedavg_test.py b/src/py/flwr/server/strategy/fault_tolerant_fedavg_test.py index 98f4cac032cb..a01a3a5c0ad5 100644 --- a/src/py/flwr/server/strategy/fault_tolerant_fedavg_test.py +++ b/src/py/flwr/server/strategy/fault_tolerant_fedavg_test.py @@ -15,7 +15,7 @@ """FaultTolerantFedAvg tests.""" -from typing import List, Optional, Tuple, Union +from typing import Optional, Union from unittest.mock import MagicMock from flwr.common import ( @@ -36,8 +36,8 @@ def test_aggregate_fit_no_results_no_failures() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_fit=0.1) - results: List[Tuple[ClientProxy, FitRes]] = [] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + results: list[tuple[ClientProxy, FitRes]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] expected: Optional[Parameters] = None # Execute @@ -51,8 +51,8 @@ def test_aggregate_fit_no_results() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_fit=0.1) - results: List[Tuple[ClientProxy, FitRes]] = [] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [Exception()] + results: list[tuple[ClientProxy, FitRes]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [Exception()] expected: Optional[Parameters] = None # Execute @@ -66,7 +66,7 @@ def test_aggregate_fit_not_enough_results() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_fit=0.5) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( @@ -77,7 +77,7 @@ def test_aggregate_fit_not_enough_results() -> None: ), ) ] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [ + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [ Exception(), Exception(), ] @@ -94,7 +94,7 @@ def test_aggregate_fit_just_enough_results() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_fit=0.5) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( @@ -105,7 +105,7 @@ def test_aggregate_fit_just_enough_results() -> None: ), ) ] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [Exception()] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [Exception()] expected: Optional[NDArrays] = [] # Execute @@ -120,7 +120,7 @@ def test_aggregate_fit_no_failures() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_fit=0.99) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( @@ -131,7 +131,7 @@ def test_aggregate_fit_no_failures() -> None: ), ) ] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] expected: Optional[NDArrays] = [] # Execute @@ -146,8 +146,8 @@ def test_aggregate_evaluate_no_results_no_failures() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_evaluate=0.1) - results: List[Tuple[ClientProxy, EvaluateRes]] = [] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [] + results: list[tuple[ClientProxy, EvaluateRes]] = [] + failures: 
list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [] expected: Optional[float] = None # Execute @@ -161,8 +161,8 @@ def test_aggregate_evaluate_no_results() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_evaluate=0.1) - results: List[Tuple[ClientProxy, EvaluateRes]] = [] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [ + results: list[tuple[ClientProxy, EvaluateRes]] = [] + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [ Exception() ] expected: Optional[float] = None @@ -178,7 +178,7 @@ def test_aggregate_evaluate_not_enough_results() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_evaluate=0.5) - results: List[Tuple[ClientProxy, EvaluateRes]] = [ + results: list[tuple[ClientProxy, EvaluateRes]] = [ ( MagicMock(), EvaluateRes( @@ -189,7 +189,7 @@ def test_aggregate_evaluate_not_enough_results() -> None: ), ) ] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [ + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [ Exception(), Exception(), ] @@ -206,7 +206,7 @@ def test_aggregate_evaluate_just_enough_results() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_evaluate=0.5) - results: List[Tuple[ClientProxy, EvaluateRes]] = [ + results: list[tuple[ClientProxy, EvaluateRes]] = [ ( MagicMock(), EvaluateRes( @@ -217,7 +217,7 @@ def test_aggregate_evaluate_just_enough_results() -> None: ), ) ] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [ + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [ Exception() ] expected: Optional[float] = 2.3 @@ -233,7 +233,7 @@ def test_aggregate_evaluate_no_failures() -> None: """Test evaluate function.""" # Prepare strategy = FaultTolerantFedAvg(min_completion_rate_evaluate=0.99) - results: List[Tuple[ClientProxy, EvaluateRes]] = [ + results: list[tuple[ClientProxy, EvaluateRes]] = [ ( MagicMock(), EvaluateRes( @@ -244,7 +244,7 @@ def test_aggregate_evaluate_no_failures() -> None: ), ) ] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [] + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [] expected: Optional[float] = 2.3 # Execute diff --git a/src/py/flwr/server/strategy/fedadagrad.py b/src/py/flwr/server/strategy/fedadagrad.py index f13c5358da25..75befdd0e796 100644 --- a/src/py/flwr/server/strategy/fedadagrad.py +++ b/src/py/flwr/server/strategy/fedadagrad.py @@ -20,7 +20,7 @@ """ -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import numpy as np @@ -89,12 +89,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, accept_failures: bool = True, @@ -131,9 +131,9 @@ def __repr__(self) -> str: def 
aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit( server_round=server_round, results=results, failures=failures diff --git a/src/py/flwr/server/strategy/fedadagrad_test.py b/src/py/flwr/server/strategy/fedadagrad_test.py index b43a4c75d123..96d98fe750f3 100644 --- a/src/py/flwr/server/strategy/fedadagrad_test.py +++ b/src/py/flwr/server/strategy/fedadagrad_test.py @@ -15,7 +15,6 @@ """FedAdagrad tests.""" -from typing import List, Tuple from unittest.mock import MagicMock from numpy import array, float32 @@ -54,7 +53,7 @@ def test_aggregate_fit() -> None: bridge = MagicMock() client_0 = GrpcClientProxy(cid="0", bridge=bridge) client_1 = GrpcClientProxy(cid="1", bridge=bridge) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( client_0, FitRes( diff --git a/src/py/flwr/server/strategy/fedadam.py b/src/py/flwr/server/strategy/fedadam.py index dc90e90c7568..d0f87a43f79b 100644 --- a/src/py/flwr/server/strategy/fedadam.py +++ b/src/py/flwr/server/strategy/fedadam.py @@ -20,7 +20,7 @@ """ -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import numpy as np @@ -93,12 +93,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Parameters, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -137,9 +137,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit( server_round=server_round, results=results, failures=failures diff --git a/src/py/flwr/server/strategy/fedavg.py b/src/py/flwr/server/strategy/fedavg.py index 3b9b2640c2b5..2d0b855c3186 100644 --- a/src/py/flwr/server/strategy/fedavg.py +++ b/src/py/flwr/server/strategy/fedavg.py @@ -19,7 +19,7 @@ from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union from flwr.common import ( EvaluateIns, @@ -99,12 +99,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + 
[int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -138,12 +138,12 @@ def __repr__(self) -> str: rep = f"FedAvg(accept_failures={self.accept_failures})" return rep - def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]: + def num_fit_clients(self, num_available_clients: int) -> tuple[int, int]: """Return the sample size and the required number of available clients.""" num_clients = int(num_available_clients * self.fraction_fit) return max(num_clients, self.min_fit_clients), self.min_available_clients - def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]: + def num_evaluation_clients(self, num_available_clients: int) -> tuple[int, int]: """Use a fraction of available clients for evaluation.""" num_clients = int(num_available_clients * self.fraction_evaluate) return max(num_clients, self.min_evaluate_clients), self.min_available_clients @@ -158,7 +158,7 @@ def initialize_parameters( def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function.""" if self.evaluate_fn is None: # No evaluation function provided @@ -172,7 +172,7 @@ def evaluate( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" config = {} if self.on_fit_config_fn is not None: @@ -193,7 +193,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" # Do not configure federated evaluation if fraction eval is 0. 
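Note: num_fit_clients and num_evaluation_clients above are simple arithmetic: sample a fraction of the available clients, never fewer than the configured minimum, and report min_available_clients as the wait threshold. A self-contained mirror of that logic with made-up parameter values:

fraction_fit, min_fit_clients, min_available_clients = 0.1, 2, 2

def num_fit_clients(num_available_clients: int) -> tuple[int, int]:
    num_clients = int(num_available_clients * fraction_fit)
    return max(num_clients, min_fit_clients), min_available_clients

print(num_fit_clients(100))  # (10, 2): sample 10 of 100 available clients
print(num_fit_clients(15))   # (2, 2): int(15 * 0.1) = 1, bumped up to min_fit_clients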
if self.fraction_evaluate == 0.0: @@ -220,9 +220,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" if not results: return None, {} @@ -256,9 +256,9 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using weighted average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fedavg_android.py b/src/py/flwr/server/strategy/fedavg_android.py index 2f49cf8784c9..bcecf8efb504 100644 --- a/src/py/flwr/server/strategy/fedavg_android.py +++ b/src/py/flwr/server/strategy/fedavg_android.py @@ -18,7 +18,7 @@ """ -from typing import Callable, Dict, List, Optional, Tuple, Union, cast +from typing import Callable, Optional, Union, cast import numpy as np @@ -81,12 +81,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, ) -> None: @@ -107,12 +107,12 @@ def __repr__(self) -> str: rep = f"FedAvg(accept_failures={self.accept_failures})" return rep - def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]: + def num_fit_clients(self, num_available_clients: int) -> tuple[int, int]: """Return the sample size and the required number of available clients.""" num_clients = int(num_available_clients * self.fraction_fit) return max(num_clients, self.min_fit_clients), self.min_available_clients - def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]: + def num_evaluation_clients(self, num_available_clients: int) -> tuple[int, int]: """Use a fraction of available clients for evaluation.""" num_clients = int(num_available_clients * self.fraction_evaluate) return max(num_clients, self.min_evaluate_clients), self.min_available_clients @@ -127,7 +127,7 @@ def initialize_parameters( def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function.""" if self.evaluate_fn is None: # No evaluation function provided @@ -141,7 +141,7 @@ def evaluate( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: 
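Note: FedAvg's aggregate_fit reduces the collected (parameters, num_examples) pairs to a per-layer mean weighted by example counts. A standalone sketch of that reduction (not the in-place optimized path the strategy also supports):

import numpy as np
from numpy.typing import NDArray

def weighted_average(results: list[tuple[list[NDArray], int]]) -> list[NDArray]:
    # results: one (layer list, num_examples) pair per client.
    total_examples = sum(num_examples for _, num_examples in results)
    weighted = [
        [layer * num_examples for layer in layers] for layers, num_examples in results
    ]
    return [
        np.sum(layer_updates, axis=0) / total_examples
        for layer_updates in zip(*weighted)
    ]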
"""Configure the next round of training.""" config = {} if self.on_fit_config_fn is not None: @@ -162,7 +162,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" # Do not configure federated evaluation if fraction_evaluate is 0 if self.fraction_evaluate == 0.0: @@ -189,9 +189,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" if not results: return None, {} @@ -208,9 +208,9 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using weighted average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fedavg_test.py b/src/py/flwr/server/strategy/fedavg_test.py index e62eaa5c5832..66241c3ab66a 100644 --- a/src/py/flwr/server/strategy/fedavg_test.py +++ b/src/py/flwr/server/strategy/fedavg_test.py @@ -15,7 +15,7 @@ """FedAvg tests.""" -from typing import List, Tuple, Union +from typing import Union from unittest.mock import MagicMock import numpy as np @@ -140,7 +140,7 @@ def test_inplace_aggregate_fit_equivalence() -> None: weights1_0 = np.random.randn(100, 64) weights1_1 = np.random.randn(314, 628, 3) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( @@ -160,7 +160,7 @@ def test_inplace_aggregate_fit_equivalence() -> None: ), ), ] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] fedavg_reference = FedAvg(inplace=False) fedavg_inplace = FedAvg() diff --git a/src/py/flwr/server/strategy/fedavgm.py b/src/py/flwr/server/strategy/fedavgm.py index ab3d37249db6..a7c37c38770f 100644 --- a/src/py/flwr/server/strategy/fedavgm.py +++ b/src/py/flwr/server/strategy/fedavgm.py @@ -19,7 +19,7 @@ from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union from flwr.common import ( FitRes, @@ -84,12 +84,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, 
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -132,9 +132,9 @@ def initialize_parameters( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fedavgm_test.py b/src/py/flwr/server/strategy/fedavgm_test.py index 39da5f4b82c4..400fa3c97247 100644 --- a/src/py/flwr/server/strategy/fedavgm_test.py +++ b/src/py/flwr/server/strategy/fedavgm_test.py @@ -15,7 +15,7 @@ """FedAvgM tests.""" -from typing import List, Tuple, Union +from typing import Union from unittest.mock import MagicMock from numpy import array, float32 @@ -41,7 +41,7 @@ def test_aggregate_fit_using_near_one_server_lr_and_no_momentum() -> None: array([0, 0, 0, 0], dtype=float32), ] - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( @@ -61,7 +61,7 @@ def test_aggregate_fit_using_near_one_server_lr_and_no_momentum() -> None: ), ), ] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] expected: NDArrays = [ array([[1, 2, 3], [4, 5, 6]], dtype=float32), array([7, 8, 9, 10], dtype=float32), @@ -94,7 +94,7 @@ def test_aggregate_fit_server_learning_rate_and_momentum() -> None: array([0, 0, 0, 0], dtype=float32), ] - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( MagicMock(), FitRes( @@ -114,7 +114,7 @@ def test_aggregate_fit_server_learning_rate_and_momentum() -> None: ), ), ] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] expected: NDArrays = [ array([[1, 2, 3], [4, 5, 6]], dtype=float32), array([7, 8, 9, 10], dtype=float32), diff --git a/src/py/flwr/server/strategy/fedmedian.py b/src/py/flwr/server/strategy/fedmedian.py index e7cba5324fa8..35044d42b22c 100644 --- a/src/py/flwr/server/strategy/fedmedian.py +++ b/src/py/flwr/server/strategy/fedmedian.py @@ -19,7 +19,7 @@ from logging import WARNING -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from flwr.common import ( FitRes, @@ -46,9 +46,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using median.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fedmedian_test.py b/src/py/flwr/server/strategy/fedmedian_test.py index 3960ad70b145..2c9881635319 100644 --- a/src/py/flwr/server/strategy/fedmedian_test.py +++ b/src/py/flwr/server/strategy/fedmedian_test.py @@ -15,7 +15,6 @@ """FedMedian tests.""" -from typing import List, Tuple from unittest.mock import MagicMock from numpy import array, float32 @@ -159,7 +158,7 @@ def test_aggregate_fit() -> None: client_0 = 
GrpcClientProxy(cid="0", bridge=bridge) client_1 = GrpcClientProxy(cid="1", bridge=bridge) client_2 = GrpcClientProxy(cid="2", bridge=bridge) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( client_0, FitRes( diff --git a/src/py/flwr/server/strategy/fedopt.py b/src/py/flwr/server/strategy/fedopt.py index c581d4797123..3e143fc3ca59 100644 --- a/src/py/flwr/server/strategy/fedopt.py +++ b/src/py/flwr/server/strategy/fedopt.py @@ -18,7 +18,7 @@ """ -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Optional from flwr.common import ( MetricsAggregationFn, @@ -86,12 +86,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Parameters, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, diff --git a/src/py/flwr/server/strategy/fedprox.py b/src/py/flwr/server/strategy/fedprox.py index f15271e06060..218fece0491f 100644 --- a/src/py/flwr/server/strategy/fedprox.py +++ b/src/py/flwr/server/strategy/fedprox.py @@ -18,7 +18,7 @@ """ -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Optional from flwr.common import FitIns, MetricsAggregationFn, NDArrays, Parameters, Scalar from flwr.server.client_manager import ClientManager @@ -113,12 +113,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -148,7 +148,7 @@ def __repr__(self) -> str: def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training. 
Sends the proximal factor mu to the clients diff --git a/src/py/flwr/server/strategy/fedtrimmedavg.py b/src/py/flwr/server/strategy/fedtrimmedavg.py index 96b0d35e7a61..8a0e4e50fbff 100644 --- a/src/py/flwr/server/strategy/fedtrimmedavg.py +++ b/src/py/flwr/server/strategy/fedtrimmedavg.py @@ -17,7 +17,7 @@ Paper: arxiv.org/abs/1803.01498 """ from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union from flwr.common import ( FitRes, @@ -78,12 +78,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -114,9 +114,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using trimmed average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fedxgb_bagging.py b/src/py/flwr/server/strategy/fedxgb_bagging.py index a74ee81976a6..1e55466808f8 100644 --- a/src/py/flwr/server/strategy/fedxgb_bagging.py +++ b/src/py/flwr/server/strategy/fedxgb_bagging.py @@ -17,7 +17,7 @@ import json from logging import WARNING -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Callable, Optional, Union, cast from flwr.common import EvaluateRes, FitRes, Parameters, Scalar from flwr.common.logger import log @@ -34,8 +34,8 @@ def __init__( self, evaluate_function: Optional[ Callable[ - [int, Parameters, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, Parameters, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, **kwargs: Any, @@ -52,9 +52,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using bagging.""" if not results: return None, {} @@ -79,9 +79,9 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation metrics using average.""" if 
not results: return None, {} @@ -101,7 +101,7 @@ def aggregate_evaluate( def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function.""" if self.evaluate_function is None: # No evaluation function provided @@ -152,7 +152,7 @@ def aggregate( return bst_prev_bytes -def _get_tree_nums(xgb_model_org: bytes) -> Tuple[int, int]: +def _get_tree_nums(xgb_model_org: bytes) -> tuple[int, int]: xgb_model = json.loads(bytearray(xgb_model_org)) # Get the number of trees tree_num = int( diff --git a/src/py/flwr/server/strategy/fedxgb_cyclic.py b/src/py/flwr/server/strategy/fedxgb_cyclic.py index 75025a89728b..c2dc3d797c7e 100644 --- a/src/py/flwr/server/strategy/fedxgb_cyclic.py +++ b/src/py/flwr/server/strategy/fedxgb_cyclic.py @@ -16,7 +16,7 @@ from logging import WARNING -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar from flwr.common.logger import log @@ -45,9 +45,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using bagging.""" if not results: return None, {} @@ -69,9 +69,9 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation metrics using average.""" if not results: return None, {} @@ -91,7 +91,7 @@ def aggregate_evaluate( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" config = {} if self.on_fit_config_fn is not None: @@ -117,7 +117,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" # Do not configure federated evaluation if fraction eval is 0. 
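Note: the aggregate_evaluate implementations here, like the weighted_loss_avg tests at the top of this section, boil down to an example-count-weighted mean of per-client losses. A sketch consistent with those test expectations:

def weighted_loss_avg(results: list[tuple[int, float]]) -> float:
    # Each entry is (num_examples, loss); weight every loss by its example count.
    total_examples = sum(num_examples for num_examples, _ in results)
    return sum(num_examples * loss for num_examples, loss in results) / total_examples

assert weighted_loss_avg([(5, 0.5)]) == 0.5
assert weighted_loss_avg([(1, 2.0), (2, 1.0), (1, 2.0)]) == 1.5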
if self.fraction_evaluate == 0.0: diff --git a/src/py/flwr/server/strategy/fedxgb_nn_avg.py b/src/py/flwr/server/strategy/fedxgb_nn_avg.py index 4562663287ae..a7da4a919af7 100644 --- a/src/py/flwr/server/strategy/fedxgb_nn_avg.py +++ b/src/py/flwr/server/strategy/fedxgb_nn_avg.py @@ -22,7 +22,7 @@ from logging import WARNING -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union from flwr.common import FitRes, Scalar, ndarrays_to_parameters, parameters_to_ndarrays from flwr.common.logger import log, warn_deprecated_feature @@ -56,7 +56,7 @@ def __repr__(self) -> str: def evaluate( self, server_round: int, parameters: Any - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate model parameters using an evaluation function.""" if self.evaluate_fn is None: # No evaluation function provided @@ -70,9 +70,9 @@ def evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Any], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Any], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/fedyogi.py b/src/py/flwr/server/strategy/fedyogi.py index c7b2ebb51667..11873d1b781f 100644 --- a/src/py/flwr/server/strategy/fedyogi.py +++ b/src/py/flwr/server/strategy/fedyogi.py @@ -18,7 +18,7 @@ """ -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import numpy as np @@ -93,12 +93,12 @@ def __init__( min_available_clients: int = 2, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Parameters, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -137,9 +137,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit( server_round=server_round, results=results, failures=failures diff --git a/src/py/flwr/server/strategy/krum.py b/src/py/flwr/server/strategy/krum.py index 074d018c35a3..5d33874b9789 100644 --- a/src/py/flwr/server/strategy/krum.py +++ b/src/py/flwr/server/strategy/krum.py @@ -21,7 +21,7 @@ from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union from flwr.common import ( FitRes, @@ -87,12 +87,12 @@ def __init__( num_clients_to_keep: int = 0, 
evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -123,9 +123,9 @@ def __repr__(self) -> str: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using Krum.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/krum_test.py b/src/py/flwr/server/strategy/krum_test.py index b34982325b39..dc996b480630 100644 --- a/src/py/flwr/server/strategy/krum_test.py +++ b/src/py/flwr/server/strategy/krum_test.py @@ -15,7 +15,6 @@ """Krum tests.""" -from typing import List, Tuple from unittest.mock import MagicMock from numpy import array, float32 @@ -160,7 +159,7 @@ def test_aggregate_fit() -> None: client_0 = GrpcClientProxy(cid="0", bridge=bridge) client_1 = GrpcClientProxy(cid="1", bridge=bridge) client_2 = GrpcClientProxy(cid="2", bridge=bridge) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( client_0, FitRes( diff --git a/src/py/flwr/server/strategy/multikrum_test.py b/src/py/flwr/server/strategy/multikrum_test.py index 7a1a4c3ecf38..90607e2c0edc 100644 --- a/src/py/flwr/server/strategy/multikrum_test.py +++ b/src/py/flwr/server/strategy/multikrum_test.py @@ -15,7 +15,6 @@ """Krum tests.""" -from typing import List, Tuple from unittest.mock import MagicMock from numpy import array, float32 @@ -59,7 +58,7 @@ def test_aggregate_fit() -> None: client_1 = GrpcClientProxy(cid="1", bridge=bridge) client_2 = GrpcClientProxy(cid="2", bridge=bridge) - results: List[Tuple[ClientProxy, FitRes]] = [ + results: list[tuple[ClientProxy, FitRes]] = [ ( client_0, FitRes( diff --git a/src/py/flwr/server/strategy/qfedavg.py b/src/py/flwr/server/strategy/qfedavg.py index 26a397d4cf8c..30a3cc53ee94 100644 --- a/src/py/flwr/server/strategy/qfedavg.py +++ b/src/py/flwr/server/strategy/qfedavg.py @@ -19,7 +19,7 @@ from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import numpy as np @@ -60,12 +60,12 @@ def __init__( min_available_clients: int = 1, evaluate_fn: Optional[ Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + Optional[tuple[float, dict[str, Scalar]]], ] ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_fit_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], dict[str, Scalar]]] = None, accept_failures: bool = True, initial_parameters: 
Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, @@ -95,19 +95,19 @@ def __repr__(self) -> str: rep += f"q_param={self.q_param}, pre_weights={self.pre_weights})" return rep - def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]: + def num_fit_clients(self, num_available_clients: int) -> tuple[int, int]: """Return the sample size and the required number of available clients.""" num_clients = int(num_available_clients * self.fraction_fit) return max(num_clients, self.min_fit_clients), self.min_available_clients - def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]: + def num_evaluation_clients(self, num_available_clients: int) -> tuple[int, int]: """Use a fraction of available clients for evaluation.""" num_clients = int(num_available_clients * self.fraction_evaluate) return max(num_clients, self.min_evaluate_clients), self.min_available_clients def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training.""" weights = parameters_to_ndarrays(parameters) self.pre_weights = weights @@ -131,7 +131,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation.""" # Do not configure federated evaluation if fraction_evaluate is 0 if self.fraction_evaluate == 0.0: @@ -158,9 +158,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate fit results using weighted average.""" if not results: return None, {} @@ -229,9 +229,9 @@ def norm_grad(grad_list: NDArrays) -> float: def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation losses using weighted average.""" if not results: return None, {} diff --git a/src/py/flwr/server/strategy/strategy.py b/src/py/flwr/server/strategy/strategy.py index cfdfe2e246c5..14999e9a8993 100644 --- a/src/py/flwr/server/strategy/strategy.py +++ b/src/py/flwr/server/strategy/strategy.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar from flwr.server.client_manager import ClientManager @@ -47,7 +47,7 @@ def initialize_parameters( @abstractmethod def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training. 
Parameters @@ -72,9 +72,9 @@ def configure_fit( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]], + ) -> tuple[Optional[Parameters], dict[str, Scalar]]: """Aggregate training results. Parameters @@ -108,7 +108,7 @@ def aggregate_fit( @abstractmethod def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """Configure the next round of evaluation. Parameters @@ -134,9 +134,9 @@ def configure_evaluate( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Aggregate evaluation results. Parameters @@ -164,7 +164,7 @@ def aggregate_evaluate( @abstractmethod def evaluate( self, server_round: int, parameters: Parameters - ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + ) -> Optional[tuple[float, dict[str, Scalar]]]: """Evaluate the current model parameters. This function can be used to perform centralized (i.e., server-side) evaluation diff --git a/src/py/flwr/server/superlink/driver/driver_grpc.py b/src/py/flwr/server/superlink/driver/driver_grpc.py index b7b914206f72..70354387812e 100644 --- a/src/py/flwr/server/superlink/driver/driver_grpc.py +++ b/src/py/flwr/server/superlink/driver/driver_grpc.py @@ -15,7 +15,7 @@ """Driver gRPC API.""" from logging import INFO -from typing import Optional, Tuple +from typing import Optional import grpc @@ -35,7 +35,7 @@ def run_driver_api_grpc( address: str, state_factory: StateFactory, ffs_factory: FfsFactory, - certificates: Optional[Tuple[bytes, bytes, bytes]], + certificates: Optional[tuple[bytes, bytes, bytes]], ) -> grpc.Server: """Run Driver API (gRPC, request-response).""" # Create Driver API gRPC server diff --git a/src/py/flwr/server/superlink/driver/driver_servicer.py b/src/py/flwr/server/superlink/driver/driver_servicer.py index 73cd1c73a6fd..4d7d6cb6ce89 100644 --- a/src/py/flwr/server/superlink/driver/driver_servicer.py +++ b/src/py/flwr/server/superlink/driver/driver_servicer.py @@ -17,7 +17,7 @@ import time from logging import DEBUG -from typing import List, Optional, Set +from typing import Optional from uuid import UUID import grpc @@ -68,8 +68,8 @@ def GetNodes( """Get available nodes.""" log(DEBUG, "DriverServicer.GetNodes") state: State = self.state_factory.state() - all_ids: Set[int] = state.get_nodes(request.run_id) - nodes: List[Node] = [ + all_ids: set[int] = state.get_nodes(request.run_id) + nodes: list[Node] = [ Node(node_id=node_id, anonymous=False) for node_id in all_ids ] return GetNodesResponse(nodes=nodes) @@ -119,7 +119,7 @@ def PushTaskIns( state: State = self.state_factory.state() # Store each TaskIns - task_ids: List[Optional[UUID]] = [] + task_ids: list[Optional[UUID]] = [] for task_ins in request.task_ins_list: task_id: Optional[UUID] = state.store_task_ins(task_ins=task_ins) task_ids.append(task_id) @@ -135,7 +135,7 @@ def PullTaskRes( log(DEBUG, 
"DriverServicer.PullTaskRes") # Convert each task_id str to UUID - task_ids: Set[UUID] = {UUID(task_id) for task_id in request.task_ids} + task_ids: set[UUID] = {UUID(task_id) for task_id in request.task_ids} # Init state state: State = self.state_factory.state() @@ -155,7 +155,7 @@ def on_rpc_done() -> None: context.add_callback(on_rpc_done) # Read from state - task_res_list: List[TaskRes] = state.get_task_res(task_ids=task_ids, limit=None) + task_res_list: list[TaskRes] = state.get_task_res(task_ids=task_ids, limit=None) context.set_code(grpc.StatusCode.OK) return PullTaskResResponse(task_res_list=task_res_list) diff --git a/src/py/flwr/server/superlink/ffs/disk_ffs.py b/src/py/flwr/server/superlink/ffs/disk_ffs.py index 98ec4f93498f..4f1ab05be9a2 100644 --- a/src/py/flwr/server/superlink/ffs/disk_ffs.py +++ b/src/py/flwr/server/superlink/ffs/disk_ffs.py @@ -17,7 +17,7 @@ import hashlib import json from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Optional from flwr.server.superlink.ffs.ffs import Ffs @@ -35,7 +35,7 @@ def __init__(self, base_dir: str) -> None: """ self.base_dir = Path(base_dir) - def put(self, content: bytes, meta: Dict[str, str]) -> str: + def put(self, content: bytes, meta: dict[str, str]) -> str: """Store bytes and metadata and return key (hash of content). Parameters @@ -58,7 +58,7 @@ def put(self, content: bytes, meta: Dict[str, str]) -> str: return content_hash - def get(self, key: str) -> Optional[Tuple[bytes, Dict[str, str]]]: + def get(self, key: str) -> Optional[tuple[bytes, dict[str, str]]]: """Return tuple containing the object content and metadata. Parameters @@ -90,7 +90,7 @@ def delete(self, key: str) -> None: (self.base_dir / key).unlink() (self.base_dir / f"{key}.META").unlink() - def list(self) -> List[str]: + def list(self) -> list[str]: """List all keys. Return all available keys in this `Ffs` instance. diff --git a/src/py/flwr/server/superlink/ffs/ffs.py b/src/py/flwr/server/superlink/ffs/ffs.py index fab3b1fdfb3e..b1d26e74c157 100644 --- a/src/py/flwr/server/superlink/ffs/ffs.py +++ b/src/py/flwr/server/superlink/ffs/ffs.py @@ -16,14 +16,14 @@ import abc -from typing import Dict, List, Optional, Tuple +from typing import Optional class Ffs(abc.ABC): # pylint: disable=R0904 """Abstract Flower File Storage interface for large objects.""" @abc.abstractmethod - def put(self, content: bytes, meta: Dict[str, str]) -> str: + def put(self, content: bytes, meta: dict[str, str]) -> str: """Store bytes and metadata and return sha256hex hash of data as str. Parameters @@ -40,7 +40,7 @@ def put(self, content: bytes, meta: Dict[str, str]) -> str: """ @abc.abstractmethod - def get(self, key: str) -> Optional[Tuple[bytes, Dict[str, str]]]: + def get(self, key: str) -> Optional[tuple[bytes, dict[str, str]]]: """Return tuple containing the object content and metadata. Parameters @@ -65,7 +65,7 @@ def delete(self, key: str) -> None: """ @abc.abstractmethod - def list(self) -> List[str]: + def list(self) -> list[str]: """List keys of all stored objects. Return all available keys in this `Ffs` instance. 
diff --git a/src/py/flwr/server/superlink/ffs/ffs_test.py b/src/py/flwr/server/superlink/ffs/ffs_test.py index f7fbbf1218e1..5cf28cfd2cbe 100644 --- a/src/py/flwr/server/superlink/ffs/ffs_test.py +++ b/src/py/flwr/server/superlink/ffs/ffs_test.py @@ -21,7 +21,6 @@ import tempfile import unittest from abc import abstractmethod -from typing import Dict from flwr.server.superlink.ffs import DiskFfs, Ffs @@ -65,7 +64,7 @@ def test_get(self) -> None: ffs: Ffs = self.ffs_factory() content_expected = b"content" hash_expected = hashlib.sha256(content_expected).hexdigest() - meta_expected: Dict[str, str] = {"meta_key": "meta_value"} + meta_expected: dict[str, str] = {"meta_key": "meta_value"} with open(os.path.join(self.tmp_dir.name, hash_expected), "wb") as file: file.write(content_expected) @@ -93,7 +92,7 @@ def test_delete(self) -> None: ffs: Ffs = self.ffs_factory() content_expected = b"content" hash_expected = hashlib.sha256(content_expected).hexdigest() - meta_expected: Dict[str, str] = {"meta_key": "meta_value"} + meta_expected: dict[str, str] = {"meta_key": "meta_value"} with open(os.path.join(self.tmp_dir.name, hash_expected), "wb") as file: file.write(content_expected) @@ -117,7 +116,7 @@ def test_list(self) -> None: ffs: Ffs = self.ffs_factory() content_expected = b"content" hash_expected = hashlib.sha256(content_expected).hexdigest() - meta_expected: Dict[str, str] = {"meta_key": "meta_value"} + meta_expected: dict[str, str] = {"meta_key": "meta_value"} with open(os.path.join(self.tmp_dir.name, hash_expected), "wb") as file: file.write(content_expected) diff --git a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py index 278e20eb1d69..dbfbb236a7e4 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py @@ -16,7 +16,7 @@ from logging import DEBUG, INFO -from typing import Callable, Type, TypeVar +from typing import Callable, TypeVar import grpc from google.protobuf.message import Message as GrpcMessage @@ -47,7 +47,7 @@ def _handle( msg_container: MessageContainer, - request_type: Type[T], + request_type: type[T], handler: Callable[[T], GrpcMessage], ) -> MessageContainer: req = request_type.FromString(msg_container.grpc_message_content) diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py index 79f1a8f9902b..38f0dfdae299 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py @@ -19,7 +19,8 @@ """ import uuid -from typing import Callable, Iterator +from collections.abc import Iterator +from typing import Callable import grpc from iterators import TimeoutIterator diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py index 5fe0396696ab..476e2914f4d9 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py @@ -15,10 +15,11 @@ """Provides class GrpcBridge.""" +from collections.abc import Iterator from dataclasses import dataclass from enum import Enum from threading import Condition -from typing import Iterator, Optional +from typing import Optional from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 
ClientMessage, diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py index f9b6b97030f0..6d9e081d8dd4 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_bridge_test.py @@ -17,7 +17,7 @@ import time from threading import Thread -from typing import List, Union +from typing import Union from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 ClientMessage, @@ -32,7 +32,7 @@ def start_worker( - rounds: int, bridge: GrpcBridge, results: List[ClientMessage] + rounds: int, bridge: GrpcBridge, results: list[ClientMessage] ) -> Thread: """Simulate processing loop with five calls.""" @@ -59,7 +59,7 @@ def test_workflow_successful() -> None: """Test full workflow.""" # Prepare rounds = 5 - client_messages_received: List[ClientMessage] = [] + client_messages_received: list[ClientMessage] = [] bridge = GrpcBridge() ins_wrapper_iterator = bridge.ins_wrapper_iterator() @@ -90,7 +90,7 @@ def test_workflow_close() -> None: """ # Prepare rounds = 5 - client_messages_received: List[ClientMessage] = [] + client_messages_received: list[ClientMessage] = [] bridge = GrpcBridge() ins_wrapper_iterator = bridge.ins_wrapper_iterator() @@ -135,7 +135,7 @@ def test_ins_wrapper_iterator_close_while_blocking() -> None: """ # Prepare rounds = 5 - client_messages_received: List[ClientMessage] = [] + client_messages_received: list[ClientMessage] = [] bridge = GrpcBridge() ins_wrapper_iterator = bridge.ins_wrapper_iterator() diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py index dd78acb72fb1..b161492000f2 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py @@ -17,8 +17,9 @@ import concurrent.futures import sys +from collections.abc import Sequence from logging import ERROR -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Union import grpc @@ -46,7 +47,7 @@ AddServicerToServerFn = Callable[..., Any] -def valid_certificates(certificates: Tuple[bytes, bytes, bytes]) -> bool: +def valid_certificates(certificates: tuple[bytes, bytes, bytes]) -> bool: """Validate certificates tuple.""" is_valid = ( all(isinstance(certificate, bytes) for certificate in certificates) @@ -65,7 +66,7 @@ def start_grpc_server( # pylint: disable=too-many-arguments max_concurrent_workers: int = 1000, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, keepalive_time_ms: int = 210000, - certificates: Optional[Tuple[bytes, bytes, bytes]] = None, + certificates: Optional[tuple[bytes, bytes, bytes]] = None, ) -> grpc.Server: """Create and start a gRPC server running FlowerServiceServicer. 
@@ -157,16 +158,16 @@ def start_grpc_server( # pylint: disable=too-many-arguments def generic_create_grpc_server( # pylint: disable=too-many-arguments servicer_and_add_fn: Union[ - Tuple[FleetServicer, AddServicerToServerFn], - Tuple[GrpcAdapterServicer, AddServicerToServerFn], - Tuple[FlowerServiceServicer, AddServicerToServerFn], - Tuple[DriverServicer, AddServicerToServerFn], + tuple[FleetServicer, AddServicerToServerFn], + tuple[GrpcAdapterServicer, AddServicerToServerFn], + tuple[FlowerServiceServicer, AddServicerToServerFn], + tuple[DriverServicer, AddServicerToServerFn], ], server_address: str, max_concurrent_workers: int = 1000, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, keepalive_time_ms: int = 210000, - certificates: Optional[Tuple[bytes, bytes, bytes]] = None, + certificates: Optional[tuple[bytes, bytes, bytes]] = None, interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, ) -> grpc.Server: """Create a gRPC server with a single servicer. diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py index 7ff730b17afa..9635993e0ad5 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server_test.py @@ -20,7 +20,7 @@ from contextlib import closing from os.path import abspath, dirname, join from pathlib import Path -from typing import Tuple, cast +from typing import cast from flwr.server.client_manager import SimpleClientManager from flwr.server.superlink.fleet.grpc_bidi.grpc_server import ( @@ -31,7 +31,7 @@ root_dir = dirname(abspath(join(__file__, "../../../../../../.."))) -def load_certificates() -> Tuple[str, str, str]: +def load_certificates() -> tuple[str, str, str]: """Generate and load SSL credentials/certificates. Utility function for loading for SSL-enabled gRPC servertests. 
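# Editor's note (illustrative sketch, not part of the patch): abstract container
# types such as Sequence now come from collections.abc rather than typing, while
# concrete generics use the builtin classes. The tiny helper below mirrors the
# lookup pattern of _get_value_from_tuples in the interceptor change that follows;
# it is illustrative only and not part of the Flower code base.
from collections.abc import Sequence
from typing import Union


def _first_value_for_key(
    key: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
) -> Union[str, bytes]:
    # Return the first value whose key matches, falling back to an empty string
    return next((value for k, value in tuples if k == key), "")


# Example: key/value pairs shaped like gRPC invocation metadata
assert _first_value_for_key("token", [("token", b"abc"), ("other", "x")]) == b"abc"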
diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index 2c58d0049849..d836a74bef2e 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -16,8 +16,9 @@ import base64 +from collections.abc import Sequence from logging import INFO, WARNING -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Union import grpc from cryptography.hazmat.primitives.asymmetric import ec @@ -68,7 +69,7 @@ def _get_value_from_tuples( - key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] + key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]] ) -> bytes: value = next((value for key, value in tuples if key == key_string), "") if isinstance(value, str): diff --git a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py index 64f9ac609998..85f3fa34e0ac 100644 --- a/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py +++ b/src/py/flwr/server/superlink/fleet/message_handler/message_handler.py @@ -16,7 +16,7 @@ import time -from typing import List, Optional +from typing import Optional from uuid import UUID from flwr.common.serde import fab_to_proto, user_config_to_proto @@ -83,7 +83,7 @@ def pull_task_ins(request: PullTaskInsRequest, state: State) -> PullTaskInsRespo node_id: Optional[int] = None if node.anonymous else node.node_id # Retrieve TaskIns from State - task_ins_list: List[TaskIns] = state.get_task_ins(node_id=node_id, limit=1) + task_ins_list: list[TaskIns] = state.get_task_ins(node_id=node_id, limit=1) # Build response response = PullTaskInsResponse( diff --git a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py index cf5ad16f7999..a988252b3ea2 100644 --- a/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py +++ b/src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py @@ -18,7 +18,8 @@ from __future__ import annotations import sys -from typing import Awaitable, Callable, TypeVar +from collections.abc import Awaitable +from typing import Callable, TypeVar from google.protobuf.message import Message as GrpcMessage diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py b/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py index a8c671810a51..31129fce1b1b 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/__init__.py @@ -15,17 +15,16 @@ """Simulation Engine Backends.""" import importlib -from typing import Dict, Type from .backend import Backend, BackendConfig is_ray_installed = importlib.util.find_spec("ray") is not None # Mapping of supported backends -supported_backends: Dict[str, Type[Backend]] = {} +supported_backends: dict[str, type[Backend]] = {} # To log backend-specific error message when chosen backend isn't available -error_messages_backends: Dict[str, str] = {} +error_messages_backends: dict[str, str] = {} if is_ray_installed: from .raybackend import RayBackend diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py index 89341c0d238f..38be6032e3a5 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py +++ 
b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py @@ -16,14 +16,14 @@ from abc import ABC, abstractmethod -from typing import Callable, Dict, Tuple +from typing import Callable from flwr.client.client_app import ClientApp from flwr.common.context import Context from flwr.common.message import Message from flwr.common.typing import ConfigsRecordValues -BackendConfig = Dict[str, Dict[str, ConfigsRecordValues]] +BackendConfig = dict[str, dict[str, ConfigsRecordValues]] class Backend(ABC): @@ -62,5 +62,5 @@ def process_message( self, message: Message, context: Context, - ) -> Tuple[Message, Context]: + ) -> tuple[Message, Context]: """Submit a job to the backend.""" diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py index 2024b8760d95..dd79d2ef7f62 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -16,7 +16,7 @@ import sys from logging import DEBUG, ERROR -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Optional, Union import ray @@ -31,8 +31,8 @@ from .backend import Backend, BackendConfig -ClientResourcesDict = Dict[str, Union[int, float]] -ActorArgsDict = Dict[str, Union[int, float, Callable[[], None]]] +ClientResourcesDict = dict[str, Union[int, float]] +ActorArgsDict = dict[str, Union[int, float, Callable[[], None]]] class RayBackend(Backend): @@ -101,7 +101,7 @@ def _validate_actor_arguments(self, config: BackendConfig) -> ActorArgsDict: def init_ray(self, backend_config: BackendConfig) -> None: """Intialises Ray if not already initialised.""" if not ray.is_initialized(): - ray_init_args: Dict[ + ray_init_args: dict[ str, ConfigsRecordValues, ] = {} @@ -144,7 +144,7 @@ def process_message( self, message: Message, context: Context, - ) -> Tuple[Message, Context]: + ) -> tuple[Message, Context]: """Run ClientApp that process a given message. Return output message and updated context. 
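# Editor's note (illustrative sketch, not part of the patch): module-level type
# aliases such as BackendConfig and ClientResourcesDict above are now written with
# builtin generics, which are valid at runtime on Python 3.9+ without
# `from __future__ import annotations`. The alias and keys below are simplified
# placeholders for illustration only, not the actual Flower configuration schema.
from typing import Union

SimpleBackendConfig = dict[str, dict[str, Union[int, float, str]]]

example_config: SimpleBackendConfig = {
    "client_resources": {"num_cpus": 1, "num_gpus": 0.0},
}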
diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index cdb11401c29c..1cbdc230c938 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -15,7 +15,7 @@ """Test for Ray backend for the Fleet API using the Simulation Engine.""" from math import pi -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Optional, Union from unittest import TestCase import ray @@ -47,7 +47,7 @@ class DummyClient(NumPyClient): def __init__(self, state: RecordSet) -> None: self.client_state = state - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """Return properties by doing a simple calculation.""" result = float(config["factor"]) * pi @@ -69,8 +69,8 @@ def _load_app() -> ClientApp: def backend_build_process_and_termination( backend: RayBackend, app_fn: Callable[[], ClientApp], - process_args: Optional[Tuple[Message, Context]] = None, -) -> Union[Tuple[Message, Context], None]: + process_args: Optional[tuple[Message, Context]] = None, +) -> Union[tuple[Message, Context], None]: """Build, process job and terminate RayBackend.""" backend.build(app_fn) to_return = None @@ -83,7 +83,7 @@ def backend_build_process_and_termination( return to_return -def _create_message_and_context() -> Tuple[Message, Context, float]: +def _create_message_and_context() -> tuple[Message, Context, float]: # Construct a Message mult_factor = 2024 diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 165c2de73c21..8f4e18e14e28 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -24,7 +24,7 @@ from pathlib import Path from queue import Empty, Queue from time import sleep -from typing import Callable, Dict, Optional +from typing import Callable, Optional from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.clientapp.utils import get_load_client_app_fn @@ -44,7 +44,7 @@ from .backend import Backend, error_messages_backends, supported_backends -NodeToPartitionMapping = Dict[int, int] +NodeToPartitionMapping = dict[int, int] def _register_nodes( @@ -64,9 +64,9 @@ def _register_node_states( nodes_mapping: NodeToPartitionMapping, run: Run, app_dir: Optional[str] = None, -) -> Dict[int, NodeState]: +) -> dict[int, NodeState]: """Create NodeState objects and pre-register the context for the run.""" - node_states: Dict[int, NodeState] = {} + node_states: dict[int, NodeState] = {} num_partitions = len(set(nodes_mapping.values())) for node_id, partition_id in nodes_mapping.items(): node_states[node_id] = NodeState( @@ -89,7 +89,7 @@ def _register_node_states( def worker( taskins_queue: "Queue[TaskIns]", taskres_queue: "Queue[TaskRes]", - node_states: Dict[int, NodeState], + node_states: dict[int, NodeState], backend: Backend, f_stop: threading.Event, ) -> None: @@ -177,7 +177,7 @@ def run_api( backend_fn: Callable[[], Backend], nodes_mapping: NodeToPartitionMapping, state_factory: StateFactory, - node_states: Dict[int, NodeState], + node_states: dict[int, NodeState], f_stop: threading.Event, ) -> None: """Run the VCE.""" diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index 
76e8ac9156d2..1cc3a8f128b6 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -22,7 +22,7 @@ from math import pi from pathlib import Path from time import sleep -from typing import Dict, Optional, Set, Tuple +from typing import Optional from unittest import TestCase from uuid import UUID @@ -57,7 +57,7 @@ class DummyClient(NumPyClient): def __init__(self, state: RecordSet) -> None: self.client_state = state - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """Return properties by doing a simple calculation.""" result = float(config["factor"]) * pi @@ -86,7 +86,7 @@ def terminate_simulation(f_stop: threading.Event, sleep_duration: int) -> None: def init_state_factory_nodes_mapping( num_nodes: int, num_messages: int, -) -> Tuple[StateFactory, NodeToPartitionMapping, Dict[UUID, float]]: +) -> tuple[StateFactory, NodeToPartitionMapping, dict[UUID, float]]: """Instatiate StateFactory, register nodes and pre-insert messages in the state.""" # Register a state and a run_id in it run_id = 1234 @@ -110,7 +110,7 @@ def register_messages_into_state( nodes_mapping: NodeToPartitionMapping, run_id: int, num_messages: int, -) -> Dict[UUID, float]: +) -> dict[UUID, float]: """Register `num_messages` into the state factory.""" state: InMemoryState = state_factory.state() # type: ignore state.run_ids[run_id] = Run( @@ -123,7 +123,7 @@ def register_messages_into_state( # Artificially add TaskIns to state so they can be processed # by the Simulation Engine logic nodes_cycle = cycle(nodes_mapping.keys()) # we have more messages than supernodes - task_ids: Set[UUID] = set() # so we can retrieve them later + task_ids: set[UUID] = set() # so we can retrieve them later expected_results = {} for i in range(num_messages): dst_node_id = next(nodes_cycle) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index c87ba86e47e7..e34d15374350 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -18,7 +18,7 @@ import threading import time from logging import ERROR -from typing import Dict, List, Optional, Set, Tuple +from typing import Optional from uuid import UUID, uuid4 from flwr.common import log, now @@ -37,15 +37,15 @@ class InMemoryState(State): # pylint: disable=R0902,R0904 def __init__(self) -> None: # Map node_id to (online_until, ping_interval) - self.node_ids: Dict[int, Tuple[float, float]] = {} - self.public_key_to_node_id: Dict[bytes, int] = {} + self.node_ids: dict[int, tuple[float, float]] = {} + self.public_key_to_node_id: dict[bytes, int] = {} # Map run_id to (fab_id, fab_version) - self.run_ids: Dict[int, Run] = {} - self.task_ins_store: Dict[UUID, TaskIns] = {} - self.task_res_store: Dict[UUID, TaskRes] = {} + self.run_ids: dict[int, Run] = {} + self.task_ins_store: dict[UUID, TaskIns] = {} + self.task_res_store: dict[UUID, TaskRes] = {} - self.node_public_keys: Set[bytes] = set() + self.node_public_keys: set[bytes] = set() self.server_public_key: Optional[bytes] = None self.server_private_key: Optional[bytes] = None @@ -76,13 +76,13 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: def get_task_ins( self, node_id: Optional[int], limit: Optional[int] - ) -> List[TaskIns]: + ) -> list[TaskIns]: """Get all TaskIns that have not been delivered yet.""" if limit is not None and limit < 1: 
raise AssertionError("`limit` must be >= 1") # Find TaskIns for node_id that were not delivered yet - task_ins_list: List[TaskIns] = [] + task_ins_list: list[TaskIns] = [] with self.lock: for _, task_ins in self.task_ins_store.items(): # pylint: disable=too-many-boolean-expressions @@ -133,15 +133,15 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: # Return the new task_id return task_id - def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]: + def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]: """Get all TaskRes that have not been delivered yet.""" if limit is not None and limit < 1: raise AssertionError("`limit` must be >= 1") with self.lock: # Find TaskRes that were not delivered yet - task_res_list: List[TaskRes] = [] - replied_task_ids: Set[UUID] = set() + task_res_list: list[TaskRes] = [] + replied_task_ids: set[UUID] = set() for _, task_res in self.task_res_store.items(): reply_to = UUID(task_res.task.ancestry[0]) if reply_to in task_ids and task_res.task.delivered_at == "": @@ -175,10 +175,10 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe # Return TaskRes return task_res_list - def delete_tasks(self, task_ids: Set[UUID]) -> None: + def delete_tasks(self, task_ids: set[UUID]) -> None: """Delete all delivered TaskIns/TaskRes pairs.""" - task_ins_to_be_deleted: Set[UUID] = set() - task_res_to_be_deleted: Set[UUID] = set() + task_ins_to_be_deleted: set[UUID] = set() + task_res_to_be_deleted: set[UUID] = set() with self.lock: for task_ins_id in task_ids: @@ -253,7 +253,7 @@ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: del self.node_ids[node_id] - def get_nodes(self, run_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> set[int]: """Return all available nodes. 
Constraints @@ -318,7 +318,7 @@ def get_server_public_key(self) -> Optional[bytes]: """Retrieve `server_public_key` in urlsafe bytes.""" return self.server_public_key - def store_node_public_keys(self, public_keys: Set[bytes]) -> None: + def store_node_public_keys(self, public_keys: set[bytes]) -> None: """Store a set of `node_public_keys` in state.""" with self.lock: self.node_public_keys = public_keys @@ -328,7 +328,7 @@ def store_node_public_key(self, public_key: bytes) -> None: with self.lock: self.node_public_keys.add(public_key) - def get_node_public_keys(self) -> Set[bytes]: + def get_node_public_keys(self) -> set[bytes]: """Retrieve all currently stored `node_public_keys` as a set.""" return self.node_public_keys diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index daa211560912..4bb31fa6cea5 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -19,8 +19,9 @@ import re import sqlite3 import time +from collections.abc import Sequence from logging import DEBUG, ERROR -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast +from typing import Any, Optional, Union, cast from uuid import UUID, uuid4 from flwr.common import log, now @@ -110,7 +111,7 @@ ); """ -DictOrTuple = Union[Tuple[Any, ...], Dict[str, Any]] +DictOrTuple = Union[tuple[Any, ...], dict[str, Any]] class SqliteState(State): # pylint: disable=R0904 @@ -131,7 +132,7 @@ def __init__( self.database_path = database_path self.conn: Optional[sqlite3.Connection] = None - def initialize(self, log_queries: bool = False) -> List[Tuple[str]]: + def initialize(self, log_queries: bool = False) -> list[tuple[str]]: """Create tables if they don't exist yet. Parameters @@ -162,7 +163,7 @@ def query( self, query: str, data: Optional[Union[Sequence[DictOrTuple], DictOrTuple]] = None, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Execute a SQL query.""" if self.conn is None: raise AttributeError("State is not initialized.") @@ -237,7 +238,7 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: def get_task_ins( self, node_id: Optional[int], limit: Optional[int] - ) -> List[TaskIns]: + ) -> list[TaskIns]: """Get undelivered TaskIns for one node (either anonymous or with ID). Usually, the Fleet API calls this for Nodes planning to work on one or more @@ -271,7 +272,7 @@ def get_task_ins( ) raise AssertionError(msg) - data: Dict[str, Union[str, int]] = {} + data: dict[str, Union[str, int]] = {} if node_id is None: # Retrieve all anonymous Tasks @@ -367,7 +368,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: return task_id # pylint: disable-next=R0914 - def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]: + def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]: """Get TaskRes for task_ids. Usually, the Driver API calls this method to get results for instructions it has @@ -397,7 +398,7 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe AND delivered_at = "" """ - data: Dict[str, Union[str, float, int]] = {} + data: dict[str, Union[str, float, int]] = {} if limit is not None: query += " LIMIT :limit" @@ -435,7 +436,7 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe # 1. 
Query: Fetch consumer_node_id of remaining task_ids # Assume the ancestry field only contains one element data.clear() - replied_task_ids: Set[UUID] = {UUID(str(row["ancestry"])) for row in rows} + replied_task_ids: set[UUID] = {UUID(str(row["ancestry"])) for row in rows} remaining_task_ids = task_ids - replied_task_ids placeholders = ",".join([f":id_{i}" for i in range(len(remaining_task_ids))]) query = f""" @@ -499,10 +500,10 @@ def num_task_res(self) -> int: """ query = "SELECT count(*) AS num FROM task_res;" rows = self.query(query) - result: Dict[str, int] = rows[0] + result: dict[str, int] = rows[0] return result["num"] - def delete_tasks(self, task_ids: Set[UUID]) -> None: + def delete_tasks(self, task_ids: set[UUID]) -> None: """Delete all delivered TaskIns/TaskRes pairs.""" ids = list(task_ids) if len(ids) == 0: @@ -588,7 +589,7 @@ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: except KeyError as exc: log(ERROR, {"query": query, "data": params, "exception": exc}) - def get_nodes(self, run_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> set[int]: """Retrieve all currently stored node IDs as a set. Constraints @@ -604,7 +605,7 @@ def get_nodes(self, run_id: int) -> Set[int]: # Get nodes query = "SELECT node_id FROM node WHERE online_until > ?;" rows = self.query(query, (time.time(),)) - result: Set[int] = {row["node_id"] for row in rows} + result: set[int] = {row["node_id"] for row in rows} return result def get_node_id(self, node_public_key: bytes) -> Optional[int]: @@ -684,7 +685,7 @@ def get_server_public_key(self) -> Optional[bytes]: public_key = None return public_key - def store_node_public_keys(self, public_keys: Set[bytes]) -> None: + def store_node_public_keys(self, public_keys: set[bytes]) -> None: """Store a set of `node_public_keys` in state.""" query = "INSERT INTO public_key (public_key) VALUES (?)" data = [(key,) for key in public_keys] @@ -695,11 +696,11 @@ def store_node_public_key(self, public_key: bytes) -> None: query = "INSERT INTO public_key (public_key) VALUES (:public_key)" self.query(query, {"public_key": public_key}) - def get_node_public_keys(self) -> Set[bytes]: + def get_node_public_keys(self) -> set[bytes]: """Retrieve all currently stored `node_public_keys` as a set.""" query = "SELECT public_key FROM public_key" rows = self.query(query) - result: Set[bytes] = {row["public_key"] for row in rows} + result: set[bytes] = {row["public_key"] for row in rows} return result def get_run(self, run_id: int) -> Optional[Run]: @@ -733,7 +734,7 @@ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: def dict_factory( cursor: sqlite3.Cursor, row: sqlite3.Row, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Turn SQLite results into dicts. Less efficent for retrival of large amounts of data but easier to use. 
@@ -742,7 +743,7 @@ def dict_factory( return dict(zip(fields, row)) -def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]: +def task_ins_to_dict(task_msg: TaskIns) -> dict[str, Any]: """Transform TaskIns to dict.""" result = { "task_id": task_msg.task_id, @@ -763,7 +764,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]: return result -def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]: +def task_res_to_dict(task_msg: TaskRes) -> dict[str, Any]: """Transform TaskRes to dict.""" result = { "task_id": task_msg.task_id, @@ -784,7 +785,7 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]: return result -def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns: +def dict_to_task_ins(task_dict: dict[str, Any]) -> TaskIns: """Turn task_dict into protobuf message.""" recordset = RecordSet() recordset.ParseFromString(task_dict["recordset"]) @@ -814,7 +815,7 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns: return result -def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes: +def dict_to_task_res(task_dict: dict[str, Any]) -> TaskRes: """Turn task_dict into protobuf message.""" recordset = RecordSet() recordset.ParseFromString(task_dict["recordset"]) diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index fea53105b23f..39da052fb0aa 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -16,7 +16,7 @@ import abc -from typing import List, Optional, Set +from typing import Optional from uuid import UUID from flwr.common.typing import Run, UserConfig @@ -51,7 +51,7 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: @abc.abstractmethod def get_task_ins( self, node_id: Optional[int], limit: Optional[int] - ) -> List[TaskIns]: + ) -> list[TaskIns]: """Get TaskIns optionally filtered by node_id. Usually, the Fleet API calls this for Nodes planning to work on one or more @@ -98,7 +98,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: """ @abc.abstractmethod - def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRes]: + def get_task_res(self, task_ids: set[UUID], limit: Optional[int]) -> list[TaskRes]: """Get TaskRes for task_ids. Usually, the Driver API calls this method to get results for instructions it has @@ -129,7 +129,7 @@ def num_task_res(self) -> int: """ @abc.abstractmethod - def delete_tasks(self, task_ids: Set[UUID]) -> None: + def delete_tasks(self, task_ids: set[UUID]) -> None: """Delete all delivered TaskIns/TaskRes pairs.""" @abc.abstractmethod @@ -143,7 +143,7 @@ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: """Remove `node_id` from state.""" @abc.abstractmethod - def get_nodes(self, run_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> set[int]: """Retrieve all currently stored node IDs as a set. 
Constraints @@ -199,7 +199,7 @@ def get_server_public_key(self) -> Optional[bytes]: """Retrieve `server_public_key` in urlsafe bytes.""" @abc.abstractmethod - def store_node_public_keys(self, public_keys: Set[bytes]) -> None: + def store_node_public_keys(self, public_keys: set[bytes]) -> None: """Store a set of `node_public_keys` in state.""" @abc.abstractmethod @@ -207,7 +207,7 @@ def store_node_public_key(self, public_key: bytes) -> None: """Store a `node_public_key` in state.""" @abc.abstractmethod - def get_node_public_keys(self) -> Set[bytes]: + def get_node_public_keys(self) -> set[bytes]: """Retrieve all currently stored `node_public_keys` as a set.""" @abc.abstractmethod diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 0cf30a42ca2c..42c0768f1c7d 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -20,7 +20,6 @@ import unittest from abc import abstractmethod from datetime import datetime, timezone -from typing import List from unittest.mock import patch from uuid import uuid4 @@ -655,7 +654,7 @@ def test_node_unavailable_error(self) -> None: # Execute current_time = time.time() - task_res_list: List[TaskRes] = [] + task_res_list: list[TaskRes] = [] with patch("time.time", side_effect=lambda: current_time + 50): task_res_list = state.get_task_res({task_id_0, task_id_1}, limit=None) @@ -698,7 +697,7 @@ def create_task_ins( def create_task_res( producer_node_id: int, anonymous: bool, - ancestry: List[str], + ancestry: list[str], run_id: int, ) -> TaskRes: """Create a TaskRes for testing.""" diff --git a/src/py/flwr/server/utils/tensorboard.py b/src/py/flwr/server/utils/tensorboard.py index 5d38fc159657..281e8949c53c 100644 --- a/src/py/flwr/server/utils/tensorboard.py +++ b/src/py/flwr/server/utils/tensorboard.py @@ -18,7 +18,7 @@ import os from datetime import datetime from logging import WARN -from typing import Callable, Dict, List, Optional, Tuple, Union, cast +from typing import Callable, Optional, Union, cast from flwr.common import EvaluateRes, Scalar from flwr.common.logger import log @@ -92,9 +92,9 @@ class TBWrapper(strategy_class): # type: ignore def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> tuple[Optional[float], dict[str, Scalar]]: """Hooks into aggregate_evaluate for TensorBoard logging purpose.""" # Execute decorated function and extract results for logging # They will be returned at the end of this function but also diff --git a/src/py/flwr/server/utils/validator.py b/src/py/flwr/server/utils/validator.py index c0b0ec85761c..fb3d0425db86 100644 --- a/src/py/flwr/server/utils/validator.py +++ b/src/py/flwr/server/utils/validator.py @@ -15,13 +15,13 @@ """Validators.""" -from typing import List, Union +from typing import Union from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 # pylint: disable-next=too-many-branches,too-many-statements -def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> List[str]: +def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str]: """Validate a TaskIns or TaskRes.""" validation_errors = [] diff --git a/src/py/flwr/server/utils/validator_test.py 
b/src/py/flwr/server/utils/validator_test.py index 61fe094c23d4..20162883efea 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -17,7 +17,6 @@ import time import unittest -from typing import List, Tuple from flwr.common import DEFAULT_TTL from flwr.proto.node_pb2 import Node # pylint: disable=E0611 @@ -52,12 +51,12 @@ def test_is_valid_task_res(self) -> None: """Test is_valid task_res.""" # Prepare # (producer_node_id, anonymous, ancestry) - valid_res: List[Tuple[int, bool, List[str]]] = [ + valid_res: list[tuple[int, bool, list[str]]] = [ (0, True, ["1"]), (1, False, ["1"]), ] - invalid_res: List[Tuple[int, bool, List[str]]] = [ + invalid_res: list[tuple[int, bool, list[str]]] = [ (0, False, []), (0, False, ["1"]), (0, True, []), @@ -110,7 +109,7 @@ def create_task_ins( def create_task_res( producer_node_id: int, anonymous: bool, - ancestry: List[str], + ancestry: list[str], ) -> TaskRes: """Create a TaskRes for testing.""" task_res = TaskRes( diff --git a/src/py/flwr/server/workflow/default_workflows.py b/src/py/flwr/server/workflow/default_workflows.py index 82d8d5d4ccb6..484a747292d5 100644 --- a/src/py/flwr/server/workflow/default_workflows.py +++ b/src/py/flwr/server/workflow/default_workflows.py @@ -18,7 +18,7 @@ import io import timeit from logging import INFO, WARN -from typing import List, Optional, Tuple, Union, cast +from typing import Optional, Union, cast import flwr.common.recordset_compat as compat from flwr.common import ( @@ -276,8 +276,8 @@ def default_fit_workflow( # pylint: disable=R0914 ) # Aggregate training results - results: List[Tuple[ClientProxy, FitRes]] = [] - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + results: list[tuple[ClientProxy, FitRes]] = [] + failures: list[Union[tuple[ClientProxy, FitRes], BaseException]] = [] for msg in messages: if msg.has_content(): proxy = node_id_to_proxy[msg.metadata.src_node_id] @@ -362,8 +362,8 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None: ) # Aggregate the evaluation results - results: List[Tuple[ClientProxy, EvaluateRes]] = [] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [] + results: list[tuple[ClientProxy, EvaluateRes]] = [] + failures: list[Union[tuple[ClientProxy, EvaluateRes], BaseException]] = [] for msg in messages: if msg.has_content(): proxy = node_id_to_proxy[msg.metadata.src_node_id] diff --git a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py index 322e32ed5019..d84a5496dfe1 100644 --- a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +++ b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py @@ -18,7 +18,7 @@ import random from dataclasses import dataclass, field from logging import DEBUG, ERROR, INFO, WARN -from typing import Dict, List, Optional, Set, Tuple, Union, cast +from typing import Optional, Union, cast import flwr.common.recordset_compat as compat from flwr.common import ( @@ -65,23 +65,23 @@ class WorkflowState: # pylint: disable=R0902 """The state of the SecAgg+ protocol.""" - nid_to_proxies: Dict[int, ClientProxy] = field(default_factory=dict) - nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict) - sampled_node_ids: Set[int] = field(default_factory=set) - active_node_ids: Set[int] = field(default_factory=set) + nid_to_proxies: dict[int, ClientProxy] = field(default_factory=dict) + nid_to_fitins: dict[int, 
RecordSet] = field(default_factory=dict) + sampled_node_ids: set[int] = field(default_factory=set) + active_node_ids: set[int] = field(default_factory=set) num_shares: int = 0 threshold: int = 0 clipping_range: float = 0.0 quantization_range: int = 0 mod_range: int = 0 max_weight: float = 0.0 - nid_to_neighbours: Dict[int, Set[int]] = field(default_factory=dict) - nid_to_publickeys: Dict[int, List[bytes]] = field(default_factory=dict) - forward_srcs: Dict[int, List[int]] = field(default_factory=dict) - forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict) + nid_to_neighbours: dict[int, set[int]] = field(default_factory=dict) + nid_to_publickeys: dict[int, list[bytes]] = field(default_factory=dict) + forward_srcs: dict[int, list[int]] = field(default_factory=dict) + forward_ciphertexts: dict[int, list[bytes]] = field(default_factory=dict) aggregate_ndarrays: NDArrays = field(default_factory=list) - legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list) - failures: List[Exception] = field(default_factory=list) + legacy_results: list[tuple[ClientProxy, FitRes]] = field(default_factory=list) + failures: list[Exception] = field(default_factory=list) class SecAggPlusWorkflow: @@ -444,13 +444,13 @@ def make(nid: int) -> Message: ) # Build forward packet list dictionary - srcs: List[int] = [] - dsts: List[int] = [] - ciphertexts: List[bytes] = [] - fwd_ciphertexts: Dict[int, List[bytes]] = { + srcs: list[int] = [] + dsts: list[int] = [] + ciphertexts: list[bytes] = [] + fwd_ciphertexts: dict[int, list[bytes]] = { nid: [] for nid in state.active_node_ids } # dest node ID -> list of ciphertexts - fwd_srcs: Dict[int, List[int]] = { + fwd_srcs: dict[int, list[int]] = { nid: [] for nid in state.active_node_ids } # dest node ID -> list of src node IDs for msg in msgs: @@ -459,8 +459,8 @@ def make(nid: int) -> Message: continue node_id = msg.metadata.src_node_id res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS] - dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST]) - ctxt_lst = cast(List[bytes], res_dict[Key.CIPHERTEXT_LIST]) + dst_lst = cast(list[int], res_dict[Key.DESTINATION_LIST]) + ctxt_lst = cast(list[bytes], res_dict[Key.CIPHERTEXT_LIST]) srcs += [node_id] * len(dst_lst) dsts += dst_lst ciphertexts += ctxt_lst @@ -525,7 +525,7 @@ def make(nid: int) -> Message: state.failures.append(Exception(msg.error)) continue res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS] - bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS]) + bytes_list = cast(list[bytes], res_dict[Key.MASKED_PARAMETERS]) client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list] if masked_vector is None: masked_vector = client_masked_vec @@ -592,7 +592,7 @@ def make(nid: int) -> Message: ) # Build collected shares dict - collected_shares_dict: Dict[int, List[bytes]] = {} + collected_shares_dict: dict[int, list[bytes]] = {} for nid in state.sampled_node_ids: collected_shares_dict[nid] = [] for msg in msgs: @@ -600,8 +600,8 @@ def make(nid: int) -> Message: state.failures.append(Exception(msg.error)) continue res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS] - nids = cast(List[int], res_dict[Key.NODE_ID_LIST]) - shares = cast(List[bytes], res_dict[Key.SHARE_LIST]) + nids = cast(list[int], res_dict[Key.NODE_ID_LIST]) + shares = cast(list[bytes], res_dict[Key.SHARE_LIST]) for owner_nid, share in zip(nids, shares): collected_shares_dict[owner_nid].append(share) diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index 
973a9a89e652..0070d75c53dc 100644 --- a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -22,7 +22,7 @@ import traceback import warnings from logging import ERROR, INFO -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Optional, Union import ray from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy @@ -72,7 +72,7 @@ """ -NodeToPartitionMapping = Dict[int, int] +NodeToPartitionMapping = dict[int, int] def _create_node_id_to_partition_mapping( @@ -94,16 +94,16 @@ def start_simulation( *, client_fn: ClientFnExt, num_clients: int, - clients_ids: Optional[List[str]] = None, # UNSUPPORTED, WILL BE REMOVED - client_resources: Optional[Dict[str, float]] = None, + clients_ids: Optional[list[str]] = None, # UNSUPPORTED, WILL BE REMOVED + client_resources: Optional[dict[str, float]] = None, server: Optional[Server] = None, config: Optional[ServerConfig] = None, strategy: Optional[Strategy] = None, client_manager: Optional[ClientManager] = None, - ray_init_args: Optional[Dict[str, Any]] = None, + ray_init_args: Optional[dict[str, Any]] = None, keep_initialised: Optional[bool] = False, - actor_type: Type[VirtualClientEngineActor] = ClientAppActor, - actor_kwargs: Optional[Dict[str, Any]] = None, + actor_type: type[VirtualClientEngineActor] = ClientAppActor, + actor_kwargs: Optional[dict[str, Any]] = None, actor_scheduling: Union[str, NodeAffinitySchedulingStrategy] = "DEFAULT", ) -> History: """Start a Ray-based Flower simulation server. @@ -279,7 +279,7 @@ def start_simulation( # An actor factory. This is called N times to add N actors # to the pool. If at some point the pool can accommodate more actors # this will be called again. - def create_actor_fn() -> Type[VirtualClientEngineActor]: + def create_actor_fn() -> type[VirtualClientEngineActor]: return actor_type.options( # type: ignore **client_resources, scheduling_strategy=actor_scheduling, diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 698eb78f2aef..4fb48a99b689 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -17,7 +17,7 @@ import threading from abc import ABC from logging import DEBUG, ERROR, WARNING -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Optional, Union import ray from ray import ObjectRef @@ -44,7 +44,7 @@ def run( message: Message, cid: str, context: Context, - ) -> Tuple[str, Message, Context]: + ) -> tuple[str, Message, Context]: """Run a client run.""" # Pass message through ClientApp and return a message # return also cid which is needed to ensure results @@ -81,7 +81,7 @@ def __init__(self, on_actor_init_fn: Optional[Callable[[], None]] = None) -> Non on_actor_init_fn() -def pool_size_from_resources(client_resources: Dict[str, Union[int, float]]) -> int: +def pool_size_from_resources(client_resources: dict[str, Union[int, float]]) -> int: """Calculate number of Actors that fit in the cluster. 
For this we consider the resources available on each node and those required per @@ -162,9 +162,9 @@ class VirtualClientEngineActorPool(ActorPool): def __init__( self, - create_actor_fn: Callable[[], Type[VirtualClientEngineActor]], - client_resources: Dict[str, Union[int, float]], - actor_list: Optional[List[Type[VirtualClientEngineActor]]] = None, + create_actor_fn: Callable[[], type[VirtualClientEngineActor]], + client_resources: dict[str, Union[int, float]], + actor_list: Optional[list[type[VirtualClientEngineActor]]] = None, ): self.client_resources = client_resources self.create_actor_fn = create_actor_fn @@ -183,10 +183,10 @@ def __init__( # A dict that maps cid to another dict containing: a reference to the remote job # and its status (i.e. whether it is ready or not) - self._cid_to_future: Dict[ - str, Dict[str, Union[bool, Optional[ObjectRef[Any]]]] + self._cid_to_future: dict[ + str, dict[str, Union[bool, Optional[ObjectRef[Any]]]] ] = {} - self.actor_to_remove: Set[str] = set() # a set + self.actor_to_remove: set[str] = set() # a set self.num_actors = len(actors) self.lock = threading.RLock() @@ -210,7 +210,7 @@ def add_actors_to_pool(self, num_actors: int) -> None: self._idle_actors.extend(new_actors) self.num_actors += num_actors - def submit(self, fn: Any, value: Tuple[ClientAppFn, Message, str, Context]) -> None: + def submit(self, fn: Any, value: tuple[ClientAppFn, Message, str, Context]) -> None: """Take an idle actor and assign it to run a client app and Message. Submit a job to an actor by first removing it from the list of idle actors, then @@ -220,7 +220,7 @@ def submit(self, fn: Any, value: Tuple[ClientAppFn, Message, str, Context]) -> N actor = self._idle_actors.pop() if self._check_and_remove_actor_from_pool(actor): future = fn(actor, app_fn, mssg, cid, context) - future_key = tuple(future) if isinstance(future, List) else future + future_key = tuple(future) if isinstance(future, list) else future self._future_to_actor[future_key] = (self._next_task_index, actor, cid) self._next_task_index += 1 @@ -228,7 +228,7 @@ def submit(self, fn: Any, value: Tuple[ClientAppFn, Message, str, Context]) -> N self._cid_to_future[cid]["future"] = future_key def submit_client_job( - self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context] + self, actor_fn: Any, job: tuple[ClientAppFn, Message, str, Context] ) -> None: """Submit a job while tracking client ids.""" _, _, cid, _ = job @@ -268,7 +268,7 @@ def _is_future_ready(self, cid: str) -> bool: return self._cid_to_future[cid]["ready"] # type: ignore - def _fetch_future_result(self, cid: str) -> Tuple[Message, Context]: + def _fetch_future_result(self, cid: str) -> tuple[Message, Context]: """Fetch result and updated context for a VirtualClient from Object Store. The job submitted by the ClientProxy interfacing with client with cid=cid is @@ -382,7 +382,7 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None: def get_client_result( self, cid: str, timeout: Optional[float] - ) -> Tuple[Message, Context]: + ) -> tuple[Message, Context]: """Get result from VirtualClient with specific cid.""" # Loop until all jobs submitted to the pool are completed. 
Break early # if the result for the ClientProxy calling this method is ready @@ -403,14 +403,14 @@ class BasicActorPool: def __init__( self, - actor_type: Type[VirtualClientEngineActor], - client_resources: Dict[str, Union[int, float]], - actor_kwargs: Dict[str, Any], + actor_type: type[VirtualClientEngineActor], + client_resources: dict[str, Union[int, float]], + actor_kwargs: dict[str, Any], ): self.client_resources = client_resources # Queue of idle actors - self.pool: List[VirtualClientEngineActor] = [] + self.pool: list[VirtualClientEngineActor] = [] self.num_actors = 0 # Resolve arguments to pass during actor init @@ -424,7 +424,7 @@ def __init__( # Figure out how many actors can be created given the cluster resources # and the resources the user indicates each VirtualClient will need self.actors_capacity = pool_size_from_resources(client_resources) - self._future_to_actor: Dict[Any, VirtualClientEngineActor] = {} + self._future_to_actor: dict[Any, VirtualClientEngineActor] = {} def is_actor_available(self) -> bool: """Return true if there is an idle actor.""" @@ -450,7 +450,7 @@ def terminate_all_actors(self) -> None: log(DEBUG, "Terminated %i actors", num_terminated) def submit( - self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context] + self, actor_fn: Any, job: tuple[ClientAppFn, Message, str, Context] ) -> Any: """On idle actor, submit job and return future.""" # Remove idle actor from pool @@ -470,7 +470,7 @@ def add_actor_back_to_pool(self, future: Any) -> None: def fetch_result_and_return_actor_to_pool( self, future: Any - ) -> Tuple[Message, Context]: + ) -> tuple[Message, Context]: """Pull result given a future and add actor back to pool.""" # Retrieve result for object store # Instead of doing ray.get(future) we await it diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index 1c2aa455d9cd..ce0ef46d135f 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -17,7 +17,6 @@ from math import pi from random import shuffle -from typing import Dict, List, Tuple, Type import ray @@ -60,7 +59,7 @@ def __init__(self, node_id: int, state: RecordSet) -> None: self.node_id = node_id self.client_state = state - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """Return properties by doing a simple calculation.""" result = self.node_id * pi # store something in context @@ -76,14 +75,14 @@ def get_dummy_client(context: Context) -> Client: def prep( - actor_type: Type[VirtualClientEngineActor] = ClientAppActor, -) -> Tuple[ - List[RayActorClientProxy], VirtualClientEngineActorPool, NodeToPartitionMapping + actor_type: type[VirtualClientEngineActor] = ClientAppActor, +) -> tuple[ + list[RayActorClientProxy], VirtualClientEngineActorPool, NodeToPartitionMapping ]: # pragma: no cover """Prepare ClientProxies and pool for tests.""" client_resources = {"num_cpus": 1, "num_gpus": 0.0} - def create_actor_fn() -> Type[VirtualClientEngineActor]: + def create_actor_fn() -> type[VirtualClientEngineActor]: return actor_type.options(**client_resources).remote() # type: ignore # Create actor pool @@ -195,7 +194,7 @@ def test_cid_consistency_without_proxies() -> None: node_ids = list(mapping.keys()) # register node states - node_states: Dict[int, NodeState] = {} + node_states: dict[int, NodeState] = {} for node_id, partition_id in 
mapping.items(): node_states[node_id] = NodeState( node_id=node_id, diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index be6410dcbd6b..2d29629c4f01 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -25,7 +25,7 @@ from logging import DEBUG, ERROR, INFO, WARNING from pathlib import Path from time import sleep -from typing import Any, List, Optional +from typing import Any, Optional from flwr.cli.config_utils import load_and_validate from flwr.client import ClientApp @@ -56,7 +56,7 @@ def _check_args_do_not_interfere(args: Namespace) -> bool: mode_one_args = ["app", "run_config"] mode_two_args = ["client_app", "server_app"] - def _resolve_message(conflict_keys: List[str]) -> str: + def _resolve_message(conflict_keys: list[str]) -> str: return ",".join([f"`--{key}`".replace("_", "-") for key in conflict_keys]) # When passing `--app`, `--app-dir` is ignored diff --git a/src/py/flwr/superexec/app.py b/src/py/flwr/superexec/app.py index 36f781706146..c00aa0f88e7b 100644 --- a/src/py/flwr/superexec/app.py +++ b/src/py/flwr/superexec/app.py @@ -18,7 +18,7 @@ import sys from logging import INFO, WARN from pathlib import Path -from typing import Optional, Tuple +from typing import Optional import grpc @@ -130,7 +130,7 @@ def _parse_args_run_superexec() -> argparse.ArgumentParser: def _try_obtain_certificates( args: argparse.Namespace, -) -> Optional[Tuple[bytes, bytes, bytes]]: +) -> Optional[tuple[bytes, bytes, bytes]]: # Obtain certificates if args.insecure: log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.") diff --git a/src/py/flwr/superexec/exec_grpc.py b/src/py/flwr/superexec/exec_grpc.py index a32ebc1b3e35..017395bc8002 100644 --- a/src/py/flwr/superexec/exec_grpc.py +++ b/src/py/flwr/superexec/exec_grpc.py @@ -15,7 +15,7 @@ """SuperExec gRPC API.""" from logging import INFO -from typing import Optional, Tuple +from typing import Optional import grpc @@ -32,7 +32,7 @@ def run_superexec_api_grpc( address: str, executor: Executor, - certificates: Optional[Tuple[bytes, bytes, bytes]], + certificates: Optional[tuple[bytes, bytes, bytes]], config: UserConfig, ) -> grpc.Server: """Run SuperExec API (gRPC, request-response).""" diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index dda3e96994de..5b729dbc2b8e 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -15,8 +15,9 @@ """SuperExec API servicer.""" +from collections.abc import Generator from logging import ERROR, INFO -from typing import Any, Dict, Generator +from typing import Any import grpc @@ -38,7 +39,7 @@ class ExecServicer(exec_pb2_grpc.ExecServicer): def __init__(self, executor: Executor) -> None: self.executor = executor - self.runs: Dict[int, RunTracker] = {} + self.runs: dict[int, RunTracker] = {} def StartRun( self, request: StartRunRequest, context: grpc.ServicerContext