Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
core: Deprecate methods that should be using BlockInsertPoint
Browse files Browse the repository at this point in the history
stack-info: PR: #3705, branch: math-fehr/stack/9
math-fehr committed Jan 20, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 97e1eb6 commit a886efb
Showing 8 changed files with 74 additions and 32 deletions.
14 changes: 9 additions & 5 deletions tests/pattern_rewriter/test_pattern_rewriter.py
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@
attr_type_rewrite_pattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.rewriter import BlockInsertPoint, InsertPoint


def rewrite_and_compare(
@@ -1225,7 +1225,7 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
if op.parent is None:
return

rewriter.inline_region_before(op.regions[0], op.parent)
rewriter.inline_region(op.regions[0], BlockInsertPoint.before(op.parent))
rewriter.erase_matched_op()

rewrite_and_compare(
@@ -1272,7 +1272,7 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
if op.parent is None:
return

rewriter.inline_region_after(op.regions[0], op.parent)
rewriter.inline_region(op.regions[0], BlockInsertPoint.after(op.parent))
rewriter.erase_matched_op()

rewrite_and_compare(
@@ -1320,7 +1320,9 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
if parent_region is None:
return

rewriter.inline_region_at_start(op.regions[0], parent_region)
rewriter.inline_region(
op.regions[0], BlockInsertPoint.at_start(parent_region)
)
rewriter.erase_matched_op()

rewrite_and_compare(
@@ -1368,7 +1370,9 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
if parent_region is None:
return

rewriter.inline_region_at_end(op.regions[0], parent_region)
rewriter.inline_region(
op.regions[0], BlockInsertPoint.at_end(parent_region)
)
rewriter.erase_matched_op()

rewrite_and_compare(
16 changes: 8 additions & 8 deletions tests/test_op_builder.py
Original file line number Diff line number Diff line change
@@ -107,29 +107,29 @@ def test_builder_create_block():
target = Region([block1, block2])
builder = Builder(InsertPoint.at_start(block1))

new_block1 = builder.create_block_at_start(target, (i32,))
new_block1 = builder.create_block(BlockInsertPoint.at_start(target), (i32,))
assert len(new_block1.args) == 1
assert new_block1.args[0].type == i32
assert len(target.blocks) == 3
assert target.blocks[0] == new_block1
assert builder.insertion_point == InsertPoint.at_start(new_block1)

new_block2 = builder.create_block_at_end(target, (i64,))
new_block2 = builder.create_block(BlockInsertPoint.at_end(target), (i64,))
assert len(new_block2.args) == 1
assert new_block2.args[0].type == i64
assert len(target.blocks) == 4
assert target.blocks[3] == new_block2
assert builder.insertion_point == InsertPoint.at_start(new_block2)

new_block3 = builder.create_block_before(block2, (i32, i64))
new_block3 = builder.create_block(BlockInsertPoint.before(block2), (i32, i64))
assert len(new_block3.args) == 2
assert new_block3.args[0].type == i32
assert new_block3.args[1].type == i64
assert len(target.blocks) == 5
assert target.blocks[2] == new_block3
assert builder.insertion_point == InsertPoint.at_start(new_block3)

new_block4 = builder.create_block_after(block2, (i64, i32))
new_block4 = builder.create_block(BlockInsertPoint.after(block2), (i64, i32))
assert len(new_block4.args) == 2
assert new_block4.args[0].type == i64
assert new_block4.args[1].type == i32
@@ -173,10 +173,10 @@ def add_block_on_create(b: Block):

b.block_creation_handler = [add_block_on_create]

b1 = b.create_block_at_start(region)
b2 = b.create_block_at_end(region)
b3 = b.create_block_before(block)
b4 = b.create_block_after(block)
b1 = b.create_block(BlockInsertPoint.at_start(region))
b2 = b.create_block(BlockInsertPoint.at_end(region))
b3 = b.create_block(BlockInsertPoint.before(block))
b4 = b.create_block(BlockInsertPoint.after(block))

assert created_blocks == [b1, b2, b3, b4]

26 changes: 15 additions & 11 deletions tests/test_rewriter.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from xdsl.dialects.builtin import Builtin, Float32Type, Float64Type, ModuleOp, i32, i64
from xdsl.ir import Block, Region
from xdsl.parser import Parser
from xdsl.rewriter import InsertPoint, Rewriter
from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter


def rewrite_and_compare(
@@ -289,7 +289,9 @@ def test_insert_block_before():
"""

def insert_empty_block_before(module: ModuleOp, rewriter: Rewriter) -> None:
rewriter.insert_block_before(Block(), module.regions[0].blocks[0])
rewriter.insert_block(
Block(), BlockInsertPoint.before(module.regions[0].blocks[0])
)

rewrite_and_compare(prog, expected, insert_empty_block_before)

@@ -312,7 +314,9 @@ def test_insert_block_after():
"""

def insert_empty_block_after(module: ModuleOp, rewriter: Rewriter) -> None:
rewriter.insert_block_after(Block(), module.regions[0].blocks[0])
rewriter.insert_block(
Block(), BlockInsertPoint.after(module.regions[0].blocks[0])
)

rewrite_and_compare(prog, expected, insert_empty_block_after)

@@ -510,7 +514,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:
Block((test.TestOp(result_types=(Float64Type(),)),)),
)
)
rewriter.inline_region_before(region, module.body.blocks[1])
rewriter.inline_region(region, BlockInsertPoint.before(module.body.blocks[1]))

rewrite_and_compare(prog, expected, transformation)

@@ -544,7 +548,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:
Block((test.TestOp(result_types=(Float64Type(),)),)),
)
)
rewriter.inline_region_after(region, module.body.blocks[0])
rewriter.inline_region(region, BlockInsertPoint.after(module.body.blocks[0]))

rewrite_and_compare(prog, expected, transformation)

@@ -578,7 +582,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:
Block((test.TestOp(result_types=(Float64Type(),)),)),
)
)
rewriter.inline_region_at_start(region, module.body)
rewriter.inline_region(region, BlockInsertPoint.at_start(module.body))

