From c3721db034b4311e0a8799d6a28ce8aea7be1da1 Mon Sep 17 00:00:00 2001
From: Brian Patton
Date: Fri, 23 Jun 2023 12:32:20 -0700
Subject: [PATCH] Fix a bug in lowering of integer powers.

Also avoids asking for broadcasts in the case where the reshape target
shape is scalar.

PiperOrigin-RevId: 542929594
---
 jax_triton/pallas/triton_lowering.py | 10 ++++--
 tests/pallas_test.py                 | 49 +++++++++++++++++++---------
 2 files changed, 41 insertions(+), 18 deletions(-)

diff --git a/jax_triton/pallas/triton_lowering.py b/jax_triton/pallas/triton_lowering.py
index ff2b1926..37b10b10 100644
--- a/jax_triton/pallas/triton_lowering.py
+++ b/jax_triton/pallas/triton_lowering.py
@@ -445,7 +445,9 @@ def _integer_pow_lowering_rule(ctx: TritonLoweringRuleContext, a, *, y):
   if y == 3:
     return a.__mul__(a.__mul__(a, _builder=ctx.builder), _builder=ctx.builder)
   if y == -2:
-    return tl.math.rsqrt(a, _builder=ctx.builder)
+    one_ = tl.core._to_tensor(1.0, ctx.builder)
+    a_sq = a.__mul__(a, _builder=ctx.builder)
+    return one_.__truediv__(a_sq, _builder=ctx.builder)
   return tl.math.pow(a, y, _builder=ctx.builder)

@@ -620,8 +622,10 @@ def _reshape_lowering_rule(
   if tuple(s.value for s in a.shape) == dst_shp:
     return a
   if not a.type.is_block():
-    return tl.broadcast_to(a, [tl.constexpr(s) for s in dst_shp],
-                           _builder=ctx.builder)
+    if dst_shp:
+      return tl.broadcast_to(a, [tl.constexpr(s) for s in dst_shp],
+                             _builder=ctx.builder)
+    return a
   # Expand-dims or reduce-sum to handle singleton dims.
   if ([s.value for s in a.shape if s.value != 1]
       == [s for s in dst_shp if s != 1]):
diff --git a/tests/pallas_test.py b/tests/pallas_test.py
index d9a24ab9..d98657c0 100644
--- a/tests/pallas_test.py
+++ b/tests/pallas_test.py
@@ -521,7 +521,7 @@ def test_reduce_only_dim(self, use_store):
     @functools.partial(
         self.pallas_call,
         out_shape=out_shape,
-        grid=1, debug=True)
+        grid=1, debug=False)
     def reduce(x_ref, y_ref):
       x = pl.load(x_ref, (jnp.arange(m),))
       y = jnp.sum(x, axis=-1)
@@ -846,7 +846,7 @@ def test_scan_cond_vm_explicit_ref_arg(self):
             pl.BlockSpec(lambda i: (i,), (bx,))],  # x
         out_specs=pl.BlockSpec(lambda i: (i,), (bx,)),
         grid=jt.cdiv(x.shape[0], bx),
-        debug=True)
+        debug=False)
     def f(program_ref, params_ref, x_ref, out_ref):
       x = x_ref[...]

@@ -1069,15 +1069,23 @@ def body(i):
 class PallasControlFlowInterpreterTest(PallasControlFlowTest):
   INTERPRET = True

-class PallasCallAutodifferentiationTest(PallasTest):
-
-  @parameterized.named_parameters(*[
+AD_TEST_CASES = [
     ("square", lambda x: x * x),
+    ("square_pow", lambda x: x ** 2),
+    ("square_fn", jnp.square),
     ("add_one", lambda x: x + 1.),
     ("exp", jnp.exp),
-      # ("tanh", jnp.tanh), TODO(sharadmv): re-enable this case when libdevice is
-      # updated
-  ])
+    ("reciprocal", jnp.reciprocal),
+    ("one_over_x", lambda x: 1. / x),
+    ("recip_exp_sq", lambda x: jnp.reciprocal(jnp.exp(x) ** 2)),
+    ("exp_neg_sq", lambda x: jnp.exp(-x) ** 2),
+    ("sin", jnp.sin),
+    ("tanh", jnp.tanh),
+]
+
+class PallasCallAutodifferentiationTest(PallasTest):
+
+  @parameterized.named_parameters(*AD_TEST_CASES)
   def test_jvp(self, impl):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
@@ -1097,13 +1105,24 @@ def pallas_impl(x_ref, o_ref):
       rtol=1e-5)
     jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2)

-  @parameterized.named_parameters(*[
-      ("square", lambda x: x * x),
-      ("add_one", lambda x: x + 1.),
-      ("exp", jnp.exp),
-      # ("tanh", jnp.tanh), TODO(sharadmv): re-enable this case when libdevice is
-      # updated
-  ])
+  @parameterized.named_parameters(*AD_TEST_CASES)
+  def test_pallas_around_grad(self, impl):
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((), jnp.float32),
+        name=self.id().split(".")[-1],
+        debug=True,
+        grid=1)
+    def pallas_impl(x_ref, o_ref):
+      x = x_ref[()]
+      o_ref[()] = jax.grad(impl)(x)
+
+    x = random.normal(random.PRNGKey(0))
+    out_grad = pallas_impl(x)
+    out_grad_ref = jax.grad(impl)(x)
+    np.testing.assert_allclose(out_grad, out_grad_ref, atol=1e-5, rtol=1e-5)
+
+  @parameterized.named_parameters(*AD_TEST_CASES)
   def test_jvp_slice(self, impl):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),