From 77b42b24c0a6b7e1e5bf6264457c442734acfc2c Mon Sep 17 00:00:00 2001 From: dnwpark Date: Wed, 4 Sep 2024 10:39:58 -0400 Subject: [PATCH] Improve ir ast compatibility with NodeTransformer. --- edb/common/ast/base.py | 7 +++++-- edb/common/ast/transformer.py | 36 ++++++++++++++++++++++++++--------- edb/common/ast/visitor.py | 22 +++++++++++++++------ edb/common/span.py | 2 ++ edb/graphql/translator.py | 2 ++ edb/ir/ast.py | 3 +++ edb/ir/utils.py | 5 ++++- 7 files changed, 59 insertions(+), 18 deletions(-) diff --git a/edb/common/ast/base.py b/edb/common/ast/base.py index 9ef48c6e9f1..3b44ffb8a5e 100644 --- a/edb/common/ast/base.py +++ b/edb/common/ast/base.py @@ -261,7 +261,7 @@ def __init__(self, **kwargs): self.__dict__ = kwargs def __copy__(self): - copied = self.__class__() + copied = self._init_copy() for field, value in iter_fields(self, include_meta=True): try: object.__setattr__(copied, field, value) @@ -271,11 +271,14 @@ def __copy__(self): return copied def __deepcopy__(self, memo): - copied = self.__class__() + copied = self._init_copy() for field, value in iter_fields(self, include_meta=True): object.__setattr__(copied, field, copy.deepcopy(value, memo)) return copied + def _init_copy(self): + return self.__class__() + def replace(self: T, **changes) -> T: copied = copy.copy(self) for field, value in changes.items(): diff --git a/edb/common/ast/transformer.py b/edb/common/ast/transformer.py index 0dab0c2e13a..30d11671c79 100644 --- a/edb/common/ast/transformer.py +++ b/edb/common/ast/transformer.py @@ -60,16 +60,34 @@ def visit_Name(self, node): """ def generic_visit(self, node): - for field, old_value in base.iter_fields(node, include_meta=False): - old_value = getattr(node, field, None) + if isinstance(node, base.ImmutableASTMixin): + changes = {} - if typeutils.is_container(old_value): - new_values = old_value.__class__(self.visit(old_value)) - setattr(node, field, old_value.__class__(new_values)) + for field, old_value in base.iter_fields(node, include_meta=False): + old_value = getattr(node, field, None) - elif isinstance(old_value, base.AST): - new_node = self.visit(old_value) - if new_node is not old_value: - setattr(node, field, new_node) + if typeutils.is_container(old_value): + new_values = old_value.__class__(self.visit(old_value)) + changes[field] = old_value.__class__(new_values) + + elif isinstance(old_value, base.AST): + new_node = self.visit(old_value) + if new_node is not old_value: + changes[field] = new_node + + node = node.replace(**changes) + + else: + for field, old_value in base.iter_fields(node, include_meta=False): + old_value = getattr(node, field, None) + + if typeutils.is_container(old_value): + new_values = old_value.__class__(self.visit(old_value)) + setattr(node, field, old_value.__class__(new_values)) + + elif isinstance(old_value, base.AST): + new_node = self.visit(old_value) + if new_node is not old_value: + setattr(node, field, new_node) return node diff --git a/edb/common/ast/visitor.py b/edb/common/ast/visitor.py index 73d68634c93..b1df270f320 100644 --- a/edb/common/ast/visitor.py +++ b/edb/common/ast/visitor.py @@ -121,12 +121,22 @@ def run(cls, node, **kwargs): return visitor.visit(node) def container_visit(self, node): - result = [] - for elem in (node.values() if isinstance(node, dict) else node): - if base.is_ast_node(elem) or typeutils.is_container(elem): - result.append(self.visit(elem)) - else: - result.append(elem) + if isinstance(node, dict): + result = {} + for key, value in node.items(): + if base.is_ast_node(value) or typeutils.is_container(value): + result[key] = self.visit(value) + else: + result[key] = value + + else: + result = [] + for elem in node: + if base.is_ast_node(elem) or typeutils.is_container(elem): + result.append(self.visit(elem)) + else: + result.append(elem) + return result def repeated_node_visit(self, node): diff --git a/edb/common/span.py b/edb/common/span.py index ca543c95588..1610bda4581 100644 --- a/edb/common/span.py +++ b/edb/common/span.py @@ -260,6 +260,8 @@ def container_visit(self, node) -> List[Span | None]: pass elif isinstance(span, list): span_list.extend(span) + elif isinstance(span, dict): + span_list.extend(span.values()) else: span_list.append(span) return span_list diff --git a/edb/graphql/translator.py b/edb/graphql/translator.py index 7860394691a..9c419a8fc9b 100644 --- a/edb/graphql/translator.py +++ b/edb/graphql/translator.py @@ -1809,6 +1809,8 @@ def combine_field_results(self, results, *, flatten=True): for res in results: if isinstance(res, Field): flattened.append(res) + elif isinstance(res, dict): + flattened.extend(res.values()) elif typeutils.is_container(res): flattened.extend(res) else: diff --git a/edb/ir/ast.py b/edb/ir/ast.py index 9559f05a022..dbaa1327828 100644 --- a/edb/ir/ast.py +++ b/edb/ir/ast.py @@ -829,6 +829,9 @@ def __init__( if self.value is None: raise ValueError('cannot create irast.Constant without a value') + def _init_copy(self) -> BaseConstant: + return self.__class__(typeref=self.typeref, value=self.value) + class BaseStrConstant(BaseConstant): __abstract_node__ = True diff --git a/edb/ir/utils.py b/edb/ir/utils.py index 06289569d5c..a431ee84b34 100644 --- a/edb/ir/utils.py +++ b/edb/ir/utils.py @@ -304,10 +304,11 @@ def __init__(self, *, skip_bindings: bool) -> None: super().__init__() self.skip_bindings = skip_bindings - def combine_field_results(self, xs: List[Optional[bool]]) -> bool: + def combine_field_results(self, xs: Iterable[Optional[bool]]) -> bool: return any( x is True or (isinstance(x, list) and self.combine_field_results(x)) + or (isinstance(x, dict) and self.combine_field_results(x.values())) for x in xs ) @@ -420,6 +421,8 @@ def combine_field_results(self, xs: Any) -> Set[irast.Set]: for x in xs: if isinstance(x, list): x = self.combine_field_results(x) + if isinstance(x, dict): + x = self.combine_field_results(x.values()) if x: if isinstance(x, set): out.update(x)