Skip to content

Commit

Permalink
Fix vmap rules when num_index_operands > 0
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 548730993
  • Loading branch information
apaszke authored and The jax_triton Authors committed Jul 17, 2023
1 parent f5b7705 commit 1966b57
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions jax_triton/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,11 @@ def _block_map_function(new_idx, *args):
if dim is not batching.not_mapped:
indices.insert(dim, new_idx)
return tuple(indices)
idx_avals = [jax_core.ShapedArray((), jnp.int32)] * (len(grid) + 1)
i32_aval = jax_core.ShapedArray((), jnp.int32)
if block_mapping is None:
idx_avals = [i32_aval] * (len(grid) + 1)
else:
idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
block_mapping_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_block_map_function), idx_avals)
shape = aval.shape if block_mapping is None else block_mapping.block_shape
Expand Down Expand Up @@ -271,8 +275,11 @@ def _pallas_call_batching_rule(args, dims, *,

all_dims = list(dims) + [0] * len(out_shapes)

batched_block_mappings = map(partial(_batch_block_mapping, grid_mapping.grid),
avals, all_dims, block_mappings)
num_index_operands = grid_mapping.num_index_operands
batched_block_mappings = map(
partial(_batch_block_mapping, grid_mapping.grid),
avals[num_index_operands:], all_dims[num_index_operands:], block_mappings)

batched_in_shapes = tuple(
jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else
tuple_insert(x.shape, dim, axis_size),
Expand Down

0 comments on commit 1966b57

Please sign in to comment.