Skip to content

Commit

Permalink
ψ(`∇´)ψ dicts in jit loops
Browse files Browse the repository at this point in the history
  • Loading branch information
marmaduke woodman committed Nov 8, 2023
1 parent eca0113 commit acdad03
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions vbjax/tests/test_jaxisms.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,25 @@ def test_batched_jax_norm():
bbnf = jax.vmap(bnf, 0)
bbJf = jax.vmap(bJf, 0)
assert bbnf(x).shape == bbJf(x).shape == (100, 50)


def test_loop_dict():
import jax, jax.numpy as np
ns = {
"int": 3,
"float": 3.14,
"array": np.r_[:10.0]
}
def op(ns, inputs):
ns["int"] += inputs
return ns, ns["int"]
@jax.jit
def loop(ns: dict, inps: np.ndarray):
ns, _ = jax.lax.scan(op, ns, inps)
return ns

ns1 = loop(ns, np.r_[:3])
ns2 = loop(ns, np.r_[2:4])
assert ns["int"] == 3
assert ns1["int"] == 6
assert ns2["int"] == 8

0 comments on commit acdad03

Please sign in to comment.