Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use float32 in where #23

Merged
merged 8 commits into from
Jul 24, 2024
Merged

Use float32 in where #23

merged 8 commits into from
Jul 24, 2024

Conversation

MatejUrbanQC
Copy link
Contributor

@MatejUrbanQC MatejUrbanQC commented Jul 22, 2024

Onnxruntime doesn't implement ReduceProd for float64. We will error for
float64 and otherwise use float32.

Also fixes the implementation for unsigned integers

onnxruntime doesn't implement where for 16, 32 and 64 bit unsigned integers
Onnxruntime doesn't implement ReduceProd for float64. We will error for
float64 and otherwise use float32.
@MatejUrbanQC MatejUrbanQC linked an issue Jul 22, 2024 that may be closed by this pull request
@@ -1029,6 +1037,34 @@ def _via_i64_f64(
return _variadic_op(arrays, fn, via_dtype, cast_return)


def _via_i64_f32(
Copy link
Member

@adityagoel4512 adityagoel4512 Jul 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you share some of the logic between this and _via_i64_f64? Specifically, I think providing the types as an input parameter and making sure you don't downcast as you do for i64 here (you can check ndx.iinfo(dtype).bits/ndx.finfo(dtype).bits and signedness) is sufficiently general. It would then make it easier to do more upcasting permutations as and when we find the need.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah makes sense, I'll do that.

Copy link
Collaborator

@cbourjau cbourjau left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, float64 is usually the default in our use cases. I thought we had upstreamed these implementations in the past. Maybe @neNasko1 can find some time to add float64 support for ReduceProd upstream?

Edit: The related issue has a link to a discussion on the onnxruntime issue tracker. People seemed reluctant to implement ops. It should be an easy change though. I'd say we just make a PR upstream and see what happens.

MatejUrbanQC and others added 3 commits July 23, 2024 14:01
Co-authored-by: Christian Bourjau <[email protected]>
Co-authored-by: Christian Bourjau <[email protected]>
Co-authored-by: Christian Bourjau <[email protected]>
arrays,
int_dtype=dtypes.int64,
float_dtype=dtypes.float64,
use_unsafe_uint_cast=True, # TODO this can cause overflow, we should set it to false and fix all uses
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's check this by triggering the Array API tests too.

Copy link
Member

@adityagoel4512 adityagoel4512 Jul 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should automate the kernel support / standard diffing process - in the meantime, it looks like Less and LessEqual don't support unsigned types (but for some reason Greater and GreaterEqual do). There are only 13 uses so we can follow up on this in a subsequent PR.

@neNasko1
Copy link
Contributor

Unfortunately, float64 is usually the default in our use cases. I thought we had upstreamed these implementations in the past. Maybe @neNasko1 can find some time to add float64 support for ReduceProd upstream?

I will submit a PR to upstream onnxruntime to see what their opinion on the matter is.

Copy link
Member

@adityagoel4512 adityagoel4512 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, let's follow up on the unsigned cast safety in a separate PR.

@MatejUrbanQC MatejUrbanQC merged commit 69e5d64 into main Jul 24, 2024
16 checks passed
@adityagoel4512 adityagoel4512 deleted the 20-ndxprod-fails-for-float64 branch July 25, 2024 12:00
@neNasko1
Copy link
Contributor

neNasko1 commented Aug 1, 2024

Should this be reverted in the next version of onnxruntime as per PR.

@cbourjau
Copy link
Collaborator

cbourjau commented Aug 1, 2024

Yes, let's revert them after the next release + a little bit of a grace period. I created an issue to keep track of it: #42

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ndx.prod() fails for float64
4 participants