Skip to content

Commit

Permalink
[pallas] Add Triton lowering rule for rsqrt.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 545717971
  • Loading branch information
chr1sj0nes authored and The jax_triton Authors committed Jul 5, 2023
1 parent 260833b commit e9efd24
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions jax_triton/pallas/triton_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,13 @@ def _sqrt_lowering_rule(ctx: TritonLoweringRuleContext, a):
triton_lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule


def _rsqrt_lowering_rule(ctx: TritonLoweringRuleContext, a):
return tl.math.rsqrt(a, _builder=ctx.builder)


triton_lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule


def _neg_lowering_rule(ctx: TritonLoweringRuleContext, a):
return a.__neg__(_builder=ctx.builder)

Expand Down

0 comments on commit e9efd24

Please sign in to comment.