Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (arith) Add support for bitcast operation #3805

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions tests/dialects/test_arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
AddiOp,
AddUIExtendedOp,
AndIOp,
BitcastOp,
CeilDivSIOp,
CeilDivUIOp,
CmpfOp,
Expand Down Expand Up @@ -55,6 +56,7 @@
IndexType,
IntegerAttr,
IntegerType,
Signedness,
TensorType,
VectorType,
f32,
Expand Down Expand Up @@ -249,6 +251,54 @@ def test_select_op():
assert select_f_op.result.type == f.result.type


@pytest.mark.parametrize(
"in_type, out_type, should_verify",
[
(i1, IntegerType(1, signedness=Signedness.UNSIGNED), True),
(i32, f32, True),
(i64, f64, True),
(i32, i32, True),
(IndexType(), i1, True),
(i1, IndexType(), True),
(f32, IndexType(), True),
(IndexType(), f64, True),
(VectorType(i64, [3]), VectorType(f64, [3]), True),
(VectorType(f32, [3]), VectorType(i32, [3]), True),
# false cases
(i1, i32, False),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would personally split out the tests that error from the others.

(i32, i64, False),
(i64, i32, False),
(f32, i64, False),
(f32, f64, False),
(VectorType(i32, [5]), i32, False),
(i64, VectorType(i64, [5]), False),
(VectorType(i32, [5]), VectorType(f32, [6]), False),
(VectorType(i32, [5]), VectorType(f64, [5]), False),
],
)
def test_bitcast_op(in_type: Attribute, out_type: Attribute, should_verify: bool):
in_arg = TestSSAValue(in_type)
cast = BitcastOp(in_arg, out_type)

if should_verify:
cast.verify_()
assert cast.result.type == out_type
return

# expecting test to fail
with pytest.raises(VerifyException) as e:
cast.verify_()

err_msg1 = "Expected operand and result types to be signless-integer-or-float-like"
err_msg2 = "'arith.bitcast' operand and result types must have equal bitwidths"
err_msg3 = "'arith.bitcast' operand and result must both be containers or scalars"
err_msg4 = (
"'arith.bitcast' operand and result type elements must have equal bitwidths"
)
err_msg5 = "'arith.bitcast' operand and result types must have the same shape"
Comment on lines +292 to +298
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make these messages part of the parametrize decorator to match each individual case

assert e.value.args[0] in [err_msg1, err_msg2, err_msg3, err_msg4, err_msg5]


def test_index_cast_op():
a = ConstantOp.from_int_and_width(0, 32)
cast = IndexCastOp(a, IndexType())
Expand Down
63 changes: 63 additions & 0 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
AnyIntegerAttr,
AnyIntegerAttrConstr,
ContainerOf,
ContainerType,
DenseIntOrFPElementsAttr,
FixedBitwidthType,
Float16Type,
Float32Type,
Float64Type,
Expand All @@ -19,6 +21,7 @@
IntAttr,
IntegerAttr,
IntegerType,
ShapedType,
TensorType,
UnrankedTensorType,
VectorType,
Expand Down Expand Up @@ -54,6 +57,7 @@
)
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.str_enum import StrEnum
from xdsl.utils.type import get_element_type_or_self

boolLike = ContainerOf(IntegerType(1))
signlessIntegerLike = ContainerOf(AnyOf([IntegerType, IndexType]))
Expand Down Expand Up @@ -1223,6 +1227,64 @@ class MinnumfOp(FloatingPointLikeBinaryOperation):
traits = traits_def(Pure())


@irdl_op_definition
class BitcastOp(IRDLOperation):
name = "arith.bitcast"

input = operand_def(signlessIntegerLike | floatingPointLike)
result = result_def(signlessIntegerLike | floatingPointLike)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mlir operation also accepts memrefs


assembly_format = "$input attr-dict `:` type($input) `to` type($result)"

def __init__(self, in_arg: SSAValue | Operation, target_type: Attribute):
super().__init__(operands=[in_arg], result_types=[target_type])

def verify_(self) -> None:
# check if we have a ContainerType
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment doesn't add much and it's a bit confusing since we first check for shaped types

in_type = self.input.type
res_type = self.result.type

if isinstance(in_type, ShapedType) and isinstance(res_type, ShapedType):
if in_type.get_shape() != res_type.get_shape():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it not enough for these to have_compatible_shape?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, the logic in upstream MLIR is enforced by the SameOperandsAndResultShape which reuses heavily parts of the SameOperandsAndResultType trait that we got in xDSL recently.
Actually, this operation observes the SameOperandsAndResultShape, but since we don't have that yet, maybe reuse part of that infrastructure in the custom verifier here for now as @alexarice suggests/questions.

raise VerifyException(
"'arith.bitcast' operand and result types must have the same shape"
)

t1 = get_element_type_or_self(in_type)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this check not go right at the start of the function? This would save you checking again in the scalar case at the bottom

t2 = get_element_type_or_self(res_type)
if BitcastOp._equal_bitwidths(t1, t2):
jakedves marked this conversation as resolved.
Show resolved Hide resolved
return

raise VerifyException(
"'arith.bitcast' operand and result type elements must have equal bitwidths"
)

if isinstance(in_type, ContainerType) or isinstance(res_type, ContainerType):
raise VerifyException(
"'arith.bitcast' operand and result must both be containers or scalars"
)

# at this point we know we have two scalar types
if not BitcastOp._equal_bitwidths(in_type, res_type):
raise VerifyException(
"'arith.bitcast' operand and result types must have equal bitwidths"
)

@staticmethod
def _equal_bitwidths(type_a: Attribute, type_b: Attribute) -> bool:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something that could be reused? If so, maybe pull out as a utility method?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, it feels like this could be in xdsl.utils.type

if isinstance(type_a, IndexType) or isinstance(type_b, IndexType):
return True

if isinstance(type_a, FixedBitwidthType) and isinstance(
type_b, FixedBitwidthType
):
return type_a.bitwidth == type_b.bitwidth

raise VerifyException(
"Expected operand and result types to be signless-integer-or-float-like"
)
Comment on lines +1283 to +1285
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it a bit confusing that this returns bool and can also throw an exception.
Maybe leave the exception raising to clients and just plainly return true/false here.



@irdl_op_definition
class IndexCastOp(IRDLOperation):
name = "arith.index_cast"
Expand Down Expand Up @@ -1420,6 +1482,7 @@ def verify_(self) -> None:
MaximumfOp,
MaxnumfOp,
# Casts
BitcastOp,
IndexCastOp,
FPToSIOp,
SIToFPOp,
Expand Down