In torch and JAX, an all-pairs difference can be computed with a one-liner of indexing "black magic":

```python
dt_segment_sum_jax = dA_cumsum_jax[:, :, :, :, None] - dA_cumsum_jax[:, :, :, None, :]
```
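Indexing with `None` inserts a length-1 axis, so for a 4-D array the left operand has shape `(..., T, 1)` and the right one `(..., 1, T)`. Broadcasting then produces shape `(..., T, T)` with `out[..., i, j] = x[..., i] - x[..., j]`. The sketch below checks this against an explicit loop. It uses NumPy, whose broadcasting rules `jax.numpy` and torch also follow, and writes `x[..., :, None]` in place of `x[:, :, :, :, None]` (equivalent for a 4-D array); the input shape is an arbitrary assumption, not the Mamba 2 layout.

```python
import numpy as np


def all_pairs_diff(x: np.ndarray) -> np.ndarray:
    """Return out with out[..., i, j] = x[..., i] - x[..., j]."""
    # x[..., :, None] has shape (..., T, 1); x[..., None, :] has shape (..., 1, T).
    # Broadcasting the two operands gives shape (..., T, T).
    return x[..., :, None] - x[..., None, :]


def all_pairs_diff_loop(x: np.ndarray) -> np.ndarray:
    """Explicit-loop reference for the same computation."""
    T = x.shape[-1]
    out = np.empty(x.shape + (T,), dtype=x.dtype)
    for i in range(T):
        for j in range(T):
            out[..., i, j] = x[..., i] - x[..., j]
    return out


# Arbitrary 4-D input; the last axis plays the role of T.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 5))

fast = all_pairs_diff(x)
print(fast.shape)  # (2, 3, 4, 5, 5)
assert np.allclose(fast, all_pairs_diff_loop(x))
```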
The torch/JAX one-liner is neither readable nor self-explanatory, and it was not obvious how to express the equivalent in Haliax because of its axis-subset constraint. A potential solution is below:
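```python
import haliax as hax
with hax.auto_broadcast():
    # Proposed: rename T -> T2 so the difference spans both axes (all pairs along T).
    named1_diff = named1 - named1.rename({"T": "T2"})
```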
Basically, the only thing stopping this from working is an explicit check I do to avoid accidentally combining arrays when neither array's axes are a subset of the other's (here `named1` has `T` while the renamed copy has `T2`).
Alternatively, I could relax the check to require only "at least one overlapping axis".
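To make the two options concrete, here is a toy model of name-based broadcasting with both the strict subset check and the relaxed "at least one overlapping axis" check. It is a hypothetical NumPy sketch: the `Named` class, `sub`, `_align`, and the `relaxed` flag are illustrative names, not Haliax's API or implementation.

```python
import numpy as np


class Named:
    """Hypothetical toy named array (not Haliax's NamedArray): data plus one name per axis."""

    def __init__(self, data, axes):
        self.data = np.asarray(data)
        self.axes = tuple(axes)
        assert self.data.ndim == len(self.axes), "one name per dimension"

    def rename(self, mapping):
        # Same data, with some axis names replaced.
        return Named(self.data, [mapping.get(a, a) for a in self.axes])


def _align(x, axes):
    """Reorder x's dims to follow `axes`, inserting size-1 dims for names x lacks."""
    present = [a for a in axes if a in x.axes]
    data = np.transpose(x.data, [x.axes.index(a) for a in present])
    shape = [data.shape[present.index(a)] if a in x.axes else 1 for a in axes]
    return data.reshape(shape)


def sub(x, y, relaxed=False):
    """Elementwise x - y, broadcasting by axis name."""
    sx, sy = set(x.axes), set(y.axes)
    if relaxed:
        # Relaxed rule: the operands only need one axis name in common.
        if not (sx & sy):
            raise ValueError(f"no overlapping axes: {x.axes} vs {y.axes}")
    elif not (sx <= sy or sy <= sx):
        # Strict rule: one operand's axis names must be a subset of the other's.
        raise ValueError(f"neither axis set is a subset: {x.axes} vs {y.axes}")
    # Result axes: x's axes in order, followed by y's axes that x lacks.
    out_axes = x.axes + tuple(a for a in y.axes if a not in sx)
    return Named(_align(x, out_axes) - _align(y, out_axes), out_axes)


x = Named(np.arange(6.0).reshape(2, 3), ["B", "T"])
y = x.rename({"T": "T2"})  # axes ("B", "T2")

try:
    sub(x, y)  # strict rule: {"B", "T"} and {"B", "T2"} are not nested
except ValueError as err:
    print("strict:", err)

d = sub(x, y, relaxed=True)
print(d.axes, d.data.shape)  # ('B', 'T', 'T2') (2, 3, 3)
```

Under the relaxed rule the shared axis `B` is matched by name while `T` and `T2` broadcast against each other, which is exactly the all-pairs difference. The trade-off is that the relaxed rule would also silently broadcast genuinely unrelated axes whenever some other axis happens to be shared.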
The torch/JAX all-pairs difference above is performed in the reference implementation of Mamba 2.
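The pairwise difference is useful because its input is a cumulative sum: if `c` is the cumsum of `a` along the last axis, then for `i >= j`, `c[i] - c[j] = a[j+1] + ... + a[i]`, the sum of `a` over the segment `(j, i]` (presumably where the name `dt_segment_sum` comes from). A small NumPy check of that identity on an arbitrary input; `segment_sums` is a hypothetical helper name, and this sketch does not reproduce any masking the Mamba 2 reference may apply.

```python
import numpy as np


def segment_sums(a: np.ndarray) -> np.ndarray:
    """out[..., i, j] = c[..., i] - c[..., j], where c is the cumsum of a over the last axis.

    For i >= j this equals a[..., j+1] + ... + a[..., i] (zero when i == j).
    Entries with i < j hold the negated sum a[..., i+1] + ... + a[..., j].
    """
    c = np.cumsum(a, axis=-1)
    return c[..., :, None] - c[..., None, :]


a = np.array([1.0, 2.0, 3.0, 4.0])
s = segment_sums(a)

# Check the identity against direct segment sums for every i >= j.
for i in range(len(a)):
    for j in range(i + 1):
        assert np.isclose(s[i, j], a[j + 1 : i + 1].sum())

print(s[3, 0])  # 9.0  (= 2 + 3 + 4)
```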
This issue exists to provide better support for this kind of operation.