Skip to content

Commit

Permalink
use libcst
Browse files Browse the repository at this point in the history
for preserving comments, code structure
  • Loading branch information
SmartManoj committed Sep 26, 2024
1 parent 13157c1 commit 35a8e83
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 53 deletions.
158 changes: 106 additions & 52 deletions openhands/runtime/plugins/agent_skills/file_ops/ast_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import ast
import os

import libcst as cst


# Converting the AST 'arguments' object to a readable format
def format_signature(args):
Expand Down Expand Up @@ -33,9 +35,9 @@ def read_file(file_path):
return f.read()


def find_base_class_file(base_class_module):
def find_base_class_file(file_path, base_class_module):
# Assuming the module is in the current working directory or a standard path
dir_name = os.path.dirname(file_name)
dir_name = os.path.dirname(file_path)
if dir_name == '':
dir_name = os.getcwd()
possible_paths = [
Expand Down Expand Up @@ -71,7 +73,7 @@ def get_base_class_init_signature(file_path, base_class_name):
current_code = read_file(file_path)
base_class_module = find_imported_base_class(current_code, base_class_name)
if base_class_module:
base_class_file = find_base_class_file(base_class_module)
base_class_file = find_base_class_file(file_path, base_class_module)
if base_class_file:
base_init_signature = process_file_for_base_class_init(
base_class_file, base_class_name
Expand Down Expand Up @@ -100,8 +102,8 @@ def visit_ClassDef(self, node):
return node


def get_base_class_name(file_name, class_name):
code = read_file(file_name)
def get_base_class_name(file_path, class_name):
code = read_file(file_path)
tree = ast.parse(code)
finder = BaseClassFinder()

Expand All @@ -114,68 +116,120 @@ def get_base_class_name(file_name, class_name):
return finder.base_class_name


class SelectiveClassInitModifier(ast.NodeTransformer):
def __init__(self, file_name, class_name, param_name):
self.file_name = file_name
class InitMethodModifier(cst.CSTTransformer):
def __init__(self, file_path, class_name, param_name):
self.file_path = file_path
self.class_name = class_name
self.param_name = param_name

def visit_ClassDef(self, node):
# Only modify the specified class
if node.name == self.class_name:
for n in node.body:
if isinstance(n, ast.FunctionDef) and n.name == '__init__':
self.visit_FunctionDef(n)
return node

def visit_FunctionDef(self, node):
# Add the new parameter to the __init__ method
new_param = ast.arg(arg=self.param_name, annotation=None)
node.args.args.append(new_param)

# Add the parameter to the super().__init__ call

if self.param_name in get_base_class_init_signature(
self.file_name, get_base_class_name(self.file_name, self.class_name)
self.inside_target_class = False

def readable_params(self, params):
return [param.name.value for param in params]

def readable_args(self, args):
return [arg.keyword.value for arg in args]

def leave_FunctionDef(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
# Modify only if we are inside the target class's __init__ method
if self.inside_target_class and original_node.name.value == '__init__':
# Add the new parameter to the __init__ method
new_param = cst.Param(cst.Name(self.param_name))
if self.param_name not in self.readable_params(updated_node.params.params):
new_params = updated_node.params.with_changes(
params=[*updated_node.params.params, new_param]
)

# Update the __init__ function with the new parameter
return updated_node.with_changes(params=new_params)
else:
print(
f'{self.param_name} already present in the __init__ method of the {self.class_name} class in the {self.file_path} file.'
)
self.already_added = True

return updated_node

def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef):
# Exit the target class after processing
self.inside_target_class = False
return updated_node

def visit_ClassDef(self, node: cst.ClassDef):
# Enter the target class
if node.name.value == self.class_name:
self.inside_target_class = True

def leave_Expr(self, original_node: cst.Expr, updated_node: cst.Expr) -> cst.Expr:
# if self.already_added:
# return updated_node
if self.param_name not in (
get_base_class_init_signature(
self.file_path, get_base_class_name(self.file_path, self.class_name)
)
or []
):
for n in node.body:
if isinstance(n, ast.Expr) and isinstance(n.value, ast.Call):
if (
isinstance(n.value.func, ast.Attribute)
and n.value.func.attr == '__init__'
):
kw = ast.keyword(
arg=self.param_name,
value=ast.Name(id=self.param_name, ctx=ast.Load()),
return updated_node
# Modify the super().__init__ call to include the new parameter
if self.inside_target_class and isinstance(original_node.value, cst.Call):
if isinstance(original_node.value.func, cst.Attribute):
if original_node.value.func.attr.value == '__init__':
# Check if existing arguments contain keyword arguments
has_keyword_args = any(
isinstance(arg, cst.Arg) and arg.keyword
for arg in original_node.value.args
)
# if the param is already present, do not add it again
if self.param_name in self.readable_args(original_node.value.args):
return updated_node
if has_keyword_args:
# Add the new parameter as a keyword argument to avoid positional after keyword error
new_arg = cst.Arg(
keyword=cst.Name(self.param_name),
value=cst.Name(self.param_name),
)
n.value.keywords.append(kw)
return node
else:
# Add the new parameter as a positional argument
new_arg = cst.Arg(value=cst.Name(self.param_name))

new_args = [*original_node.value.args, new_arg]
new_call = updated_node.value.with_changes(args=new_args)
return updated_node.with_changes(value=new_call)

return updated_node


def add_param_to_init_in_subclass(file_name, class_name, param_name):
def add_param_to_init_in_subclass(file_path, class_name, param_name):
"""
This function adds a new parameter to the __init__ method of a specified sub class in a given file and adds it to the super().__init__ call if the parameter is present in the base class __init__ method using AST.
This function adds a new parameter to the __init__ method of a specified sub class in a given file and adds it to the super().__init__ call automatically by checking if the parameter is present in the base class __init__ method using AST.
Args:
file_name (str): The path to the file containing the class.
file_path (str): The path to the file containing the class.
class_name (str): The name of the subclass to modify.
param_name (str): The name of the new parameter to add to the __init__ method.
Returns:
str: The modified code with the new parameter added to the __init__ method.
"""
code = read_file(file_name)
tree = ast.parse(code)
modifier = SelectiveClassInitModifier(file_name, class_name, param_name)
modified_tree = modifier.visit(tree)
return ast.unparse(modified_tree)
code = read_file(file_path)
# Parse the code into a CST tree
tree = cst.parse_module(code)

# Create the transformer to modify the __init__ method
transformer = InitMethodModifier(file_path, class_name, param_name)

# Apply the transformation
modified_tree = tree.visit(transformer)
new_code = modified_tree.code
if new_code != code:
with open(file_path, 'w') as f:
f.write(modified_tree.code)
print(
f'{param_name} added to the __init__ method of the {class_name} class in the {file_path} file.'
)


if __name__ == '__main__':
file_name = r'C:\Users\smart\Desktop\GD\astropy\astropy\io\ascii\rst.py'
# print(get_base_class_init_signature(file_name, "FixedWidth"))
# print(get_base_class_init_signature(file_path, "FixedWidth"))

modified_code_class_one = add_param_to_init_in_subclass(
file_name, 'RST', 'header_rows'
)
print(modified_code_class_one)
add_param_to_init_in_subclass(file_name, 'RST', 'header_rows')
42 changes: 41 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ pylatexenc = "*"
tornado = "*"
python-dotenv = "*"

libcst = "^1.4.0"
[tool.poetry.group.llama-index.dependencies]
llama-index = "*"
llama-index-vector-stores-chroma = "*"
Expand Down

0 comments on commit 35a8e83

Please sign in to comment.