Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Mar 2, 2025
1 parent 5136adc commit 9653694
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions src/gt4py/eve/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,16 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
def _preserve_annex(
node: concepts.Node, new_node: concepts.Node, preserved_annex_attrs: tuple[str, ...]
) -> None:
if preserved_annex_attrs and (old_annex := getattr(node, "__node_annex__", None)):
# access to `new_node.annex` implicitly creates the `__node_annex__` attribute in the property getter
new_annex_dict = new_node.annex.__dict__
for key in preserved_annex_attrs:
if (value := getattr(old_annex, key, NOTHING)) is not NOTHING:
# Note: The annex value of the new node might not be equal
# (in the sense that an equality comparison returns false),
# but in the context of the pass, they are equivalent.
# Therefore, we don't assert equality here.
new_annex_dict[key] = value
# access to `new_node.annex` implicitly creates the `__node_annex__` attribute in the property getter
old_annex = node.annex
new_annex_dict = new_node.annex.__dict__
for key in preserved_annex_attrs:
if (value := getattr(old_annex, key, NOTHING)) is not NOTHING:
# Note: The annex value of the new node might not be equal
# (in the sense that an equality comparison returns false),
# but in the context of the pass, they are equivalent.
# Therefore, we don't assert equality here.
new_annex_dict[key] = value


class NodeTranslator(NodeVisitor):
Expand Down Expand Up @@ -173,7 +173,8 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
if (new_child := self.visit(child, **kwargs)) is not NOTHING
}
)
_preserve_annex(node, new_node, self.PRESERVED_ANNEX_ATTRS)
if self.PRESERVED_ANNEX_ATTRS and (old_annex := getattr(node, "__node_annex__", None)):
_preserve_annex(node, new_node, self.PRESERVED_ANNEX_ATTRS)

return new_node

Expand Down Expand Up @@ -202,7 +203,12 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
new_node = super().visit(node, **kwargs)

if isinstance(node, concepts.Node) and isinstance(new_node, concepts.Node):
if (
isinstance(node, concepts.Node)
and isinstance(new_node, concepts.Node)
and self.PRESERVED_ANNEX_ATTRS
and (old_annex := getattr(node, "__node_annex__", None))
):
_preserve_annex(node, new_node, self.PRESERVED_ANNEX_ATTRS)

return new_node

0 comments on commit 9653694

Please sign in to comment.