Skip to content

Commit

Permalink
dialects: (vector) Add vector.insertelement and vector.extractelement (
Browse files Browse the repository at this point in the history
…#3649)

Added "vector.insertelement" and "vector.extractelement" ops
  • Loading branch information
watermelonwolverine authored Dec 19, 2024
1 parent d5dd188 commit 0e1cec0
Show file tree
Hide file tree
Showing 6 changed files with 311 additions and 1 deletion.
136 changes: 136 additions & 0 deletions tests/dialects/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from xdsl.dialects.vector import (
BroadcastOp,
CreatemaskOp,
ExtractElementOp,
FMAOp,
InsertElementOp,
LoadOp,
MaskedloadOp,
MaskedstoreOp,
Expand Down Expand Up @@ -517,3 +519,137 @@ def test_vector_create_mask_verify_indexing_exception():
match="Expected an operand value for each dimension of resultant mask.",
):
create_mask.verify()


def test_vector_extract_element_verify_vector_rank_0_or_1():
vector_type = VectorType(IndexType(), [3, 3])

vector = TestSSAValue(vector_type)
position = TestSSAValue(IndexType())
extract_element = ExtractElementOp(vector, position)

with pytest.raises(Exception, match="Unexpected >1 vector rank."):
extract_element.verify()


def test_vector_extract_element_construction_1d():
vector_type = VectorType(IndexType(), [3])

vector = TestSSAValue(vector_type)
position = TestSSAValue(IndexType())

extract_element = ExtractElementOp(vector, position)

assert extract_element.vector is vector
assert extract_element.position is position
assert extract_element.result.type == vector_type.element_type


def test_vector_extract_element_1d_verify_non_empty_position():
vector_type = VectorType(IndexType(), [3])

vector = TestSSAValue(vector_type)

extract_element = ExtractElementOp(vector)

with pytest.raises(Exception, match="Expected position for 1-D vector."):
extract_element.verify()


def test_vector_extract_element_construction_0d():
vector_type = VectorType(IndexType(), [])

vector = TestSSAValue(vector_type)

extract_element = ExtractElementOp(vector)

assert extract_element.vector is vector
assert extract_element.position is None
assert extract_element.result.type == vector_type.element_type


def test_vector_extract_element_0d_verify_empty_position():
vector_type = VectorType(IndexType(), [])

vector = TestSSAValue(vector_type)
position = TestSSAValue(IndexType())

extract_element = ExtractElementOp(vector, position)

with pytest.raises(
Exception, match="Expected position to be empty with 0-D vector."
):
extract_element.verify()


def test_vector_insert_element_verify_vector_rank_0_or_1():
vector_type = VectorType(IndexType(), [3, 3])

source = TestSSAValue(IndexType())
dest = TestSSAValue(vector_type)
position = TestSSAValue(IndexType())

insert_element = InsertElementOp(source, dest, position)

with pytest.raises(Exception, match="Unexpected >1 vector rank."):
insert_element.verify()


def test_vector_insert_element_construction_1d():
vector_type = VectorType(IndexType(), [3])

source = TestSSAValue(IndexType())
dest = TestSSAValue(vector_type)
position = TestSSAValue(IndexType())

insert_element = InsertElementOp(source, dest, position)

assert insert_element.source is source
assert insert_element.dest is dest
assert insert_element.position is position
assert insert_element.result.type == vector_type


def test_vector_insert_element_1d_verify_non_empty_position():
vector_type = VectorType(IndexType(), [3])

source = TestSSAValue(IndexType())
dest = TestSSAValue(vector_type)

insert_element = InsertElementOp(source, dest)

with pytest.raises(
Exception,
match="Expected position for 1-D vector.",
):
insert_element.verify()


def test_vector_insert_element_construction_0d():
vector_type = VectorType(IndexType(), [])

source = TestSSAValue(IndexType())
dest = TestSSAValue(vector_type)

insert_element = InsertElementOp(source, dest)

assert insert_element.source is source
assert insert_element.dest is dest
assert insert_element.position is None
assert insert_element.result.type == vector_type


def test_vector_insert_element_0d_verify_empty_position():
vector_type = VectorType(IndexType(), [])

source = TestSSAValue(IndexType())
dest = TestSSAValue(vector_type)
position = TestSSAValue(IndexType())

insert_element = InsertElementOp(source, dest, position)

with pytest.raises(
Exception,
match="Expected position to be empty with 0-D vector.",
):
insert_element.verify()
27 changes: 27 additions & 0 deletions tests/filecheck/dialects/vector/vector_extractelement_verify.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: xdsl-opt --split-input-file --verify-diagnostics %s | filecheck %s

%vector, %i0 = "test.op"() : () -> (vector<index>, index)

%0 = "vector.extractelement"(%vector, %i0) : (vector<index>, index) -> index
// CHECK: Expected position to be empty with 0-D vector.

// -----

%vector, %i0 = "test.op"() : () -> (vector<4x4xindex>, index)

%0 = "vector.extractelement"(%vector, %i0) : (vector<4x4xindex>, index) -> index
// CHECK: Operation does not verify: Unexpected >1 vector rank.

// -----

%vector, %i0= "test.op"() : () -> (vector<4xindex>, index)

%0 = "vector.extractelement"(%vector, %i0) : (vector<4xindex>, index) -> f64
// CHECK: Expected result type to match element type of vector operand.

// -----

%vector, %i0 = "test.op"() : () -> (vector<1xindex>, index)

%1 = "vector.extractelement"(%vector) : (vector<1xindex>) -> index
// CHECK: Expected position for 1-D vector.
34 changes: 34 additions & 0 deletions tests/filecheck/dialects/vector/vector_insertelement_verify.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: xdsl-opt --split-input-file --verify-diagnostics %s | filecheck %s

%vector, %i0 = "test.op"() : () -> (vector<index>, index)

%0 = "vector.insertelement"(%i0, %vector, %i0) : (index, vector<index>, index) -> vector<index>
// CHECK: Expected position to be empty with 0-D vector.

// -----

%vector, %i0 = "test.op"() : () -> (vector<1xindex>, index)

%1 = "vector.insertelement"(%i0, %vector) : (index, vector<1xindex>) -> vector<1xindex>
// CHECK: Expected position for 1-D vector.

// -----

%vector, %i0, %f0 = "test.op"() : () -> (vector<4xindex>, index, f64)

%0 = "vector.insertelement"(%f0, %vector, %i0) : (f64, vector<4xindex>, index) -> vector<4xindex>
// CHECK: Expected source operand type to match element type of dest operand.

// -----

%vector, %i0 = "test.op"() : () -> (vector<4xindex>, index)

%0 = "vector.insertelement"(%i0, %vector, %i0) : (index, vector<4xindex>, index) -> vector<3xindex>
// CHECK: Expected dest operand and result to have matching types.

// -----

%vector, %i0 = "test.op"() : () -> (vector<4x4xindex>, index)

%0 = "vector.insertelement"(%i0, %vector, %i0) : (index, vector<4x4xindex>, index) -> vector<4x4xindex>
// CHECK: Operation does not verify: Unexpected >1 vector rank.
3 changes: 2 additions & 1 deletion tests/filecheck/dialects/vector/vector_pure_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
"vector.store"(%load, %m0, %i0, %i0) : (vector<2xindex>, memref<4x4xindex>, index, index) -> ()
%broadcast = "vector.broadcast"(%i0) : (index) -> vector<1xindex>
%fma = "vector.fma"(%load, %load, %load) : (vector<2xindex>, vector<2xindex>, vector<2xindex>) -> vector<2xindex>

%extract_op = "vector.extractelement"(%broadcast, %i0) : (vector<1xindex>, index) -> index
"vector.insertelement"(%extract_op, %broadcast, %i0) : (index, vector<1xindex>, index) -> vector<1xindex>
/// Check that unused results from vector.broadcast and vector.fma are eliminated
// CHECK: %m0, %i0 = "test.op"() : () -> (memref<4x4xindex>, index)
// CHECK-NEXT: %load = "vector.load"(%m0, %i0, %i0) : (memref<4x4xindex>, index, index) -> vector<2xindex>
Expand Down
19 changes: 19 additions & 0 deletions tests/filecheck/mlir-conversion/with-mlir/dialects/vector/ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: xdsl-opt --print-op-generic %s | mlir-opt --mlir-print-op-generic | xdsl-opt --print-op-generic | filecheck %s

builtin.module{

%vector0, %vector1, %i0 = "test.op"() : () -> (vector<index>, vector<3xindex>, index)
// CHECK: %0, %1, %2 = "test.op"() : () -> (vector<index>, vector<3xindex>, index)

%0 = "vector.insertelement"(%i0, %vector0) : (index, vector<index>) -> vector<index>
// CHECK-NEXT: %3 = "vector.insertelement"(%2, %0) : (index, vector<index>) -> vector<index>

%1 = "vector.insertelement"(%i0, %vector1, %i0) : (index, vector<3xindex>, index) -> vector<3xindex>
// CHECK-NEXT: %4 = "vector.insertelement"(%2, %1, %2) : (index, vector<3xindex>, index) -> vector<3xindex>

%2 = "vector.extractelement"(%vector1, %i0) : (vector<3xindex>, index) -> index
// CHECK-NEXT: %5 = "vector.extractelement"(%1, %2) : (vector<3xindex>, index) -> index

%3 = "vector.extractelement"(%vector0) : (vector<index>) -> index
// CHECK-NEXT: %6 = "vector.extractelement"(%0) : (vector<index>) -> index
}
93 changes: 93 additions & 0 deletions xdsl/dialects/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from xdsl.dialects.builtin import (
IndexType,
IndexTypeConstr,
MemRefType,
SignlessIntegerConstraint,
VectorBaseTypeAndRankConstraint,
VectorBaseTypeConstraint,
VectorRankConstraint,
Expand All @@ -16,6 +18,7 @@
IRDLOperation,
irdl_op_definition,
operand_def,
opt_operand_def,
result_def,
traits_def,
var_operand_def,
Expand Down Expand Up @@ -292,6 +295,94 @@ def get(mask_operands: list[Operation | SSAValue]) -> CreatemaskOp:
)


@irdl_op_definition
class ExtractElementOp(IRDLOperation):
name = "vector.extractelement"
vector = operand_def(VectorType)
position = opt_operand_def(IndexTypeConstr | SignlessIntegerConstraint)
result = result_def(Attribute)
traits = traits_def(Pure())

def verify_(self):
assert isa(self.vector.type, VectorType[Attribute])

if self.result.type != self.vector.type.element_type:
raise VerifyException(
"Expected result type to match element type of vector operand."
)

if self.vector.type.get_num_dims() == 0:
if self.position is not None:
raise VerifyException("Expected position to be empty with 0-D vector.")
return
if self.vector.type.get_num_dims() != 1:
raise VerifyException("Unexpected >1 vector rank.")
if self.position is None:
raise VerifyException("Expected position for 1-D vector.")

def __init__(
self,
vector: SSAValue | Operation,
position: SSAValue | Operation | None = None,
):
vector = SSAValue.get(vector)
assert isa(vector.type, VectorType[Attribute])

result_type = vector.type.element_type

super().__init__(
operands=[vector, position],
result_types=[result_type],
)


@irdl_op_definition
class InsertElementOp(IRDLOperation):
name = "vector.insertelement"
source = operand_def(Attribute)
dest = operand_def(VectorType)
position = opt_operand_def(IndexTypeConstr | SignlessIntegerConstraint)
result = result_def(VectorType)
traits = traits_def(Pure())

def verify_(self):
assert isa(self.dest.type, VectorType[Attribute])

if self.result.type != self.dest.type:
raise VerifyException(
"Expected dest operand and result to have matching types."
)
if self.source.type != self.dest.type.element_type:
raise VerifyException(
"Expected source operand type to match element type of dest operand."
)

if self.dest.type.get_num_dims() == 0:
if self.position is not None:
raise VerifyException("Expected position to be empty with 0-D vector.")
return
if self.dest.type.get_num_dims() != 1:
raise VerifyException("Unexpected >1 vector rank.")
if self.position is None:
raise VerifyException("Expected position for 1-D vector.")

def __init__(
self,
source: SSAValue | Operation,
dest: SSAValue | Operation,
position: SSAValue | Operation | None = None,
):
dest = SSAValue.get(dest)
assert isa(dest.type, VectorType[Attribute])

result_type = SSAValue.get(dest).type

super().__init__(
operands=[source, dest, position],
result_types=[result_type],
)


Vector = Dialect(
"vector",
[
Expand All @@ -303,6 +394,8 @@ def get(mask_operands: list[Operation | SSAValue]) -> CreatemaskOp:
MaskedstoreOp,
PrintOp,
CreatemaskOp,
ExtractElementOp,
InsertElementOp,
],
[],
)

0 comments on commit 0e1cec0

Please sign in to comment.