From 0c1e580ec691b3cf8b1b21aaca7703606c675819 Mon Sep 17 00:00:00 2001
From: Mathieu Fehr <mathieu.fehr@gmail.com>
Date: Mon, 6 Jan 2025 15:24:18 +0000
Subject: [PATCH] core: Deprecate methods that should be using BlockInsertPoint

stack-info: PR: https://github.com/xdslproject/xdsl/pull/3705, branch: math-fehr/stack/9
---
 .../pattern_rewriter/test_pattern_rewriter.py | 14 ++++++----
 tests/test_op_builder.py                      | 16 ++++++------
 tests/test_rewriter.py                        | 26 +++++++++++--------
 .../lowering/convert_riscv_scf_to_riscv_cf.py |  4 +--
 xdsl/builder.py                               | 14 +++++++++-
 xdsl/pattern_rewriter.py                      | 12 +++++++++
 xdsl/rewriter.py                              | 10 +++++++
 xdsl/transforms/convert_scf_to_cf.py          | 10 +++----
 8 files changed, 74 insertions(+), 32 deletions(-)

diff --git a/tests/pattern_rewriter/test_pattern_rewriter.py b/tests/pattern_rewriter/test_pattern_rewriter.py
index e0ee262daa..bba8c20f96 100644
--- a/tests/pattern_rewriter/test_pattern_rewriter.py
+++ b/tests/pattern_rewriter/test_pattern_rewriter.py
@@ -33,7 +33,7 @@
     attr_type_rewrite_pattern,
     op_type_rewrite_pattern,
 )
