From d43ac2fcfb301dd132bc267507eebb6a73cdee0a Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Mon, 16 Dec 2024 10:30:31 +0000 Subject: [PATCH 01/10] interactive: use individual rewrite pass directly instead of helpers (#3642) It doesn't seem that this abstraction currently adds anything, both the available pass and pass itself are already immutable and can be reasoned about directly. --- tests/interactive/test_app.py | 9 +- tests/interactive/test_rewrites.py | 99 +++++--------------- xdsl/interactive/get_all_available_passes.py | 16 +--- xdsl/interactive/rewrites.py | 58 ++---------- 4 files changed, 40 insertions(+), 142 deletions(-) diff --git a/tests/interactive/test_app.py b/tests/interactive/test_app.py index 0abf9bf7ef..ac129f1c74 100644 --- a/tests/interactive/test_app.py +++ b/tests/interactive/test_app.py @@ -20,10 +20,7 @@ from xdsl.interactive.add_arguments_screen import AddArguments from xdsl.interactive.app import InputApp from xdsl.interactive.passes import AvailablePass, get_condensed_pass_list -from xdsl.interactive.rewrites import ( - convert_indexed_individual_rewrites_to_available_pass, - get_all_possible_rewrites, -) +from xdsl.interactive.rewrites import get_all_possible_rewrites from xdsl.ir import Block, Region from xdsl.transforms import ( get_all_passes, @@ -295,9 +292,7 @@ def callback(x: str): ) assert app.available_pass_list == get_condensed_pass_list( expected_module, app.all_passes - ) + convert_indexed_individual_rewrites_to_available_pass( - rewrites, expected_module - ) + ) + tuple(rewrites) # press "Uncondense" button await pilot.click("#uncondense_button") diff --git a/tests/interactive/test_rewrites.py b/tests/interactive/test_rewrites.py index 0d9a42bb86..25313758a8 100644 --- a/tests/interactive/test_rewrites.py +++ b/tests/interactive/test_rewrites.py @@ -5,20 +5,15 @@ ) from xdsl.dialects.test import Test, TestOp from xdsl.interactive.passes import AvailablePass -from xdsl.interactive.rewrites import ( - IndexedIndividualRewrite, - IndividualRewrite, - convert_indexed_individual_rewrites_to_available_pass, - get_all_possible_rewrites, -) +from xdsl.interactive.rewrites import get_all_possible_rewrites from xdsl.parser import Parser from xdsl.pattern_rewriter import ( PatternRewriter, RewritePattern, op_type_rewrite_pattern, ) -from xdsl.transforms import individual_rewrite -from xdsl.utils.parse_pipeline import parse_pipeline +from xdsl.transforms.individual_rewrite import ApplyIndividualRewritePass +from xdsl.utils.parse_pipeline import PipelinePassSpec class Rewrite(RewritePattern): @@ -47,75 +42,31 @@ def test_get_all_possible_rewrite(): module = parser.parse_module() expected_res = [ - ( - IndexedIndividualRewrite( - 1, IndividualRewrite(operation="test.op", pattern="TestRewrite") - ) + AvailablePass( + display_name='TestOp("test.op"() {"label" = "a"} : () -> ()):test.op:TestRewrite', + module_pass=ApplyIndividualRewritePass, + pass_spec=PipelinePassSpec( + "apply-individual-rewrite", + { + "matched_operation_index": (1,), + "operation_name": ("test.op",), + "pattern_name": ("TestRewrite",), + }, + ), ), - ( - IndexedIndividualRewrite( - operation_index=2, - rewrite=IndividualRewrite(operation="test.op", pattern="TestRewrite"), - ) + AvailablePass( + display_name='TestOp("test.op"() {"label" = "a"} : () -> ()):test.op:TestRewrite', + module_pass=ApplyIndividualRewritePass, + pass_spec=PipelinePassSpec( + "apply-individual-rewrite", + { + "matched_operation_index": (2,), + "operation_name": ("test.op",), + "pattern_name": ("TestRewrite",), + }, + ), ), ] res = get_all_possible_rewrites(module, {"test.op": {"TestRewrite": Rewrite()}}) assert res == expected_res - - -def test_convert_indexed_individual_rewrites_to_available_pass(): - # build module - prog = """ - builtin.module { - "test.op"() {"label" = "a"} : () -> () - "test.op"() {"label" = "a"} : () -> () - "test.op"() {"label" = "b"} : () -> () - } - """ - - ctx = MLContext() - ctx.load_dialect(Builtin) - ctx.load_dialect(Test) - parser = Parser(ctx, prog) - module = parser.parse_module() - - rewrites = ( - ( - IndexedIndividualRewrite( - 1, IndividualRewrite(operation="test.op", pattern="TestRewrite") - ) - ), - ( - IndexedIndividualRewrite( - operation_index=2, - rewrite=IndividualRewrite(operation="test.op", pattern="TestRewrite"), - ) - ), - ) - - expected_res = tuple( - ( - AvailablePass( - display_name='TestOp("test.op"() {"label" = "a"} : () -> ()):test.op:TestRewrite', - module_pass=individual_rewrite.ApplyIndividualRewritePass, - pass_spec=list( - parse_pipeline( - 'apply-individual-rewrite{matched_operation_index=1 operation_name="test.op" pattern_name="TestRewrite"}' - ) - )[0], - ), - AvailablePass( - display_name='TestOp("test.op"() {"label" = "a"} : () -> ()):test.op:TestRewrite', - module_pass=individual_rewrite.ApplyIndividualRewritePass, - pass_spec=list( - parse_pipeline( - 'apply-individual-rewrite{matched_operation_index=2 operation_name="test.op" pattern_name="TestRewrite"}' - ) - )[0], - ), - ) - ) - - res = convert_indexed_individual_rewrites_to_available_pass(rewrites, module) - assert res == expected_res diff --git a/xdsl/interactive/get_all_available_passes.py b/xdsl/interactive/get_all_available_passes.py index 6283905fbc..3547df2a7d 100644 --- a/xdsl/interactive/get_all_available_passes.py +++ b/xdsl/interactive/get_all_available_passes.py @@ -6,10 +6,7 @@ get_condensed_pass_list, get_new_registered_context, ) -from xdsl.interactive.rewrites import ( - convert_indexed_individual_rewrites_to_available_pass, - get_all_possible_rewrites, -) +from xdsl.interactive.rewrites import get_all_possible_rewrites from xdsl.ir import Dialect from xdsl.parser import Parser from xdsl.passes import ModulePass @@ -34,19 +31,14 @@ def get_available_pass_list( current_module = apply_passes_to_module(current_module, ctx, pass_pipeline) - # get all rewrites - rewrites = get_all_possible_rewrites( + # get all individual rewrites + individual_rewrites = get_all_possible_rewrites( current_module, rewrite_by_names_dict, ) - # transform rewrites into passes - rewrites_as_pass_list = convert_indexed_individual_rewrites_to_available_pass( - rewrites, current_module - ) # merge rewrite passes with "other" pass list if condense_mode: pass_list = get_condensed_pass_list(current_module, all_passes) - return pass_list + rewrites_as_pass_list else: pass_list = tuple(AvailablePass(p.name, p, None) for _, p in all_passes) - return pass_list + rewrites_as_pass_list + return pass_list + tuple(individual_rewrites) diff --git a/xdsl/interactive/rewrites.py b/xdsl/interactive/rewrites.py index b151e910ee..50ca4ed908 100644 --- a/xdsl/interactive/rewrites.py +++ b/xdsl/interactive/rewrites.py @@ -1,68 +1,23 @@ from collections.abc import Sequence -from typing import NamedTuple from xdsl.dialects.builtin import ModuleOp from xdsl.interactive.passes import AvailablePass from xdsl.pattern_rewriter import PatternRewriter, RewritePattern from xdsl.traits import HasCanonicalizationPatternsTrait from xdsl.transforms import individual_rewrite -from xdsl.utils.parse_pipeline import PipelinePassSpec - - -class IndividualRewrite(NamedTuple): - """ - Type alias for a possible rewrite, described by an operation and pattern name. - """ - - operation: str - pattern: str - - -class IndexedIndividualRewrite(NamedTuple): - """ - Type alias for a specific rewrite pattern, additionally consisting of its operation index. - """ - - operation_index: int - rewrite: IndividualRewrite - - -def convert_indexed_individual_rewrites_to_available_pass( - rewrites: Sequence[IndexedIndividualRewrite], current_module: ModuleOp -) -> tuple[AvailablePass, ...]: - """ - Function that takes a tuple of rewrites, converts each rewrite into an IndividualRewrite pass and returns the tuple of AvailablePass. - """ - rewrites_as_pass_list: tuple[AvailablePass, ...] = () - for op_idx, (op_name, pat_name) in rewrites: - rewrite_pass = individual_rewrite.ApplyIndividualRewritePass - rewrite_spec = PipelinePassSpec( - name=rewrite_pass.name, - args={ - "matched_operation_index": (op_idx,), - "operation_name": (op_name,), - "pattern_name": (pat_name,), - }, - ) - op = list(current_module.walk())[op_idx] - rewrites_as_pass_list = ( - *rewrites_as_pass_list, - (AvailablePass(f"{op}:{op_name}:{pat_name}", rewrite_pass, rewrite_spec)), - ) - return rewrites_as_pass_list def get_all_possible_rewrites( module: ModuleOp, rewrite_by_name: dict[str, dict[str, RewritePattern]], -) -> Sequence[IndexedIndividualRewrite]: +) -> Sequence[AvailablePass]: """ Function that takes a sequence of IndividualRewrite Patterns and a ModuleOp, and returns the possible rewrites. Issue filed: https://github.com/xdslproject/xdsl/issues/2162 """ - res: list[IndexedIndividualRewrite] = [] + res: list[AvailablePass] = [] for op_idx, matched_op in enumerate(module.walk()): pattern_by_name = rewrite_by_name.get(matched_op.name, {}).copy() @@ -78,9 +33,14 @@ def get_all_possible_rewrites( rewriter = PatternRewriter(cloned_op) pattern.match_and_rewrite(cloned_op, rewriter) if rewriter.has_done_action: + p = individual_rewrite.ApplyIndividualRewritePass( + op_idx, cloned_op.name, pattern_name + ) res.append( - IndexedIndividualRewrite( - op_idx, IndividualRewrite(cloned_op.name, pattern_name) + AvailablePass( + f"{cloned_op}:{cloned_op.name}:{pattern_name}", + individual_rewrite.ApplyIndividualRewritePass, + p.pipeline_pass_spec(), ) ) From fd6296dd69de1debdbd7c9f5501e050ef8fc8f4b Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Mon, 16 Dec 2024 14:28:08 +0000 Subject: [PATCH 02/10] misc: add a DisjointSet data structure (#3621) This is useful in a few places, notably in bufferization. --- tests/utils/test_disjoint_set.py | 133 +++++++++++++++++++++++ xdsl/utils/disjoint_set.py | 175 +++++++++++++++++++++++++++++++ 2 files changed, 308 insertions(+) create mode 100644 tests/utils/test_disjoint_set.py create mode 100644 xdsl/utils/disjoint_set.py diff --git a/tests/utils/test_disjoint_set.py b/tests/utils/test_disjoint_set.py new file mode 100644 index 0000000000..3422e1acb4 --- /dev/null +++ b/tests/utils/test_disjoint_set.py @@ -0,0 +1,133 @@ +import pytest + +from xdsl.utils.disjoint_set import DisjointSet, IntDisjointSet + + +def test_disjoint_set_init(): + ds = IntDisjointSet(size=5) + assert ds.value_count() == 5 + # Each element should start in its own set + for i in range(5): + assert ds[i] == i + + +def test_disjoint_set_add(): + ds = IntDisjointSet(size=2) + assert ds.value_count() == 2 + + new_val = ds.add() + assert new_val == 2 + assert ds.value_count() == 3 + assert ds[new_val] == new_val + + +def test_disjoint_set_find_invalid(): + ds = IntDisjointSet(size=3) + with pytest.raises(KeyError): + ds[3] + with pytest.raises(KeyError): + ds[-1] + + +def test_disjoint_set_union(): + ds = IntDisjointSet(size=4) + + # Union 0 and 1 + assert ds.union(0, 1) + root = ds[0] + assert ds[1] == root + assert ds.connected(0, 1) + assert not ds.connected(0, 2) + + # Union 2 and 3 + assert ds.union(2, 3) + root2 = ds[2] + assert ds[3] == root2 + assert ds.connected(2, 3) + assert not ds.connected(1, 2) + + # Union already connected elements + assert not ds.union(0, 1) + assert ds.connected(0, 1) + + # Union two sets + assert ds.union(1, 2) + final_root = ds[0] + assert ds[1] == final_root + assert ds[2] == final_root + assert ds[3] == final_root + # After unioning all elements, they should all be connected + assert ds.connected(0, 1) + assert ds.connected(1, 2) + assert ds.connected(2, 3) + assert ds.connected(0, 3) + + +def test_disjoint_set_path_compression(): + ds = IntDisjointSet(size=4) + + # Create a chain: 3->2->1->0 + ds._parent = [0, 0, 1, 2] # pyright: ignore[reportPrivateUsage] + ds._count = [4, 3, 2, 1] # pyright: ignore[reportPrivateUsage] + + # Find should compress the path + root = ds[3] + # After compression, all nodes should point directly to root + assert ds._parent[3] == root # pyright: ignore[reportPrivateUsage] + assert ds._parent[2] == root # pyright: ignore[reportPrivateUsage] + assert ds._parent[1] == root # pyright: ignore[reportPrivateUsage] + assert ds._parent[0] == root # pyright: ignore[reportPrivateUsage] + + +def test_generic_disjoint_set(): + ds = DisjointSet(["a", "b", "c", "d"]) + + # Union a and b + assert ds.union("a", "b") + root = ds.find("a") + assert ds.find("b") == root + assert ds.connected("a", "b") + assert not ds.connected("a", "c") + + # Union c and d + assert ds.union("c", "d") + root2 = ds.find("c") + assert ds.find("d") == root2 + assert ds.connected("c", "d") + assert not ds.connected("b", "c") + + # Union already connected elements + assert not ds.union("a", "b") + assert ds.connected("a", "b") + + # Union two sets + assert ds.union("b", "c") + final_root = ds.find("a") + assert ds.find("b") == final_root + assert ds.find("c") == final_root + assert ds.find("d") == final_root + # After unioning all elements, they should all be connected + assert ds.connected("a", "b") + assert ds.connected("b", "c") + assert ds.connected("c", "d") + assert ds.connected("a", "d") + + +def test_generic_disjoint_set_add(): + ds = DisjointSet(["a", "b"]) + ds.add("c") + ds.add("d") + + assert ds.union("a", "c") + root = ds.find("a") + assert ds.find("c") == root + + assert ds.union("b", "d") + root2 = ds.find("b") + assert ds.find("d") == root2 + + +def test_generic_disjoint_set_find_invalid(): + ds = DisjointSet(["a", "b", "c"]) + with pytest.raises(KeyError): + ds.find("d") diff --git a/xdsl/utils/disjoint_set.py b/xdsl/utils/disjoint_set.py new file mode 100644 index 0000000000..1cf7e0857e --- /dev/null +++ b/xdsl/utils/disjoint_set.py @@ -0,0 +1,175 @@ +""" +Generic implementation of a disjoint set data structure. + +https://en.wikipedia.org/wiki/Disjoint-set_data_structure +""" + +from collections.abc import Hashable, Sequence +from typing import Generic, TypeVar + + +class IntDisjointSet: + """ + Represents a collection of disjoint sets of integers. + The integers stored are always in the range [0,n), where n is the number of elements + in this structure. + + This implementation uses path compression and union by size for efficiency. + The amortized time complexity for operations is nearly constant. + """ + + _parent: list[int] + """ + Index of the parent node. If the node is its own parent then it is a root node. + """ + _count: list[int] + """ + If the node is a root node, the corresponding value is the count of elements in the + set. For non-root nodes, these counts may be stale and should not be used. + """ + + def __init__(self, *, size: int) -> None: + """ + Initialize disjoint sets with elements [0,size). + Each element starts in its own singleton set. + """ + self._parent = list(range(size)) + self._count = [1] * size + + def value_count(self) -> int: + """Number of nodes in this structure.""" + return len(self._parent) + + def add(self) -> int: + """ + Add a new element to this set as a singleton. + Returns the added value, which will be equal to the previous size. + """ + res = len(self._parent) + self._parent.append(res) + self._count.append(1) + return res + + def __getitem__(self, value: int) -> int: + """ + Returns the root/representative value of this set. + Uses path compression - updates parent pointers to point directly to the root + as we traverse up the tree, improving amortized performance. + """ + if value < 0 or len(self._parent) <= value: + raise KeyError(f"Index {value} not found") + + # Find the root + root = value + while self._parent[root] != root: + root = self._parent[root] + + # Path compression - point all nodes on path to root + current = value + while current != root: + next_parent = self._parent[current] + self._parent[current] = root + current = next_parent + + return root + + def union(self, lhs: int, rhs: int) -> bool: + """ + Merges the sets containing lhs and rhs if they are different. + Returns True if the sets were merged, False if they were already the same set. + + Uses union by size - the smaller tree is attached to the larger tree's root + to maintain balance. This ensures the maximum tree height is O(log n). + """ + lhs_root = self[lhs] + rhs_root = self[rhs] + if lhs_root == rhs_root: + return False + + lhs_count = self._count[lhs_root] + rhs_count = self._count[rhs_root] + # Choose the root of the larger tree as the new parent + new_parent, new_child = ( + (lhs_root, rhs_root) if lhs_count <= rhs_count else (rhs_root, lhs_root) + ) + self._parent[new_child] = new_parent + self._count[new_parent] = lhs_count + rhs_count + # Note: We don't need to update _count[new_child] since it's no longer a root + return True + + def connected(self, lhs: int, rhs: int) -> bool: + return self[lhs] == self[rhs] + + +_T = TypeVar("_T", bound=Hashable) + + +class DisjointSet(Generic[_T]): + """ + A disjoint-set data structure that works with arbitrary hashable values. + Internally uses IntDisjointSet by mapping values to integer indices. + """ + + _base: IntDisjointSet + _values: list[_T] + _index_by_value: dict[_T, int] + + def __init__(self, values: Sequence[_T] = ()): + """ + Initialize a DisjointSet with the given sequence of values. + Each value starts in its own singleton set. + + Args: + values: Initial sequence of values to add to the disjoint set + """ + self._values = list(values) + self._index_by_value = {v: i for i, v in enumerate(self._values)} + self._base = IntDisjointSet(size=len(self._values)) + + def __len__(self): + return len(self._values) + + def add(self, value: _T): + """ + Add a new value to the disjoint set in its own singleton set. + + Args: + value: The value to add + """ + index = self._base.add() + self._values.append(value) + self._index_by_value[value] = index + + def find(self, value: _T) -> _T: + """ + Find the representative value for the set containing the given value. + + Returns the representative value for the set. + + Raises: + KeyError: If the value is not in the disjoint set + """ + index = self._base[self._index_by_value[value]] + return self._values[index] + + def union(self, lhs: _T, rhs: _T) -> bool: + """ + Merge the sets containing the two given values if they are different. + + Returns `True` if the sets were merged, `False` if they were already the same set. + + Raises: + KeyError: If either value is not in the disjoint set + """ + return self._base.union(self._index_by_value[lhs], self._index_by_value[rhs]) + + def connected(self, lhs: _T, rhs: _T) -> bool: + """ + Returns `True` if the values are in the same set. + + Raises: + KeyError: If either value is not in the disjoint set + """ + return self._base.connected( + self._index_by_value[lhs], self._index_by_value[rhs] + ) From 91c2790e754b307ecfe9bd84db107da2bc85ffb9 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Mon, 16 Dec 2024 15:04:36 +0000 Subject: [PATCH 03/10] misc: use activated venv dir if any, and add back project venv export (#3630) This change lets us use the activated venv, if any, for all our make commands. In order to make it work, I had to add back the export, because otherwise the inner uv commands would complain about the fact that the current venv is not the same as the project venv, which is .venv if not overridden. --- Makefile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index f422f3eeb5..5be8b96cb1 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,9 @@ COVERAGE_FILE ?= .coverage # allow overriding the name of the venv directory VENV_DIR ?= .venv -UV_PROJECT_ENVIRONMENT=${VENV_DIR} + +# use activated venv if any +export UV_PROJECT_ENVIRONMENT=$(if $(VIRTUAL_ENV),$(VIRTUAL_ENV),$(VENV_DIR)) # allow overriding which extras are installed VENV_EXTRAS ?= --extra gui --extra dev --extra jax --extra riscv From a2505b3f7e70ace69ee2cc40b8fb11bfd5a42c1e Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Mon, 16 Dec 2024 16:59:17 +0000 Subject: [PATCH 04/10] CI: add lockfile update action to manually trigger for dependabot PRs (#3647) I tested it [here](https://github.com/xdslproject/xdsl/pull/3646) and it works if you do the dumb thing and trigger on push, but it seems to not run actions on unsigned commits by default. I'm hoping that if we trigger the action manually it'll inherit the clicker's permissions and re-run the action. In either case, I think whatever we adopt should supercede the cron job one as we want to never be out of sync between the pyproject and lockfile. --- .github/workflows/update-bot.yml | 38 ----------------------- .github/workflows/update-lockfile-bot.yml | 32 +++++++++++++++++++ 2 files changed, 32 insertions(+), 38 deletions(-) delete mode 100644 .github/workflows/update-bot.yml create mode 100644 .github/workflows/update-lockfile-bot.yml diff --git a/.github/workflows/update-bot.yml b/.github/workflows/update-bot.yml deleted file mode 100644 index f05c4f0d5c..0000000000 --- a/.github/workflows/update-bot.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: Update Bot - -on: - workflow_dispatch: - # Set the schedule, every week at 8:00am on Monday - schedule: - - cron: 0 8 * * 1 - -permissions: - contents: write - pull-requests: write - -jobs: - lock: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - uses: astral-sh/setup-uv@v3 - - - run: | - echo "\`\`\`" > uv_output.md - uv lock &>> uv_output.md - echo "\`\`\`" >> uv_output.md - - - name: Create pull request - uses: peter-evans/create-pull-request@v7 - with: - token: ${{ secrets.GITHUB_TOKEN }} - commit-message: Update uv lockfile - title: Update uv lockfile - body-path: uv_output.md - branch: update-uv - base: main - labels: install - delete-branch: true - add-paths: uv.lock - assignees: math-fehr, georgebisbas, superlopuh diff --git a/.github/workflows/update-lockfile-bot.yml b/.github/workflows/update-lockfile-bot.yml new file mode 100644 index 0000000000..93c5bf1b4e --- /dev/null +++ b/.github/workflows/update-lockfile-bot.yml @@ -0,0 +1,32 @@ +name: Update Lockfile Bot + +on: + workflow_dispatch: + +permissions: + contents: write + pull-requests: write + +jobs: + lock: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v3 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python + run: uv python install 3.12 + + - name: Install the package locally and update lockfile + run: | + # Install all default extras. + XDSL_VERSION_OVERRIDE="0+dynamic" make venv + + - uses: EndBug/add-and-commit@v9 + with: + add: uv.lock From 3725f40fdc21c91d069f1ac523e7fce25efcdbd3 Mon Sep 17 00:00:00 2001 From: Chris Vasiladiotis Date: Mon, 16 Dec 2024 17:01:34 +0000 Subject: [PATCH 05/10] dialects (llvm): Add dense array constraint for the `position` attribute of `llvm.extractvalue` and `llvm.insertvalue` operations (#3643) This PR: - Adds a dense array constraint for the `position` attribute of `llvm.extractvalue` and `llvm.insertvalue` operations, restricting to `i64` as in MLIR. - Filecheck tests of the above Resolves #3155 --- tests/filecheck/dialects/llvm/invalid.mlir | 21 ++++++++++++++++++++- xdsl/dialects/builtin.py | 4 ++++ xdsl/dialects/llvm.py | 5 +++-- 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/tests/filecheck/dialects/llvm/invalid.mlir b/tests/filecheck/dialects/llvm/invalid.mlir index 09627efd7a..0f5d854a6f 100644 --- a/tests/filecheck/dialects/llvm/invalid.mlir +++ b/tests/filecheck/dialects/llvm/invalid.mlir @@ -7,10 +7,29 @@ builtin.module { // CHECK: Varargs specifier `...` must be at the end of the argument definition // ----- -// CHECK: ----- builtin.module { %cc = "test.op"() {"cconv" = #llvm.cconv} : () -> () } // CHECK: Unknown calling convention + +// ----- + +func.func public @main() { + %0 = "test.op"() : () -> (!llvm.struct<(i32)>) + %1 = "llvm.extractvalue"(%0) {"position" = array} : (!llvm.struct<(i32)>) -> i32 + func.return +} + +// CHECK: Expected attribute i64 but got i32 + +// ----- + +func.func public @main() { + %0, %1 = "test.op"() : () -> (!llvm.struct<(i32)>, i32) + %2 = "llvm.insertvalue"(%0, %1) {"position" = array} : (!llvm.struct<(i32)>, i32) -> !llvm.struct<(i32)> + func.return +} + +// CHECK: Expected attribute i64 but got i32 diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index f24f2c83fa..e4bf295f56 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -1248,6 +1248,10 @@ def __len__(self) -> int: return len(self.data.data) // self.elt_type.size +DenseI64ArrayConstr = ParamAttrConstraint(DenseArrayBase, [i64, BytesAttr]) +"""Type constraint for DenseArrays containing integers of i64 integers.""" + + @irdl_attr_definition class FunctionType(ParametrizedAttribute, TypeAttribute): name = "fun" diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index dd3752ff43..3d767ac40e 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -12,6 +12,7 @@ ArrayAttr, ContainerType, DenseArrayBase, + DenseI64ArrayConstr, IndexType, IntAttr, IntegerAttr, @@ -1262,7 +1263,7 @@ class ExtractValueOp(IRDLOperation): name = "llvm.extractvalue" - position = prop_def(DenseArrayBase) + position = prop_def(DenseI64ArrayConstr) container = operand_def(Attribute) res = result_def(Attribute) @@ -1292,7 +1293,7 @@ class InsertValueOp(IRDLOperation): name = "llvm.insertvalue" - position = prop_def(DenseArrayBase) + position = prop_def(DenseI64ArrayConstr) container = operand_def(Attribute) value = operand_def(Attribute) From b8611e04fefa80ac2238f20311d28ab9b75aa4d6 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Tue, 17 Dec 2024 13:21:38 +0000 Subject: [PATCH 06/10] dialects: (arith) add SignlessIntegerBinaryOperation canonicalization (#3583) Generically implements various arith canonicalizations on SignlessIntegerBinaryOperation --- .../dialects/arith/canonicalize.mlir | 10 + tests/interactive/test_app.py | 10 +- xdsl/dialects/arith.py | 220 +++++++++++++++--- xdsl/traits.py | 6 + .../canonicalization_patterns/arith.py | 62 ++--- .../canonicalization_patterns/cf.py | 4 +- .../canonicalization_patterns/scf.py | 4 +- .../canonicalization_patterns/utils.py | 14 +- 8 files changed, 249 insertions(+), 81 deletions(-) diff --git a/tests/filecheck/dialects/arith/canonicalize.mlir b/tests/filecheck/dialects/arith/canonicalize.mlir index 2b49a6a847..8cc3984468 100644 --- a/tests/filecheck/dialects/arith/canonicalize.mlir +++ b/tests/filecheck/dialects/arith/canonicalize.mlir @@ -149,3 +149,13 @@ func.func @test_const_var_const() { %9 = arith.cmpi uge, %int, %int : i32 "test.op"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %int) : (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i32) -> () + +// Subtraction is not commutative so should not have the constant swapped to the right +// CHECK: arith.subi %c2, %a : i32 +%10 = arith.subi %c2, %a : i32 +"test.op"(%10) : (i32) -> () + +// CHECK: %{{.*}} = arith.constant false +%11 = arith.constant true +%12 = arith.addi %11, %11 : i1 +"test.op"(%12) : (i1) -> () diff --git a/tests/interactive/test_app.py b/tests/interactive/test_app.py index ac129f1c74..3e116aed3a 100644 --- a/tests/interactive/test_app.py +++ b/tests/interactive/test_app.py @@ -329,11 +329,11 @@ async def test_rewrites(): await pilot.click("#condense_button") addi_pass = AvailablePass( - display_name="AddiOp(%res = arith.addi %n, %c0 : i32):arith.addi:AddiIdentityRight", + display_name="AddiOp(%res = arith.addi %n, %c0 : i32):arith.addi:SignlessIntegerBinaryOperationZeroOrUnitRight", module_pass=individual_rewrite.ApplyIndividualRewritePass, pass_spec=list( parse_pipeline( - 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}' + 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}' ) )[0], ) @@ -354,7 +354,7 @@ async def test_rewrites(): individual_rewrite.ApplyIndividualRewritePass, list( parse_pipeline( - 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}' + 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}' ) )[0], ), @@ -563,7 +563,7 @@ async def test_apply_individual_rewrite(): n.data is not None and n.data[1] is not None and str(n.data[1]) - == 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiConstantProp"}' + == 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationConstantProp"}' ): node = n @@ -593,7 +593,7 @@ async def test_apply_individual_rewrite(): n.data is not None and n.data[1] is not None and str(n.data[1]) - == 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="AddiIdentityRight"}' + == 'apply-individual-rewrite{matched_operation_index=3 operation_name="arith.addi" pattern_name="SignlessIntegerBinaryOperationZeroOrUnitRight"}' ): node = n diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index 768c051863..6850f33ffc 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -45,6 +45,7 @@ from xdsl.pattern_rewriter import RewritePattern from xdsl.printer import Printer from xdsl.traits import ( + Commutative, ConditionallySpeculatable, ConstantLike, HasCanonicalizationPatternsTrait, @@ -195,6 +196,36 @@ class SignlessIntegerBinaryOperation(IRDLOperation, abc.ABC): assembly_format = "$lhs `,` $rhs attr-dict `:` type($result)" + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + """ + Performs a python function corresponding to this operation. + + If `i := py_operation(lhs, rhs)` is an int, then this operation can be + canonicalized to a constant with value `i` when the inputs are constants + with values `lhs` and `rhs`. + """ + return None + + @staticmethod + def is_right_zero(attr: AnyIntegerAttr) -> bool: + """ + Returns True only when 'attr' is a right zero for the operation + https://en.wikipedia.org/wiki/Absorbing_element + + Note that this depends on the operation and does *not* imply that + attr.value.data == 0 + """ + return False + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + """ + Return True only when 'attr' is a right unit/identity for the operation + https://en.wikipedia.org/wiki/Identity_element + """ + return False + def __init__( self, operand1: Operation | SSAValue, @@ -209,6 +240,22 @@ def __hash__(self) -> int: return id(self) +class SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait( + HasCanonicalizationPatternsTrait +): + @classmethod + def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: + from xdsl.transforms.canonicalization_patterns.arith import ( + SignlessIntegerBinaryOperationConstantProp, + SignlessIntegerBinaryOperationZeroOrUnitRight, + ) + + return ( + SignlessIntegerBinaryOperationConstantProp(), + SignlessIntegerBinaryOperationZeroOrUnitRight(), + ) + + class SignlessIntegerBinaryOperationWithOverflow( SignlessIntegerBinaryOperation, abc.ABC ): @@ -318,22 +365,23 @@ def print(self, printer: Printer): printer.print_attribute(self.result.type) -class AddiOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait): - @classmethod - def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: - from xdsl.transforms.canonicalization_patterns.arith import ( - AddiConstantProp, - AddiIdentityRight, - ) - - return (AddiIdentityRight(), AddiConstantProp()) - - @irdl_op_definition class AddiOp(SignlessIntegerBinaryOperationWithOverflow): name = "arith.addi" - traits = traits_def(Pure(), AddiOpHasCanonicalizationPatternsTrait()) + traits = traits_def( + Pure(), + Commutative(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + return lhs + rhs + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 @irdl_op_definition @@ -400,19 +448,27 @@ def infer_overflow_type(input_type: Attribute) -> Attribute: ) -class MuliHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait): - @classmethod - def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: - from xdsl.transforms.canonicalization_patterns import arith - - return (arith.MuliIdentityRight(), arith.MuliConstantProp()) - - @irdl_op_definition class MuliOp(SignlessIntegerBinaryOperationWithOverflow): name = "arith.muli" - traits = traits_def(Pure(), MuliHasCanonicalizationPatterns()) + traits = traits_def( + Pure(), + Commutative(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + return lhs * rhs + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr == IntegerAttr(1, attr.type) + + @staticmethod + def is_right_zero(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 class MulExtendedBase(IRDLOperation): @@ -460,7 +516,17 @@ class MulSIExtendedOp(MulExtendedBase): class SubiOp(SignlessIntegerBinaryOperationWithOverflow): name = "arith.subi" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() + ) + + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + return lhs - rhs + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 class DivUISpeculatable(ConditionallySpeculatable): @@ -483,7 +549,15 @@ class DivUIOp(SignlessIntegerBinaryOperation): name = "arith.divui" - traits = traits_def(NoMemoryEffect(), DivUISpeculatable()) + traits = traits_def( + NoMemoryEffect(), + DivUISpeculatable(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr == IntegerAttr(1, attr.type) @irdl_op_definition @@ -495,7 +569,14 @@ class DivSIOp(SignlessIntegerBinaryOperation): name = "arith.divsi" - traits = traits_def(NoMemoryEffect()) + traits = traits_def( + NoMemoryEffect(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr == IntegerAttr(1, attr.type) @irdl_op_definition @@ -506,21 +587,40 @@ class FloorDivSIOp(SignlessIntegerBinaryOperation): name = "arith.floordivsi" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr == IntegerAttr(1, attr.type) @irdl_op_definition class CeilDivSIOp(SignlessIntegerBinaryOperation): name = "arith.ceildivsi" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr == IntegerAttr(1, attr.type) @irdl_op_definition class CeilDivUIOp(SignlessIntegerBinaryOperation): name = "arith.ceildivui" - traits = traits_def(NoMemoryEffect()) + traits = traits_def( + NoMemoryEffect(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr == IntegerAttr(1, attr.type) @irdl_op_definition @@ -567,21 +667,57 @@ class MaxSIOp(SignlessIntegerBinaryOperation): class AndIOp(SignlessIntegerBinaryOperation): name = "arith.andi" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), + Commutative(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + return lhs & rhs + + @staticmethod + def is_right_zero(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 @irdl_op_definition class OrIOp(SignlessIntegerBinaryOperation): name = "arith.ori" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), + Commutative(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + return lhs | rhs + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 @irdl_op_definition class XOrIOp(SignlessIntegerBinaryOperation): name = "arith.xori" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), + Commutative(), + SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait(), + ) + + @staticmethod + def py_operation(lhs: int, rhs: int) -> int | None: + return lhs ^ rhs + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 @irdl_op_definition @@ -593,7 +729,13 @@ class ShLIOp(SignlessIntegerBinaryOperationWithOverflow): name = "arith.shli" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 @irdl_op_definition @@ -606,7 +748,13 @@ class ShRUIOp(SignlessIntegerBinaryOperation): name = "arith.shrui" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 @irdl_op_definition @@ -620,7 +768,13 @@ class ShRSIOp(SignlessIntegerBinaryOperation): name = "arith.shrsi" - traits = traits_def(Pure()) + traits = traits_def( + Pure(), SignlessIntegerBinaryOperationHasCanonicalizationPatternsTrait() + ) + + @staticmethod + def is_right_unit(attr: AnyIntegerAttr) -> bool: + return attr.value.data == 0 class ComparisonOperation(IRDLOperation): diff --git a/xdsl/traits.py b/xdsl/traits.py index 0b0e5383dd..9fd65519c9 100644 --- a/xdsl/traits.py +++ b/xdsl/traits.py @@ -687,6 +687,12 @@ class Pure(NoMemoryEffect, AlwaysSpeculatable): """ +class Commutative(OpTrait): + """ + A trait that signals that an operation is commutative. + """ + + class HasInsnRepresentation(OpTrait, abc.ABC): """ A trait providing information on how to encode an operation using a .insn assember directive. diff --git a/xdsl/transforms/canonicalization_patterns/arith.py b/xdsl/transforms/canonicalization_patterns/arith.py index d8005b3bf5..caa569c529 100644 --- a/xdsl/transforms/canonicalization_patterns/arith.py +++ b/xdsl/transforms/canonicalization_patterns/arith.py @@ -5,33 +5,46 @@ RewritePattern, op_type_rewrite_pattern, ) -from xdsl.transforms.canonicalization_patterns.utils import const_evaluate_operand +from xdsl.traits import Commutative +from xdsl.transforms.canonicalization_patterns.utils import ( + const_evaluate_operand, + const_evaluate_operand_attribute, +) from xdsl.utils.hints import isa -class AddiIdentityRight(RewritePattern): +class SignlessIntegerBinaryOperationZeroOrUnitRight(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, op: arith.AddiOp, rewriter: PatternRewriter) -> None: - if (rhs := const_evaluate_operand(op.rhs)) is None: - return - if rhs != 0: + def match_and_rewrite( + self, op: arith.SignlessIntegerBinaryOperation, rewriter: PatternRewriter, / + ): + if (rhs := const_evaluate_operand_attribute(op.rhs)) is None: return - rewriter.replace_matched_op((), (op.lhs,)) + if op.is_right_zero(rhs): + rewriter.replace_matched_op((), (op.rhs,)) + elif op.is_right_unit(rhs): + rewriter.replace_matched_op((), (op.lhs,)) -class AddiConstantProp(RewritePattern): +class SignlessIntegerBinaryOperationConstantProp(RewritePattern): @op_type_rewrite_pattern - def match_and_rewrite(self, op: arith.AddiOp, rewriter: PatternRewriter): + def match_and_rewrite( + self, op: arith.SignlessIntegerBinaryOperation, rewriter: PatternRewriter, / + ): if (lhs := const_evaluate_operand(op.lhs)) is None: return if (rhs := const_evaluate_operand(op.rhs)) is None: # Swap inputs if lhs is constant and rhs is not - rewriter.replace_matched_op(arith.AddiOp(op.rhs, op.lhs)) + if op.has_trait(Commutative): + rewriter.replace_matched_op(op.__class__(op.rhs, op.lhs)) return + if (res := op.py_operation(lhs, rhs)) is None: + return assert isinstance(op.result.type, IntegerType | IndexType) + rewriter.replace_matched_op( - arith.ConstantOp.from_int_and_width(lhs + rhs, op.result.type) + arith.ConstantOp.from_int_and_width(res, op.result.type, truncate_bits=True) ) @@ -176,33 +189,6 @@ def match_and_rewrite(self, op: arith.SelectOp, rewriter: PatternRewriter): rewriter.replace_matched_op((), (op.lhs,)) -class MuliIdentityRight(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: arith.MuliOp, rewriter: PatternRewriter): - if (rhs := const_evaluate_operand(op.rhs)) is None: - return - if rhs != 1: - return - - rewriter.replace_matched_op((), (op.lhs,)) - - -class MuliConstantProp(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: arith.MuliOp, rewriter: PatternRewriter): - if (lhs := const_evaluate_operand(op.lhs)) is None: - return - if (rhs := const_evaluate_operand(op.rhs)) is None: - # Swap inputs if rhs is constant and lhs is not - rewriter.replace_matched_op(arith.MuliOp(op.rhs, op.lhs)) - return - - assert isinstance(op.result.type, IntegerType | IndexType) - rewriter.replace_matched_op( - arith.ConstantOp.from_int_and_width(lhs * rhs, op.result.type) - ) - - class ApplyCmpiPredicateToEqualOperands(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: arith.CmpiOp, rewriter: PatternRewriter): diff --git a/xdsl/transforms/canonicalization_patterns/cf.py b/xdsl/transforms/canonicalization_patterns/cf.py index eea4fb276e..0af86049d0 100644 --- a/xdsl/transforms/canonicalization_patterns/cf.py +++ b/xdsl/transforms/canonicalization_patterns/cf.py @@ -15,7 +15,9 @@ op_type_rewrite_pattern, ) from xdsl.rewriter import InsertPoint -from xdsl.transforms.canonicalization_patterns.utils import const_evaluate_operand +from xdsl.transforms.canonicalization_patterns.utils import ( + const_evaluate_operand, +) class AssertTrue(RewritePattern): diff --git a/xdsl/transforms/canonicalization_patterns/scf.py b/xdsl/transforms/canonicalization_patterns/scf.py index 2285070cc0..2e38bf15f4 100644 --- a/xdsl/transforms/canonicalization_patterns/scf.py +++ b/xdsl/transforms/canonicalization_patterns/scf.py @@ -9,7 +9,9 @@ ) from xdsl.rewriter import InsertPoint from xdsl.traits import ConstantLike -from xdsl.transforms.canonicalization_patterns.utils import const_evaluate_operand +from xdsl.transforms.canonicalization_patterns.utils import ( + const_evaluate_operand, +) class RehoistConstInLoops(RewritePattern): diff --git a/xdsl/transforms/canonicalization_patterns/utils.py b/xdsl/transforms/canonicalization_patterns/utils.py index 273de0ec26..bab4fdb59a 100644 --- a/xdsl/transforms/canonicalization_patterns/utils.py +++ b/xdsl/transforms/canonicalization_patterns/utils.py @@ -1,13 +1,21 @@ from xdsl.dialects import arith -from xdsl.dialects.builtin import IntegerAttr +from xdsl.dialects.builtin import AnyIntegerAttr, IntegerAttr from xdsl.ir import SSAValue -def const_evaluate_operand(operand: SSAValue) -> int | None: +def const_evaluate_operand_attribute(operand: SSAValue) -> AnyIntegerAttr | None: """ Try to constant evaluate an SSA value, returning None on failure. """ if isinstance(op := operand.owner, arith.ConstantOp) and isinstance( val := op.value, IntegerAttr ): - return val.value.data + return val + + +def const_evaluate_operand(operand: SSAValue) -> int | None: + """ + Try to constant evaluate an SSA value, returning None on failure. + """ + if (attr := const_evaluate_operand_attribute(operand)) is not None: + return attr.value.data From b394263200b3a7bba207037ba7d3a9bebfd09bea Mon Sep 17 00:00:00 2001 From: Fehr Mathieu Date: Wed, 18 Dec 2024 14:44:15 +0000 Subject: [PATCH 07/10] transforms: Allow to pass a pattern rewriter in CSE (#3539) Stacked PRs: * #3540 * __->__#3539 * #3538 * #3537 --- --- --- ### transforms: Allow to pass a pattern rewriter in CSE Without passing the pattern rewriter, CSE couldn't be called inside a pattern rewriter walker, as it would not notify the operations that were deleted or replaced. --- .../canonicalization_patterns/stencil.py | 2 +- .../common_subexpression_elimination.py | 21 +++++++++++-------- xdsl/transforms/control_flow_hoist.py | 4 ++-- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/xdsl/transforms/canonicalization_patterns/stencil.py b/xdsl/transforms/canonicalization_patterns/stencil.py index 6a019275a1..58fe67f0a5 100644 --- a/xdsl/transforms/canonicalization_patterns/stencil.py +++ b/xdsl/transforms/canonicalization_patterns/stencil.py @@ -41,7 +41,7 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> N continue a.replace_by(bbargs[rbargs[i]]) - cse(op.region.block) + cse(op.region.block, rewriter) class ApplyUnusedOperands(RewritePattern): diff --git a/xdsl/transforms/common_subexpression_elimination.py b/xdsl/transforms/common_subexpression_elimination.py index e14441d4a3..9c301fd6a3 100644 --- a/xdsl/transforms/common_subexpression_elimination.py +++ b/xdsl/transforms/common_subexpression_elimination.py @@ -5,6 +5,7 @@ from xdsl.dialects.builtin import ModuleOp, UnregisteredOp from xdsl.ir import Block, Operation, Region, Use from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import PatternRewriter from xdsl.rewriter import Rewriter from xdsl.traits import ( IsolatedFromAbove, @@ -115,19 +116,15 @@ def has_other_side_effecting_op_in_between( return False +@dataclass class CSEDriver: """ Boilerplate class to handle and carry the state for CSE. """ - _rewriter: Rewriter + _rewriter: Rewriter | PatternRewriter = field(default_factory=Rewriter) _to_erase: set[Operation] = field(default_factory=set) - _known_ops: KnownOps = KnownOps() - - def __init__(self): - self._rewriter = Rewriter() - self._to_erase = set() - self._known_ops = KnownOps() + _known_ops: KnownOps = field(default_factory=KnownOps) def _mark_erasure(self, op: Operation): self._to_erase.add(op) @@ -250,8 +247,14 @@ def simplify(self, thing: Operation | Block | Region): self._commit_erasures() -def cse(thing: Operation | Block | Region): - CSEDriver().simplify(thing) +def cse( + thing: Operation | Block | Region, + rewriter: Rewriter | PatternRewriter | None = None, +): + if rewriter is not None: + CSEDriver(_rewriter=rewriter).simplify(thing) + else: + CSEDriver().simplify(thing) class CommonSubexpressionElimination(ModulePass): diff --git a/xdsl/transforms/control_flow_hoist.py b/xdsl/transforms/control_flow_hoist.py index 37b55f9fcb..e30b2f6b98 100644 --- a/xdsl/transforms/control_flow_hoist.py +++ b/xdsl/transforms/control_flow_hoist.py @@ -57,7 +57,7 @@ def match_and_rewrite(self, op: affine.IfOp, rewriter: PatternRewriter): return block = op.parent if block: - cse(block) + cse(block, rewriter) class SCFIfHoistPattern(RewritePattern): @@ -84,7 +84,7 @@ def match_and_rewrite(self, op: scf.IfOp, rewriter: PatternRewriter): block = op.parent if block: # If we hoisted some ops, run CSE on that block to not keep pushing duplicates upward. - cse(block) + cse(block, rewriter) class ControlFlowHoistPass(ModulePass): From 4b15917e639a5ec3c1512735208d4c4d740c2967 Mon Sep 17 00:00:00 2001 From: Chris Vasiladiotis Date: Thu, 19 Dec 2024 08:47:02 +0000 Subject: [PATCH 08/10] dialects (func): Add SymbolUserOpInterface implementation for `func.call` operation (#3652) This PR: - Adds support for the `SymbolUserOpInterface` interface and implements it for `func.call` - Adds tests (pytest and filecheck) of the above Resolves #3497 --- tests/dialects/test_func.py | 41 ++++++++++++- .../filecheck/dialects/func/func_invalid.mlir | 57 +++++++++++++++++++ xdsl/dialects/func.py | 44 +++++++++++++- xdsl/traits.py | 20 +++++++ 4 files changed, 158 insertions(+), 4 deletions(-) diff --git a/tests/dialects/test_func.py b/tests/dialects/test_func.py index 3679e2ee80..755e41b7c3 100644 --- a/tests/dialects/test_func.py +++ b/tests/dialects/test_func.py @@ -3,10 +3,23 @@ from xdsl.builder import Builder, ImplicitBuilder from xdsl.dialects.arith import AddiOp, ConstantOp -from xdsl.dialects.builtin import IntegerAttr, IntegerType, ModuleOp, i32, i64 +from xdsl.dialects.builtin import ( + IntegerAttr, + IntegerType, + ModuleOp, + StringAttr, + i32, + i64, +) from xdsl.dialects.func import CallOp, FuncOp, ReturnOp from xdsl.ir import Block, Region -from xdsl.traits import CallableOpInterface +from xdsl.irdl import ( + IRDLOperation, + attr_def, + irdl_op_definition, + traits_def, +) +from xdsl.traits import CallableOpInterface, SymbolOpInterface from xdsl.utils.exceptions import VerifyException @@ -261,6 +274,30 @@ def test_call_II(): assert_print_op(mod, expected, None) +def test_call_III(): + """Call a symbol that is not func.func""" + + @irdl_op_definition + class SymbolOp(IRDLOperation): + name = "test.symbol" + + sym_name = attr_def(StringAttr) + + traits = traits_def(SymbolOpInterface()) + + def __init__(self, name: str): + return super().__init__(attributes={"sym_name": StringAttr(name)}) + + symop = SymbolOp("foo") + call0 = CallOp("foo", [], []) + mod = ModuleOp([symop, call0]) + + with pytest.raises( + VerifyException, match="'@foo' does not reference a valid function" + ): + mod.verify() + + def test_return(): # Create two constants and add them, then return a = ConstantOp.from_int_and_width(1, i32) diff --git a/tests/filecheck/dialects/func/func_invalid.mlir b/tests/filecheck/dialects/func/func_invalid.mlir index 9fd7953410..afcad9667d 100644 --- a/tests/filecheck/dialects/func/func_invalid.mlir +++ b/tests/filecheck/dialects/func/func_invalid.mlir @@ -38,3 +38,60 @@ builtin.module { // CHECK: Operation does not verify: Unexpected nested symbols in FlatSymbolRefAttr // CHECK-NEXT: Underlying verification failure: expected empty array, but got ["invalid"] + +// ----- + +func.func @bar() { + %1 = "test.op"() : () -> !test.type<"int"> + %2 = func.call @foo(%1) : (!test.type<"int">) -> !test.type<"int"> + func.return +} + +// CHECK: '@foo' could not be found in symbol table + +// ----- + +func.func @foo(%0 : !test.type<"int">) -> !test.type<"int"> + +func.func @bar() { + %1 = func.call @foo() : () -> !test.type<"int"> + func.return +} + +// CHECK: incorrect number of operands for callee + +// ----- + +func.func @foo(%0 : !test.type<"int">) + +func.func @bar() { + %1 = "test.op"() : () -> !test.type<"int"> + %2 = func.call @foo(%1) : (!test.type<"int">) -> !test.type<"int"> + func.return +} + +// CHECK: incorrect number of results for callee + +// ----- + +func.func @foo(%0 : !test.type<"int">) -> !test.type<"int"> + +func.func @bar() { + %1 = "test.op"() : () -> !test.type<"foo"> + %2 = func.call @foo(%1) : (!test.type<"foo">) -> !test.type<"int"> + func.return +} + +// CHECK: operand type mismatch: expected operand type !test.type<"int">, but provided !test.type<"foo"> for operand number 0 + +// ----- + +func.func @foo(%0 : !test.type<"int">) -> !test.type<"int"> + +func.func @bar() { + %1 = "test.op"() : () -> !test.type<"int"> + %2 = func.call @foo(%1) : (!test.type<"int">) -> !test.type<"foo"> + func.return +} + +// CHECK: result type mismatch: expected result type !test.type<"int">, but provided !test.type<"foo"> for result number 0 diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index c5a2e28c58..279f22d3a1 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -41,6 +41,8 @@ IsolatedFromAbove, IsTerminator, SymbolOpInterface, + SymbolTable, + SymbolUserOpInterface, ) from xdsl.utils.exceptions import VerifyException @@ -62,6 +64,42 @@ def get_result_types(cls, op: Operation) -> tuple[Attribute, ...]: return op.function_type.outputs.data +class CallOpSymbolUserOpInterface(SymbolUserOpInterface): + def verify(self, op: Operation) -> None: + assert isinstance(op, CallOp) + + found_callee = SymbolTable.lookup_symbol(op, op.callee) + if not found_callee: + raise VerifyException(f"'{op.callee}' could not be found in symbol table") + + if not isinstance(found_callee, FuncOp): + raise VerifyException(f"'{op.callee}' does not reference a valid function") + + if len(found_callee.function_type.inputs) != len(op.arguments): + raise VerifyException("incorrect number of operands for callee") + + if len(found_callee.function_type.outputs) != len(op.result_types): + raise VerifyException("incorrect number of results for callee") + + for idx, (found_operand, operand) in enumerate( + zip(found_callee.function_type.inputs, (arg.type for arg in op.arguments)) + ): + if found_operand != operand: + raise VerifyException( + f"operand type mismatch: expected operand type {found_operand}, but provided {operand} for operand number {idx}" + ) + + for idx, (found_res, res) in enumerate( + zip(found_callee.function_type.outputs, op.result_types) + ): + if found_res != res: + raise VerifyException( + f"result type mismatch: expected result type {found_res}, but provided {res} for result number {idx}" + ) + + return + + @irdl_op_definition class FuncOp(IRDLOperation): name = "func.func" @@ -108,7 +146,6 @@ def verify_(self) -> None: if len(self.body.blocks) == 0: return - # TODO: how to verify that there is a terminator? entry_block = self.body.blocks.first assert entry_block is not None block_arg_types = entry_block.arg_types @@ -272,11 +309,14 @@ class CallOp(IRDLOperation): callee = prop_def(FlatSymbolRefAttrConstr) res = var_result_def() + traits = traits_def( + CallOpSymbolUserOpInterface(), + ) + assembly_format = ( "$callee `(` $arguments `)` attr-dict `:` functional-type($arguments, $res)" ) - # TODO how do we verify that the types are correct? def __init__( self, callee: str | SymbolRefAttr, diff --git a/xdsl/traits.py b/xdsl/traits.py index 9fd65519c9..23eca97fcd 100644 --- a/xdsl/traits.py +++ b/xdsl/traits.py @@ -239,6 +239,26 @@ def verify(self, op: Operation) -> None: regions += child_op.regions +class SymbolUserOpInterface(OpTrait, abc.ABC): + """ + Used to represent operations that reference Symbol operations. This provides the + ability to perform safe and efficient verification of symbol uses, as well as + additional functionality. + + https://mlir.llvm.org/docs/Interfaces/#symbolinterfaces + """ + + @abc.abstractmethod + def verify(self, op: Operation) -> None: + """ + This method should be adapted to the requirements of specific symbol users per + operation. + + It corresponds to the verifySymbolUses in upstream MLIR. + """ + raise NotImplementedError() + + class SymbolTable(OpTrait): """ SymbolTable operations are containers for Symbol operations. They offer lookup From d5dd188e9017dbd953be3c022ee515938b5bd819 Mon Sep 17 00:00:00 2001 From: Nicolai Stawinoga <36768051+n-io@users.noreply.github.com> Date: Thu, 19 Dec 2024 11:03:07 +0100 Subject: [PATCH 09/10] bug: (csl-lowering) Make multi-apply lowering work (#3614) This PR includes a few small fixes, described below. --------- Co-authored-by: n-io --- .../transforms/convert-stencil-to-csl-stencil.mlir | 6 +++--- xdsl/transforms/convert_stencil_to_csl_stencil.py | 10 ++++++---- xdsl/transforms/csl_stencil_bufferize.py | 9 +++++++++ .../experimental/stencil_tensorize_z_dimension.py | 6 +++++- xdsl/transforms/memref_to_dsd.py | 7 ++++++- 5 files changed, 29 insertions(+), 9 deletions(-) diff --git a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir index 8031f867e9..e550a04561 100644 --- a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir +++ b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir @@ -238,12 +238,12 @@ builtin.module { // CHECK-NEXT: %0 = tensor.empty() : tensor<1x64xf32> // CHECK-NEXT: csl_stencil.apply(%arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %0 : tensor<1x64xf32>) -> () <{"swaps" = [#csl_stencil.exchange], "topo" = #dmp.topo<64x64>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%1 : tensor<1x32xf32>, %2 : index, %3 : tensor<1x64xf32>): -// CHECK-NEXT: %4 = csl_stencil.access %3[-1, 0] : tensor<1x64xf32> -// CHECK-NEXT: %5 = "tensor.insert_slice"(%4, %3, %2) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<32xf32>, tensor<1x64xf32>, index) -> tensor<1x64xf32> +// CHECK-NEXT: %4 = csl_stencil.access %1[-1, 0] : tensor<1x32xf32> +// CHECK-NEXT: %5 = "tensor.insert_slice"(%4, %3, %2) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<32xf32>, tensor<1x64xf32>, index) -> tensor<1x64xf32> // CHECK-NEXT: csl_stencil.yield %5 : tensor<1x64xf32> // CHECK-NEXT: }, { // CHECK-NEXT: ^1(%6 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %7 : tensor<1x64xf32>): -// CHECK-NEXT: csl_stencil.yield %7 : tensor<1x64xf32> +// CHECK-NEXT: csl_stencil.yield // CHECK-NEXT: }) // CHECK-NEXT: %1 = tensor.empty() : tensor<64xf32> // CHECK-NEXT: csl_stencil.apply(%arg0 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %1 : tensor<64xf32>, %arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %0 : tensor<1x64xf32>) outs (%arg4 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) <{"swaps" = [#csl_stencil.exchange], "topo" = #dmp.topo<64x64>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py index 76981530d4..8551846aee 100644 --- a/xdsl/transforms/convert_stencil_to_csl_stencil.py +++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py @@ -606,19 +606,21 @@ def match_and_rewrite( block = Block(arg_types=[chunk_buf_t, builtin.IndexType(), op.result.type]) block2 = Block(arg_types=[op.input_stencil.type, op.result.type]) - block2.add_op(csl_stencil.YieldOp(block2.args[1])) + block2.add_op(csl_stencil.YieldOp()) - with ImplicitBuilder(block) as (_, offset, acc): + with ImplicitBuilder(block) as (buf, offset, acc): dest = acc for i, acc_offset in enumerate(offsets): ac_op = csl_stencil.AccessOp( - dest, stencil.IndexAttr.get(*acc_offset), chunk_t + buf, stencil.IndexAttr.get(*acc_offset), chunk_t ) assert isa(ac_op.result.type, AnyTensorType) + # inserts 1 (see static_sizes) 1d slice into a 2d tensor at offset (i, `offset`) (see static_offsets) + # where the latter offset is provided dynamically (see offsets) dest = tensor.InsertSliceOp.get( source=ac_op.result, dest=dest, - static_sizes=ac_op.result.type.get_shape(), + static_sizes=[1, *ac_op.result.type.get_shape()], static_offsets=[i, memref.SubviewOp.DYNAMIC_INDEX], offsets=[offset], ).result diff --git a/xdsl/transforms/csl_stencil_bufferize.py b/xdsl/transforms/csl_stencil_bufferize.py index 92d2ea9708..bd4f46d9fe 100644 --- a/xdsl/transforms/csl_stencil_bufferize.py +++ b/xdsl/transforms/csl_stencil_bufferize.py @@ -94,6 +94,10 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, # convert args buf_args: list[SSAValue] = [] to_memrefs: list[Operation] = [buf_iter_arg := to_memref_op(op.accumulator)] + # in case of subsequent apply ops accessing this accumulator, replace uses with `bufferization.to_memref` + op.accumulator.replace_by_if( + buf_iter_arg.memref, lambda use: use.operation != buf_iter_arg + ) for arg in [*op.args_rchunk, *op.args_dexchng]: if isa(arg.type, TensorType[Attribute]): to_memrefs.append(new_arg := to_memref_op(arg)) @@ -385,6 +389,11 @@ def match_and_rewrite(self, op: arith.ConstantOp, rewriter: PatternRewriter, /): class InjectApplyOutsIntoLinalgOuts(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, /): + # require bufferized apply (with op.dest specified) + # zero-output apply ops may be used for communicate-only, to which this pattern does not apply + if not op.dest: + return + yld = op.done_exchange.block.last_op assert isinstance(yld, csl_stencil.YieldOp) new_dest: list[SSAValue] = [] diff --git a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py index 473d6b3628..ddccef2eec 100644 --- a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py +++ b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py @@ -84,9 +84,13 @@ def get_required_result_type(op: Operation) -> TensorType[Attribute] | None: tuple[int, ...], ) ): + assert is_tensor(use.operation.source.type) + # inserting an (n-1)d tensor into an (n)d tensor should not require the input tensor to also be (n)d + # instead, drop the first `dimdiff` dimensions + dimdiff = len(static_sizes) - len(use.operation.source.type.shape) return TensorType( use.operation.result.type.get_element_type(), - static_sizes, + static_sizes[dimdiff:], ) for ret in use.operation.results: if isa(r_type := ret.type, TensorType[Attribute]): diff --git a/xdsl/transforms/memref_to_dsd.py b/xdsl/transforms/memref_to_dsd.py index 98f4a59a7d..a14b06e2b3 100644 --- a/xdsl/transforms/memref_to_dsd.py +++ b/xdsl/transforms/memref_to_dsd.py @@ -133,7 +133,12 @@ def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /): last_op = stride_ops[-1] if len(stride_ops) > 0 else last_op offset_ops = self._update_offsets(op, last_op) - rewriter.replace_matched_op([*size_ops, *stride_ops, *offset_ops]) + new_ops = [*size_ops, *stride_ops, *offset_ops] + if new_ops: + rewriter.replace_matched_op([*size_ops, *stride_ops, *offset_ops]) + else: + # subview has no effect (todo: this could be canonicalized away) + rewriter.replace_matched_op([], new_results=[op.source]) @staticmethod def _update_sizes( From 0e1cec0f1b4329ea771d6ff3eedcd8faee2183cd Mon Sep 17 00:00:00 2001 From: Watermelon Wolverine <29666253+watermelonwolverine@users.noreply.github.com> Date: Thu, 19 Dec 2024 12:01:40 +0100 Subject: [PATCH 10/10] dialects: (vector) Add vector.insertelement and vector.extractelement (#3649) Added "vector.insertelement" and "vector.extractelement" ops --- tests/dialects/test_vector.py | 136 ++++++++++++++++++ .../vector/vector_extractelement_verify.mlir | 27 ++++ .../vector/vector_insertelement_verify.mlir | 34 +++++ .../dialects/vector/vector_pure_ops.mlir | 3 +- .../with-mlir/dialects/vector/ops.mlir | 19 +++ xdsl/dialects/vector.py | 93 ++++++++++++ 6 files changed, 311 insertions(+), 1 deletion(-) create mode 100644 tests/filecheck/dialects/vector/vector_extractelement_verify.mlir create mode 100644 tests/filecheck/dialects/vector/vector_insertelement_verify.mlir create mode 100644 tests/filecheck/mlir-conversion/with-mlir/dialects/vector/ops.mlir diff --git a/tests/dialects/test_vector.py b/tests/dialects/test_vector.py index 229b619f34..faea6e9bd6 100644 --- a/tests/dialects/test_vector.py +++ b/tests/dialects/test_vector.py @@ -12,7 +12,9 @@ from xdsl.dialects.vector import ( BroadcastOp, CreatemaskOp, + ExtractElementOp, FMAOp, + InsertElementOp, LoadOp, MaskedloadOp, MaskedstoreOp, @@ -517,3 +519,137 @@ def test_vector_create_mask_verify_indexing_exception(): match="Expected an operand value for each dimension of resultant mask.", ): create_mask.verify() + + +def test_vector_extract_element_verify_vector_rank_0_or_1(): + vector_type = VectorType(IndexType(), [3, 3]) + + vector = TestSSAValue(vector_type) + position = TestSSAValue(IndexType()) + extract_element = ExtractElementOp(vector, position) + + with pytest.raises(Exception, match="Unexpected >1 vector rank."): + extract_element.verify() + + +def test_vector_extract_element_construction_1d(): + vector_type = VectorType(IndexType(), [3]) + + vector = TestSSAValue(vector_type) + position = TestSSAValue(IndexType()) + + extract_element = ExtractElementOp(vector, position) + + assert extract_element.vector is vector + assert extract_element.position is position + assert extract_element.result.type == vector_type.element_type + + +def test_vector_extract_element_1d_verify_non_empty_position(): + vector_type = VectorType(IndexType(), [3]) + + vector = TestSSAValue(vector_type) + + extract_element = ExtractElementOp(vector) + + with pytest.raises(Exception, match="Expected position for 1-D vector."): + extract_element.verify() + + +def test_vector_extract_element_construction_0d(): + vector_type = VectorType(IndexType(), []) + + vector = TestSSAValue(vector_type) + + extract_element = ExtractElementOp(vector) + + assert extract_element.vector is vector + assert extract_element.position is None + assert extract_element.result.type == vector_type.element_type + + +def test_vector_extract_element_0d_verify_empty_position(): + vector_type = VectorType(IndexType(), []) + + vector = TestSSAValue(vector_type) + position = TestSSAValue(IndexType()) + + extract_element = ExtractElementOp(vector, position) + + with pytest.raises( + Exception, match="Expected position to be empty with 0-D vector." + ): + extract_element.verify() + + +def test_vector_insert_element_verify_vector_rank_0_or_1(): + vector_type = VectorType(IndexType(), [3, 3]) + + source = TestSSAValue(IndexType()) + dest = TestSSAValue(vector_type) + position = TestSSAValue(IndexType()) + + insert_element = InsertElementOp(source, dest, position) + + with pytest.raises(Exception, match="Unexpected >1 vector rank."): + insert_element.verify() + + +def test_vector_insert_element_construction_1d(): + vector_type = VectorType(IndexType(), [3]) + + source = TestSSAValue(IndexType()) + dest = TestSSAValue(vector_type) + position = TestSSAValue(IndexType()) + + insert_element = InsertElementOp(source, dest, position) + + assert insert_element.source is source + assert insert_element.dest is dest + assert insert_element.position is position + assert insert_element.result.type == vector_type + + +def test_vector_insert_element_1d_verify_non_empty_position(): + vector_type = VectorType(IndexType(), [3]) + + source = TestSSAValue(IndexType()) + dest = TestSSAValue(vector_type) + + insert_element = InsertElementOp(source, dest) + + with pytest.raises( + Exception, + match="Expected position for 1-D vector.", + ): + insert_element.verify() + + +def test_vector_insert_element_construction_0d(): + vector_type = VectorType(IndexType(), []) + + source = TestSSAValue(IndexType()) + dest = TestSSAValue(vector_type) + + insert_element = InsertElementOp(source, dest) + + assert insert_element.source is source + assert insert_element.dest is dest + assert insert_element.position is None + assert insert_element.result.type == vector_type + + +def test_vector_insert_element_0d_verify_empty_position(): + vector_type = VectorType(IndexType(), []) + + source = TestSSAValue(IndexType()) + dest = TestSSAValue(vector_type) + position = TestSSAValue(IndexType()) + + insert_element = InsertElementOp(source, dest, position) + + with pytest.raises( + Exception, + match="Expected position to be empty with 0-D vector.", + ): + insert_element.verify() diff --git a/tests/filecheck/dialects/vector/vector_extractelement_verify.mlir b/tests/filecheck/dialects/vector/vector_extractelement_verify.mlir new file mode 100644 index 0000000000..780d11ceab --- /dev/null +++ b/tests/filecheck/dialects/vector/vector_extractelement_verify.mlir @@ -0,0 +1,27 @@ +// RUN: xdsl-opt --split-input-file --verify-diagnostics %s | filecheck %s + +%vector, %i0 = "test.op"() : () -> (vector, index) + +%0 = "vector.extractelement"(%vector, %i0) : (vector, index) -> index +// CHECK: Expected position to be empty with 0-D vector. + +// ----- + +%vector, %i0 = "test.op"() : () -> (vector<4x4xindex>, index) + +%0 = "vector.extractelement"(%vector, %i0) : (vector<4x4xindex>, index) -> index +// CHECK: Operation does not verify: Unexpected >1 vector rank. + +// ----- + +%vector, %i0= "test.op"() : () -> (vector<4xindex>, index) + +%0 = "vector.extractelement"(%vector, %i0) : (vector<4xindex>, index) -> f64 +// CHECK: Expected result type to match element type of vector operand. + +// ----- + +%vector, %i0 = "test.op"() : () -> (vector<1xindex>, index) + +%1 = "vector.extractelement"(%vector) : (vector<1xindex>) -> index +// CHECK: Expected position for 1-D vector. diff --git a/tests/filecheck/dialects/vector/vector_insertelement_verify.mlir b/tests/filecheck/dialects/vector/vector_insertelement_verify.mlir new file mode 100644 index 0000000000..0f5619fb15 --- /dev/null +++ b/tests/filecheck/dialects/vector/vector_insertelement_verify.mlir @@ -0,0 +1,34 @@ +// RUN: xdsl-opt --split-input-file --verify-diagnostics %s | filecheck %s + +%vector, %i0 = "test.op"() : () -> (vector, index) + +%0 = "vector.insertelement"(%i0, %vector, %i0) : (index, vector, index) -> vector +// CHECK: Expected position to be empty with 0-D vector. + +// ----- + +%vector, %i0 = "test.op"() : () -> (vector<1xindex>, index) + +%1 = "vector.insertelement"(%i0, %vector) : (index, vector<1xindex>) -> vector<1xindex> +// CHECK: Expected position for 1-D vector. + +// ----- + +%vector, %i0, %f0 = "test.op"() : () -> (vector<4xindex>, index, f64) + +%0 = "vector.insertelement"(%f0, %vector, %i0) : (f64, vector<4xindex>, index) -> vector<4xindex> +// CHECK: Expected source operand type to match element type of dest operand. + +// ----- + +%vector, %i0 = "test.op"() : () -> (vector<4xindex>, index) + +%0 = "vector.insertelement"(%i0, %vector, %i0) : (index, vector<4xindex>, index) -> vector<3xindex> +// CHECK: Expected dest operand and result to have matching types. + +// ----- + +%vector, %i0 = "test.op"() : () -> (vector<4x4xindex>, index) + +%0 = "vector.insertelement"(%i0, %vector, %i0) : (index, vector<4x4xindex>, index) -> vector<4x4xindex> +// CHECK: Operation does not verify: Unexpected >1 vector rank. diff --git a/tests/filecheck/dialects/vector/vector_pure_ops.mlir b/tests/filecheck/dialects/vector/vector_pure_ops.mlir index eda9f79ccc..a3c304976b 100644 --- a/tests/filecheck/dialects/vector/vector_pure_ops.mlir +++ b/tests/filecheck/dialects/vector/vector_pure_ops.mlir @@ -5,7 +5,8 @@ "vector.store"(%load, %m0, %i0, %i0) : (vector<2xindex>, memref<4x4xindex>, index, index) -> () %broadcast = "vector.broadcast"(%i0) : (index) -> vector<1xindex> %fma = "vector.fma"(%load, %load, %load) : (vector<2xindex>, vector<2xindex>, vector<2xindex>) -> vector<2xindex> - +%extract_op = "vector.extractelement"(%broadcast, %i0) : (vector<1xindex>, index) -> index +"vector.insertelement"(%extract_op, %broadcast, %i0) : (index, vector<1xindex>, index) -> vector<1xindex> /// Check that unused results from vector.broadcast and vector.fma are eliminated // CHECK: %m0, %i0 = "test.op"() : () -> (memref<4x4xindex>, index) // CHECK-NEXT: %load = "vector.load"(%m0, %i0, %i0) : (memref<4x4xindex>, index, index) -> vector<2xindex> diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/vector/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/vector/ops.mlir new file mode 100644 index 0000000000..e4c2d5649c --- /dev/null +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/vector/ops.mlir @@ -0,0 +1,19 @@ +// RUN: xdsl-opt --print-op-generic %s | mlir-opt --mlir-print-op-generic | xdsl-opt --print-op-generic | filecheck %s + +builtin.module{ + +%vector0, %vector1, %i0 = "test.op"() : () -> (vector, vector<3xindex>, index) +// CHECK: %0, %1, %2 = "test.op"() : () -> (vector, vector<3xindex>, index) + +%0 = "vector.insertelement"(%i0, %vector0) : (index, vector) -> vector +// CHECK-NEXT: %3 = "vector.insertelement"(%2, %0) : (index, vector) -> vector + +%1 = "vector.insertelement"(%i0, %vector1, %i0) : (index, vector<3xindex>, index) -> vector<3xindex> +// CHECK-NEXT: %4 = "vector.insertelement"(%2, %1, %2) : (index, vector<3xindex>, index) -> vector<3xindex> + +%2 = "vector.extractelement"(%vector1, %i0) : (vector<3xindex>, index) -> index +// CHECK-NEXT: %5 = "vector.extractelement"(%1, %2) : (vector<3xindex>, index) -> index + +%3 = "vector.extractelement"(%vector0) : (vector) -> index +// CHECK-NEXT: %6 = "vector.extractelement"(%0) : (vector) -> index +} diff --git a/xdsl/dialects/vector.py b/xdsl/dialects/vector.py index 7432ba2442..fd99afeb70 100644 --- a/xdsl/dialects/vector.py +++ b/xdsl/dialects/vector.py @@ -4,7 +4,9 @@ from xdsl.dialects.builtin import ( IndexType, + IndexTypeConstr, MemRefType, + SignlessIntegerConstraint, VectorBaseTypeAndRankConstraint, VectorBaseTypeConstraint, VectorRankConstraint, @@ -16,6 +18,7 @@ IRDLOperation, irdl_op_definition, operand_def, + opt_operand_def, result_def, traits_def, var_operand_def, @@ -292,6 +295,94 @@ def get(mask_operands: list[Operation | SSAValue]) -> CreatemaskOp: ) +@irdl_op_definition +class ExtractElementOp(IRDLOperation): + name = "vector.extractelement" + vector = operand_def(VectorType) + position = opt_operand_def(IndexTypeConstr | SignlessIntegerConstraint) + result = result_def(Attribute) + traits = traits_def(Pure()) + + def verify_(self): + assert isa(self.vector.type, VectorType[Attribute]) + + if self.result.type != self.vector.type.element_type: + raise VerifyException( + "Expected result type to match element type of vector operand." + ) + + if self.vector.type.get_num_dims() == 0: + if self.position is not None: + raise VerifyException("Expected position to be empty with 0-D vector.") + return + if self.vector.type.get_num_dims() != 1: + raise VerifyException("Unexpected >1 vector rank.") + if self.position is None: + raise VerifyException("Expected position for 1-D vector.") + + def __init__( + self, + vector: SSAValue | Operation, + position: SSAValue | Operation | None = None, + ): + vector = SSAValue.get(vector) + assert isa(vector.type, VectorType[Attribute]) + + result_type = vector.type.element_type + + super().__init__( + operands=[vector, position], + result_types=[result_type], + ) + + +@irdl_op_definition +class InsertElementOp(IRDLOperation): + name = "vector.insertelement" + source = operand_def(Attribute) + dest = operand_def(VectorType) + position = opt_operand_def(IndexTypeConstr | SignlessIntegerConstraint) + result = result_def(VectorType) + traits = traits_def(Pure()) + + def verify_(self): + assert isa(self.dest.type, VectorType[Attribute]) + + if self.result.type != self.dest.type: + raise VerifyException( + "Expected dest operand and result to have matching types." + ) + if self.source.type != self.dest.type.element_type: + raise VerifyException( + "Expected source operand type to match element type of dest operand." + ) + + if self.dest.type.get_num_dims() == 0: + if self.position is not None: + raise VerifyException("Expected position to be empty with 0-D vector.") + return + if self.dest.type.get_num_dims() != 1: + raise VerifyException("Unexpected >1 vector rank.") + if self.position is None: + raise VerifyException("Expected position for 1-D vector.") + + def __init__( + self, + source: SSAValue | Operation, + dest: SSAValue | Operation, + position: SSAValue | Operation | None = None, + ): + dest = SSAValue.get(dest) + assert isa(dest.type, VectorType[Attribute]) + + result_type = SSAValue.get(dest).type + + super().__init__( + operands=[source, dest, position], + result_types=[result_type], + ) + + Vector = Dialect( "vector", [ @@ -303,6 +394,8 @@ def get(mask_operands: list[Operation | SSAValue]) -> CreatemaskOp: MaskedstoreOp, PrintOp, CreatemaskOp, + ExtractElementOp, + InsertElementOp, ], [], )