Skip to content

Commit

Permalink
small update to ELBO calculation to use einsum instead of explicit pr…
Browse files Browse the repository at this point in the history
…oduct+trace
  • Loading branch information
quattro committed Apr 11, 2023
1 parent e864ce4 commit b18dc6d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion susiepca/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def compute_elbo(X: ArrayLike, params: ModelParams) -> ELBOResults:
# (X.T @ E[Z] @ E[W]) is p x p (big!); compute (E[W] @ X.T @ E[Z]) (k x k)
E_ll = (-0.5 * params.tau) * (
jnp.sum(X ** 2) # tr(X.T @ X)
- 2 * jnp.einsum("kp,np,nk", E_W, X, params.mu_z) # tr(E[W] @ X.T @ E[Z])
- 2 * jnp.einsum("kp,np,nk->", E_W, X, params.mu_z) # tr(E[W] @ X.T @ E[Z])
+ jnp.einsum("ij,ji->", E_ZZ, E_WW) # tr(E[Z.T @ Z] @ E[W @ W.T])
) + 0.5 * n_dim * p_dim * jnp.log(
params.tau
Expand Down

0 comments on commit b18dc6d

Please sign in to comment.