From fa884ec0af19f26db62065448025d8a22163a21d Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 21 Dec 2023 08:42:31 -0800 Subject: [PATCH] No longer using a dynamic trace. This existed just to allow for `jnp.zeros{,_like}` to produce Zeros even though none of its inputs were. However we have now removed that as "too magic" (#3), so means that we can simplify this too. --- examples/prng/_core.py | 4 ++-- quax/_core.py | 15 ++++----------- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/examples/prng/_core.py b/examples/prng/_core.py index 046ba75..0f3fe2f 100644 --- a/examples/prng/_core.py +++ b/examples/prng/_core.py @@ -165,5 +165,5 @@ def split(key: PRNG_T, num: int = 2) -> Sequence[PRNG_T]: # Allows for `jnp.where(pred, key1, key2)`. @quax.register(lax.select_n_p) -def _(pred, *cases: PRNG) -> PRNG: - return jtu.tree_map(ft.partial(lax.select_n, pred), *cases) +def _(pred: quax.DenseArrayValue, *cases: PRNG) -> PRNG: + return jtu.tree_map(ft.partial(lax.select_n, pred.array), *cases) diff --git a/quax/_core.py b/quax/_core.py index 215c3cd..953a499 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -116,10 +116,7 @@ def _default_process(primitive, values, params): f"Multiple array-ish types {types} are specifying default process rules." ) - # Avoid an infinite loop, by pushing a new interpreter to the dynamic interpreter - # stack. - with jax.ensure_compile_time_eval(): - return default(primitive, values, params) # pyright: ignore + return default(primitive, values, params) # pyright: ignore class _QuaxTrace(core.Trace[_QuaxTracer]): @@ -155,10 +152,7 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zero in_leaves, in_treedef = jtu.tree_flatten(in_values) fun, out_treedef1 = _custom_jvp_fun_wrap(fun, self.main, in_treedef) # pyright: ignore jvp, out_treedef2 = _custom_jvp_jvp_wrap(jvp, self.main, in_treedef) # pyright: ignore - with jax.ensure_compile_time_eval(): - out_leaves = primitive.bind( - fun, jvp, *in_leaves, symbolic_zeros=symbolic_zeros - ) + out_leaves = primitive.bind(fun, jvp, *in_leaves, symbolic_zeros=symbolic_zeros) _, out_treedef = lu.merge_linear_aux(out_treedef1, out_treedef2) out_values = jtu.tree_unflatten(out_treedef, out_leaves) return [_QuaxTracer(self, x) for x in out_values] @@ -250,7 +244,7 @@ def __wrapped__(self): return self.fn def __call__(self, *args, **kwargs): - with core.new_main(_QuaxTrace, dynamic=True) as main: + with core.new_main(_QuaxTrace) as main: trace = _QuaxTrace(main, core.cur_sublevel()) # Note that we do *not* wrap arraylikes here. We let that happen in # `_QuaxTrace.{pure,lift}` as necessary. This means that we can do e.g. @@ -442,8 +436,7 @@ def _(*args: ArrayValue, jaxpr, inline, **kwargs): else: leaves, treedef = jtu.tree_flatten(args) # remove all Values flat_fun = lambda x: fun(*jtu.tree_unflatten(treedef, x)) - with jax.ensure_compile_time_eval(): # replace the dynamic QuaxTrace - return jax.jit(flat_fun)(leaves) # now we can call without Quax. + return jax.jit(flat_fun)(leaves) # TODO: also register higher-order primitives like `lax.cond_p` etc.