Skip to content

Commit

Permalink
ci(framework:skip) Add test and format for __init__.__all__
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll committed Jun 19, 2024
1 parent 2cf54b0 commit e6e95b1
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 26 deletions.
1 change: 1 addition & 0 deletions dev/format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../

# Python
python -m flwr_tool.init_py_fix src/py/flwr
python -m isort --skip src/py/flwr/proto src/py
python -m black -q --exclude src/py/flwr/proto src/py
python -m docformatter -i -r src/py/flwr -e src/py/flwr/proto
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
"Client",
"ClientApp",
"ClientFn",
"mod",
"NumPyClient",
"mod",
"run_client_app",
"run_supernode",
"start_client",
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/mod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
from .utils import make_ffn

__all__ = [
"LocalDpMod",
"adaptiveclipping_mod",
"fixedclipping_mod",
"LocalDpMod",
"make_ffn",
"secagg_mod",
"secaggplus_mod",
"message_size_mod",
"parameters_size_mod",
"secagg_mod",
"secaggplus_mod",
]
24 changes: 12 additions & 12 deletions src/py/flwr/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,48 +63,48 @@

__all__ = [
"Array",
"array_from_numpy",
"bytes_to_ndarray",
"ClientMessage",
"Code",
"Config",
"ConfigsRecord",
"configure",
"Context",
"DEFAULT_TTL",
"DisconnectRes",
"Error",
"EvaluateIns",
"EvaluateRes",
"event",
"EventType",
"FitIns",
"FitRes",
"Error",
"GRPC_MAX_MESSAGE_LENGTH",
"GetParametersIns",
"GetParametersRes",
"GetPropertiesIns",
"GetPropertiesRes",
"GRPC_MAX_MESSAGE_LENGTH",
"log",
"Message",
"MessageType",
"MessageTypeLegacy",
"DEFAULT_TTL",
"Metadata",
"Metrics",
"MetricsAggregationFn",
"MetricsRecord",
"ndarray_to_bytes",
"now",
"NDArray",
"NDArrays",
"ndarrays_to_parameters",
"Parameters",
"parameters_to_ndarrays",
"ParametersRecord",
"Properties",
"ReconnectIns",
"RecordSet",
"Scalar",
"ServerMessage",
"Status",
"array_from_numpy",
"bytes_to_ndarray",
"configure",
"event",
"log",
"ndarray_to_bytes",
"ndarrays_to_parameters",
"now",
"parameters_to_ndarrays",
]
2 changes: 1 addition & 1 deletion src/py/flwr/common/record/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@

__all__ = [
"Array",
"array_from_numpy",
"ConfigsRecord",
"MetricsRecord",
"ParametersRecord",
"RecordSet",
"array_from_numpy",
]
4 changes: 2 additions & 2 deletions src/py/flwr/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@
"Driver",
"History",
"LegacyContext",
"run_server_app",
"run_superlink",
"Server",
"ServerApp",
"ServerConfig",
"SimpleClientManager",
"run_server_app",
"run_superlink",
"start_server",
"strategy",
"workflow",
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/server/strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@
"DPFedAvgAdaptive",
"DPFedAvgFixed",
"DifferentialPrivacyClientSideAdaptiveClipping",
"DifferentialPrivacyServerSideAdaptiveClipping",
"DifferentialPrivacyClientSideFixedClipping",
"DifferentialPrivacyServerSideAdaptiveClipping",
"DifferentialPrivacyServerSideFixedClipping",
"FaultTolerantFedAvg",
"FedAdagrad",
"FedAdam",
"FedAvg",
Expand All @@ -69,7 +70,6 @@
"FedXgbCyclic",
"FedXgbNnAvg",
"FedYogi",
"FaultTolerantFedAvg",
"Krum",
"QFedAvg",
"Strategy",
Expand Down
5 changes: 4 additions & 1 deletion src/py/flwr/simulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,7 @@ def start_simulation(*args, **kwargs): # type: ignore
raise ImportError(RAY_IMPORT_ERROR)


