From 2c35423e620f130f52daa3d6e9c3f5bb078e182d Mon Sep 17 00:00:00 2001 From: kaylendog Date: Sat, 1 Feb 2025 13:05:50 +0000 Subject: [PATCH 1/9] transforms: (convert_memref_to_ptr) add func arg rewrite option --- .../convert_memref_args_to_ptr.mlir | 22 ++++ xdsl/transforms/convert_memref_to_ptr.py | 105 +++++++++++++++++- 2 files changed, 123 insertions(+), 4 deletions(-) create mode 100644 tests/filecheck/transforms/convert_memref_args_to_ptr.mlir diff --git a/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir b/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir new file mode 100644 index 0000000000..39f5eae76b --- /dev/null +++ b/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir @@ -0,0 +1,22 @@ +// RUN: xdsl-opt -p convert-memref-to-ptr{convert_func_args=true} --split-input-file --verify-diagnostics %s | filecheck %s + +func.func @declaration(%arg : memref<2x2xf32>) + +func.func @simple(%arg : memref<2x2xf32>) { + func.return +} + +func.func @id(%arg : memref<2x2xf32>) -> memref<2x2xf32> { + func.return %arg : memref<2x2xf32> +} + +func.func @id2(%arg : memref<2x2xf32>) -> memref<2x2xf32> { + %res = func.call @id(%arg) : (memref<2x2xf32>) -> memref<2x2xf32> + func.return %res : memref<2x2xf32> +} + +func.func @first(%arg : memref<2x2xf32>) -> f32 { + %pointer = ptr_xdsl.to_ptr %arg : memref<2x2xf32> -> !ptr_xdsl.ptr + %res = ptr_xdsl.load %pointer : !ptr_xdsl.ptr -> f32 + func.return %res : f32 +} diff --git a/xdsl/transforms/convert_memref_to_ptr.py b/xdsl/transforms/convert_memref_to_ptr.py index cb1f37a76c..174d8e8827 100644 --- a/xdsl/transforms/convert_memref_to_ptr.py +++ b/xdsl/transforms/convert_memref_to_ptr.py @@ -3,8 +3,9 @@ from typing import cast from xdsl.context import MLContext -from xdsl.dialects import arith, builtin, memref, ptr +from xdsl.dialects import arith, builtin, memref, ptr, func from xdsl.ir import Operation, SSAValue +from xdsl.ir.core import Attribute from xdsl.irdl import Any from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( @@ -14,6 +15,7 @@ RewritePattern, op_type_rewrite_pattern, ) +from xdsl.rewriter import InsertPoint from xdsl.utils.exceptions import DiagnosticException @@ -153,12 +155,107 @@ def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /): rewriter.replace_matched_op(ops, new_results=[load_result.res]) +@dataclass +class LowerMemrefFuncArgsPattern(RewritePattern): + """ + Rewrites function arguments of MemRefType to PtrType - leaves IR in invalid state(?) + + Args: + RewritePattern (_type_): _description_ + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): + new_input_types = [ + ptr.PtrType() if isinstance(arg, builtin.MemRefType) else arg + for arg in op.function_type.inputs + ] + new_output_types = [ + ptr.PtrType() if isinstance(arg, builtin.MemRefType) else arg + for arg in op.function_type.outputs + ] + op.function_type = func.FunctionType.from_lists( + new_input_types, + new_output_types, + ) + + if op.is_declaration: + return + + insert_point = InsertPoint.at_start(op.body.blocks[0]) + + for arg in op.args: + if isinstance(arg_type := arg.type, memref.MemRefType): + old_type = cast(memref.MemRefType[Attribute], arg_type) + arg.type = ptr.PtrType() + + if not arg.uses: + continue + + rewriter.insert_op( + cast_op := builtin.UnrealizedConversionCastOp.get( + [arg], [old_type] + ), + insert_point, + ) + arg.replace_by_if(cast_op.results[0], lambda x: x.operation != cast_op) + + +@dataclass +class LowerMemrefReturnPattern(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /): + if not any(isinstance(arg.type, memref.MemRefType) for arg in op.arguments): + return + + new_arguments = [ + (arg.owner.inputs[0], arg.owner) + if isinstance(arg.owner, builtin.UnrealizedConversionCastOp) + and isinstance(arg.owner.inputs[0].type, ptr.PtrType) + else (arg, None) + for arg in op.arguments + ] + + rewriter.replace_matched_op(func.ReturnOp(*(arg for (arg, _) in new_arguments))) + + for _, cast_op in new_arguments: + if cast_op is not None and not cast_op.results[0].uses: + rewriter.erase_op(cast_op) + + +@dataclass +class LowerMemrefCallArgsPattern(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /): + pass + + +@dataclass +class LowerMemrefToPtrPattern(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: builtin.ModuleOp, rewriter: PatternRewriter, /): + pass + + @dataclass(frozen=True) class ConvertMemrefToPtr(ModulePass): name = "convert-memref-to-ptr" + convert_func_args: bool = False + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: - the_one_pass = PatternRewriteWalker( + PatternRewriteWalker( GreedyRewritePatternApplier([ConvertStoreOp(), ConvertLoadOp()]) - ) - the_one_pass.rewrite_module(op) + ).rewrite_module(op) + + if self.convert_func_args: + PatternRewriteWalker( + GreedyRewritePatternApplier( + [ + LowerMemrefFuncArgsPattern(), + LowerMemrefCallArgsPattern(), + LowerMemrefToPtrPattern(), + LowerMemrefReturnPattern(), + ] + ) + ).rewrite_module(op) From 82b192fb57ae86ab4bffc4ec7844195c2a939343 Mon Sep 17 00:00:00 2001 From: kaylendog Date: Sat, 1 Feb 2025 13:09:01 +0000 Subject: [PATCH 2/9] transforms: (convert_memref_to_ptr) add filecheck tests --- .../convert_memref_args_to_ptr.mlir | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir b/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir index 39f5eae76b..c29e2abb83 100644 --- a/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir +++ b/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir @@ -1,22 +1,39 @@ // RUN: xdsl-opt -p convert-memref-to-ptr{convert_func_args=true} --split-input-file --verify-diagnostics %s | filecheck %s +// CHECK: builtin.module { + +// CHECK-NEXT: func.func @declaration(!ptr_xdsl.ptr) -> () func.func @declaration(%arg : memref<2x2xf32>) + +// CHECK-NEXT: func.func @simple(%arg : !ptr_xdsl.ptr) { +// CHECK-NEXT: func.return +// CHECK-NEXT: } func.func @simple(%arg : memref<2x2xf32>) { func.return } +// CHECK-NEXT: func.func @id(%arg : !ptr_xdsl.ptr) -> !ptr_xdsl.ptr { +// CHECK-NEXT: func.return %arg : !ptr_xdsl.ptr +// CHECK-NEXT: } func.func @id(%arg : memref<2x2xf32>) -> memref<2x2xf32> { func.return %arg : memref<2x2xf32> } -func.func @id2(%arg : memref<2x2xf32>) -> memref<2x2xf32> { - %res = func.call @id(%arg) : (memref<2x2xf32>) -> memref<2x2xf32> - func.return %res : memref<2x2xf32> -} +// func.func @id2(%arg : memref<2x2xf32>) -> memref<2x2xf32> { +// %res = func.call @id(%arg) : (memref<2x2xf32>) -> memref<2x2xf32> +// func.return %res : memref<2x2xf32> +// } + +// CHECK-NEXT: func.func @first(%arg : !ptr_xdsl.ptr) -> f32 { +// CHECK-NEXT: %res = ptr_xdsl.load %arg : !ptr_xdsl.ptr -> f32 +// CHECK-NEXT: func.return %res : f32 +// CHECK-NEXT: } func.func @first(%arg : memref<2x2xf32>) -> f32 { %pointer = ptr_xdsl.to_ptr %arg : memref<2x2xf32> -> !ptr_xdsl.ptr %res = ptr_xdsl.load %pointer : !ptr_xdsl.ptr -> f32 func.return %res : f32 } + +// CHECK-NEXT: } From 5bd96859337dcb3f752a7394dd6d1835965280f4 Mon Sep 17 00:00:00 2001 From: kaylendog Date: Sat, 1 Feb 2025 15:15:46 +0000 Subject: [PATCH 3/9] feat: add reconcile ptr casts --- .../convert_memref_args_to_ptr.mlir | 8 +- xdsl/transforms/convert_memref_to_ptr.py | 93 +++++++++++++------ 2 files changed, 68 insertions(+), 33 deletions(-) diff --git a/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir b/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir index c29e2abb83..62a21564fc 100644 --- a/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir +++ b/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir @@ -20,10 +20,10 @@ func.func @id(%arg : memref<2x2xf32>) -> memref<2x2xf32> { func.return %arg : memref<2x2xf32> } -// func.func @id2(%arg : memref<2x2xf32>) -> memref<2x2xf32> { -// %res = func.call @id(%arg) : (memref<2x2xf32>) -> memref<2x2xf32> -// func.return %res : memref<2x2xf32> -// } +func.func @id2(%arg : memref<2x2xf32>) -> memref<2x2xf32> { + %res = func.call @id(%arg) : (memref<2x2xf32>) -> memref<2x2xf32> + func.return %res : memref<2x2xf32> +} // CHECK-NEXT: func.func @first(%arg : !ptr_xdsl.ptr) -> f32 { diff --git a/xdsl/transforms/convert_memref_to_ptr.py b/xdsl/transforms/convert_memref_to_ptr.py index 174d8e8827..43220a3d45 100644 --- a/xdsl/transforms/convert_memref_to_ptr.py +++ b/xdsl/transforms/convert_memref_to_ptr.py @@ -3,9 +3,8 @@ from typing import cast from xdsl.context import MLContext -from xdsl.dialects import arith, builtin, memref, ptr, func -from xdsl.ir import Operation, SSAValue -from xdsl.ir.core import Attribute +from xdsl.dialects import arith, builtin, func, memref, ptr +from xdsl.ir import Attribute, Operation, SSAValue from xdsl.irdl import Any from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( @@ -156,7 +155,7 @@ def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /): @dataclass -class LowerMemrefFuncArgsPattern(RewritePattern): +class LowerMemrefFuncOpPattern(RewritePattern): """ Rewrites function arguments of MemRefType to PtrType - leaves IR in invalid state(?) @@ -185,24 +184,24 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): insert_point = InsertPoint.at_start(op.body.blocks[0]) for arg in op.args: - if isinstance(arg_type := arg.type, memref.MemRefType): - old_type = cast(memref.MemRefType[Attribute], arg_type) - arg.type = ptr.PtrType() - - if not arg.uses: - continue - - rewriter.insert_op( - cast_op := builtin.UnrealizedConversionCastOp.get( - [arg], [old_type] - ), - insert_point, - ) - arg.replace_by_if(cast_op.results[0], lambda x: x.operation != cast_op) + if not isinstance(arg_type := arg.type, memref.MemRefType): + continue + + old_type = cast(memref.MemRefType[Attribute], arg_type) + arg.type = ptr.PtrType() + + if not arg.uses: + continue + + rewriter.insert_op( + cast_op := builtin.UnrealizedConversionCastOp.get([arg], [old_type]), + insert_point, + ) + arg.replace_by_if(cast_op.results[0], lambda x: x.operation != cast_op) @dataclass -class LowerMemrefReturnPattern(RewritePattern): +class LowerMemrefFuncReturnPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /): if not any(isinstance(arg.type, memref.MemRefType) for arg in op.arguments): @@ -224,17 +223,53 @@ def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /): @dataclass -class LowerMemrefCallArgsPattern(RewritePattern): +class LowerMemrefFuncCallPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /): - pass + if not any( + isinstance(arg.type, memref.MemRefType) for arg in op.arguments + ) and not any(isinstance(type, memref.MemRefType) for type in op.result_types): + return + new_arguments = [ + (arg.owner.inputs[0], arg.owner) + if isinstance(arg.owner, builtin.UnrealizedConversionCastOp) + and isinstance(arg.owner.inputs[0].type, ptr.PtrType) + else (arg, None) + for arg in op.arguments + ] + new_return_types = [ + ptr.PtrType() if isinstance(type, memref.MemRefType) else type + for type in op.result_types + ] + rewriter.replace_matched_op( + func.CallOp( + op.callee, [arg for (arg, _) in new_arguments], new_return_types + ) + ) -@dataclass -class LowerMemrefToPtrPattern(RewritePattern): + for _, cast_op in new_arguments: + if cast_op is not None and not cast_op.results[0].uses: + rewriter.erase_op(cast_op) + + +class ReconcilePtrCasts(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, op: builtin.ModuleOp, rewriter: PatternRewriter, /): - pass + def match_and_rewrite( + self, op: builtin.UnrealizedConversionCastOp, rewriter: PatternRewriter, / + ): + if not isinstance(op.inputs[0].type, ptr.PtrType): + return + + cast_ops = [use.operation for use in op.outputs[0].uses] + if not all(isinstance(op, ptr.ToPtrOp) for op in cast_ops): + return + + for cast_op in cast_ops: + cast_op.results[0].replace_by(op.inputs[0]) + rewriter.erase_op(cast_op) + + rewriter.erase_op(op) @dataclass(frozen=True) @@ -252,10 +287,10 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: PatternRewriteWalker( GreedyRewritePatternApplier( [ - LowerMemrefFuncArgsPattern(), - LowerMemrefCallArgsPattern(), - LowerMemrefToPtrPattern(), - LowerMemrefReturnPattern(), + LowerMemrefFuncOpPattern(), + LowerMemrefFuncCallPattern(), + LowerMemrefFuncReturnPattern(), + ReconcilePtrCasts(), ] ) ).rewrite_module(op) From e40ba0010849a4d9e7db44e7b3cd584892fa4ce0 Mon Sep 17 00:00:00 2001 From: kaylendog Date: Sat, 1 Feb 2025 15:54:26 +0000 Subject: [PATCH 4/9] feat: unrealized ptr cast reconciliation --- xdsl/transforms/convert_memref_to_ptr.py | 116 +++++++++++++++++------ 1 file changed, 88 insertions(+), 28 deletions(-) diff --git a/xdsl/transforms/convert_memref_to_ptr.py b/xdsl/transforms/convert_memref_to_ptr.py index 43220a3d45..d46582c921 100644 --- a/xdsl/transforms/convert_memref_to_ptr.py +++ b/xdsl/transforms/convert_memref_to_ptr.py @@ -157,7 +157,7 @@ def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /): @dataclass class LowerMemrefFuncOpPattern(RewritePattern): """ - Rewrites function arguments of MemRefType to PtrType - leaves IR in invalid state(?) + Rewrites function arguments of MemRefType to PtrType. Args: RewritePattern (_type_): _description_ @@ -202,24 +202,31 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): @dataclass class LowerMemrefFuncReturnPattern(RewritePattern): + """ + Rewrites all `memref` arguments to `func.return` into `ptr.PtrType` + """ + @op_type_rewrite_pattern def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /): if not any(isinstance(arg.type, memref.MemRefType) for arg in op.arguments): return - new_arguments = [ - (arg.owner.inputs[0], arg.owner) - if isinstance(arg.owner, builtin.UnrealizedConversionCastOp) - and isinstance(arg.owner.inputs[0].type, ptr.PtrType) - else (arg, None) - for arg in op.arguments - ] + insert_point = InsertPoint.before(op) + new_arguments: list[SSAValue] = [] - rewriter.replace_matched_op(func.ReturnOp(*(arg for (arg, _) in new_arguments))) + for argument in op.arguments: + if isinstance(argument.type, memref.MemRefType): + rewriter.insert_op( + cast_op := builtin.UnrealizedConversionCastOp.get( + [argument], [ptr.PtrType()] + ), + insert_point, + ) + new_arguments.append(cast_op.results[0]) + else: + new_arguments.append(argument) - for _, cast_op in new_arguments: - if cast_op is not None and not cast_op.results[0].uses: - rewriter.erase_op(cast_op) + rewriter.replace_matched_op(func.ReturnOp(*new_arguments)) @dataclass @@ -231,36 +238,89 @@ def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /): ) and not any(isinstance(type, memref.MemRefType) for type in op.result_types): return - new_arguments = [ - (arg.owner.inputs[0], arg.owner) - if isinstance(arg.owner, builtin.UnrealizedConversionCastOp) - and isinstance(arg.owner.inputs[0].type, ptr.PtrType) - else (arg, None) - for arg in op.arguments - ] + # rewrite arguments + insert_point = InsertPoint.before(op) + new_arguments: list[SSAValue] = [] + + for argument in op.arguments: + if isinstance(argument.type, memref.MemRefType): + rewriter.insert_op( + cast_op := builtin.UnrealizedConversionCastOp.get( + [argument], [ptr.PtrType()] + ), + insert_point, + ) + new_arguments.append(cast_op.results[0]) + else: + new_arguments.append(argument) + + insert_point = InsertPoint.after(op) + new_results: list[SSAValue] = [] + + # rewrite results + for result in op.results: + if isinstance(result.type, memref.MemRefType): + rewriter.insert_op( + cast_op := builtin.UnrealizedConversionCastOp.get( + [result], [ptr.PtrType()] + ), + insert_point, + ) + new_results.append(cast_op.results[0]) + else: + new_results.append(result) + new_return_types = [ ptr.PtrType() if isinstance(type, memref.MemRefType) else type for type in op.result_types ] + rewriter.replace_matched_op( - func.CallOp( - op.callee, [arg for (arg, _) in new_arguments], new_return_types - ) + func.CallOp(op.callee, new_arguments, new_return_types) ) - for _, cast_op in new_arguments: - if cast_op is not None and not cast_op.results[0].uses: - rewriter.erase_op(cast_op) +class ReconcileUnrealizedPtrCasts(RewritePattern): + """ + Eliminates three variants of unrealized ptr casts: + - `ptr_xdsl.ptr -> ptr_xdsl.ptr`; + - `ptr_xdsl.ptr -> memref.MemRef -> ptr_xdsl.ptr`; + - Casts from `ptr_xdsl.ptr` where all uses are `ToPtrOp` operations. + """ -class ReconcilePtrCasts(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite( self, op: builtin.UnrealizedConversionCastOp, rewriter: PatternRewriter, / ): - if not isinstance(op.inputs[0].type, ptr.PtrType): + # preconditions + if ( + len(op.inputs) != 1 + or len(op.outputs) != 1 + or not isinstance(op.inputs[0].type, ptr.PtrType) + ): + return + + # erase ptr -> ptr casts + if isinstance(op.outputs[0].type, ptr.PtrType): + op.outputs[0].replace_by(op.inputs[0]) + rewriter.erase_matched_op() + return + + if not isinstance(op.outputs[0].type, memref.MemRefType): return + # erase ptr -> memref -> ptr cast pairs + uses = [use for use in op.outputs[0].uses] + for use in uses: + if ( + isinstance(use.operation, builtin.UnrealizedConversionCastOp) + and isinstance(use.operation.inputs[0].type, memref.MemRefType) + and isinstance(use.operation.outputs[0].type, ptr.PtrType) + ): + use.operation.outputs[0].replace_by(op.inputs[0]) + rewriter.erase_op(use.operation) + + # erase this cast entirely if all uses are ToPtrOp cast_ops = [use.operation for use in op.outputs[0].uses] if not all(isinstance(op, ptr.ToPtrOp) for op in cast_ops): return @@ -290,7 +350,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: LowerMemrefFuncOpPattern(), LowerMemrefFuncCallPattern(), LowerMemrefFuncReturnPattern(), - ReconcilePtrCasts(), + ReconcileUnrealizedPtrCasts(), ] ) ).rewrite_module(op) From aa649c0df6580abb48105505801851f9067f10be Mon Sep 17 00:00:00 2001 From: kaylendog Date: Sat, 1 Feb 2025 16:04:46 +0000 Subject: [PATCH 5/9] fix: fixup return receiver cast --- xdsl/transforms/convert_memref_to_ptr.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/xdsl/transforms/convert_memref_to_ptr.py b/xdsl/transforms/convert_memref_to_ptr.py index d46582c921..7ff9156d3c 100644 --- a/xdsl/transforms/convert_memref_to_ptr.py +++ b/xdsl/transforms/convert_memref_to_ptr.py @@ -165,6 +165,7 @@ class LowerMemrefFuncOpPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): + # rewrite function declaration new_input_types = [ ptr.PtrType() if isinstance(arg, builtin.MemRefType) else arg for arg in op.function_type.inputs @@ -183,6 +184,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): insert_point = InsertPoint.at_start(op.body.blocks[0]) + # rewrite arguments for arg in op.args: if not isinstance(arg_type := arg.type, memref.MemRefType): continue @@ -214,6 +216,7 @@ def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /): insert_point = InsertPoint.before(op) new_arguments: list[SSAValue] = [] + # insert `memref -> ptr` casts for memref return values for argument in op.arguments: if isinstance(argument.type, memref.MemRefType): rewriter.insert_op( @@ -242,6 +245,7 @@ def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /): insert_point = InsertPoint.before(op) new_arguments: list[SSAValue] = [] + # insert `memref -> ptr` casts for memref arguments values for argument in op.arguments: if isinstance(argument.type, memref.MemRefType): rewriter.insert_op( @@ -257,12 +261,14 @@ def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /): insert_point = InsertPoint.after(op) new_results: list[SSAValue] = [] - # rewrite results + # insert `ptr -> memref` casts for return values for result in op.results: if isinstance(result.type, memref.MemRefType): rewriter.insert_op( cast_op := builtin.UnrealizedConversionCastOp.get( - [result], [ptr.PtrType()] + [result], + # TODO: annoying pyright warnings - Sasha, pls help + [result.type], # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType] ), insert_point, ) @@ -282,10 +288,9 @@ def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /): class ReconcileUnrealizedPtrCasts(RewritePattern): """ - Eliminates three variants of unrealized ptr casts: - - `ptr_xdsl.ptr -> ptr_xdsl.ptr`; + Eliminates two variants of unrealized ptr casts: - `ptr_xdsl.ptr -> memref.MemRef -> ptr_xdsl.ptr`; - - Casts from `ptr_xdsl.ptr` where all uses are `ToPtrOp` operations. + - `ptr_xdsl.ptr -> memref.memref` where all uses are `ToPtrOp` operations. """ @op_type_rewrite_pattern @@ -300,12 +305,6 @@ def match_and_rewrite( ): return - # erase ptr -> ptr casts - if isinstance(op.outputs[0].type, ptr.PtrType): - op.outputs[0].replace_by(op.inputs[0]) - rewriter.erase_matched_op() - return - if not isinstance(op.outputs[0].type, memref.MemRefType): return @@ -320,7 +319,7 @@ def match_and_rewrite( use.operation.outputs[0].replace_by(op.inputs[0]) rewriter.erase_op(use.operation) - # erase this cast entirely if all uses are ToPtrOp + # erase this cast entirely if all remaining uses are by ToPtr operations cast_ops = [use.operation for use in op.outputs[0].uses] if not all(isinstance(op, ptr.ToPtrOp) for op in cast_ops): return From f74d8475a13ac29bf415f250ac5704d441fa8ee9 Mon Sep 17 00:00:00 2001 From: kaylendog Date: Sat, 1 Feb 2025 16:10:07 +0000 Subject: [PATCH 6/9] tests: update filecheck to include id2 --- tests/filecheck/transforms/convert_memref_args_to_ptr.mlir | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir b/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir index 62a21564fc..04a0dd2034 100644 --- a/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir +++ b/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir @@ -20,12 +20,15 @@ func.func @id(%arg : memref<2x2xf32>) -> memref<2x2xf32> { func.return %arg : memref<2x2xf32> } +// CHECK-NEXT: func.func @id2(%arg : !ptr_xdsl.ptr) -> !ptr_xdsl.ptr { +// CHECK-NEXT: %res = func.call @id(%arg) : (!ptr_xdsl.ptr) -> !ptr_xdsl.ptr +// CHECK-NEXT: func.return %res : !ptr_xdsl.ptr +// CHECK-NEXT: } func.func @id2(%arg : memref<2x2xf32>) -> memref<2x2xf32> { %res = func.call @id(%arg) : (memref<2x2xf32>) -> memref<2x2xf32> func.return %res : memref<2x2xf32> } - // CHECK-NEXT: func.func @first(%arg : !ptr_xdsl.ptr) -> f32 { // CHECK-NEXT: %res = ptr_xdsl.load %arg : !ptr_xdsl.ptr -> f32 // CHECK-NEXT: func.return %res : f32 From 2f4d5b7363e9c03c34bc1dd76217105d0a55da71 Mon Sep 17 00:00:00 2001 From: kaylendog Date: Sat, 1 Feb 2025 16:21:52 +0000 Subject: [PATCH 7/9] chore: rename flag --- xdsl/transforms/convert_memref_to_ptr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xdsl/transforms/convert_memref_to_ptr.py b/xdsl/transforms/convert_memref_to_ptr.py index 7ff9156d3c..a6775c14d7 100644 --- a/xdsl/transforms/convert_memref_to_ptr.py +++ b/xdsl/transforms/convert_memref_to_ptr.py @@ -335,14 +335,14 @@ def match_and_rewrite( class ConvertMemrefToPtr(ModulePass): name = "convert-memref-to-ptr" - convert_func_args: bool = False + lower_func: bool = False def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: PatternRewriteWalker( GreedyRewritePatternApplier([ConvertStoreOp(), ConvertLoadOp()]) ).rewrite_module(op) - if self.convert_func_args: + if self.lower_func: PatternRewriteWalker( GreedyRewritePatternApplier( [ From 450073d4d3cc39fb2482683c3b32ad57483fc879 Mon Sep 17 00:00:00 2001 From: kaylendog Date: Sat, 1 Feb 2025 16:42:51 +0000 Subject: [PATCH 8/9] chore: update test --- tests/filecheck/transforms/convert_memref_args_to_ptr.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir b/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir index 04a0dd2034..d107e7e89a 100644 --- a/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir +++ b/tests/filecheck/transforms/convert_memref_args_to_ptr.mlir @@ -1,4 +1,4 @@ -// RUN: xdsl-opt -p convert-memref-to-ptr{convert_func_args=true} --split-input-file --verify-diagnostics %s | filecheck %s +// RUN: xdsl-opt -p convert-memref-to-ptr{lower_func=true} --split-input-file --verify-diagnostics %s | filecheck %s // CHECK: builtin.module { From d79167ad9bf722cc05b6ec6e3586f6695ca0e5f4 Mon Sep 17 00:00:00 2001 From: Skye Date: Sat, 1 Feb 2025 19:14:59 +0000 Subject: [PATCH 9/9] fix: address PR comments from sasha --- xdsl/transforms/convert_memref_to_ptr.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/xdsl/transforms/convert_memref_to_ptr.py b/xdsl/transforms/convert_memref_to_ptr.py index a6775c14d7..cb13929e27 100644 --- a/xdsl/transforms/convert_memref_to_ptr.py +++ b/xdsl/transforms/convert_memref_to_ptr.py @@ -158,9 +158,6 @@ def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /): class LowerMemrefFuncOpPattern(RewritePattern): """ Rewrites function arguments of MemRefType to PtrType. - - Args: - RewritePattern (_type_): _description_ """ @op_type_rewrite_pattern @@ -199,7 +196,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): cast_op := builtin.UnrealizedConversionCastOp.get([arg], [old_type]), insert_point, ) - arg.replace_by_if(cast_op.results[0], lambda x: x.operation != cast_op) + arg.replace_by_if(cast_op.results[0], lambda x: x.operation is not cast_op) @dataclass @@ -302,14 +299,12 @@ def match_and_rewrite( len(op.inputs) != 1 or len(op.outputs) != 1 or not isinstance(op.inputs[0].type, ptr.PtrType) + or not not isinstance(op.outputs[0].type, memref.MemRefType) ): return - if not isinstance(op.outputs[0].type, memref.MemRefType): - return - # erase ptr -> memref -> ptr cast pairs - uses = [use for use in op.outputs[0].uses] + uses = (use for use in op.outputs[0].uses) for use in uses: if ( isinstance(use.operation, builtin.UnrealizedConversionCastOp)