-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use float32
in where
#23
Conversation
onnxruntime doesn't implement where for 16, 32 and 64 bit unsigned integers
Onnxruntime doesn't implement ReduceProd for float64. We will error for float64 and otherwise use float32.
ndonnx/_core/_impl.py
Outdated
@@ -1029,6 +1037,34 @@ def _via_i64_f64( | |||
return _variadic_op(arrays, fn, via_dtype, cast_return) | |||
|
|||
|
|||
def _via_i64_f32( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you share some of the logic between this and _via_i64_f64
? Specifically, I think providing the types as an input parameter and making sure you don't downcast as you do for i64
here (you can check ndx.iinfo(dtype).bits
/ndx.finfo(dtype).bits
and signedness) is sufficiently general. It would then make it easier to do more upcasting permutations as and when we find the need.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah makes sense, I'll do that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately, float64
is usually the default in our use cases. I thought we had upstreamed these implementations in the past. Maybe @neNasko1 can find some time to add float64
support for ReduceProd
upstream?
Edit: The related issue has a link to a discussion on the onnxruntime issue tracker. People seemed reluctant to implement ops. It should be an easy change though. I'd say we just make a PR upstream and see what happens.
Co-authored-by: Christian Bourjau <[email protected]>
Co-authored-by: Christian Bourjau <[email protected]>
Co-authored-by: Christian Bourjau <[email protected]>
arrays, | ||
int_dtype=dtypes.int64, | ||
float_dtype=dtypes.float64, | ||
use_unsafe_uint_cast=True, # TODO this can cause overflow, we should set it to false and fix all uses |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's check this by triggering the Array API tests too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should automate the kernel support / standard diffing process - in the meantime, it looks like Less
and LessEqual
don't support unsigned types (but for some reason Greater
and GreaterEqual
do). There are only 13 uses so we can follow up on this in a subsequent PR.
I will submit a PR to upstream onnxruntime to see what their opinion on the matter is. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, let's follow up on the unsigned cast safety in a separate PR.
Should this be reverted in the next version of |
Yes, let's revert them after the next release + a little bit of a grace period. I created an issue to keep track of it: #42 |
Onnxruntime doesn't implement
ReduceProd
forfloat64
. We will error forfloat64
and otherwise usefloat32
.Also fixes the implementation for unsigned integers