rewrite_and_compare(prog, expected, transformation)

@@ -612,7 +616,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:
Block((test.TestOp(result_types=(Float64Type(),)),)),
)
)
rewriter.inline_region_at_end(region, module.body)
rewriter.inline_region(region, BlockInsertPoint.at_end(module.body))

rewrite_and_compare(prog, expected, transformation)

@@ -621,13 +625,13 @@ def test_verify_inline_region():
region = Region(Block())

with pytest.raises(ValueError, match="Cannot move region into itself."):
Rewriter.inline_region_before(region, region.block)
Rewriter.inline_region(region, BlockInsertPoint.before(region.block))

with pytest.raises(ValueError, match="Cannot move region into itself."):
Rewriter.inline_region_after(region, region.block)
Rewriter.inline_region(region, BlockInsertPoint.after(region.block))

with pytest.raises(ValueError, match="Cannot move region into itself."):
Rewriter.inline_region_at_start(region, region)
Rewriter.inline_region(region, BlockInsertPoint.at_start(region))

with pytest.raises(ValueError, match="Cannot move region into itself."):
Rewriter.inline_region_at_end(region, region)
Rewriter.inline_region(region, BlockInsertPoint.at_end(region))
4 changes: 2 additions & 2 deletions xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.rewriter import BlockInsertPoint, InsertPoint


class LowerRiscvScfForPattern(RewritePattern):
@@ -119,7 +119,7 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /):
),
)

rewriter.inline_region_before(op.body, end_block)
rewriter.inline_region(op.body, BlockInsertPoint.before(end_block))

# Move lb to new register to initialize the iv.
# Skip for loop if condition is not satisfied at start.
14 changes: 13 additions & 1 deletion xdsl/builder.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,8 @@
from types import TracebackType
from typing import ClassVar, TypeAlias, overload

from typing_extensions import deprecated

from xdsl.dialects.builtin import ArrayAttr
from xdsl.ir import Attribute, Block, BlockArgument, Operation, OperationInvT, Region
from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter
@@ -75,7 +77,7 @@ def insert(self, op: OperationInvT) -> OperationInvT:
return op

def create_block(
self, insert_point: BlockInsertPoint, arg_types: Iterable[Attribute]
self, insert_point: BlockInsertPoint, arg_types: Iterable[Attribute] = ()
) -> Block:
"""
Create a block at the given location, and set the operation insertion point
@@ -89,6 +91,9 @@ def create_block(
self.handle_block_creation(block)
return block

@deprecated(
"Use create_block(BlockInsertPoint.before(insert_before), arg_types) instead"
)
def create_block_before(
self, insert_before: Block, arg_types: Iterable[Attribute] = ()
) -> Block:
@@ -98,6 +103,9 @@ def create_block_before(
"""
return self.create_block(BlockInsertPoint.before(insert_before), arg_types)

