TPU hangs if using random inside a pure_callback call #24720

Open
tengomucho opened this issue Nov 5, 2024 · 5 comments
Labels
bug Something isn't working

Comments

@tengomucho

Description

I can reproduce this on a v5e-litepod8 TPU. It seems that random calls get stuck if called from a pure_callback call. I have not found a workaround so far. Here's a snippet to reproduce the issue:

import jax, jax.numpy as jnp


class DataProcessor:
    def __init__(self, key, topk=0, temperature=1.0):
        self.key = key
        self.topk = topk
        self.temperature = temperature
        self.sample_topk_logits = self._sample_topk_logits
        self.do_something = jax.jit(self._do_something)

    def _do_something(self, logits):
        def inner(logits):
            if self.topk <= 0:
                return jnp.argmax(logits, axis=-1)
            # self.key, _ = jax.random.split(self.key)
            return self.sample_topk_logits(
                logits,
                self.topk,
                self.temperature,
                self.key,
            )

        # If we did this, it would not check the topk value after the first call
        # token = inner(logits)

        # We use a pure_callback to allow the inner function to use python side-effects and
        # detect objects info
        token = jax.pure_callback(
            inner,
            result_shape_dtypes=jax.ShapeDtypeStruct((1,), jnp.int32),
            logits=logits,
        )

        return token

    def _sample_topk_logits(self, logits, topk, temperature, rng):
        """Restricting sampling to the best k logits."""
        if topk <= 0:
            raise ValueError(f"Can't apply algorithm topk with parameter {topk=} <= 0")
        topk_logits, topk_idxs = jax.lax.top_k(logits, topk)
        topk_token = jnp.expand_dims(
            jax.random.categorical(rng, topk_logits / temperature).astype(jnp.int32),
            axis=-1,
        )
        sampled_tokens = jnp.squeeze(jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1).astype(jnp.int32)
        return sampled_tokens


def main():
    key = jax.random.PRNGKey(0)
    scores = jax.random.uniform(key, shape=(1, 32000)) * 9.0 - 3.7

    v = DataProcessor(key)

    # Top k is 0, so it will not use random and it will work
    assert v.topk == 0
    out = v.do_something(scores)
    print(out)
    # We can do it again
    out = v.do_something(scores)
    print(out)

    # Now, we change the topk to 50, and it will use topk sampling, and hang
    print("Now topk sampling")
    print(v.key)
    v.topk = 50
    out = v.do_something(scores)
    print(out)


if __name__ == "__main__":
    main()

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=6, process_index=0, coords=(0,3,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,3,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-56d87a44-w-0', release='6.5.0-1013-gcp', version='#13~22.04.1-Ubuntu SMP Wed Jan 24 23:39:40 UTC 2024', machine='x86_64')
@tengomucho added the bug label Nov 5, 2024
@jakevdp
Collaborator

jakevdp commented Nov 5, 2024

Fundamentally, I think the issue here is that you're misusing pure_callback. As the name implies, the callback function must be pure, and your callback is not pure (it may raise a ValueError in some cases). This is particularly problematic on TPU, because there's no mechanism to halt execution at runtime, and so I think the behavior of your code is undefined.

Does the issue persist if you change your callback function to be pure? And as a follow-up: if your callback is pure, is there any reason to use a callback, or could you just execute the code as part of the normal program flow?

@jakevdp
Collaborator

jakevdp commented Nov 5, 2024

Another issue with your code: your approach of setting self.do_something = jax.jit(self._do_something) implicitly assumes self is static and immutable, which is problematic for the reasons discussed in https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods, and when you mutate your class instance (setting v.topk = 50), this will lead to unintended behavior – your jit-compiled method will still be using the old value of the attribute.

I would suggest changing to one of the JIT compilation strategies mentioned in that FAQ entry.
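
A minimal sketch of the stale-attribute behavior described above, next to one of the FAQ strategies (a helper function with a static argument). `StaleSelector` and `select` are hypothetical names, and the top-k branch returns the index of the k-th largest logit as a deterministic stand-in for sampling; none of this is the reporter's actual code.

```python
from functools import partial

import jax
import jax.numpy as jnp


class StaleSelector:
    """Hypothetical class mirroring the pattern in the report: jit wraps a bound method."""

    def __init__(self, topk=0):
        self.topk = topk
        self.select = jax.jit(self._select)

    def _select(self, logits):
        # self.topk is read while tracing, so its value is baked into the compiled function.
        if self.topk <= 0:
            return jnp.argmax(logits, axis=-1)
        _, idxs = jax.lax.top_k(logits, self.topk)
        return idxs[..., -1]


@partial(jax.jit, static_argnames="topk")
def select(logits, topk):
    # topk is a static argument: each new value triggers a fresh trace and compile.
    if topk <= 0:
        return jnp.argmax(logits, axis=-1)
    _, idxs = jax.lax.top_k(logits, topk)
    return idxs[..., -1]  # stand-in for sampling: index of the topk-th largest logit


logits = jnp.array([[0.1, 3.0, 2.0, 1.0]])

stale = StaleSelector()
first = stale.select(logits)   # traced with topk == 0
stale.topk = 2
second = stale.select(logits)  # same input shape, so the cached topk == 0 trace is reused

fresh_greedy = select(logits, topk=0)
fresh_topk = select(logits, topk=2)  # retraced because the static value changed
```

Here `second` still holds the argmax index even though `stale.topk` was changed, while `select` follows the new `topk` because the value is part of the compilation cache key.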

@tengomucho
Author

Thanks for the quick feedback @jakevdp! I will try to adapt my code, perhaps making the class a pytree is the best way.

Let me give you a little more background. I am using Jetstream Pytorch, which has a JIT'ed prefill method, and I want to pass a custom select function that selects differently when the parameters of the calling object change; that is how I got stuck. I will try to rewrite it (I tried the partial approach, but I think I must have done something wrong).

@tengomucho
Author

OK, I rewrote the example code using pytrees as you suggested, and it worked. I will try to make it work with the base code I am working on. Thank you for your answer.
I still think it would be good to provide some feedback instead of hanging though.
Last question, you said the function was not pure, what did you mean by that, that it should not have side-effects?
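
The rewritten example is not included in the thread. A minimal sketch of what a pytree-based version could look like, assuming the PRNG key is treated as dynamic data and `topk` and `temperature` as static metadata; the class name `Sampler` and the method name `select` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Sampler:
    def __init__(self, key, topk=0, temperature=1.0):
        self.key = key
        self.topk = topk
        self.temperature = temperature

    def tree_flatten(self):
        children = (self.key,)                     # dynamic: traced as an array
        aux_data = (self.topk, self.temperature)   # static: a changed value forces a retrace
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], *aux_data)

    @jax.jit
    def select(self, logits):
        # self is a registered pytree, so jit can accept it as an ordinary argument.
        if self.topk <= 0:
            return jnp.argmax(logits, axis=-1).astype(jnp.int32)
        topk_logits, topk_idxs = jax.lax.top_k(logits, self.topk)
        choice = jax.random.categorical(self.key, topk_logits / self.temperature)
        picked = jnp.take_along_axis(topk_idxs, choice[..., None], axis=-1)
        return picked.squeeze(-1).astype(jnp.int32)


logits = jnp.array([[0.1, 3.0, 2.0, 1.0]])
sampler = Sampler(jax.random.PRNGKey(0))
greedy = sampler.select(logits)    # topk == 0: plain argmax

sampler.topk = 2                   # mutation changes aux_data, so the next call retraces
sampled = sampler.select(logits)   # one of the two highest-scoring indices
```

With this layout no callback is needed: the branch on `topk` is resolved at trace time, and mutating `topk` is picked up on the next call.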

@jakevdp
Collaborator

jakevdp commented Nov 5, 2024

> Last question, you said the function was not pure, what did you mean by that, that it should not have side-effects?

Precisely – and raising any sort of error from Python is a side-effect that should not be part of a pure callback, and has undefined behavior.
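
A minimal sketch of a callback that stays pure in this sense, assuming argument validation is moved into ordinary Python before the jitted call. `host_argmax`, `select`, and `validate_topk` are hypothetical names, and using only NumPy inside the callback is a choice made for this sketch; the thread does not establish whether that alone avoids the TPU hang.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_argmax(logits):
    # Pure: the result depends only on the input. No exceptions, no mutated state.
    return np.argmax(np.asarray(logits), axis=-1).astype(np.int32)


@jax.jit
def select(logits):
    # Declare the callback's output shape and dtype up front.
    out_spec = jax.ShapeDtypeStruct(logits.shape[:-1], jnp.int32)
    return jax.pure_callback(host_argmax, out_spec, logits)


def validate_topk(topk):
    # Raising here happens in regular Python, outside any traced or callback code.
    if topk < 0:
        raise ValueError(f"topk must be >= 0, got {topk}")
    return topk


logits = jnp.array([[0.1, 3.0, 2.0, 1.0]])
topk = validate_topk(0)
token = select(logits)
```

Errors are then raised before the compiled program starts, so the callback itself has nothing to signal at runtime.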
