torch matmul does not handle different dtypes #245

Open
ev-br opened this issue Jan 30, 2025 · 3 comments

Comments

@ev-br
Contributor

ev-br commented Jan 30, 2025

The spec requires that matmul follow the type promotion rules for its arguments, but PyTorch requires that the dtypes match:

In [3]: import array_api_strict as xp

In [5]: xp.ones(3, dtype=xp.float32) @ xp.ones(3, dtype=xp.float64)
Out[5]: Array(3., dtype=array_api_strict.float64)

In [6]: torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)

RuntimeError: dot : expected both vectors to have same dtype, but found Float and Double

It's not immediately clear to me whether we want to paper over this in compat or leave the conversion to end users: it's easy to imagine a use case where the copying overhead is significant.

@rgommers
Member

I don't think there's a good way to override @ behavior. For matmul, I think we can do same-kind type promotion. There shouldn't be extra overhead: no other library has mixed-dtype implementations either, AFAIK (e.g., see np.matmul.types).
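For illustration, the promotion suggested above could look roughly like the sketch below. `matmul_promoted` is a hypothetical name, not an existing function in torch or in compat, and it assumes that `torch.promote_types` gives the result dtype the spec expects for the pair of inputs (that should hold for same-kind pairs such as float32/float64, but PyTorch's rules are not guaranteed to match the spec in every mixed-kind case).

```python
import torch


def matmul_promoted(x1, x2):
    """Hypothetical helper: promote both operands to a common dtype, then matmul."""
    # Pick the common dtype from the two input dtypes only (no value inspection).
    dtype = torch.promote_types(x1.dtype, x2.dtype)
    # Tensor.to returns the tensor itself when the dtype already matches,
    # so only the operand that actually needs upcasting is copied.
    return torch.matmul(x1.to(dtype), x2.to(dtype))


a = torch.ones(3, dtype=torch.float32)
b = torch.ones(3, dtype=torch.float64)
result = matmul_promoted(a, b)
```

With this sketch, same-dtype inputs go straight to `torch.matmul` without a copy, and the mixed-dtype case pays for one upcast of the narrower operand.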

@ev-br
Contributor Author

ev-br commented Jan 30, 2025

NumPy seems to do it:

In [8]: np.ones(3) @ np.ones(3, dtype=complex)
Out[8]: np.complex128(3+0j)

Cross-ref https://discuss.pytorch.org/t/matmul-mixed-dtypes/216044 for a related question on the PyTorch Discourse.

@rgommers
Member

Yeah I know, I didn't say it doesn't. I meant that it upcasts internally and then calls a routine where both dtypes are the same.
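One way to check this from Python is to look at the loop signatures that np.matmul registers. The snippet below is only a small sketch: it assumes each entry of `np.matmul.types` is a string of the form `"dd->d"` (two input type characters, then the output), and the variable names are made up for the example.

```python
import numpy as np

# Each registered loop is described as "<in1><in2>-><out>", e.g. "dd->d".
signatures = np.matmul.types

# Collect any loop whose two input type characters differ.
mixed = [sig for sig in signatures if sig[0] != sig[1]]

# The mixed-dtype call still works because the inputs are upcast first.
out = np.ones(3) @ np.ones(3, dtype=complex)
promoted = np.result_type(np.float64, np.complex128)
```

If `mixed` comes back empty, every inner routine takes two inputs of the same dtype, and the dtype of `out` is just the promoted dtype of the two operands.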
