diff --git a/src/py/flwr_tool/check_copyright.py b/src/py/flwr_tool/check_copyright.py index 60dd7f2954ff..4ebcab823a38 100755 --- a/src/py/flwr_tool/check_copyright.py +++ b/src/py/flwr_tool/check_copyright.py @@ -35,6 +35,7 @@ def _get_file_creation_year(filepath: str): ["git", "log", "--diff-filter=A", "--format=%ai", "--", filepath], stdout=subprocess.PIPE, text=True, + check=True, ) date_str = result.stdout.splitlines()[-1] # Get the first commit date creation_year = date_str.split("-")[0] # Extract the year diff --git a/src/py/flwr_tool/fix_copyright.py b/src/py/flwr_tool/fix_copyright.py index 378e0749f4ce..5366cd5ad82e 100755 --- a/src/py/flwr_tool/fix_copyright.py +++ b/src/py/flwr_tool/fix_copyright.py @@ -15,26 +15,33 @@ from flwr_tool.init_py_check import get_init_dir_list_and_warnings +def _insert_or_edit_copyright(py_file: Path) -> None: + contents = py_file.read_text() + lines = contents.splitlines() + creation_year = _get_file_creation_year(str(py_file.absolute())) + expected_copyright = COPYRIGHT_FORMAT.format(creation_year) + + if expected_copyright not in contents: + if "Copyright" in lines[0]: + end_index = 0 + for idx, line in enumerate(lines): + if ( + line.strip() + == COPYRIGHT_FORMAT.rsplit("\n", maxsplit=1)[-1].strip() + ): + end_index = idx + 1 + break + lines = lines[end_index:] + + lines.insert(0, expected_copyright) + py_file.write_text("\n".join(lines)) + + def _fix_copyright(dir_list: List[str]) -> None: for valid_dir in dir_list: dir_path = Path(valid_dir) for py_file in dir_path.glob("*.py"): - contents = py_file.read_text() - lines = contents.splitlines() - creation_year = _get_file_creation_year(str(py_file.absolute())) - expected_copyright = COPYRIGHT_FORMAT.format(creation_year) - - if expected_copyright not in contents: - if "Copyright" in lines[0]: - end_index = 0 - for i, line in enumerate(lines): - if line.strip() == COPYRIGHT_FORMAT.split("\n")[-1].strip(): - end_index = i + 1 - break - lines = lines[end_index:] - - lines.insert(0, expected_copyright) - py_file.write_text("\n".join(lines)) + _insert_or_edit_copyright(py_file) if __name__ == "__main__":