Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: (convert-memref-to-ptr) add lower-func flag #3820

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
42 changes: 42 additions & 0 deletions tests/filecheck/transforms/convert_memref_args_to_ptr.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: xdsl-opt -p convert-memref-to-ptr{lower_func=true} --split-input-file --verify-diagnostics %s | filecheck %s

// CHECK: builtin.module {

// CHECK-NEXT: func.func @declaration(!ptr_xdsl.ptr) -> ()
func.func @declaration(%arg : memref<2x2xf32>)


// CHECK-NEXT: func.func @simple(%arg : !ptr_xdsl.ptr) {
// CHECK-NEXT: func.return
// CHECK-NEXT: }
func.func @simple(%arg : memref<2x2xf32>) {
func.return
}

// CHECK-NEXT: func.func @id(%arg : !ptr_xdsl.ptr) -> !ptr_xdsl.ptr {
// CHECK-NEXT: func.return %arg : !ptr_xdsl.ptr
// CHECK-NEXT: }
func.func @id(%arg : memref<2x2xf32>) -> memref<2x2xf32> {
func.return %arg : memref<2x2xf32>
}

// CHECK-NEXT: func.func @id2(%arg : !ptr_xdsl.ptr) -> !ptr_xdsl.ptr {
// CHECK-NEXT: %res = func.call @id(%arg) : (!ptr_xdsl.ptr) -> !ptr_xdsl.ptr
// CHECK-NEXT: func.return %res : !ptr_xdsl.ptr
// CHECK-NEXT: }
func.func @id2(%arg : memref<2x2xf32>) -> memref<2x2xf32> {
%res = func.call @id(%arg) : (memref<2x2xf32>) -> memref<2x2xf32>
func.return %res : memref<2x2xf32>
}

// CHECK-NEXT: func.func @first(%arg : !ptr_xdsl.ptr) -> f32 {
// CHECK-NEXT: %res = ptr_xdsl.load %arg : !ptr_xdsl.ptr -> f32
// CHECK-NEXT: func.return %res : f32
// CHECK-NEXT: }
func.func @first(%arg : memref<2x2xf32>) -> f32 {
%pointer = ptr_xdsl.to_ptr %arg : memref<2x2xf32> -> !ptr_xdsl.ptr
%res = ptr_xdsl.load %pointer : !ptr_xdsl.ptr -> f32
func.return %res : f32
}

// CHECK-NEXT: }
196 changes: 191 additions & 5 deletions xdsl/transforms/convert_memref_to_ptr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import cast

from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, memref, ptr
from xdsl.ir import Operation, SSAValue
from xdsl.dialects import arith, builtin, func, memref, ptr
from xdsl.ir import Attribute, Operation, SSAValue
from xdsl.irdl import Any
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
Expand All @@ -14,6 +14,7 @@
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.utils.exceptions import DiagnosticException


Expand Down Expand Up @@ -153,12 +154,197 @@ def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /):
rewriter.replace_matched_op(ops, new_results=[load_result.res])


@dataclass
class LowerMemrefFuncOpPattern(RewritePattern):
"""
Rewrites function arguments of MemRefType to PtrType.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
# rewrite function declaration
new_input_types = [
ptr.PtrType() if isinstance(arg, builtin.MemRefType) else arg
for arg in op.function_type.inputs
]
new_output_types = [
ptr.PtrType() if isinstance(arg, builtin.MemRefType) else arg
for arg in op.function_type.outputs
]
op.function_type = func.FunctionType.from_lists(
new_input_types,
new_output_types,
)

if op.is_declaration:
return

insert_point = InsertPoint.at_start(op.body.blocks[0])

# rewrite arguments
for arg in op.args:
if not isinstance(arg_type := arg.type, memref.MemRefType):
continue

old_type = cast(memref.MemRefType[Attribute], arg_type)
arg.type = ptr.PtrType()

if not arg.uses:
continue

rewriter.insert_op(
cast_op := builtin.UnrealizedConversionCastOp.get([arg], [old_type]),
insert_point,
)
arg.replace_by_if(cast_op.results[0], lambda x: x.operation is not cast_op)


@dataclass
class LowerMemrefFuncReturnPattern(RewritePattern):
"""
Rewrites all `memref` arguments to `func.return` into `ptr.PtrType`
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /):
if not any(isinstance(arg.type, memref.MemRefType) for arg in op.arguments):
return

