Skip to content

Commit

Permalink
[FRONTEND] Always use argument names as the key for constants and `…
Browse files Browse the repository at this point in the history
…signature` dictionaries (#4693)

Previously, some code supplied argument indices as the keys for these
dictionaries, which is wrong because the key type was changed to `str`
a few months ago.

https://github.com/triton-lang/triton/blob/a0c1bc9c8c966158cadd603e5a20704081deaa62/python/triton/runtime/jit.py#L645

In addition, this PR raises a `TypeError` if a dictionary with
non-string keys is passed to `ASTSource`.
  • Loading branch information
Jokeren authored Sep 11, 2024
1 parent 8b33bcf commit f03acdc
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 26 deletions.
2 changes: 1 addition & 1 deletion python/test/unit/language/test_compile_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def kernel(a=GLOBAL):
pass

# No error.
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={0: "i32"}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={}))


def test_defaults_assign_no_err():
Expand Down
10 changes: 6 additions & 4 deletions python/test/unit/runtime/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,12 @@ def walk_fn(op):
]
src = triton.compiler.compiler.ASTSource(
fn=kernel,
signature={i: kernel._type_of(kernel._key_of(arg))
for i, arg in enumerate(args)
if i not in kernel.constexprs},
constants={i: arg
signature={
kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg))
for i, arg in enumerate(args)
if i not in kernel.constexprs
},
constants={kernel.arg_names[i]: arg
for i, arg in enumerate(args)
if not isinstance(arg, torch.Tensor)},
attrs=kernel._get_config(*args, ),
Expand Down
8 changes: 4 additions & 4 deletions python/test/unit/runtime/test_subproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def kernel_sub(a, b, o, N: tl.constexpr):

src = ASTSource(
fn=kernel_sub,
constants={3: 32},
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
constants={'N': 32},
signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32"},
attrs=attrs,
)
triton.compile(src=src, target=target)
Expand All @@ -42,7 +42,7 @@ def kernel_dot(Z):
z = tl.dot(z, z)
tl.store(Z + offs, z)

src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict())
src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constants={})
triton.compile(src=src, target=target)


Expand All @@ -63,7 +63,7 @@ def empty_kernel():

import gc
gc.collect()
src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constants=dict())
src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constants={})
triton.compile(src=src, target=target)


Expand Down
11 changes: 7 additions & 4 deletions python/test/unit/test_perf_warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk,
triton.compile(
triton.compiler.ASTSource(
fn=matmul_kernel, signature={
0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32', 7: 'i32', 8: 'i32', 9:
'i32', 10: 'i32', 11: 'i32'
'a_ptr': '*fp32', 'b_ptr': '*fp32', 'c_ptr': '*fp32', 'M': 'i32', 'N': 'i32', 'K': 'i32', 'stride_am':
'i32', 'stride_ak': 'i32', 'stride_bk': 'i32', 'stride_bn': 'i32', 'stride_cm': 'i32', 'stride_cn':
'i32'
}, constants={}))
captured = capfd.readouterr()

Expand Down Expand Up @@ -75,8 +76,10 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr)

XBLOCK = 1024
triton.compile(
triton.compiler.ASTSource(fn=ldst_vec, signature={0: '*i64', 1: '*i64', 2: '*fp16', 3: '*fp32', 4: '*fp16'},
constants={"XBLOCK": XBLOCK}), options={"num_warps": 1})
triton.compiler.ASTSource(
fn=ldst_vec, signature={
'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*fp16', 'in_ptr3': '*fp32', 'out_ptr0': '*fp16'
}, constants={"XBLOCK": XBLOCK}), options={"num_warps": 1})

_, err = capfd.readouterr()
assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark"
Expand Down
10 changes: 9 additions & 1 deletion python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,16 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None:
self.attrs = attrs
if isinstance(self.signature, str):
self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))}
else:
for k in self.signature.keys():
if not isinstance(k, str):
raise TypeError("Signature keys must be string")
if self.constants is None:
self.constants = dict()
self.constants = {}
else:
for k in self.constants.keys():
if not isinstance(k, str):
raise TypeError("Constants keys must be string")
if self.attrs is None:
self.attrs = AttrsDescriptor()

Expand Down
35 changes: 23 additions & 12 deletions python/triton/tools/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,15 @@ def constexpr(s):

hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s}
hints = {k: v for k, v in hints.items() if v is not None}
constants = {i: constexpr(s) for i, s in enumerate(signature)}
constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)}
constants = {k: v for k, v in constants.items() if v is not None}
signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constants}
signature = {
kernel.arg_names[i]: s.split(":")[0]
for i, s in enumerate(signature)
if kernel.arg_names[i] not in constants
}
const_sig = 'x'.join([str(v) for v in constants.values()])
doc_string = [f"{kernel.arg_names[i]}={constants[i]}" for i in constants.keys()]
doc_string = [f"{k}={v}" for k, v in constants.items()]
doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"]

# compile ast into cubin
Expand All @@ -106,16 +110,23 @@ def constexpr(s):
equal_to_1 = [i for i, h in hints.items() if h == 1]
attrs = triton.compiler.AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1)
for i in equal_to_1:
constants.update({i: 1})
constants.update({kernel.arg_names[i]: 1})
src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs)
opts = {"num_warps": args.num_warps, "num_stages": args.num_stages}
ccinfo = triton.compile(src, options=opts)
arg_names = []
arg_types = []
for i in signature.keys():
if i not in equal_to_1:
arg_names += [kernel.arg_names[i]]
arg_types += [signature[i]]
arg_names_not_1 = []
arg_types_not_1 = []
for i, arg_name in enumerate(kernel.arg_names):
if arg_name not in constants:
arg_names.append(arg_name)
arg_types.append(signature[arg_name])
arg_names_not_1.append(arg_name)
arg_types_not_1.append(signature[arg_name])
elif i in equal_to_1:
arg_names.append(arg_name)
arg_types.append(signature[arg_name])

# dump C stub code
suffix = kernel_suffix(signature.values(), attrs)
Expand All @@ -126,10 +137,10 @@ def constexpr(s):
"triton_kernel_name": args.kernel_name,
"bin_size": len(hex_),
"bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]),
"signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
"full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]),
"arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]),
"num_args": len(arg_names),
"signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]),
"full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
"arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1]),
"num_args": len(arg_names_not_1),
"kernel_docstring": doc_string,
"shared": ccinfo.metadata.shared,
"num_warps": args.num_warps,
Expand Down

0 comments on commit f03acdc

Please sign in to comment.