diff --git a/openhands/runtime/plugins/agent_skills/file_ops/ast_ops.py b/openhands/runtime/plugins/agent_skills/file_ops/ast_ops.py index a84af7db4b79..28b9307a07e4 100644 --- a/openhands/runtime/plugins/agent_skills/file_ops/ast_ops.py +++ b/openhands/runtime/plugins/agent_skills/file_ops/ast_ops.py @@ -1,6 +1,8 @@ import ast import os +import libcst as cst + # Converting the AST 'arguments' object to a readable format def format_signature(args): @@ -33,9 +35,9 @@ def read_file(file_path): return f.read() -def find_base_class_file(base_class_module): +def find_base_class_file(file_path, base_class_module): # Assuming the module is in the current working directory or a standard path - dir_name = os.path.dirname(file_name) + dir_name = os.path.dirname(file_path) if dir_name == '': dir_name = os.getcwd() possible_paths = [ @@ -71,7 +73,7 @@ def get_base_class_init_signature(file_path, base_class_name): current_code = read_file(file_path) base_class_module = find_imported_base_class(current_code, base_class_name) if base_class_module: - base_class_file = find_base_class_file(base_class_module) + base_class_file = find_base_class_file(file_path, base_class_module) if base_class_file: base_init_signature = process_file_for_base_class_init( base_class_file, base_class_name @@ -100,8 +102,8 @@ def visit_ClassDef(self, node): return node -def get_base_class_name(file_name, class_name): - code = read_file(file_name) +def get_base_class_name(file_path, class_name): + code = read_file(file_path) tree = ast.parse(code) finder = BaseClassFinder() @@ -114,68 +116,120 @@ def get_base_class_name(file_name, class_name): return finder.base_class_name -class SelectiveClassInitModifier(ast.NodeTransformer): - def __init__(self, file_name, class_name, param_name): - self.file_name = file_name +class InitMethodModifier(cst.CSTTransformer): + def __init__(self, file_path, class_name, param_name): + self.file_path = file_path self.class_name = class_name self.param_name = param_name - - def visit_ClassDef(self, node): - # Only modify the specified class - if node.name == self.class_name: - for n in node.body: - if isinstance(n, ast.FunctionDef) and n.name == '__init__': - self.visit_FunctionDef(n) - return node - - def visit_FunctionDef(self, node): - # Add the new parameter to the __init__ method - new_param = ast.arg(arg=self.param_name, annotation=None) - node.args.args.append(new_param) - - # Add the parameter to the super().__init__ call - - if self.param_name in get_base_class_init_signature( - self.file_name, get_base_class_name(self.file_name, self.class_name) + self.inside_target_class = False + + def readable_params(self, params): + return [param.name.value for param in params] + + def readable_args(self, args): + return [arg.keyword.value for arg in args] + + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.FunctionDef: + # Modify only if we are inside the target class's __init__ method + if self.inside_target_class and original_node.name.value == '__init__': + # Add the new parameter to the __init__ method + new_param = cst.Param(cst.Name(self.param_name)) + if self.param_name not in self.readable_params(updated_node.params.params): + new_params = updated_node.params.with_changes( + params=[*updated_node.params.params, new_param] + ) + + # Update the __init__ function with the new parameter + return updated_node.with_changes(params=new_params) + else: + print( + f'{self.param_name} already present in the __init__ method of the {self.class_name} class in the {self.file_path} file.' + ) + self.already_added = True + + return updated_node + + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef): + # Exit the target class after processing + self.inside_target_class = False + return updated_node + + def visit_ClassDef(self, node: cst.ClassDef): + # Enter the target class + if node.name.value == self.class_name: + self.inside_target_class = True + + def leave_Expr(self, original_node: cst.Expr, updated_node: cst.Expr) -> cst.Expr: + # if self.already_added: + # return updated_node + if self.param_name not in ( + get_base_class_init_signature( + self.file_path, get_base_class_name(self.file_path, self.class_name) + ) + or [] ): - for n in node.body: - if isinstance(n, ast.Expr) and isinstance(n.value, ast.Call): - if ( - isinstance(n.value.func, ast.Attribute) - and n.value.func.attr == '__init__' - ): - kw = ast.keyword( - arg=self.param_name, - value=ast.Name(id=self.param_name, ctx=ast.Load()), + return updated_node + # Modify the super().__init__ call to include the new parameter + if self.inside_target_class and isinstance(original_node.value, cst.Call): + if isinstance(original_node.value.func, cst.Attribute): + if original_node.value.func.attr.value == '__init__': + # Check if existing arguments contain keyword arguments + has_keyword_args = any( + isinstance(arg, cst.Arg) and arg.keyword + for arg in original_node.value.args + ) + # if the param is already present, do not add it again + if self.param_name in self.readable_args(original_node.value.args): + return updated_node + if has_keyword_args: + # Add the new parameter as a keyword argument to avoid positional after keyword error + new_arg = cst.Arg( + keyword=cst.Name(self.param_name), + value=cst.Name(self.param_name), ) - n.value.keywords.append(kw) - return node + else: + # Add the new parameter as a positional argument + new_arg = cst.Arg(value=cst.Name(self.param_name)) + + new_args = [*original_node.value.args, new_arg] + new_call = updated_node.value.with_changes(args=new_args) + return updated_node.with_changes(value=new_call) + + return updated_node -def add_param_to_init_in_subclass(file_name, class_name, param_name): +def add_param_to_init_in_subclass(file_path, class_name, param_name): """ - This function adds a new parameter to the __init__ method of a specified sub class in a given file and adds it to the super().__init__ call if the parameter is present in the base class __init__ method using AST. + This function adds a new parameter to the __init__ method of a specified sub class in a given file and adds it to the super().__init__ call automatically by checking if the parameter is present in the base class __init__ method using AST. Args: - file_name (str): The path to the file containing the class. + file_path (str): The path to the file containing the class. class_name (str): The name of the subclass to modify. param_name (str): The name of the new parameter to add to the __init__ method. - Returns: - str: The modified code with the new parameter added to the __init__ method. """ - code = read_file(file_name) - tree = ast.parse(code) - modifier = SelectiveClassInitModifier(file_name, class_name, param_name) - modified_tree = modifier.visit(tree) - return ast.unparse(modified_tree) + code = read_file(file_path) + # Parse the code into a CST tree + tree = cst.parse_module(code) + + # Create the transformer to modify the __init__ method + transformer = InitMethodModifier(file_path, class_name, param_name) + + # Apply the transformation + modified_tree = tree.visit(transformer) + new_code = modified_tree.code + if new_code != code: + with open(file_path, 'w') as f: + f.write(modified_tree.code) + print( + f'{param_name} added to the __init__ method of the {class_name} class in the {file_path} file.' + ) if __name__ == '__main__': file_name = r'C:\Users\smart\Desktop\GD\astropy\astropy\io\ascii\rst.py' - # print(get_base_class_init_signature(file_name, "FixedWidth")) + # print(get_base_class_init_signature(file_path, "FixedWidth")) - modified_code_class_one = add_param_to_init_in_subclass( - file_name, 'RST', 'header_rows' - ) - print(modified_code_class_one) + add_param_to_init_in_subclass(file_name, 'RST', 'header_rows') diff --git a/poetry.lock b/poetry.lock index dbdd19d04fa9..afd5454e02bf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3707,6 +3707,46 @@ dev = ["changelist (==0.5)"] lint = ["pre-commit (==3.7.0)"] test = ["pytest (>=7.4)", "pytest-cov (>=4.1)"] +[[package]] +name = "libcst" +version = "1.4.0" +description = "A concrete syntax tree with AST-like properties for Python 3.0 through 3.12 programs." +optional = false +python-versions = ">=3.9" +files = [ + {file = "libcst-1.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:279b54568ea1f25add50ea4ba3d76d4f5835500c82f24d54daae4c5095b986aa"}, + {file = "libcst-1.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3401dae41fe24565387a65baee3887e31a44e3e58066b0250bc3f3ccf85b1b5a"}, + {file = "libcst-1.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1989fa12d3cd79118ebd29ebe2a6976d23d509b1a4226bc3d66fcb7cb50bd5d"}, + {file = "libcst-1.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:addc6d585141a7677591868886f6bda0577529401a59d210aa8112114340e129"}, + {file = "libcst-1.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:17d71001cb25e94cfe8c3d997095741a8c4aa7a6d234c0f972bc42818c88dfaf"}, + {file = "libcst-1.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:2d47de16d105e7dd5f4e01a428d9f4dc1e71efd74f79766daf54528ce37f23c3"}, + {file = "libcst-1.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e6227562fc5c9c1efd15dfe90b0971ae254461b8b6b23c1b617139b6003de1c1"}, + {file = "libcst-1.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3399e6c95df89921511b44d8c5bf6a75bcbc2d51f1f6429763609ba005c10f6b"}, + {file = "libcst-1.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48601e3e590e2d6a7ab8c019cf3937c70511a78d778ab3333764531253acdb33"}, + {file = "libcst-1.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42797309bb725f0f000510d5463175ccd7155395f09b5e7723971b0007a976d"}, + {file = "libcst-1.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb4e42ea107a37bff7f9fdbee9532d39f9ea77b89caa5c5112b37057b12e0838"}, + {file = "libcst-1.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:9d0cc3c5a2a51fa7e1d579a828c0a2e46b2170024fd8b1a0691c8a52f3abb2d9"}, + {file = "libcst-1.4.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:7ece51d935bc9bf60b528473d2e5cc67cbb88e2f8146297e40ee2c7d80be6f13"}, + {file = "libcst-1.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:81653dea1cdfa4c6520a7c5ffb95fa4d220cbd242e446c7a06d42d8636bfcbba"}, + {file = "libcst-1.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6abce0e66bba2babfadc20530fd3688f672d565674336595b4623cd800b91ef"}, + {file = "libcst-1.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5da9d7dc83801aba3b8d911f82dc1a375db0d508318bad79d9fb245374afe068"}, + {file = "libcst-1.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c54aa66c86d8ece9c93156a2cf5ca512b0dce40142fe9e072c86af2bf892411"}, + {file = "libcst-1.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:62e2682ee1567b6a89c91853865372bf34f178bfd237853d84df2b87b446e654"}, + {file = "libcst-1.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b8ecdba8934632b4dadacb666cd3816627a6ead831b806336972ccc4ba7ca0e9"}, + {file = "libcst-1.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8e54c777b8d27339b70f304d16fc8bc8674ef1bd34ed05ea874bf4921eb5a313"}, + {file = "libcst-1.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:061d6855ef30efe38b8a292b7e5d57c8e820e71fc9ec9846678b60a934b53bbb"}, + {file = "libcst-1.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb0abf627ee14903d05d0ad9b2c6865f1b21eb4081e2c7bea1033f85db2b8bae"}, + {file = "libcst-1.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d024f44059a853b4b852cfc04fec33e346659d851371e46fc8e7c19de24d3da9"}, + {file = "libcst-1.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:3c6a8faab9da48c5b371557d0999b4ca51f4f2cbd37ee8c2c4df0ac01c781465"}, + {file = "libcst-1.4.0.tar.gz", hash = "sha256:449e0b16604f054fa7f27c3ffe86ea7ef6c409836fe68fe4e752a1894175db00"}, +] + +[package.dependencies] +pyyaml = ">=5.2" + +[package.extras] +dev = ["Sphinx (>=5.1.1)", "black (==23.12.1)", "build (>=0.10.0)", "coverage (>=4.5.4)", "fixit (==2.1.0)", "flake8 (==7.0.0)", "hypothesis (>=4.36.0)", "hypothesmith (>=0.0.4)", "jinja2 (==3.1.4)", "jupyter (>=1.0.0)", "maturin (>=0.8.3,<1.6)", "nbsphinx (>=0.4.2)", "prompt-toolkit (>=2.0.9)", "pyre-check (==0.9.18)", "setuptools-rust (>=1.5.2)", "setuptools-scm (>=6.0.1)", "slotscheck (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "ufmt (==2.6.0)", "usort (==1.0.8.post1)"] + [[package]] name = "libvisualwebarena" version = "0.0.8" @@ -9688,4 +9728,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "90636ce436e5c05146a69730f461f46fd3185b595be37d3eafd8aef36667db81" +content-hash = "dbe32b510876baad6a9c2a09704c089c5ca93238d37f8713c2e0786b23c02e95" diff --git a/pyproject.toml b/pyproject.toml index eb17e59f984b..afb05c6d4d3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ pylatexenc = "*" tornado = "*" python-dotenv = "*" +libcst = "^1.4.0" [tool.poetry.group.llama-index.dependencies] llama-index = "*" llama-index-vector-stores-chroma = "*"