The spec requires that matmul follows the type promotion rules for the arguments, but PyTorch requires that the dtypes match:
In [3]: import array_api_strict as xp
In [4]: import torch
In [5]: xp.ones(3, dtype=xp.float32) @ xp.ones(3, dtype=xp.float64)
Out[5]: Array(3., dtype=array_api_strict.float64)
In [6]: torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 1
----> 1 torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)
RuntimeError: dot : expected both vectors to have same dtype, but found Float and Double
It's not immediately clear to me whether we want to paper over it in compat, or leave the conversion to end users: it's easy to imagine a use case where the copying overhead is significant.
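For comparison, leaving it to end users would mean an explicit conversion before @, roughly like this sketch (torch.result_type picks the common dtype; the float32 operand gets copied):

import torch

a = torch.ones(3, dtype=torch.float32)
b = torch.ones(3, dtype=torch.float64)

# Promote explicitly to a common dtype before @; the float32 operand is copied.
common = torch.result_type(a, b)
result = a.to(common) @ b.to(common)   # tensor(3., dtype=torch.float64)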
There's no good way to override @ behavior, I think. For matmul we can do same-kind type promotion; there shouldn't be extra overhead, since no other library has mixed-dtype implementations either AFAIK (e.g., see np.matmul.types).
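A minimal sketch of what that same-kind promotion could look like in a torch matmul wrapper (the wrapper name and placement are hypothetical, and it assumes torch.result_type agrees with the spec's promotion rules for these dtypes):

import torch

def matmul(x1: torch.Tensor, x2: torch.Tensor, /) -> torch.Tensor:
    # Hypothetical compat-layer wrapper: promote both operands to a common
    # dtype before calling torch.matmul, which requires matching dtypes.
    dtype = torch.result_type(x1, x2)
    # .to() returns the tensor unchanged (no copy) if it already has `dtype`.
    return torch.matmul(x1.to(dtype), x2.to(dtype))

With this, matmul(torch.ones(3, dtype=torch.float32), torch.ones(3, dtype=torch.float64)) returns a float64 tensor, matching the array_api_strict result above.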