-from xdsl.rewriter import InsertPoint
+from xdsl.rewriter import BlockInsertPoint, InsertPoint
 
 
 def rewrite_and_compare(
@@ -1225,7 +1225,7 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
             if op.parent is None:
                 return
 
-            rewriter.inline_region_before(op.regions[0], op.parent)
+            rewriter.inline_region(op.regions[0], BlockInsertPoint.before(op.parent))
             rewriter.erase_matched_op()
 
     rewrite_and_compare(
@@ -1272,7 +1272,7 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
             if op.parent is None:
                 return
 
-            rewriter.inline_region_after(op.regions[0], op.parent)
+            rewriter.inline_region(op.regions[0], BlockInsertPoint.after(op.parent))
             rewriter.erase_matched_op()
 
     rewrite_and_compare(
@@ -1320,7 +1320,9 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
             if parent_region is None:
                 return
 
-            rewriter.inline_region_at_start(op.regions[0], parent_region)
+            rewriter.inline_region(
+                op.regions[0], BlockInsertPoint.at_start(parent_region)
+            )
             rewriter.erase_matched_op()
 
     rewrite_and_compare(
@@ -1368,7 +1370,9 @@ def match_and_rewrite(self, op: test.TestOp, rewriter: PatternRewriter):
             if parent_region is None:
                 return
 
-            rewriter.inline_region_at_end(op.regions[0], parent_region)
+            rewriter.inline_region(
+                op.regions[0], BlockInsertPoint.at_end(parent_region)
+            )
             rewriter.erase_matched_op()
 
     rewrite_and_compare(
diff --git a/tests/test_op_builder.py b/tests/test_op_builder.py
index 3c1663b8bd..a407dc339e 100644
--- a/tests/test_op_builder.py
+++ b/tests/test_op_builder.py
@@ -107,21 +107,21 @@ def test_builder_create_block():
     target = Region([block1, block2])
     builder = Builder(InsertPoint.at_start(block1))
 
-    new_block1 = builder.create_block_at_start(target, (i32,))
+    new_block1 = builder.create_block(BlockInsertPoint.at_start(target), (i32,))
     assert len(new_block1.args) == 1
     assert new_block1.args[0].type == i32
     assert len(target.blocks) == 3
     assert target.blocks[0] == new_block1
     assert builder.insertion_point == InsertPoint.at_start(new_block1)
 
-    new_block2 = builder.create_block_at_end(target, (i64,))
+    new_block2 = builder.create_block(BlockInsertPoint.at_end(target), (i64,))
     assert len(new_block2.args) == 1
     assert new_block2.args[0].type == i64
     assert len(target.blocks) == 4
     assert target.blocks[3] == new_block2
     assert builder.insertion_point == InsertPoint.at_start(new_block2)
 
-    new_block3 = builder.create_block_before(block2, (i32, i64))
+    new_block3 = builder.create_block(BlockInsertPoint.before(block2), (i32, i64))
     assert len(new_block3.args) == 2
     assert new_block3.args[0].type == i32
     assert new_block3.args[1].type == i64
@@ -129,7 +129,7 @@ def test_builder_create_block():
     assert target.blocks[2] == new_block3
     assert builder.insertion_point == InsertPoint.at_start(new_block3)
 
-    new_block4 = builder.create_block_after(block2, (i64, i32))
+    new_block4 = builder.create_block(BlockInsertPoint.after(block2), (i64, i32))
     assert len(new_block4.args) == 2
     assert new_block4.args[0].type == i64
     assert new_block4.args[1].type == i32
@@ -173,10 +173,10 @@ def add_block_on_create(b: Block):
 
     b.block_creation_handler = [add_block_on_create]
 
-    b1 = b.create_block_at_start(region)
-    b2 = b.create_block_at_end(region)
-    b3 = b.create_block_before(block)
-    b4 = b.create_block_after(block)
+    b1 = b.create_block(BlockInsertPoint.at_start(region))
+    b2 = b.create_block(BlockInsertPoint.at_end(region))
+    b3 = b.create_block(BlockInsertPoint.before(block))
+    b4 = b.create_block(BlockInsertPoint.after(block))
 
     assert created_blocks == [b1, b2, b3, b4]
 
diff --git a/tests/test_rewriter.py b/tests/test_rewriter.py
index 7553537a9b..fc6052f584 100644
--- a/tests/test_rewriter.py
+++ b/tests/test_rewriter.py
@@ -9,7 +9,7 @@
 from xdsl.dialects.builtin import Builtin, Float32Type, Float64Type, ModuleOp, i32, i64
 from xdsl.ir import Block, Region
 from xdsl.parser import Parser
-from xdsl.rewriter import InsertPoint, Rewriter
+from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter
 
 
 def rewrite_and_compare(
@@ -289,7 +289,9 @@ def test_insert_block_before():
 """
 
     def insert_empty_block_before(module: ModuleOp, rewriter: Rewriter) -> None:
-        rewriter.insert_block_before(Block(), module.regions[0].blocks[0])
+        rewriter.insert_block(
+            Block(), BlockInsertPoint.before(module.regions[0].blocks[0])
+        )
 
     rewrite_and_compare(prog, expected, insert_empty_block_before)
 
@@ -312,7 +314,9 @@ def test_insert_block_after():
 """
 
     def insert_empty_block_after(module: ModuleOp, rewriter: Rewriter) -> None:
-        rewriter.insert_block_after(Block(), module.regions[0].blocks[0])
+        rewriter.insert_block(
+            Block(), BlockInsertPoint.after(module.regions[0].blocks[0])
+        )
 
     rewrite_and_compare(prog, expected, insert_empty_block_after)
 
@@ -510,7 +514,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:
                 Block((test.TestOp(result_types=(Float64Type(),)),)),
             )
         )
-        rewriter.inline_region_before(region, module.body.blocks[1])
+        rewriter.inline_region(region, BlockInsertPoint.before(module.body.blocks[1]))
 
     rewrite_and_compare(prog, expected, transformation)
 
@@ -544,7 +548,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:
                 Block((test.TestOp(result_types=(Float64Type(),)),)),
             )
         )
-        rewriter.inline_region_after(region, module.body.blocks[0])
+        rewriter.inline_region(region, BlockInsertPoint.after(module.body.blocks[0]))
 
     rewrite_and_compare(prog, expected, transformation)
 
@@ -578,7 +582,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:
                 Block((test.TestOp(result_types=(Float64Type(),)),)),
             )
         )
-        rewriter.inline_region_at_start(region, module.body)
+        rewriter.inline_region(region, BlockInsertPoint.at_start(module.body))
 
     rewrite_and_compare(prog, expected, transformation)
 
@@ -612,7 +616,7 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:
                 Block((test.TestOp(result_types=(Float64Type(),)),)),
             )
         )
-        rewriter.inline_region_at_end(region, module.body)
+        rewriter.inline_region(region, BlockInsertPoint.at_end(module.body))
 
     rewrite_and_compare(prog, expected, transformation)
 
@@ -621,13 +625,13 @@ def test_verify_inline_region():
     region = Region(Block())
 
     with pytest.raises(ValueError, match="Cannot move region into itself."):
-        Rewriter.inline_region_before(region, region.block)
+        Rewriter.inline_region(region, BlockInsertPoint.before(region.block))
 
     with pytest.raises(ValueError, match="Cannot move region into itself."):
-        Rewriter.inline_region_after(region, region.block)
+        Rewriter.inline_region(region, BlockInsertPoint.after(region.block))
 
     with pytest.raises(ValueError, match="Cannot move region into itself."):
-        Rewriter.inline_region_at_start(region, region)
+        Rewriter.inline_region(region, BlockInsertPoint.at_start(region))
 
     with pytest.raises(ValueError, match="Cannot move region into itself."):
-        Rewriter.inline_region_at_end(region, region)
+        Rewriter.inline_region(region, BlockInsertPoint.at_end(region))
diff --git a/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py b/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py
index 24abc9f874..c288bffbeb 100644
--- a/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py
+++ b/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py
@@ -7,7 +7,7 @@
     RewritePattern,
     op_type_rewrite_pattern,
 )
