Skip to content

Commit

Permalink
fold stencil.cast op with equal input-output types
Browse files Browse the repository at this point in the history
  • Loading branch information
n-io committed Nov 13, 2024
1 parent ff5fefc commit b47bf3d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
12 changes: 11 additions & 1 deletion xdsl/dialects/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,16 @@ class AllocOp(IRDLOperation):
traits = traits_def(AllocOpEffect())


class CastOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.stencil import (
RemoveCastWithNoEffect,
)

return (RemoveCastWithNoEffect(),)


@irdl_op_definition
class CastOp(IRDLOperation):
"""
Expand Down Expand Up @@ -722,7 +732,7 @@ class CastOp(IRDLOperation):
"$field attr-dict-with-keyword `:` type($field) `->` type($result)"
)

traits = traits_def(NoMemoryEffect())
traits = traits_def(NoMemoryEffect(), CastOpHasCanonicalizationPatternsTrait())

@staticmethod
def get(
Expand Down
11 changes: 11 additions & 0 deletions xdsl/transforms/canonicalization_patterns/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,14 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> N

rewriter.replace_op(old_return, stencil.ReturnOp.get(return_args))
rewriter.replace_matched_op(new, replace_results)


class RemoveCastWithNoEffect(RewritePattern):
"""
Remove `stencil.cast` where input and output types are equal.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.CastOp, rewriter: PatternRewriter) -> None:
if op.result.type == op.field.type:
rewriter.replace_matched_op([], new_results=[op.field])

0 comments on commit b47bf3d

Please sign in to comment.