diff --git a/hoi/core/mi.py b/hoi/core/mi.py index 70c36b90..2809e9e5 100644 --- a/hoi/core/mi.py +++ b/hoi/core/mi.py @@ -239,7 +239,7 @@ def mi_gauss(x: jnp.array, y: jnp.array): def _cdist(x, y) -> jnp.ndarray: """Pairwise squared distances between all samples of x and y.""" diff = x.T[:, None, :] - y.T[None] - _dist = jnp.einsum("ijc->ij", diff**2) + _dist = jnp.sum(diff**2, axis=-1) return _dist