-
Notifications
You must be signed in to change notification settings - Fork 79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dialects: (arith) Add support for bitcast operation #3805
base: main
Are you sure you want to change the base?
Conversation
catch up to main
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #3805 +/- ##
=======================================
Coverage 91.23% 91.23%
=======================================
Files 461 461
Lines 57487 57534 +47
Branches 5548 5556 +8
=======================================
+ Hits 52449 52494 +45
- Misses 3615 3616 +1
- Partials 1423 1424 +1 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very good, thanks for this. I've left a few comments.
) | ||
|
||
@staticmethod | ||
def _equal_bitwidths(type_a: Attribute, type_b: Attribute) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this something that could be reused? If so, maybe pull out as a utility method?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, it feels like this could be in xdsl.utils.type
raise VerifyException( | ||
"Expected operand and result types to be signless-integer-or-float-like" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find it a bit confusing that this returns bool and can also throw an exception.
Maybe leave the exception raising to clients and just plainly return true/false here.
err_msg1 = "Expected operand and result types to be signless-integer-or-float-like" | ||
err_msg2 = "'arith.bitcast' operand and result types must have equal bitwidths" | ||
err_msg3 = "'arith.bitcast' operand and result must both be containers or scalars" | ||
err_msg4 = ( | ||
"'arith.bitcast' operand and result type elements must have equal bitwidths" | ||
) | ||
err_msg5 = "'arith.bitcast' operand and result types must have the same shape" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please make these messages part of the parametrize
decorator to match each individual case
super().__init__(operands=[in_arg], result_types=[target_type]) | ||
|
||
def verify_(self) -> None: | ||
# check if we have a ContainerType |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this comment doesn't add much and it's a bit confusing since we first check for shaped types
(VectorType(i64, [3]), VectorType(f64, [3]), True), | ||
(VectorType(f32, [3]), VectorType(i32, [3]), True), | ||
# false cases | ||
(i1, i32, False), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would personally split out the tests that error from the others.
name = "arith.bitcast" | ||
|
||
input = operand_def(signlessIntegerLike | floatingPointLike) | ||
result = result_def(signlessIntegerLike | floatingPointLike) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mlir operation also accepts memrefs
res_type = self.result.type | ||
|
||
if isinstance(in_type, ShapedType) and isinstance(res_type, ShapedType): | ||
if in_type.get_shape() != res_type.get_shape(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it not enough for these to have_compatible_shape
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, the logic in upstream MLIR is enforced by the SameOperandsAndResultShape
which reuses heavily parts of the SameOperandsAndResultType
trait that we got in xDSL recently.
Actually, this operation observes the SameOperandsAndResultShape
, but since we don't have that yet, maybe reuse part of that infrastructure in the custom verifier here for now as @alexarice suggests/questions.
) | ||
|
||
@staticmethod | ||
def _equal_bitwidths(type_a: Attribute, type_b: Attribute) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, it feels like this could be in xdsl.utils.type
"'arith.bitcast' operand and result types must have the same shape" | ||
) | ||
|
||
t1 = get_element_type_or_self(in_type) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this check not go right at the start of the function? This would save you checking again in the scalar case at the bottom
Implementation of MLIR's
arith.bitcast
op. Test cases are passing but type checking ont1
andt2
variables isn't