
How did you solve the problem of jax.pmap hanging? #61

Answered by Joshuaalbert
SmearingMap asked this question in Q&A

Hi @halgorthim, you can use `chunked_pmap` from https://github.com/Joshuaalbert/jaxns/blob/master/jaxns/internals/maps.py#L101.

It behaves just like `pmap`, with a few extra parameters. Use it like this:

```python
from jaxns.internals.maps import chunked_pmap


def embarrassingly_parallel_func(*args, **kwargs):
    # Your computation goes here.
    pass


# chunksize: how many parallel workers (devices) to use.
parallel_func = chunked_pmap(embarrassingly_parallel_func, chunksize=...)

results = parallel_func(*args, **kwargs)

# If args[0] is a pytree, you also need to specify batch_size.
parallel_func = chunked_pmap(embarrassingly_parallel_func, chunksize=..., batch_size=...)
```

`chunksize` is slightly misnamed: it means how many parallel workers to use, and it must not exceed the number of devices. `batch_size` is…
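For intuition, here is a minimal sketch of the general chunking idea in plain JAX. It is an assumption about how a wrapper like this can work, not the jaxns source. A plain `jax.pmap` cannot map over an axis longer than the number of devices, so the sketch pads the batch to a multiple of the worker count, gives each device one slice via `pmap`, and loops over that slice with `jax.lax.map`. `chunked_map_sketch` and `n_workers` are hypothetical names (`n_workers` plays the role of `chunksize`). The sketch assumes the input is a single array whose leading axis is the batch and that `f` returns one array per row.

```python
import jax
import jax.numpy as jnp


def chunked_map_sketch(f, n_workers=None):
    """Hypothetical illustration, NOT the jaxns chunked_pmap implementation.

    Applies f to every row of xs, splitting the rows across n_workers
    devices with jax.pmap and looping over each device's rows with
    jax.lax.map.
    """
    if n_workers is None:
        n_workers = jax.device_count()
    if n_workers > jax.device_count():
        raise ValueError("n_workers must not exceed jax.device_count()")

    def per_device(rows):
        # Sequential loop over this device's share of the rows.
        return jax.lax.map(f, rows)

    parallel = jax.pmap(per_device)

    def wrapped(xs):
        batch = xs.shape[0]
        # Number of extra rows needed to make the batch divisible by n_workers.
        extra = (-batch) % n_workers
        # Pad by repeating the last row so f only ever sees valid inputs.
        pad_width = [(0, extra)] + [(0, 0)] * (xs.ndim - 1)
        padded = jnp.pad(xs, pad_width, mode="edge")
        # Shape (n_workers, rows_per_worker, ...): pmap maps the leading axis.
        shards = padded.reshape((n_workers, -1) + xs.shape[1:])
        out = parallel(shards)
        # Merge the worker axis back into the batch axis, then drop padding.
        out = out.reshape((-1,) + out.shape[2:])
        return out[:batch]

    return wrapped


# Usage: sum of squares of each row, spread over all available devices.
row_sq_sum = chunked_map_sketch(lambda row: jnp.sum(row ** 2))
xs = jnp.arange(10.0).reshape(5, 2)
print(row_sq_sum(xs))
```

In this sketch, each device processes its rows one at a time. The batch length is therefore no longer tied to the device count, at the cost of a sequential loop per device. Pytree inputs or outputs would need `jax.tree_util.tree_map` around the pad, reshape, and trim steps.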