insert_point = InsertPoint.before(op)
new_arguments: list[SSAValue] = []

# insert `memref -> ptr` casts for memref return values
for argument in op.arguments:
if isinstance(argument.type, memref.MemRefType):
rewriter.insert_op(
cast_op := builtin.UnrealizedConversionCastOp.get(
[argument], [ptr.PtrType()]
),
insert_point,
)
new_arguments.append(cast_op.results[0])
else:
new_arguments.append(argument)

rewriter.replace_matched_op(func.ReturnOp(*new_arguments))


@dataclass
class LowerMemrefFuncCallPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /):
if not any(
isinstance(arg.type, memref.MemRefType) for arg in op.arguments
) and not any(isinstance(type, memref.MemRefType) for type in op.result_types):
return

# rewrite arguments
insert_point = InsertPoint.before(op)
new_arguments: list[SSAValue] = []

# insert `memref -> ptr` casts for memref arguments values
for argument in op.arguments:
if isinstance(argument.type, memref.MemRefType):
rewriter.insert_op(
cast_op := builtin.UnrealizedConversionCastOp.get(
[argument], [ptr.PtrType()]
),
insert_point,
)
new_arguments.append(cast_op.results[0])
else:
new_arguments.append(argument)

insert_point = InsertPoint.after(op)
new_results: list[SSAValue] = []

# insert `ptr -> memref` casts for return values
for result in op.results:
if isinstance(result.type, memref.MemRefType):
rewriter.insert_op(
cast_op := builtin.UnrealizedConversionCastOp.get(
[result],
# TODO: annoying pyright warnings - Sasha, pls help
[result.type], # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
),
insert_point,
)
new_results.append(cast_op.results[0])
else:
new_results.append(result)

new_return_types = [
ptr.PtrType() if isinstance(type, memref.MemRefType) else type
for type in op.result_types
]

rewriter.replace_matched_op(
func.CallOp(op.callee, new_arguments, new_return_types)
)


class ReconcileUnrealizedPtrCasts(RewritePattern):
"""
Eliminates two variants of unrealized ptr casts:
- `ptr_xdsl.ptr -> memref.MemRef -> ptr_xdsl.ptr`;
- `ptr_xdsl.ptr -> memref.memref` where all uses are `ToPtrOp` operations.
"""

@op_type_rewrite_pattern
def match_and_rewrite(
self, op: builtin.UnrealizedConversionCastOp, rewriter: PatternRewriter, /
):
# preconditions
if (
len(op.inputs) != 1
or len(op.outputs) != 1
or not isinstance(op.inputs[0].type, ptr.PtrType)
or not not isinstance(op.outputs[0].type, memref.MemRefType)
):
return

# erase ptr -> memref -> ptr cast pairs
uses = (use for use in op.outputs[0].uses)
for use in uses:
if (
isinstance(use.operation, builtin.UnrealizedConversionCastOp)
and isinstance(use.operation.inputs[0].type, memref.MemRefType)
and isinstance(use.operation.outputs[0].type, ptr.PtrType)
):
use.operation.outputs[0].replace_by(op.inputs[0])
rewriter.erase_op(use.operation)

# erase this cast entirely if all remaining uses are by ToPtr operations
cast_ops = [use.operation for use in op.outputs[0].uses]
if not all(isinstance(op, ptr.ToPtrOp) for op in cast_ops):
return

for cast_op in cast_ops:
cast_op.results[0].replace_by(op.inputs[0])
rewriter.erase_op(cast_op)

rewriter.erase_op(op)


@dataclass(frozen=True)
class ConvertMemrefToPtr(ModulePass):
name = "convert-memref-to-ptr"

lower_func: bool = False

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
the_one_pass = PatternRewriteWalker(
PatternRewriteWalker(
GreedyRewritePatternApplier([ConvertStoreOp(), ConvertLoadOp()])
)
the_one_pass.rewrite_module(op)
).rewrite_module(op)

if self.lower_func:
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
LowerMemrefFuncOpPattern(),
LowerMemrefFuncCallPattern(),
LowerMemrefFuncReturnPattern(),
ReconcileUnrealizedPtrCasts(),
]
)
).rewrite_module(op)
Loading