__all__ = ["start_simulation", "run_simulation"]
__all__ = [
"run_simulation",
"start_simulation",
]
72 changes: 68 additions & 4 deletions src/py/flwr_tool/init_py_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@
"""


import ast
import os
import re
import sys
from pathlib import Path
from typing import List, Tuple


def check_missing_init_files(absolute_path: str) -> None:
"""Search absolute_path and look for missing __init__.py files."""
def get_init_dir_list_and_warnings(absolute_path: str) -> Tuple[List[str], List[str]]:
"""Search given path and return list of dirs containing __init__.py files."""
path = os.walk(absolute_path)
warning_list = []
dir_list = []
ignore_list = ["__pycache__$", ".pytest_cache.*$", "dist", "flwr.egg-info$"]

for dir_path, _, files_in_dir in path:
Expand All @@ -26,19 +30,79 @@ def check_missing_init_files(absolute_path: str) -> None:
if not any(filename == "__init__.py" for filename in files_in_dir):
warning_message = "- " + dir_path
warning_list.append(warning_message)
else:
dir_list.append(dir_path)
return warning_list, dir_list


def check_missing_init_files(absolute_path: str) -> List[str]:
"""Search absolute_path and look for missing __init__.py files."""
warning_list, dir_list = get_init_dir_list_and_warnings(absolute_path)

if len(warning_list) > 0:
print("Could not find '__init__.py' in the following directories:")
for warning in warning_list:
print(warning)
sys.exit(1)

return dir_list


def get_all_var_list(dir: str) -> Tuple[Path, List[str], List[str]]:
"""Get the __all__ list of a __init__.py file.
The function returns the path of the '__init__.py' file of the given dir, as well as
the list itself, and the list of lines corresponding to the list.
"""
init_file = Path(dir) / "__init__.py"
all_lines = []
all_list = []
capture = False
for line in init_file.read_text().splitlines():
stripped_line = line.strip()
if stripped_line.startswith("__all__"):
capture = True
if capture:
all_lines.append(line)
if stripped_line.endswith("]"):
capture = False
break

if all_lines:
all_string = "".join(all_lines)
all_list = ast.literal_eval(all_string.split("=", 1)[1].strip())

return init_file, all_list, all_lines


def check_all_init_files(dir_list: List[str]) -> None:
"""Check if __all__ is in alphabetical order in __init__.py files."""
warning_list = []

for dir in dir_list:
init_file, all_list, _ = get_all_var_list(dir)

if all_list and not all_list == sorted(all_list):
warning_message = "- " + str(init_file)
warning_list.append(warning_message)

if len(warning_list) > 0:
print(
"'__all__' lists in the following '__init__.py' files are "
"incorrectly sorted:"
)
for warning in warning_list:
print(warning)
sys.exit(1)


if __name__ == "__main__":
if len(sys.argv) == 0:
raise Exception( # pylint: disable=W0719
"Please provide at least one directory path relative to your current working directory."
"Please provide at least one directory path relative "
"to your current working directory."
)
for i, _ in enumerate(sys.argv):
abs_path: str = os.path.abspath(os.path.join(os.getcwd(), sys.argv[i]))
check_missing_init_files(abs_path)
dir_list = check_missing_init_files(abs_path)
check_all_init_files(dir_list)
69 changes: 69 additions & 0 deletions src/py/flwr_tool/init_py_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2020 Flower Labs GmbH. All Rights Reserved.
"""Fix provided directory and sub-directories for unsorted __all__ in __init__.py files.
Example:
python -m flwr_tool.init_py_fix src/py/flwr
"""


import os
import sys
from typing import List

import black

from flwr_tool.init_py_check import get_all_var_list, get_init_dir_list_and_warnings


def fix_all_init_files(dir_list: List[str]) -> None:
"""Sort the __all__ variables that are in __init__.py files."""
warning_list = []

for dir in dir_list:
init_file, all_list, all_lines = get_all_var_list(dir)

if all_list:
sorted_all_list = sorted(all_list)
if not all_list == sorted_all_list:
warning_message = "- " + str(dir)
warning_list.append(warning_message)

old_all_lines = "\n".join(all_lines)
new_all_lines = (
old_all_lines.split("=", 1)[0]
+ "= "
+ str(sorted_all_list)[:-1]
+ ",]"
)

new_content = init_file.read_text().replace(
old_all_lines, new_all_lines
)

# Write the fixed content back to the file
init_file.write_text(new_content)

# Format the file with black
black.format_file_in_place(
init_file,
fast=False,
mode=black.FileMode(),
write_back=black.WriteBack.YES,
)

if len(warning_list) > 0:
print("'__all__' lists in the following '__init__.py' files have been sorted:")
for warning in warning_list:
print(warning)


if __name__ == "__main__":
if len(sys.argv) == 0:
raise Exception( # pylint: disable=W0719
"Please provide at least one directory path relative "
"to your current working directory."
)
for i, _ in enumerate(sys.argv):
abs_path: str = os.path.abspath(os.path.join(os.getcwd(), sys.argv[i]))
warnings, dir_list = get_init_dir_list_and_warnings(abs_path)
fix_all_init_files(dir_list)

0 comments on commit e6e95b1

Please sign in to comment.