Skip to content

Commit

Permalink
Report some debugging context when pallas->triton lowering fails.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 544755063
  • Loading branch information
brianwa84 authored and The jax_triton Authors committed Jun 30, 2023
1 parent 0256a59 commit 7d8bd3f
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion jax_triton/pallas/triton_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ class TritonCompilationResult:
lowering_result: TritonLoweringResult


class TritonLoweringException(Exception):
pass


def _eval_index_map(
ctx: TritonModuleContext, idx, block_mapping: Optional[BlockMapping]
):
Expand Down Expand Up @@ -261,7 +265,17 @@ def write_env(var: jax_core.Var, val):
rule_ctx = TritonLoweringRuleContext(
ctx, avals_in, avals_out, eqn_block_infos
)
outvals = rule(rule_ctx, *invals, **eqn.params)
try:
outvals = rule(rule_ctx, *invals, **eqn.params)
except TritonLoweringException:
raise # We only add the extra info to the innermost exception.
except Exception as e:
raise TritonLoweringException(
f"Exception while lowering eqn:\n {eqn}\n"
f"With context:\n {rule_ctx}\n"
f"With inval shapes={map(lambda t: t.shape, invals)}\n"
f"With inval types={map(lambda t: t.type, invals)}\n"
f"In jaxpr:\n{jaxpr}") from e
if eqn.primitive.multiple_results:
map(write_env, eqn.outvars, outvals)
else:
Expand Down

0 comments on commit 7d8bd3f

Please sign in to comment.