Skip to content

Commit

Permalink
Add support for float64 casts.
Browse files Browse the repository at this point in the history
Fixes an issue where literal(0, f64) produced an f32 IR value because we never checked the type of the literal.

Along the way, also enables x64 in loops.

PiperOrigin-RevId: 542416849
  • Loading branch information
brianwa84 authored and The jax_triton Authors committed Jun 22, 2023
1 parent 3bf5654 commit e97cf5d
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 13 deletions.
49 changes: 36 additions & 13 deletions jax_triton/pallas/triton_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,14 @@ def lower_jaxpr_to_triton_ir(

def read_env(var: jax_core.Atom):
if type(var) is jax_core.Literal:
return tl.core._to_tensor(np.array(var.val).tolist(), builder=ctx.builder)
t = tl.core._to_tensor(np.array(var.val).tolist(), builder=ctx.builder)
dst_ty = code_gen.str_to_ty(get_triton_type(var.aval)).element_ty
if t.type.scalar != dst_ty:
# _to_tensor(np.array(var.val).tolist()) can be lossy e.g. np.float64
# comes out of .tolist() as list[float], which then comes out of
# _to_tensor as a block of f32.
t = tl.semantic.cast(t, dst_ty, ctx.builder)
return t
return env[var]

def read_block_info_env(var: jax_core.Atom):
Expand Down Expand Up @@ -470,12 +477,18 @@ def _convert_element_type_lowering_rule(
return a
if new_dtype == jnp.float32:
new_dtype = tl.float32
elif new_dtype == jnp.float64:
new_dtype = tl.float64
elif new_dtype == jnp.float16:
new_dtype = tl.float16
elif new_dtype == jnp.bfloat16:
new_dtype = tl.bfloat16
elif new_dtype == jnp.int32:
new_dtype = tl.int32
elif new_dtype == jnp.int64:
new_dtype = tl.int64
else:
raise ValueError(f"Unhandled dtype: {new_dtype}")
return tl.semantic.cast(a, new_dtype, ctx.builder)


Expand Down Expand Up @@ -1207,11 +1220,15 @@ def _for_lowering_rule(
def _lower_jaxpr_to_for_loop(ctx: TritonLoweringRuleContext, jaxpr: jax_core.Jaxpr,
lower_bound, upper_bound, consts, *args,
has_loop_index: bool,
step: int = 1):
step: int = 1,
bound_type: tl.dtype = tl.int32):
if step != 1:
raise NotImplementedError
builder = ctx.builder
step = builder.get_int32(step)
if bound_type == tl.int64:
step = builder.get_int64(step)
else:
step = builder.get_int32(step)
current_block = builder.get_insertion_block()
for_op = builder.create_for_op(
lower_bound, upper_bound, step, [arg.handle for arg in args]
Expand Down Expand Up @@ -1269,10 +1286,14 @@ def _scan_lowering_rule(
in_index_var = jaxpr.invars[num_consts]
out_index_var = jaxpr.outvars[0]
# Check that the loop index argument is an int32 scalar
if in_index_var.aval.shape != () or in_index_var.aval.dtype != jnp.int32:
raise NotImplementedError
if out_index_var.aval.shape != () or out_index_var.aval.dtype != jnp.int32:
raise NotImplementedError
if (in_index_var.aval.shape != () or
in_index_var.aval.dtype not in (jnp.int32, jnp.int64)):
raise NotImplementedError(
f"not a fori_loop index in: {in_index_var.aval} {jaxpr=}")
if (out_index_var.aval.shape != () or
out_index_var.aval.dtype not in (jnp.int32, jnp.int64)):
raise NotImplementedError(
f"not a fori_loop index out: {out_index_var.aval} {jaxpr=}")
# Look for the equation that increments the loop index
for i, eqn in enumerate(jaxpr.eqns):
if eqn.primitive == lax.add_p:
Expand All @@ -1293,21 +1314,23 @@ def _scan_lowering_rule(
lower_bound = lb.handle
ub = lb.__add__(tl.constexpr(length), _builder=builder)
upper_bound = ub.handle
bound_type = ub.type
has_loop_index = True
else:
# If there's no carry, the loop index has been DCEd and the body does *not*
# expect a loop index as an argument.
consts, args = args, []
lower_bound = builder.get_int32(0)
upper_bound = builder.get_int32(length)
bound_type = tl.int32
has_loop_index = False
for_out = _lower_jaxpr_to_for_loop(
ctx, jaxpr, lower_bound, upper_bound, consts, *args,
has_loop_index=has_loop_index, step=1)
ctx, jaxpr, lower_bound, upper_bound, consts, *args,
has_loop_index=has_loop_index, step=1, bound_type=bound_type)
if has_loop_index:
# Need to return the final loop index value if the outer scan expects
# it as an output
return [tl.core.tensor(upper_bound, tl.int32), *for_out]
return [tl.core.tensor(upper_bound, bound_type), *for_out]
return for_out

triton_lowering_rules[lax.scan_p] = _scan_lowering_rule
Expand All @@ -1323,9 +1346,9 @@ def _maybe_pattern_match_fori_loop(ctx: TritonLoweringRuleContext, *args,
return None
# Check that the first two carry values are scalar ints
a1, a2 = cond_in_avals[:2]
if a1.shape != () or a1.dtype != jnp.int32:
if a1.shape != () or a1.dtype not in (jnp.int32, jnp.int64):
return None
if a2.shape != () or a2.dtype != jnp.int32:
if a2.shape != () or a2.dtype not in (jnp.int32, jnp.int64):
return None
# Check that the only eqn in the cond checks the loop index condition
v1, v2 = cond_invars[:2]
Expand Down Expand Up @@ -1375,7 +1398,7 @@ def _maybe_pattern_match_fori_loop(ctx: TritonLoweringRuleContext, *args,
*args_block_infos[2:]])
for_out = _lower_jaxpr_to_for_loop(ctx, jaxpr, lb.handle, ub.handle,
body_consts, *args, has_loop_index=True,
step=1)
step=1, bound_type=lb.type)
return [ub, ub, *for_out]

def _while_lowering_rule(
Expand Down
25 changes: 25 additions & 0 deletions tests/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,31 @@ def setUp(self):
if self.INTERPRET:
self.skipTest("Control flow not supported in interpreter mode yet.")

def test_loop_with_float64_carry(self):
# Test that the jnp.zeros(f64) loop init_val is actually f64, and that
# fori_loop handles i64 index variables, i.e. error: 'scf.for' op along
# control flow edge from Region #0 to Region #0: source type #0
# 'tensor<4xf64>' should match input type #0 'tensor<4xf32>'
orig_val = jax.config.jax_enable_x64
jax.config.update("jax_enable_x64", True)
try:
@functools.partial(self.pallas_call,
out_shape=jax.ShapeDtypeStruct((4,), jnp.float64),
grid=1,
debug=False)
def f(x_ref, y_ref):
def body(i, acc):
# TODO(sharadmv): DCE loop index but retain carry breaks scan pattern.
# return acc + x_ref[...]
return acc + x_ref[...] + i * 0
y_ref[...] = lax.fori_loop(
0, 3, body, jnp.zeros((4,), jnp.float64))

np.testing.assert_allclose(np.arange(1, 5.) * 3,
f(jnp.arange(1, 5., dtype=jnp.float64)))
finally:
jax.config.update("jax_enable_x64", orig_val)

def test_cond_simple(self):
arg = jnp.float32(0.)
@functools.partial(self.pallas_call,
Expand Down

0 comments on commit e97cf5d

Please sign in to comment.