diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index d0b98c9859d5..d784429b1764 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -329,7 +329,7 @@ def kernel(a=GLOBAL): pass # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={0: "i32"}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={})) def test_defaults_assign_no_err(): diff --git a/python/test/unit/runtime/test_bindings.py b/python/test/unit/runtime/test_bindings.py index b3aebc9d8526..11ce4f8fc14b 100644 --- a/python/test/unit/runtime/test_bindings.py +++ b/python/test/unit/runtime/test_bindings.py @@ -61,10 +61,12 @@ def walk_fn(op): ] src = triton.compiler.compiler.ASTSource( fn=kernel, - signature={i: kernel._type_of(kernel._key_of(arg)) - for i, arg in enumerate(args) - if i not in kernel.constexprs}, - constants={i: arg + signature={ + kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg)) + for i, arg in enumerate(args) + if i not in kernel.constexprs + }, + constants={kernel.arg_names[i]: arg for i, arg in enumerate(args) if not isinstance(arg, torch.Tensor)}, attrs=kernel._get_config(*args, ), diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 7240fb7bb562..c4afd1e0ed11 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -17,8 +17,8 @@ def kernel_sub(a, b, o, N: tl.constexpr): src = ASTSource( fn=kernel_sub, - constants={3: 32}, - signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, + constants={'N': 32}, + signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32"}, attrs=attrs, ) triton.compile(src=src, target=target) @@ -42,7 +42,7 @@ def kernel_dot(Z): z = tl.dot(z, z) tl.store(Z + offs, z) - src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict()) + src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constants={}) triton.compile(src=src, target=target) @@ -63,7 +63,7 @@ def empty_kernel(): import gc gc.collect() - src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constants=dict()) + src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constants={}) triton.compile(src=src, target=target) diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index 8b793dd36095..9f4b9a0830a0 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -37,8 +37,9 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, triton.compile( triton.compiler.ASTSource( fn=matmul_kernel, signature={ - 0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32', 7: 'i32', 8: 'i32', 9: - 'i32', 10: 'i32', 11: 'i32' + 'a_ptr': '*fp32', 'b_ptr': '*fp32', 'c_ptr': '*fp32', 'M': 'i32', 'N': 'i32', 'K': 'i32', 'stride_am': + 'i32', 'stride_ak': 'i32', 'stride_bk': 'i32', 'stride_bn': 'i32', 'stride_cm': 'i32', 'stride_cn': + 'i32' }, constants={})) captured = capfd.readouterr() @@ -75,8 +76,10 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr) XBLOCK = 1024 triton.compile( - triton.compiler.ASTSource(fn=ldst_vec, signature={0: '*i64', 1: '*i64', 2: '*fp16', 3: '*fp32', 4: '*fp16'}, - constants={"XBLOCK": XBLOCK}), options={"num_warps": 1}) + triton.compiler.ASTSource( + fn=ldst_vec, signature={ + 'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*fp16', 'in_ptr3': '*fp32', 'out_ptr0': '*fp16' + }, constants={"XBLOCK": XBLOCK}), options={"num_warps": 1}) _, err = capfd.readouterr() assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark" diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 40b21e1a1219..873a0844901a 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -96,8 +96,16 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None: self.attrs = attrs if isinstance(self.signature, str): self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + else: + for k in self.signature.keys(): + if not isinstance(k, str): + raise TypeError("Signature keys must be string") if self.constants is None: - self.constants = dict() + self.constants = {} + else: + for k in self.constants.keys(): + if not isinstance(k, str): + raise TypeError("Constants keys must be string") if self.attrs is None: self.attrs = AttrsDescriptor() diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 872332b03d99..1e9697cc8226 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -92,11 +92,15 @@ def constexpr(s): hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} hints = {k: v for k, v in hints.items() if v is not None} - constants = {i: constexpr(s) for i, s in enumerate(signature)} + constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} constants = {k: v for k, v in constants.items() if v is not None} - signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constants} + signature = { + kernel.arg_names[i]: s.split(":")[0] + for i, s in enumerate(signature) + if kernel.arg_names[i] not in constants + } const_sig = 'x'.join([str(v) for v in constants.values()]) - doc_string = [f"{kernel.arg_names[i]}={constants[i]}" for i in constants.keys()] + doc_string = [f"{k}={v}" for k, v in constants.items()] doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] # compile ast into cubin @@ -106,16 +110,23 @@ def constexpr(s): equal_to_1 = [i for i, h in hints.items() if h == 1] attrs = triton.compiler.AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) for i in equal_to_1: - constants.update({i: 1}) + constants.update({kernel.arg_names[i]: 1}) src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} ccinfo = triton.compile(src, options=opts) arg_names = [] arg_types = [] - for i in signature.keys(): - if i not in equal_to_1: - arg_names += [kernel.arg_names[i]] - arg_types += [signature[i]] + arg_names_not_1 = [] + arg_types_not_1 = [] + for i, arg_name in enumerate(kernel.arg_names): + if arg_name not in constants: + arg_names.append(arg_name) + arg_types.append(signature[arg_name]) + arg_names_not_1.append(arg_name) + arg_types_not_1.append(signature[arg_name]) + elif i in equal_to_1: + arg_names.append(arg_name) + arg_types.append(signature[arg_name]) # dump C stub code suffix = kernel_suffix(signature.values(), attrs) @@ -126,10 +137,10 @@ def constexpr(s): "triton_kernel_name": args.kernel_name, "bin_size": len(hex_), "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), - "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), - "full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]), - "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]), - "num_args": len(arg_names), + "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]), + "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), + "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1]), + "num_args": len(arg_names_not_1), "kernel_docstring": doc_string, "shared": ccinfo.metadata.shared, "num_warps": args.num_warps,