Fix a bug in lowering of integer powers.
Also avoids requesting a broadcast when the reshape target shape is scalar.

PiperOrigin-RevId: 542929594
brianwa84 authored and The jax_triton Authors committed Jun 23, 2023
1 parent 4f97b83 commit c3721db
Showing 2 changed files with 41 additions and 18 deletions.
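
For context on the first change: the y == -2 branch of the integer-power lowering emitted Triton's rsqrt, which computes x ** -0.5 rather than x ** -2. A minimal sketch in plain JAX (independent of the Triton lowering itself) of the values the old and new branches correspond to:

import jax.numpy as jnp

x = jnp.asarray(4.0, dtype=jnp.float32)

correct  = x ** -2              # 1 / x**2 == 0.0625
old_rule = 1.0 / jnp.sqrt(x)    # rsqrt(x) == 0.5, what y == -2 used to lower to
new_rule = 1.0 / (x * x)        # what the fixed rule now emits

assert jnp.allclose(correct, new_rule)
assert not jnp.allclose(correct, old_rule)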
jax_triton/pallas/triton_lowering.py: 10 changes (7 additions & 3 deletions)
@@ -445,7 +445,9 @@ def _integer_pow_lowering_rule(ctx: TritonLoweringRuleContext, a, *, y):
   if y == 3:
     return a.__mul__(a.__mul__(a, _builder=ctx.builder), _builder=ctx.builder)
   if y == -2:
-    return tl.math.rsqrt(a, _builder=ctx.builder)
+    one_ = tl.core._to_tensor(1.0, ctx.builder)
+    a_sq = a.__mul__(a, _builder=ctx.builder)
+    return one_.__truediv__(a_sq, _builder=ctx.builder)
   return tl.math.pow(a, y, _builder=ctx.builder)
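
Stated as plain Python (a paraphrase of the branches visible in this hunk, with ordinary arithmetic standing in for the tl.tensor builder methods; other exponents may be special-cased above the shown context):

import math

def integer_pow_reference(a: float, y: int) -> float:
  if y == 3:
    return a * (a * a)
  if y == -2:
    # The fixed branch: a ** -2, previously (incorrectly) rsqrt(a).
    return 1.0 / (a * a)
  return math.pow(a, y)

assert integer_pow_reference(4.0, -2) == 0.0625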


@@ -620,8 +622,10 @@ def _reshape_lowering_rule(
   if tuple(s.value for s in a.shape) == dst_shp:
     return a
   if not a.type.is_block():
-    return tl.broadcast_to(a, [tl.constexpr(s) for s in dst_shp],
-                           _builder=ctx.builder)
+    if dst_shp:
+      return tl.broadcast_to(a, [tl.constexpr(s) for s in dst_shp],
+                             _builder=ctx.builder)
+    return a
   # Expand-dims or reduce-sum to handle singleton dims.
   if ([s.value for s in a.shape if s.value != 1] ==
       [s for s in dst_shp if s != 1]):
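
The second hunk guards the non-block (scalar operand) path of _reshape_lowering_rule: a scalar operand used to be broadcast unconditionally, so a reshape whose target shape is scalar asked the builder for a broadcast to an empty shape list; it now returns the operand unchanged. Roughly, as a plain-NumPy paraphrase of the new control flow (not the actual builder calls):

import numpy as np

def reshape_scalar_operand(a, dst_shp):
  # Only broadcast when the target shape is non-empty; a scalar-to-scalar
  # reshape is a no-op.
  if dst_shp:
    return np.broadcast_to(np.asarray(a), dst_shp)
  return a

reshape_scalar_operand(2.0, (4, 4))  # broadcasts the scalar up to a (4, 4) array
reshape_scalar_operand(2.0, ())      # scalar target: returned as-is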
tests/pallas_test.py: 49 changes (34 additions & 15 deletions)
@@ -521,7 +521,7 @@ def test_reduce_only_dim(self, use_store):
     @functools.partial(
         self.pallas_call,
         out_shape=out_shape,
-        grid=1, debug=True)
+        grid=1, debug=False)
     def reduce(x_ref, y_ref):
       x = pl.load(x_ref, (jnp.arange(m),))
       y = jnp.sum(x, axis=-1)
@@ -846,7 +846,7 @@ def test_scan_cond_vm_explicit_ref_arg(self):
                   pl.BlockSpec(lambda i: (i,), (bx,))],  # x
         out_specs=pl.BlockSpec(lambda i: (i,), (bx,)),
         grid=jt.cdiv(x.shape[0], bx),
-        debug=True)
+        debug=False)
     def f(program_ref, params_ref, x_ref, out_ref):
       x = x_ref[...]

@@ -1069,15 +1069,23 @@ def body(i):
 class PallasControlFlowInterpreterTest(PallasControlFlowTest):
   INTERPRET = True
 
-class PallasCallAutodifferentiationTest(PallasTest):
-
-  @parameterized.named_parameters(*[
+AD_TEST_CASES = [
     ("square", lambda x: x * x),
+    ("square_pow", lambda x: x ** 2),
+    ("square_fn", jnp.square),
     ("add_one", lambda x: x + 1.),
     ("exp", jnp.exp),
-    # ("tanh", jnp.tanh), TODO(sharadmv): re-enable this case when libdevice is
-    # updated
-  ])
+    ("reciprocal", jnp.reciprocal),
+    ("one_over_x", lambda x: 1. / x),
+    ("recip_exp_sq", lambda x: jnp.reciprocal(jnp.exp(x) ** 2)),
+    ("exp_neg_sq", lambda x: jnp.exp(-x) ** 2),
+    ("sin", jnp.sin),
+    ("tanh", jnp.tanh),
+]
+
+class PallasCallAutodifferentiationTest(PallasTest):
+
+  @parameterized.named_parameters(*AD_TEST_CASES)
   def test_jvp(self, impl):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
@@ -1097,13 +1105,24 @@ def pallas_impl(x_ref, o_ref):
         rtol=1e-5)
     jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2)
 
-  @parameterized.named_parameters(*[
-    ("square", lambda x: x * x),
-    ("add_one", lambda x: x + 1.),
-    ("exp", jnp.exp),
-    # ("tanh", jnp.tanh), TODO(sharadmv): re-enable this case when libdevice is
-    # updated
-  ])
+  @parameterized.named_parameters(*AD_TEST_CASES)
+  def test_pallas_around_grad(self, impl):
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((), jnp.float32),
+        name=self.id().split(".")[-1],
+        debug=True,
+        grid=1)
+    def pallas_impl(x_ref, o_ref):
+      x = x_ref[()]
+      o_ref[()] = jax.grad(impl)(x)
+
+    x = random.normal(random.PRNGKey(0))
+    out_grad = pallas_impl(x)
+    out_grad_ref = jax.grad(impl)(x)
+    np.testing.assert_allclose(out_grad, out_grad_ref, atol=1e-5, rtol=1e-5)
+
+  @parameterized.named_parameters(*AD_TEST_CASES)
   def test_jvp_slice(self, impl):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
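
The expanded AD_TEST_CASES list is what exercises the integer-power fix under autodiff: differentiating reciprocal-style functions produces an integer power with exponent -2 (d/dx (1/x) = -x ** -2), so jax.grad of the new reciprocal, one_over_x, recip_exp_sq, and exp_neg_sq cases runs through the corrected y == -2 branch. A small sketch outside pallas, assuming jnp.reciprocal lowers through lax.integer_pow:

import jax
import jax.numpy as jnp

x = 2.0
assert jnp.allclose(jax.grad(jnp.reciprocal)(x), -1.0 / x ** 2)

# The traced gradient should contain an integer_pow primitive with y=-2.
print(jax.make_jaxpr(jax.grad(jnp.reciprocal))(x))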
