-
Notifications
You must be signed in to change notification settings - Fork 79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dialects: (arith) Add support for bitcast operation #3805
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
AddiOp, | ||
AddUIExtendedOp, | ||
AndIOp, | ||
BitcastOp, | ||
CeilDivSIOp, | ||
CeilDivUIOp, | ||
CmpfOp, | ||
|
@@ -55,6 +56,7 @@ | |
IndexType, | ||
IntegerAttr, | ||
IntegerType, | ||
Signedness, | ||
TensorType, | ||
VectorType, | ||
f32, | ||
|
@@ -249,6 +251,54 @@ def test_select_op(): | |
assert select_f_op.result.type == f.result.type | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"in_type, out_type, should_verify", | ||
[ | ||
(i1, IntegerType(1, signedness=Signedness.UNSIGNED), True), | ||
(i32, f32, True), | ||
(i64, f64, True), | ||
(i32, i32, True), | ||
(IndexType(), i1, True), | ||
(i1, IndexType(), True), | ||
(f32, IndexType(), True), | ||
(IndexType(), f64, True), | ||
(VectorType(i64, [3]), VectorType(f64, [3]), True), | ||
(VectorType(f32, [3]), VectorType(i32, [3]), True), | ||
# false cases | ||
(i1, i32, False), | ||
(i32, i64, False), | ||
(i64, i32, False), | ||
(f32, i64, False), | ||
(f32, f64, False), | ||
(VectorType(i32, [5]), i32, False), | ||
(i64, VectorType(i64, [5]), False), | ||
(VectorType(i32, [5]), VectorType(f32, [6]), False), | ||
(VectorType(i32, [5]), VectorType(f64, [5]), False), | ||
], | ||
) | ||
def test_bitcast_op(in_type: Attribute, out_type: Attribute, should_verify: bool): | ||
in_arg = TestSSAValue(in_type) | ||
cast = BitcastOp(in_arg, out_type) | ||
|
||
if should_verify: | ||
cast.verify_() | ||
assert cast.result.type == out_type | ||
return | ||
|
||
# expecting test to fail | ||
with pytest.raises(VerifyException) as e: | ||
cast.verify_() | ||
|
||
err_msg1 = "Expected operand and result types to be signless-integer-or-float-like" | ||
err_msg2 = "'arith.bitcast' operand and result types must have equal bitwidths" | ||
err_msg3 = "'arith.bitcast' operand and result must both be containers or scalars" | ||
err_msg4 = ( | ||
"'arith.bitcast' operand and result type elements must have equal bitwidths" | ||
) | ||
err_msg5 = "'arith.bitcast' operand and result types must have the same shape" | ||
Comment on lines
+292
to
+298
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please make these messages part of the |
||
assert e.value.args[0] in [err_msg1, err_msg2, err_msg3, err_msg4, err_msg5] | ||
|
||
|
||
def test_index_cast_op(): | ||
a = ConstantOp.from_int_and_width(0, 32) | ||
cast = IndexCastOp(a, IndexType()) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,9 @@ | |
AnyIntegerAttr, | ||
AnyIntegerAttrConstr, | ||
ContainerOf, | ||
ContainerType, | ||
DenseIntOrFPElementsAttr, | ||
FixedBitwidthType, | ||
Float16Type, | ||
Float32Type, | ||
Float64Type, | ||
|
@@ -19,6 +21,7 @@ | |
IntAttr, | ||
IntegerAttr, | ||
IntegerType, | ||
ShapedType, | ||
TensorType, | ||
UnrankedTensorType, | ||
VectorType, | ||
|
@@ -54,6 +57,7 @@ | |
) | ||
from xdsl.utils.exceptions import VerifyException | ||
from xdsl.utils.str_enum import StrEnum | ||
from xdsl.utils.type import get_element_type_or_self | ||
|
||
boolLike = ContainerOf(IntegerType(1)) | ||
signlessIntegerLike = ContainerOf(AnyOf([IntegerType, IndexType])) | ||
|
@@ -1223,6 +1227,64 @@ class MinnumfOp(FloatingPointLikeBinaryOperation): | |
traits = traits_def(Pure()) | ||
|
||
|
||
@irdl_op_definition | ||
class BitcastOp(IRDLOperation): | ||
name = "arith.bitcast" | ||
|
||
input = operand_def(signlessIntegerLike | floatingPointLike) | ||
result = result_def(signlessIntegerLike | floatingPointLike) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The mlir operation also accepts memrefs |
||
|
||
assembly_format = "$input attr-dict `:` type($input) `to` type($result)" | ||
|
||
def __init__(self, in_arg: SSAValue | Operation, target_type: Attribute): | ||
super().__init__(operands=[in_arg], result_types=[target_type]) | ||
|
||
def verify_(self) -> None: | ||
# check if we have a ContainerType | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this comment doesn't add much and it's a bit confusing since we first check for shaped types |
||
in_type = self.input.type | ||
res_type = self.result.type | ||
|
||
if isinstance(in_type, ShapedType) and isinstance(res_type, ShapedType): | ||
if in_type.get_shape() != res_type.get_shape(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it not enough for these to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, the logic in upstream MLIR is enforced by the |
||
raise VerifyException( | ||
"'arith.bitcast' operand and result types must have the same shape" | ||
) | ||
|
||
t1 = get_element_type_or_self(in_type) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could this check not go right at the start of the function? This would save you checking again in the scalar case at the bottom |
||
t2 = get_element_type_or_self(res_type) | ||
if BitcastOp._equal_bitwidths(t1, t2): | ||
jakedves marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return | ||
|
||
raise VerifyException( | ||
"'arith.bitcast' operand and result type elements must have equal bitwidths" | ||
) | ||
|
||
if isinstance(in_type, ContainerType) or isinstance(res_type, ContainerType): | ||
raise VerifyException( | ||
"'arith.bitcast' operand and result must both be containers or scalars" | ||
) | ||
|
||
# at this point we know we have two scalar types | ||
if not BitcastOp._equal_bitwidths(in_type, res_type): | ||
raise VerifyException( | ||
"'arith.bitcast' operand and result types must have equal bitwidths" | ||
) | ||
|
||
@staticmethod | ||
def _equal_bitwidths(type_a: Attribute, type_b: Attribute) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this something that could be reused? If so, maybe pull out as a utility method? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, it feels like this could be in |
||
if isinstance(type_a, IndexType) or isinstance(type_b, IndexType): | ||
return True | ||
|
||
if isinstance(type_a, FixedBitwidthType) and isinstance( | ||
type_b, FixedBitwidthType | ||
): | ||
return type_a.bitwidth == type_b.bitwidth | ||
|
||
raise VerifyException( | ||
"Expected operand and result types to be signless-integer-or-float-like" | ||
) | ||
Comment on lines
+1283
to
+1285
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I find it a bit confusing that this returns bool and can also throw an exception. |
||
|
||
|
||
@irdl_op_definition | ||
class IndexCastOp(IRDLOperation): | ||
name = "arith.index_cast" | ||
|
@@ -1420,6 +1482,7 @@ def verify_(self) -> None: | |
MaximumfOp, | ||
MaxnumfOp, | ||
# Casts | ||
BitcastOp, | ||
IndexCastOp, | ||
FPToSIOp, | ||
SIToFPOp, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would personally split out the tests that error from the others.