Skip to content

Commit

Permalink
Improve ir ast compatibility with NodeTransformer.
Browse files Browse the repository at this point in the history
  • Loading branch information
dnwpark committed Sep 10, 2024
1 parent eda72c0 commit 77b42b2
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 18 deletions.
7 changes: 5 additions & 2 deletions edb/common/ast/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def __init__(self, **kwargs):
self.__dict__ = kwargs

def __copy__(self):
copied = self.__class__()
copied = self._init_copy()
for field, value in iter_fields(self, include_meta=True):
try:
object.__setattr__(copied, field, value)
Expand All @@ -271,11 +271,14 @@ def __copy__(self):
return copied

def __deepcopy__(self, memo):
copied = self.__class__()
copied = self._init_copy()
for field, value in iter_fields(self, include_meta=True):
object.__setattr__(copied, field, copy.deepcopy(value, memo))
return copied

def _init_copy(self):
return self.__class__()

def replace(self: T, **changes) -> T:
copied = copy.copy(self)
for field, value in changes.items():
Expand Down
36 changes: 27 additions & 9 deletions edb/common/ast/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,34 @@ def visit_Name(self, node):
"""

def generic_visit(self, node):
for field, old_value in base.iter_fields(node, include_meta=False):
old_value = getattr(node, field, None)
if isinstance(node, base.ImmutableASTMixin):
changes = {}

if typeutils.is_container(old_value):
new_values = old_value.__class__(self.visit(old_value))
setattr(node, field, old_value.__class__(new_values))
for field, old_value in base.iter_fields(node, include_meta=False):
old_value = getattr(node, field, None)

elif isinstance(old_value, base.AST):
new_node = self.visit(old_value)
if new_node is not old_value:
setattr(node, field, new_node)
if typeutils.is_container(old_value):
new_values = old_value.__class__(self.visit(old_value))
changes[field] = old_value.__class__(new_values)

elif isinstance(old_value, base.AST):
new_node = self.visit(old_value)
if new_node is not old_value:
changes[field] = new_node

node = node.replace(**changes)

else:
for field, old_value in base.iter_fields(node, include_meta=False):
old_value = getattr(node, field, None)

if typeutils.is_container(old_value):
new_values = old_value.__class__(self.visit(old_value))
setattr(node, field, old_value.__class__(new_values))

elif isinstance(old_value, base.AST):
new_node = self.visit(old_value)
if new_node is not old_value:
setattr(node, field, new_node)

return node
22 changes: 16 additions & 6 deletions edb/common/ast/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,22 @@ def run(cls, node, **kwargs):
return visitor.visit(node)

def container_visit(self, node):
result = []
for elem in (node.values() if isinstance(node, dict) else node):
if base.is_ast_node(elem) or typeutils.is_container(elem):
result.append(self.visit(elem))
else:
result.append(elem)
if isinstance(node, dict):
result = {}
for key, value in node.items():
if base.is_ast_node(value) or typeutils.is_container(value):
result[key] = self.visit(value)
else:
result[key] = value

else:
result = []
for elem in node:
if base.is_ast_node(elem) or typeutils.is_container(elem):
result.append(self.visit(elem))
else:
result.append(elem)

return result

def repeated_node_visit(self, node):
Expand Down
2 changes: 2 additions & 0 deletions edb/common/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,8 @@ def container_visit(self, node) -> List[Span | None]:
pass
elif isinstance(span, list):
span_list.extend(span)
elif isinstance(span, dict):
span_list.extend(span.values())
else:
span_list.append(span)
return span_list
Expand Down
2 changes: 2 additions & 0 deletions edb/graphql/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1809,6 +1809,8 @@ def combine_field_results(self, results, *, flatten=True):
for res in results:
if isinstance(res, Field):
flattened.append(res)
elif isinstance(res, dict):
flattened.extend(res.values())
elif typeutils.is_container(res):
flattened.extend(res)
else:
Expand Down
3 changes: 3 additions & 0 deletions edb/ir/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,9 @@ def __init__(
if self.value is None:
raise ValueError('cannot create irast.Constant without a value')

def _init_copy(self) -> BaseConstant:
return self.__class__(typeref=self.typeref, value=self.value)


class BaseStrConstant(BaseConstant):
__abstract_node__ = True
Expand Down
5 changes: 4 additions & 1 deletion edb/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,11 @@ def __init__(self, *, skip_bindings: bool) -> None:
super().__init__()
self.skip_bindings = skip_bindings

def combine_field_results(self, xs: List[Optional[bool]]) -> bool:
def combine_field_results(self, xs: Iterable[Optional[bool]]) -> bool:
return any(
x is True
or (isinstance(x, list) and self.combine_field_results(x))
or (isinstance(x, dict) and self.combine_field_results(x.values()))
for x in xs
)

Expand Down Expand Up @@ -420,6 +421,8 @@ def combine_field_results(self, xs: Any) -> Set[irast.Set]:
for x in xs:
if isinstance(x, list):
x = self.combine_field_results(x)
if isinstance(x, dict):
x = self.combine_field_results(x.values())
if x:
if isinstance(x, set):
out.update(x)
Expand Down

0 comments on commit 77b42b2

Please sign in to comment.