Skip to content

Commit

Permalink
Import openai/triton from GitHub.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 547484298
  • Loading branch information
chsigg authored and The jax_triton Authors committed Jul 12, 2023
1 parent 96f6940 commit 56ffd00
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 5 additions & 0 deletions jax_triton/pallas/triton_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ def lower_jaxpr_to_triton_module(
ir_context = tl_ir.context()
ir_context.load_triton()
builder = tl_ir.builder(ir_context)
# TODO(sharadmv): handle multiple devices, right now we assume device 0
# which is fine when we have multiple of the same GPU but this won't work in
# general.
device = 0
builder.arch = triton_kernel_call_lib.get_compute_capability(device)
module = builder.create_module()
in_avals = [var.aval for var in jaxpr.invars]
triton_types = [get_triton_type(x) for x in in_avals]
Expand Down
3 changes: 2 additions & 1 deletion jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,9 @@ def get_or_create_triton_kernel(
# which is fine when we have multiple of the same GPU but this won't work in
# general.
device = 0
arch = triton_kernel_call_lib.get_compute_capability(device)
module = code_gen.ast_to_ttir(
fn, signature, specialization, constants, debug=dump
fn, signature, specialization, constants, debug=dump, arch=arch
)
ttir = str(module)  # `module` is compiled in-place, so copy TTIR here.
ptx, kernel_name, shared_mem_bytes, compute_capability = (
Expand Down

0 comments on commit 56ffd00

Please sign in to comment.