StateIndex is a Module, but not a PyTree #842
Comments
It seems like in a variety of places, Equinox is torn between making modules py-trees (with all static fields marked as static) versus leaving some static fields as dynamic to allow using This seems confusing to me (at least for now). Would it be possible to add a flag to |
I believe Looking at your PR it seems like you're trying to avoid having non-arrays in its PyTree structure. I can see that that's a small QoL improvement, which seems reasonable to me. On your latter point: note that it's not possible to have a module declare its static fields at class definition time. Consider for example WDYT? |
Hmm, I think you're mistaken. The
Not exactly. I'm avoiding having non-pytrees in dynamic parameters. Another alternative would be to use a sentinel that's a pytree. For example,
class Sentinel(eqx.Module):
    pass
sentinel = Sentinel()  # Use this instead of sentinel = object()
# When checking, you can do if isinstance(init, Sentinel) instead of init is sentinel
A user of
class JaxActivation(eqx.Module):
    f: Callable[[Array], Array] = eqx.field(static=True)
    def __call__(self, x: Array, /) -> Array:
        return self.f(x)
@jit
def f(m: MLP, /): ...
mlp = MLP(3, 2, activation=JaxActivation(jax.nn.relu))
f(mlp)  # Won't crash!
This will ensure that |
All types are pytrees. JAX is explicit about the fact that even Anyway, to ease use of |
Ah, okay! I'm using the wrong terminology. Let's say "dynamic" then? That is, things that can be passed dynamically to a function decorated by jax.jit. In Flax, they used to call this
Great! Thanks. |
Is this a correct usage of stateful programming?
If so, why isn't StateIndex a PyTree?