From e97cf5daffe14831d4e0d4938d54d41037f5eafc Mon Sep 17 00:00:00 2001 From: Brian Patton Date: Wed, 21 Jun 2023 18:15:57 -0700 Subject: [PATCH] Add support for float64 casts. Fixes an issue where literal(0, f64) produced an f32 IR value because we never checked the type of the literal. Along the way, also enables x64 in loops. PiperOrigin-RevId: 542416849 --- jax_triton/pallas/triton_lowering.py | 49 ++++++++++++++++++++-------- tests/pallas_test.py | 25 ++++++++++++++ 2 files changed, 61 insertions(+), 13 deletions(-) diff --git a/jax_triton/pallas/triton_lowering.py b/jax_triton/pallas/triton_lowering.py index a159bb1a..e19c1dc4 100644 --- a/jax_triton/pallas/triton_lowering.py +++ b/jax_triton/pallas/triton_lowering.py @@ -223,7 +223,14 @@ def lower_jaxpr_to_triton_ir( def read_env(var: jax_core.Atom): if type(var) is jax_core.Literal: - return tl.core._to_tensor(np.array(var.val).tolist(), builder=ctx.builder) + t = tl.core._to_tensor(np.array(var.val).tolist(), builder=ctx.builder) + dst_ty = code_gen.str_to_ty(get_triton_type(var.aval)).element_ty + if t.type.scalar != dst_ty: + # _to_tensor(np.array(var.val).tolist()) can be lossy e.g. np.float64 + # comes out of .tolist() as list[float], which then comes out of + # _to_tensor as a block of f32. + t = tl.semantic.cast(t, dst_ty, ctx.builder) + return t return env[var] def read_block_info_env(var: jax_core.Atom): @@ -470,12 +477,18 @@ def _convert_element_type_lowering_rule( return a if new_dtype == jnp.float32: new_dtype = tl.float32 + elif new_dtype == jnp.float64: + new_dtype = tl.float64 elif new_dtype == jnp.float16: new_dtype = tl.float16 elif new_dtype == jnp.bfloat16: new_dtype = tl.bfloat16 elif new_dtype == jnp.int32: new_dtype = tl.int32 + elif new_dtype == jnp.int64: + new_dtype = tl.int64 + else: + raise ValueError(f"Unhandled dtype: {new_dtype}") return tl.semantic.cast(a, new_dtype, ctx.builder) @@ -1207,11 +1220,15 @@ def _for_lowering_rule( def _lower_jaxpr_to_for_loop(ctx: TritonLoweringRuleContext, jaxpr: jax_core.Jaxpr, lower_bound, upper_bound, consts, *args, has_loop_index: bool, - step: int = 1): + step: int = 1, + bound_type: tl.dtype = tl.int32): if step != 1: raise NotImplementedError builder = ctx.builder - step = builder.get_int32(step) + if bound_type == tl.int64: + step = builder.get_int64(step) + else: + step = builder.get_int32(step) current_block = builder.get_insertion_block() for_op = builder.create_for_op( lower_bound, upper_bound, step, [arg.handle for arg in args] @@ -1269,10 +1286,14 @@ def _scan_lowering_rule( in_index_var = jaxpr.invars[num_consts] out_index_var = jaxpr.outvars[0] # Check that the loop index argument is an int32 scalar - if in_index_var.aval.shape != () or in_index_var.aval.dtype != jnp.int32: - raise NotImplementedError - if out_index_var.aval.shape != () or out_index_var.aval.dtype != jnp.int32: - raise NotImplementedError + if (in_index_var.aval.shape != () or + in_index_var.aval.dtype not in (jnp.int32, jnp.int64)): + raise NotImplementedError( + f"not a fori_loop index in: {in_index_var.aval} {jaxpr=}") + if (out_index_var.aval.shape != () or + out_index_var.aval.dtype not in (jnp.int32, jnp.int64)): + raise NotImplementedError( + f"not a fori_loop index out: {out_index_var.aval} {jaxpr=}") # Look for the equation that increments the loop index for i, eqn in enumerate(jaxpr.eqns): if eqn.primitive == lax.add_p: @@ -1293,6 +1314,7 @@ def _scan_lowering_rule( lower_bound = lb.handle ub = lb.__add__(tl.constexpr(length), _builder=builder) upper_bound = ub.handle + bound_type = ub.type has_loop_index = True else: # If there's no carry, the loop index has been DCEd and the body does *not* @@ -1300,14 +1322,15 @@ def _scan_lowering_rule( consts, args = args, [] lower_bound = builder.get_int32(0) upper_bound = builder.get_int32(length) + bound_type = tl.int32 has_loop_index = False for_out = _lower_jaxpr_to_for_loop( - ctx, jaxpr, lower_bound, upper_bound, consts, *args, - has_loop_index=has_loop_index, step=1) + ctx, jaxpr, lower_bound, upper_bound, consts, *args, + has_loop_index=has_loop_index, step=1, bound_type=bound_type) if has_loop_index: # Need to return the final loop index value if the outer scan expects # it as an output - return [tl.core.tensor(upper_bound, tl.int32), *for_out] + return [tl.core.tensor(upper_bound, bound_type), *for_out] return for_out triton_lowering_rules[lax.scan_p] = _scan_lowering_rule @@ -1323,9 +1346,9 @@ def _maybe_pattern_match_fori_loop(ctx: TritonLoweringRuleContext, *args, return None # Check that the first two carry values are scalar ints a1, a2 = cond_in_avals[:2] - if a1.shape != () or a1.dtype != jnp.int32: + if a1.shape != () or a1.dtype not in (jnp.int32, jnp.int64): return None - if a2.shape != () or a2.dtype != jnp.int32: + if a2.shape != () or a2.dtype not in (jnp.int32, jnp.int64): return None # Check that the only eqn in the cond checks the loop index condition v1, v2 = cond_invars[:2] @@ -1375,7 +1398,7 @@ def _maybe_pattern_match_fori_loop(ctx: TritonLoweringRuleContext, *args, *args_block_infos[2:]]) for_out = _lower_jaxpr_to_for_loop(ctx, jaxpr, lb.handle, ub.handle, body_consts, *args, has_loop_index=True, - step=1) + step=1, bound_type=lb.type) return [ub, ub, *for_out] def _while_lowering_rule( diff --git a/tests/pallas_test.py b/tests/pallas_test.py index bca5e2bc..d9a24ab9 100644 --- a/tests/pallas_test.py +++ b/tests/pallas_test.py @@ -680,6 +680,31 @@ def setUp(self): if self.INTERPRET: self.skipTest("Control flow not supported in interpreter mode yet.") + def test_loop_with_float64_carry(self): + # Test that the jnp.zeros(f64) loop init_val is actually f64, and that + # fori_loop handles i64 index variables, i.e. error: 'scf.for' op along + # control flow edge from Region #0 to Region #0: source type #0 + # 'tensor<4xf64>' should match input type #0 'tensor<4xf32>' + orig_val = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", True) + try: + @functools.partial(self.pallas_call, + out_shape=jax.ShapeDtypeStruct((4,), jnp.float64), + grid=1, + debug=False) + def f(x_ref, y_ref): + def body(i, acc): + # TODO(sharadmv): DCE loop index but retain carry breaks scan pattern. + # return acc + x_ref[...] + return acc + x_ref[...] + i * 0 + y_ref[...] = lax.fori_loop( + 0, 3, body, jnp.zeros((4,), jnp.float64)) + + np.testing.assert_allclose(np.arange(1, 5.) * 3, + f(jnp.arange(1, 5., dtype=jnp.float64))) + finally: + jax.config.update("jax_enable_x64", orig_val) + def test_cond_simple(self): arg = jnp.float32(0.) @functools.partial(self.pallas_call,