Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: (convert-memref-to-ptr) add lower-func flag #3820

Merged
merged 12 commits into from
Feb 2, 2025
Prev Previous commit
Next Next commit
feat: unrealized ptr cast reconciliation
kaylendog committed Feb 1, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit e40ba0010849a4d9e7db44e7b3cd584892fa4ce0
116 changes: 88 additions & 28 deletions xdsl/transforms/convert_memref_to_ptr.py
Original file line number Diff line number Diff line change
@@ -157,7 +157,7 @@ def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /):
@dataclass
class LowerMemrefFuncOpPattern(RewritePattern):
"""
Rewrites function arguments of MemRefType to PtrType - leaves IR in invalid state(?)
Rewrites function arguments of MemRefType to PtrType.
Args:
RewritePattern (_type_): _description_
@@ -202,24 +202,31 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):

@dataclass
class LowerMemrefFuncReturnPattern(RewritePattern):
"""
Rewrites all `memref` arguments to `func.return` into `ptr.PtrType`
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /):
if not any(isinstance(arg.type, memref.MemRefType) for arg in op.arguments):
return

new_arguments = [
(arg.owner.inputs[0], arg.owner)
if isinstance(arg.owner, builtin.UnrealizedConversionCastOp)
and isinstance(arg.owner.inputs[0].type, ptr.PtrType)
else (arg, None)
for arg in op.arguments
]
insert_point = InsertPoint.before(op)
new_arguments: list[SSAValue] = []

rewriter.replace_matched_op(func.ReturnOp(*(arg for (arg, _) in new_arguments)))
for argument in op.arguments:
if isinstance(argument.type, memref.MemRefType):
rewriter.insert_op(
cast_op := builtin.UnrealizedConversionCastOp.get(
[argument], [ptr.PtrType()]
),
insert_point,
)
new_arguments.append(cast_op.results[0])
else:
new_arguments.append(argument)

for _, cast_op in new_arguments:
if cast_op is not None and not cast_op.results[0].uses:
rewriter.erase_op(cast_op)
rewriter.replace_matched_op(func.ReturnOp(*new_arguments))


@dataclass
@@ -231,36 +238,89 @@ def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /):
) and not any(isinstance(type, memref.MemRefType) for type in op.result_types):
return

new_arguments = [
(arg.owner.inputs[0], arg.owner)
if isinstance(arg.owner, builtin.UnrealizedConversionCastOp)
and isinstance(arg.owner.inputs[0].type, ptr.PtrType)
else (arg, None)
for arg in op.arguments
]
# rewrite arguments
insert_point = InsertPoint.before(op)
new_arguments: list[SSAValue] = []

for argument in op.arguments:
if isinstance(argument.type, memref.MemRefType):
rewriter.insert_op(
cast_op := builtin.UnrealizedConversionCastOp.get(
[argument], [ptr.PtrType()]
),
insert_point,
)
new_arguments.append(cast_op.results[0])
else:
new_arguments.append(argument)

insert_point = InsertPoint.after(op)
new_results: list[SSAValue] = []

# rewrite results
for result in op.results:
if isinstance(result.type, memref.MemRefType):
rewriter.insert_op(
cast_op := builtin.UnrealizedConversionCastOp.get(
[result], [ptr.PtrType()]
),
insert_point,
)
new_results.append(cast_op.results[0])
else:
new_results.append(result)

new_return_types = [
ptr.PtrType() if isinstance(type, memref.MemRefType) else type
for type in op.result_types
]

rewriter.replace_matched_op(
func.CallOp(
op.callee, [arg for (arg, _) in new_arguments], new_return_types
)
func.CallOp(op.callee, new_arguments, new_return_types)
)

for _, cast_op in new_arguments:
if cast_op is not None and not cast_op.results[0].uses:
rewriter.erase_op(cast_op)

class ReconcileUnrealizedPtrCasts(RewritePattern):
"""
Eliminates three variants of unrealized ptr casts:
- `ptr_xdsl.ptr -> ptr_xdsl.ptr`;
- `ptr_xdsl.ptr -> memref.MemRef -> ptr_xdsl.ptr`;
- Casts from `ptr_xdsl.ptr` where all uses are `ToPtrOp` operations.
"""

class ReconcilePtrCasts(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
self, op: builtin.UnrealizedConversionCastOp, rewriter: PatternRewriter, /
):
if not isinstance(op.inputs[0].type, ptr.PtrType):
# preconditions
if (
len(op.inputs) != 1
or len(op.outputs) != 1
or not isinstance(op.inputs[0].type, ptr.PtrType)
):
return

# erase ptr -> ptr casts
if isinstance(op.outputs[0].type, ptr.PtrType):
op.outputs[0].replace_by(op.inputs[0])
rewriter.erase_matched_op()
return

if not isinstance(op.outputs[0].type, memref.MemRefType):
return

# erase ptr -> memref -> ptr cast pairs
uses = [use for use in op.outputs[0].uses]
for use in uses:
if (
isinstance(use.operation, builtin.UnrealizedConversionCastOp)
and isinstance(use.operation.inputs[0].type, memref.MemRefType)
and isinstance(use.operation.outputs[0].type, ptr.PtrType)
):
use.operation.outputs[0].replace_by(op.inputs[0])
rewriter.erase_op(use.operation)

# erase this cast entirely if all uses are ToPtrOp
cast_ops = [use.operation for use in op.outputs[0].uses]
if not all(isinstance(op, ptr.ToPtrOp) for op in cast_ops):
return
@@ -290,7 +350,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
LowerMemrefFuncOpPattern(),
LowerMemrefFuncCallPattern(),
LowerMemrefFuncReturnPattern(),
ReconcilePtrCasts(),
ReconcileUnrealizedPtrCasts(),
]
)
).rewrite_module(op)