From e6e95b13c01ba206e106e06ad189e9957f129c67 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Wed, 19 Jun 2024 12:44:07 +0200 Subject: [PATCH] ci(framework:skip) Add test and format for `__init__.__all__` --- dev/format.sh | 1 + src/py/flwr/client/__init__.py | 2 +- src/py/flwr/client/mod/__init__.py | 6 +-- src/py/flwr/common/__init__.py | 24 ++++----- src/py/flwr/common/record/__init__.py | 2 +- src/py/flwr/server/__init__.py | 4 +- src/py/flwr/server/strategy/__init__.py | 4 +- src/py/flwr/simulation/__init__.py | 5 +- src/py/flwr_tool/init_py_check.py | 72 +++++++++++++++++++++++-- src/py/flwr_tool/init_py_fix.py | 69 ++++++++++++++++++++++++ 10 files changed, 163 insertions(+), 26 deletions(-) create mode 100755 src/py/flwr_tool/init_py_fix.py diff --git a/dev/format.sh b/dev/format.sh index 05248b5eed3d..b9e3b00dffe1 100755 --- a/dev/format.sh +++ b/dev/format.sh @@ -3,6 +3,7 @@ set -e cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../ # Python +python -m flwr_tool.init_py_fix src/py/flwr python -m isort --skip src/py/flwr/proto src/py python -m black -q --exclude src/py/flwr/proto src/py python -m docformatter -i -r src/py/flwr -e src/py/flwr/proto diff --git a/src/py/flwr/client/__init__.py b/src/py/flwr/client/__init__.py index b4da71302cb4..58fd94448586 100644 --- a/src/py/flwr/client/__init__.py +++ b/src/py/flwr/client/__init__.py @@ -28,8 +28,8 @@ "Client", "ClientApp", "ClientFn", - "mod", "NumPyClient", + "mod", "run_client_app", "run_supernode", "start_client", diff --git a/src/py/flwr/client/mod/__init__.py b/src/py/flwr/client/mod/__init__.py index 1774e4b8ca0a..0b4cf6488421 100644 --- a/src/py/flwr/client/mod/__init__.py +++ b/src/py/flwr/client/mod/__init__.py @@ -22,12 +22,12 @@ from .utils import make_ffn __all__ = [ + "LocalDpMod", "adaptiveclipping_mod", "fixedclipping_mod", - "LocalDpMod", "make_ffn", - "secagg_mod", - "secaggplus_mod", "message_size_mod", "parameters_size_mod", + "secagg_mod", + "secaggplus_mod", ] diff --git a/src/py/flwr/common/__init__.py b/src/py/flwr/common/__init__.py index 2fb98c82dd6f..bbdf48425e0a 100644 --- a/src/py/flwr/common/__init__.py +++ b/src/py/flwr/common/__init__.py @@ -63,43 +63,34 @@ __all__ = [ "Array", - "array_from_numpy", - "bytes_to_ndarray", "ClientMessage", "Code", "Config", "ConfigsRecord", - "configure", "Context", + "DEFAULT_TTL", "DisconnectRes", + "Error", "EvaluateIns", "EvaluateRes", - "event", "EventType", "FitIns", "FitRes", - "Error", + "GRPC_MAX_MESSAGE_LENGTH", "GetParametersIns", "GetParametersRes", "GetPropertiesIns", "GetPropertiesRes", - "GRPC_MAX_MESSAGE_LENGTH", - "log", "Message", "MessageType", "MessageTypeLegacy", - "DEFAULT_TTL", "Metadata", "Metrics", "MetricsAggregationFn", "MetricsRecord", - "ndarray_to_bytes", - "now", "NDArray", "NDArrays", - "ndarrays_to_parameters", "Parameters", - "parameters_to_ndarrays", "ParametersRecord", "Properties", "ReconnectIns", @@ -107,4 +98,13 @@ "Scalar", "ServerMessage", "Status", + "array_from_numpy", + "bytes_to_ndarray", + "configure", + "event", + "log", + "ndarray_to_bytes", + "ndarrays_to_parameters", + "now", + "parameters_to_ndarrays", ] diff --git a/src/py/flwr/common/record/__init__.py b/src/py/flwr/common/record/__init__.py index 60bc54b8552a..88eef5f7aea1 100644 --- a/src/py/flwr/common/record/__init__.py +++ b/src/py/flwr/common/record/__init__.py @@ -22,9 +22,9 @@ __all__ = [ "Array", - "array_from_numpy", "ConfigsRecord", "MetricsRecord", "ParametersRecord", "RecordSet", + "array_from_numpy", ] diff --git a/src/py/flwr/server/__init__.py b/src/py/flwr/server/__init__.py index 19c6034bcaa1..546ce263e2d5 100644 --- a/src/py/flwr/server/__init__.py +++ b/src/py/flwr/server/__init__.py @@ -34,12 +34,12 @@ "Driver", "History", "LegacyContext", - "run_server_app", - "run_superlink", "Server", "ServerApp", "ServerConfig", "SimpleClientManager", + "run_server_app", + "run_superlink", "start_server", "strategy", "workflow", diff --git a/src/py/flwr/server/strategy/__init__.py b/src/py/flwr/server/strategy/__init__.py index b7de9a946fff..e5bc30009819 100644 --- a/src/py/flwr/server/strategy/__init__.py +++ b/src/py/flwr/server/strategy/__init__.py @@ -53,9 +53,10 @@ "DPFedAvgAdaptive", "DPFedAvgFixed", "DifferentialPrivacyClientSideAdaptiveClipping", - "DifferentialPrivacyServerSideAdaptiveClipping", "DifferentialPrivacyClientSideFixedClipping", + "DifferentialPrivacyServerSideAdaptiveClipping", "DifferentialPrivacyServerSideFixedClipping", + "FaultTolerantFedAvg", "FedAdagrad", "FedAdam", "FedAvg", @@ -69,7 +70,6 @@ "FedXgbCyclic", "FedXgbNnAvg", "FedYogi", - "FaultTolerantFedAvg", "Krum", "QFedAvg", "Strategy", diff --git a/src/py/flwr/simulation/__init__.py b/src/py/flwr/simulation/__init__.py index 57b0b01eb319..3d648b14edba 100644 --- a/src/py/flwr/simulation/__init__.py +++ b/src/py/flwr/simulation/__init__.py @@ -36,4 +36,7 @@ def start_simulation(*args, **kwargs): # type: ignore raise ImportError(RAY_IMPORT_ERROR) -__all__ = ["start_simulation", "run_simulation"] +__all__ = [ + "run_simulation", + "start_simulation", +] diff --git a/src/py/flwr_tool/init_py_check.py b/src/py/flwr_tool/init_py_check.py index 67425139f991..0ecbc6359344 100755 --- a/src/py/flwr_tool/init_py_check.py +++ b/src/py/flwr_tool/init_py_check.py @@ -6,15 +6,19 @@ """ +import ast import os import re import sys +from pathlib import Path +from typing import List, Tuple -def check_missing_init_files(absolute_path: str) -> None: - """Search absolute_path and look for missing __init__.py files.""" +def get_init_dir_list_and_warnings(absolute_path: str) -> Tuple[List[str], List[str]]: + """Search given path and return list of dirs containing __init__.py files.""" path = os.walk(absolute_path) warning_list = [] + dir_list = [] ignore_list = ["__pycache__$", ".pytest_cache.*$", "dist", "flwr.egg-info$"] for dir_path, _, files_in_dir in path: @@ -26,6 +30,14 @@ def check_missing_init_files(absolute_path: str) -> None: if not any(filename == "__init__.py" for filename in files_in_dir): warning_message = "- " + dir_path warning_list.append(warning_message) + else: + dir_list.append(dir_path) + return warning_list, dir_list + + +def check_missing_init_files(absolute_path: str) -> List[str]: + """Search absolute_path and look for missing __init__.py files.""" + warning_list, dir_list = get_init_dir_list_and_warnings(absolute_path) if len(warning_list) > 0: print("Could not find '__init__.py' in the following directories:") @@ -33,12 +45,64 @@ def check_missing_init_files(absolute_path: str) -> None: print(warning) sys.exit(1) + return dir_list + + +def get_all_var_list(dir: str) -> Tuple[Path, List[str], List[str]]: + """Get the __all__ list of a __init__.py file. + + The function returns the path of the '__init__.py' file of the given dir, as well as + the list itself, and the list of lines corresponding to the list. + """ + init_file = Path(dir) / "__init__.py" + all_lines = [] + all_list = [] + capture = False + for line in init_file.read_text().splitlines(): + stripped_line = line.strip() + if stripped_line.startswith("__all__"): + capture = True + if capture: + all_lines.append(line) + if stripped_line.endswith("]"): + capture = False + break + + if all_lines: + all_string = "".join(all_lines) + all_list = ast.literal_eval(all_string.split("=", 1)[1].strip()) + + return init_file, all_list, all_lines + + +def check_all_init_files(dir_list: List[str]) -> None: + """Check if __all__ is in alphabetical order in __init__.py files.""" + warning_list = [] + + for dir in dir_list: + init_file, all_list, _ = get_all_var_list(dir) + + if all_list and not all_list == sorted(all_list): + warning_message = "- " + str(init_file) + warning_list.append(warning_message) + + if len(warning_list) > 0: + print( + "'__all__' lists in the following '__init__.py' files are " + "incorrectly sorted:" + ) + for warning in warning_list: + print(warning) + sys.exit(1) + if __name__ == "__main__": if len(sys.argv) == 0: raise Exception( # pylint: disable=W0719 - "Please provide at least one directory path relative to your current working directory." + "Please provide at least one directory path relative " + "to your current working directory." ) for i, _ in enumerate(sys.argv): abs_path: str = os.path.abspath(os.path.join(os.getcwd(), sys.argv[i])) - check_missing_init_files(abs_path) + dir_list = check_missing_init_files(abs_path) + check_all_init_files(dir_list) diff --git a/src/py/flwr_tool/init_py_fix.py b/src/py/flwr_tool/init_py_fix.py new file mode 100755 index 000000000000..f3ebd7395a42 --- /dev/null +++ b/src/py/flwr_tool/init_py_fix.py @@ -0,0 +1,69 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +"""Fix provided directory and sub-directories for unsorted __all__ in __init__.py files. + +Example: + python -m flwr_tool.init_py_fix src/py/flwr +""" + + +import os +import sys +from typing import List + +import black + +from flwr_tool.init_py_check import get_all_var_list, get_init_dir_list_and_warnings + + +def fix_all_init_files(dir_list: List[str]) -> None: + """Sort the __all__ variables that are in __init__.py files.""" + warning_list = [] + + for dir in dir_list: + init_file, all_list, all_lines = get_all_var_list(dir) + + if all_list: + sorted_all_list = sorted(all_list) + if not all_list == sorted_all_list: + warning_message = "- " + str(dir) + warning_list.append(warning_message) + + old_all_lines = "\n".join(all_lines) + new_all_lines = ( + old_all_lines.split("=", 1)[0] + + "= " + + str(sorted_all_list)[:-1] + + ",]" + ) + + new_content = init_file.read_text().replace( + old_all_lines, new_all_lines + ) + + # Write the fixed content back to the file + init_file.write_text(new_content) + + # Format the file with black + black.format_file_in_place( + init_file, + fast=False, + mode=black.FileMode(), + write_back=black.WriteBack.YES, + ) + + if len(warning_list) > 0: + print("'__all__' lists in the following '__init__.py' files have been sorted:") + for warning in warning_list: + print(warning) + + +if __name__ == "__main__": + if len(sys.argv) == 0: + raise Exception( # pylint: disable=W0719 + "Please provide at least one directory path relative " + "to your current working directory." + ) + for i, _ in enumerate(sys.argv): + abs_path: str = os.path.abspath(os.path.join(os.getcwd(), sys.argv[i])) + warnings, dir_list = get_init_dir_list_and_warnings(abs_path) + fix_all_init_files(dir_list)