@deprecated(
"Use create_block(BlockInsertPoint.after(insert_after), arg_types) instead"
)
def create_block_after(
self, insert_after: Block, arg_types: Iterable[Attribute] = ()
) -> Block:
@@ -107,6 +115,9 @@ def create_block_after(
"""
return self.create_block(BlockInsertPoint.after(insert_after), arg_types)

@deprecated(
"Use create_block(BlockInsertPoint.at_start(region), arg_types) instead"
)
def create_block_at_start(
self, region: Region, arg_types: Iterable[Attribute] = ()
) -> Block:
@@ -116,6 +127,7 @@ def create_block_at_start(
"""
return self.create_block(BlockInsertPoint.at_start(region), arg_types)

@deprecated("Use create_block(BlockInsertPoint.at_end(region), arg_types) instead")
def create_block_at_end(
self, region: Region, arg_types: Iterable[Attribute] = ()
) -> Block:
12 changes: 12 additions & 0 deletions xdsl/pattern_rewriter.py
Original file line number Diff line number Diff line change
@@ -356,18 +356,30 @@ def inline_region(self, region: Region, insertion_point: BlockInsertPoint) -> No
self.has_done_action = True
Rewriter.inline_region(region, insertion_point)

@deprecated(
"Please use `inline_region(region, BlockInsertPoint.before(target))` instead"
)
def inline_region_before(self, region: Region, target: Block) -> None:
"""Move the region blocks to an existing region."""
self.inline_region(region, BlockInsertPoint.before(target))

@deprecated(
"Please use `inline_region(region, BlockInsertPoint.after(target))` instead"
)
def inline_region_after(self, region: Region, target: Block) -> None:
"""Move the region blocks to an existing region."""
self.inline_region(region, BlockInsertPoint.after(target))

@deprecated(
"Please use `inline_region(region, BlockInsertPoint.at_start(target))` instead"
)
def inline_region_at_start(self, region: Region, target: Region) -> None:
"""Move the region blocks to an existing region."""
self.inline_region(region, BlockInsertPoint.at_start(target))

@deprecated(
"Please use `inline_region(region, BlockInsertPoint.at_end(target))` instead"
)
def inline_region_at_end(self, region: Region, target: Region) -> None:
"""Move the region blocks to an existing region."""
self.inline_region(region, BlockInsertPoint.at_end(target))
10 changes: 10 additions & 0 deletions xdsl/rewriter.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,8 @@
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field

from typing_extensions import deprecated

from xdsl.ir import Block, Operation, Region, SSAValue


@@ -240,6 +242,7 @@ def insert_block(block: Block | Iterable[Block], insert_point: BlockInsertPoint)
else:
region.add_block(block)

@deprecated("Use `insert_block(block, BlockInsertPoint.after(target))` instead")
@staticmethod
def insert_block_after(block: Block | list[Block], target: Block):
"""
@@ -249,6 +252,7 @@ def insert_block_after(block: Block | list[Block], target: Block):
"""
Rewriter.insert_block(block, BlockInsertPoint.after(target))

@deprecated("Use `insert_block(block, BlockInsertPoint.before(target))` instead")
@staticmethod
def insert_block_before(block: Block | list[Block], target: Block):
"""
@@ -284,21 +288,27 @@ def inline_region(region: Region, insertion_point: BlockInsertPoint) -> None:
else:
region.move_blocks(insertion_point.region)

@deprecated("Use `inline_region(region, BlockInsertPoint.before(target))` instead")
@staticmethod
def inline_region_before(region: Region, target: Block) -> None:
"""Move the region blocks to an existing region, before `target`."""
Rewriter.inline_region(region, BlockInsertPoint.before(target))

@deprecated("Use `inline_region(region, BlockInsertPoint.after(target))` instead")
@staticmethod
def inline_region_after(region: Region, target: Block) -> None:
"""Move the region blocks to an existing region, after `target`."""
Rewriter.inline_region(region, BlockInsertPoint.after(target))

@deprecated(
"Use `inline_region(region, BlockInsertPoint.at_start(target))` instead"
)
@staticmethod
def inline_region_at_start(region: Region, target: Region) -> None:
"""Move the region blocks to the start of an existing region."""
Rewriter.inline_region(region, BlockInsertPoint.at_start(target))

@deprecated("Use `inline_region(region, BlockInsertPoint.at_end(target))` instead")
@staticmethod
def inline_region_at_end(region: Region, target: Region) -> None:
"""Move the region blocks to the end of an existing region."""
10 changes: 5 additions & 5 deletions xdsl/transforms/convert_scf_to_cf.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.rewriter import BlockInsertPoint, InsertPoint
from xdsl.traits import IsTerminator


@@ -60,7 +60,7 @@ def match_and_rewrite(self, if_op: IfOp, rewriter: PatternRewriter, /):
)

rewriter.erase_op(then_terminator)
rewriter.inline_region_before(then_region, continue_block)
rewriter.inline_region(then_region, BlockInsertPoint.before(continue_block))

# Move blocks from the "else" region (if present) to the region containing
# 'scf.if', place it before the continuation block and branch to it. It
@@ -78,7 +78,7 @@ def match_and_rewrite(self, if_op: IfOp, rewriter: PatternRewriter, /):
)

rewriter.erase_op(else_terminator)
rewriter.inline_region_before(else_region, continue_block)
rewriter.inline_region(else_region, BlockInsertPoint.before(continue_block))
else:
else_block = continue_block

@@ -116,7 +116,7 @@ def match_and_rewrite(self, for_op: ForOp, rewriter: PatternRewriter):
first_body_block = condition_block.split_before(first_op)
last_body_block = for_op.body.last_block
assert last_body_block is not None
rewriter.inline_region_before(for_op.body, end_block)
rewriter.inline_region(for_op.body, BlockInsertPoint.before(end_block))
iv = condition_block.args[0]

# Append the induction variable stepping logic to the last body block and
@@ -169,7 +169,7 @@ def _convert_region(
rewriter.replace_op(yield_op, BranchOp(continue_block, *yield_op.operands))

# Inline the region
rewriter.inline_region_before(region, continue_block)
rewriter.inline_region(region, BlockInsertPoint.before(continue_block))
return block

@op_type_rewrite_pattern

0 comments on commit a886efb

Please sign in to comment.