-from xdsl.rewriter import InsertPoint
+from xdsl.rewriter import BlockInsertPoint, InsertPoint
 
 
 class LowerRiscvScfForPattern(RewritePattern):
@@ -119,7 +119,7 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /):
             ),
         )
 
-        rewriter.inline_region_before(op.body, end_block)
+        rewriter.inline_region(op.body, BlockInsertPoint.before(end_block))
 
         # Move lb to new register to initialize the iv.
         # Skip for loop if condition is not satisfied at start.
diff --git a/xdsl/builder.py b/xdsl/builder.py
index 5f54fe948a..a1e4d8890a 100644
--- a/xdsl/builder.py
+++ b/xdsl/builder.py
@@ -7,6 +7,8 @@
 from types import TracebackType
 from typing import ClassVar, TypeAlias, overload
 
+from typing_extensions import deprecated
+
 from xdsl.dialects.builtin import ArrayAttr
 from xdsl.ir import Attribute, Block, BlockArgument, Operation, OperationInvT, Region
 from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter
@@ -75,7 +77,7 @@ def insert(self, op: OperationInvT) -> OperationInvT:
         return op
 
     def create_block(
-        self, insert_point: BlockInsertPoint, arg_types: Iterable[Attribute]
+        self, insert_point: BlockInsertPoint, arg_types: Iterable[Attribute] = ()
     ) -> Block:
         """
         Create a block at the given location, and set the operation insertion point
@@ -89,6 +91,9 @@ def create_block(
         self.handle_block_creation(block)
         return block
 
+    @deprecated(
+        "Use create_block(BlockInsertPoint.before(insert_before), arg_types) instead"
+    )
     def create_block_before(
         self, insert_before: Block, arg_types: Iterable[Attribute] = ()
     ) -> Block:
@@ -98,6 +103,9 @@ def create_block_before(
         """
         return self.create_block(BlockInsertPoint.before(insert_before), arg_types)
 
+    @deprecated(
+        "Use create_block(BlockInsertPoint.after(insert_after), arg_types) instead"
+    )
     def create_block_after(
         self, insert_after: Block, arg_types: Iterable[Attribute] = ()
     ) -> Block:
@@ -107,6 +115,9 @@ def create_block_after(
         """
         return self.create_block(BlockInsertPoint.after(insert_after), arg_types)
 
+    @deprecated(
+        "Use create_block(BlockInsertPoint.at_start(region), arg_types) instead"
+    )
     def create_block_at_start(
         self, region: Region, arg_types: Iterable[Attribute] = ()
     ) -> Block:
@@ -116,6 +127,7 @@ def create_block_at_start(
         """
         return self.create_block(BlockInsertPoint.at_start(region), arg_types)
 
+    @deprecated("Use create_block(BlockInsertPoint.at_end(region), arg_types) instead")
     def create_block_at_end(
         self, region: Region, arg_types: Iterable[Attribute] = ()
     ) -> Block:
diff --git a/xdsl/pattern_rewriter.py b/xdsl/pattern_rewriter.py
index 340ad51299..a0b14ba17b 100644
--- a/xdsl/pattern_rewriter.py
+++ b/xdsl/pattern_rewriter.py
@@ -356,18 +356,30 @@ def inline_region(self, region: Region, insertion_point: BlockInsertPoint) -> No
         self.has_done_action = True
         Rewriter.inline_region(region, insertion_point)
 
+    @deprecated(
+        "Please use `inline_region(region, BlockInsertPoint.before(target))` instead"
+    )
     def inline_region_before(self, region: Region, target: Block) -> None:
         """Move the region blocks to an existing region."""
         self.inline_region(region, BlockInsertPoint.before(target))
 
+    @deprecated(
+        "Please use `inline_region(region, BlockInsertPoint.after(target))` instead"
+    )
     def inline_region_after(self, region: Region, target: Block) -> None:
         """Move the region blocks to an existing region."""
         self.inline_region(region, BlockInsertPoint.after(target))
 
+    @deprecated(
+        "Please use `inline_region(region, BlockInsertPoint.at_start(target))` instead"
+    )
     def inline_region_at_start(self, region: Region, target: Region) -> None:
         """Move the region blocks to an existing region."""
         self.inline_region(region, BlockInsertPoint.at_start(target))
 
+    @deprecated(
+        "Please use `inline_region(region, BlockInsertPoint.at_end(target))` instead"
+    )
     def inline_region_at_end(self, region: Region, target: Region) -> None:
         """Move the region blocks to an existing region."""
         self.inline_region(region, BlockInsertPoint.at_end(target))
diff --git a/xdsl/rewriter.py b/xdsl/rewriter.py
index a45cc5161a..e280aef608 100644
--- a/xdsl/rewriter.py
+++ b/xdsl/rewriter.py
@@ -3,6 +3,8 @@
 from collections.abc import Iterable, Sequence
 from dataclasses import dataclass, field
 
+from typing_extensions import deprecated
+
 from xdsl.ir import Block, Operation, Region, SSAValue
 
 
@@ -240,6 +242,7 @@ def insert_block(block: Block | Iterable[Block], insert_point: BlockInsertPoint)
         else:
             region.add_block(block)
 
+    @deprecated("Use `insert_block(block, BlockInsertPoint.after(target))` instead")
     @staticmethod
     def insert_block_after(block: Block | list[Block], target: Block):
         """
@@ -249,6 +252,7 @@ def insert_block_after(block: Block | list[Block], target: Block):
         """
         Rewriter.insert_block(block, BlockInsertPoint.after(target))
 
+    @deprecated("Use `insert_block(block, BlockInsertPoint.before(target))` instead")
     @staticmethod
     def insert_block_before(block: Block | list[Block], target: Block):
         """
@@ -284,21 +288,27 @@ def inline_region(region: Region, insertion_point: BlockInsertPoint) -> None:
         else:
             region.move_blocks(insertion_point.region)
 
+    @deprecated("Use `inline_region(region, BlockInsertPoint.before(target))` instead")
     @staticmethod
     def inline_region_before(region: Region, target: Block) -> None:
         """Move the region blocks to an existing region, before `target`."""
         Rewriter.inline_region(region, BlockInsertPoint.before(target))
 
+    @deprecated("Use `inline_region(region, BlockInsertPoint.after(target))` instead")
     @staticmethod
     def inline_region_after(region: Region, target: Block) -> None:
         """Move the region blocks to an existing region, after `target`."""
         Rewriter.inline_region(region, BlockInsertPoint.after(target))
 
+    @deprecated(
+        "Use `inline_region(region, BlockInsertPoint.at_start(target))` instead"
+    )
     @staticmethod
     def inline_region_at_start(region: Region, target: Region) -> None:
         """Move the region blocks to the start of an existing region."""
         Rewriter.inline_region(region, BlockInsertPoint.at_start(target))
 
+    @deprecated("Use `inline_region(region, BlockInsertPoint.at_end(target))` instead")
     @staticmethod
     def inline_region_at_end(region: Region, target: Region) -> None:
         """Move the region blocks to the end of an existing region."""
diff --git a/xdsl/transforms/convert_scf_to_cf.py b/xdsl/transforms/convert_scf_to_cf.py
index d1c07e683b..977c756fd1 100644
--- a/xdsl/transforms/convert_scf_to_cf.py
+++ b/xdsl/transforms/convert_scf_to_cf.py
@@ -16,7 +16,7 @@
     RewritePattern,
     op_type_rewrite_pattern,
 )
-from xdsl.rewriter import InsertPoint
+from xdsl.rewriter import BlockInsertPoint, InsertPoint
 from xdsl.traits import IsTerminator
 
 
@@ -60,7 +60,7 @@ def match_and_rewrite(self, if_op: IfOp, rewriter: PatternRewriter, /):
         )
 
         rewriter.erase_op(then_terminator)
-        rewriter.inline_region_before(then_region, continue_block)
+        rewriter.inline_region(then_region, BlockInsertPoint.before(continue_block))
 
         # Move blocks from the "else" region (if present) to the region containing
         # 'scf.if', place it before the continuation block and branch to it.  It
@@ -78,7 +78,7 @@ def match_and_rewrite(self, if_op: IfOp, rewriter: PatternRewriter, /):
             )
 
             rewriter.erase_op(else_terminator)
-            rewriter.inline_region_before(else_region, continue_block)
+            rewriter.inline_region(else_region, BlockInsertPoint.before(continue_block))
         else:
             else_block = continue_block
 
@@ -116,7 +116,7 @@ def match_and_rewrite(self, for_op: ForOp, rewriter: PatternRewriter):
         first_body_block = condition_block.split_before(first_op)
         last_body_block = for_op.body.last_block
         assert last_body_block is not None
-        rewriter.inline_region_before(for_op.body, end_block)
+        rewriter.inline_region(for_op.body, BlockInsertPoint.before(end_block))
         iv = condition_block.args[0]
 
         # Append the induction variable stepping logic to the last body block and
@@ -169,7 +169,7 @@ def _convert_region(
         rewriter.replace_op(yield_op, BranchOp(continue_block, *yield_op.operands))
 
         # Inline the region
-        rewriter.inline_region_before(region, continue_block)
+        rewriter.inline_region(region, BlockInsertPoint.before(continue_block))
         return block
 
     @op_type_rewrite_pattern