Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stackless fixes #37

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

stackless fixes #37

wants to merge 2 commits into from

Conversation

mattjj
Copy link

@mattjj mattjj commented Dec 10, 2024

From jax-ml/jax#25372, this is an attempt at fixing up quax in light of the JAX core rewrite in jax-ml/jax@c36e1f7 (aka "stackless").

Discussion: jax-ml/jax#25372

`pytest tests` passes with these changes
@mattjj
Copy link
Author

mattjj commented Dec 19, 2024

@patrick-kidger @nstarman I forgot about this until just now... :P

Does this seem like it's worth merging?

@mattjj mattjj marked this pull request as ready for review December 19, 2024 22:43
@@ -141,76 +141,51 @@ def _wrap_if_array(x: Union[ArrayLike, "Value"]) -> "Value":


class _QuaxTrace(core.Trace[_QuaxTracer]):
def pure(self, val: ArrayLike) -> _QuaxTracer:
if _is_value(val):
raise TypeError(
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's important to keep this behavior, can you give me a test that exercises it?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure thing! The following should raise an error:

import quax
import jax

x = jax.numpy.arange(4.).reshape(2, 2)
key = jax.random.key(0)
y = quax.lora.LoraArray(x, rank=1, key=key)

def f(x):
    return jax.lax.add_p.bind(x, y)

quax.quaxify(f)(y)

Basically this is just a check that all of our Quax values are properly wrapped into tracers.

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay! Finding time to get back to this now :D

IIUC from what I'm reading I think this largely LGTM!

Comment on lines +148 to +152
def to_value(self, val):
if isinstance(val, _QuaxTracer) and val._trace.tag is self.tag:
return val.value
else:
return _DenseArrayValue(val)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my own understanding of how new-JAX works, is it expected that we ever hit the else branch? Old-JAX had it such that the tracers of process_primitive(..., tracers, ...) was guaranteed to be tracers from the current trace, which would imply always hitting the if statement here.

values = tuple(
x.array if isinstance(x, _DenseArrayValue) else x for x in values
)
try:
rule = _rules[primitive]
except KeyError:
out = _default_process(primitive, values, params)
with core.set_current_trace(self.parent_trace):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha! It's super interesting to see that new-JAX now has each trace record the parent in a stack like this. I've done a couple of hobby projects reimplementing various kinds of JAXlike designs (sadly none of them released yet), and I've always ended up going for something similar.

(Albeit I've tended to go the opposite way and do full-data-dependence rather than full-dynamic-context dependence, but I'm guessing you're constrained there by the desire to do omnistaging.)

else:
out = method(*values, **params)
with core.set_current_trace(self.parent_trace):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think you could hoist each core.set_current_trace out of all these try/except blocks and have just one wrapping the whole thing.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants