Description
We have a simple Triton test case for an if-else:
@triton.jit
def kernel(Cond, TrueVal, FalseVal, Out):
    # Minimal if/else reproducer: load the value at Cond; if it is nonzero,
    # copy the value at TrueVal to Out, otherwise copy the value at FalseVal.
    if tl.load(Cond):
        val = tl.load(TrueVal)
    else:
        val = tl.load(FalseVal)
    # `val` is assigned on both branches and consumed after the merge point,
    # so the lowered IR must select the loaded value per branch.
    tl.store(Out, val)
The SPIR-V dialect IR produced by lowering, before any canonicalization, is:
// -----// IR Dump After ReconcileUnrealizedCasts (reconcile-unrealized-casts) //----- //
// SPIR-V dialect IR for `kernel` right after lowering (no canonicalization applied yet).
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>, "triton_gpu.num-warps" = 8 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi64>, Input>
// %arg0..%arg3 correspond, in order, to Cond, TrueVal, FalseVal, Out of the Triton kernel:
// %arg0 is loaded and tested, %arg1/%arg2 are loaded on the true/false paths, %arg3 is stored to.
spirv.func @kernel_0d1d2d3d(%arg0: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg1: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg2: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg3: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>, sym_visibility = "public"} {
%cst0_i32 = spirv.Constant 0 : i32
%true = spirv.Constant true
%0 = spirv.Undef : i32
// Load Cond under a constant-true predicate; ^bb2 receives the loaded value (undef on the untaken edge).
spirv.BranchConditional %true, ^bb1, ^bb2(%0 : i32)
^bb1: // pred: ^bb0
%1 = spirv.Load "CrossWorkgroup" %arg0 : i32
spirv.Branch ^bb2(%1 : i32)
^bb2(%2: i32): // 2 preds: ^bb0, ^bb1
%3 = spirv.INotEqual %2, %cst0_i32 : i32
// if (Cond != 0): ^bb3..^bb5 is the true path, ^bb6..^bb8 is the false path.
spirv.BranchConditional %3, ^bb3, ^bb6
^bb3: // pred: ^bb2
%true_0 = spirv.Constant true
%4 = spirv.Undef : i32
spirv.BranchConditional %true_0, ^bb4, ^bb5(%4 : i32)
^bb4: // pred: ^bb3
// True path: load from %arg1 (TrueVal).
%5 = spirv.Load "CrossWorkgroup" %arg1 : i32
spirv.Branch ^bb5(%5 : i32)
^bb5(%6: i32): // 2 preds: ^bb3, ^bb4
spirv.Branch ^bb9(%6 : i32)
^bb6: // pred: ^bb2
%true_1 = spirv.Constant true
%7 = spirv.Undef : i32
spirv.BranchConditional %true_1, ^bb7, ^bb8(%7 : i32)
^bb7: // pred: ^bb6
// False path: load from %arg2 (FalseVal).
%8 = spirv.Load "CrossWorkgroup" %arg2 : i32
spirv.Branch ^bb8(%8 : i32)
^bb8(%9: i32): // 2 preds: ^bb6, ^bb7
spirv.Branch ^bb9(%9 : i32)
^bb9(%10: i32): // 2 preds: ^bb5, ^bb8
// %10 is `val`: the value loaded on whichever path was taken.
spirv.Branch ^bb10
^bb10: // pred: ^bb9
%true_2 = spirv.Constant true
%__builtin_var_LocalInvocationId___addr = spirv.mlir.addressof @__builtin_var_LocalInvocationId__ : !spirv.ptr<vector<3xi64>, Input>
%11 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64>
%12 = spirv.CompositeExtract %11[0 : i32] : vector<3xi64>
%13 = spirv.SConvert %12 : i64 to i32
%cst1_i32 = spirv.Constant 1 : i32
%cst1_i32_3 = spirv.Constant 1 : i32
%14 = spirv.SLessThan %13, %cst1_i32_3 : i32
%15 = spirv.LogicalAnd %true_2, %14 : i1
// Store `val` to %arg3 (Out) only when the x component of LocalInvocationId is < 1.
spirv.BranchConditional %15, ^bb11, ^bb12
^bb11: // pred: ^bb10
%16 = spirv.Undef : i32
%17 = spirv.Undef : i32
spirv.Store "CrossWorkgroup" %arg3, %10 : i32
spirv.Branch ^bb12
^bb12: // 2 preds: ^bb10, ^bb11
spirv.Return
}
}
And after canonicalization of the SPIR-V dialect IR:
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// Same function after canonicalization: the former true/false paths are merged into ^bb3,
// which receives the pointer to load from as a block argument.
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>, "triton_gpu.num-warps" = 8 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi64>, Input>
spirv.func @kernel_0d1d2d3d(%arg0: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg1: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg2: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg3: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>, sym_visibility = "public"} {
%cst1_i32 = spirv.Constant 1 : i32
%cst0_i32 = spirv.Constant 0 : i32
%true = spirv.Constant true
%0 = spirv.Undef : i32
spirv.BranchConditional %true, ^bb1, ^bb2(%0 : i32)
^bb1: // pred: ^bb0
// Load Cond (%arg0).
%1 = spirv.Load "CrossWorkgroup" %arg0 : i32
spirv.Branch ^bb2(%1 : i32)
^bb2(%2: i32): // 2 preds: ^bb0, ^bb1
%3 = spirv.INotEqual %2, %cst0_i32 : i32
// Both successors are the same block ^bb3; the two edges differ only in the block argument:
// %arg1 (TrueVal) when %3 is true, %arg2 (FalseVal) when %3 is false.
spirv.BranchConditional %3, ^bb3(%arg1 : !spirv.ptr<i32, CrossWorkgroup>), ^bb3(%arg2 : !spirv.ptr<i32, CrossWorkgroup>)
^bb3(%4: !spirv.ptr<i32, CrossWorkgroup>): // 2 preds: ^bb2, ^bb2
%5 = spirv.Undef : i32
spirv.BranchConditional %true, ^bb4(%4 : !spirv.ptr<i32, CrossWorkgroup>), ^bb5(%5 : i32)
^bb4(%6: !spirv.ptr<i32, CrossWorkgroup>): // pred: ^bb3
// Load through whichever pointer was selected by the branch in ^bb2.
%7 = spirv.Load "CrossWorkgroup" %6 : i32
spirv.Branch ^bb5(%7 : i32)
^bb5(%8: i32): // 2 preds: ^bb3, ^bb4
spirv.Branch ^bb6(%8 : i32)
^bb6(%9: i32): // pred: ^bb5
spirv.Branch ^bb7
^bb7: // pred: ^bb6
%__builtin_var_LocalInvocationId___addr = spirv.mlir.addressof @__builtin_var_LocalInvocationId__ : !spirv.ptr<vector<3xi64>, Input>
%10 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64>
%11 = spirv.CompositeExtract %10[0 : i32] : vector<3xi64>
%12 = spirv.SConvert %11 : i64 to i32
%13 = spirv.SLessThan %12, %cst1_i32 : i32
// Store %9 to %arg3 (Out) only when the x component of LocalInvocationId is < 1.
spirv.BranchConditional %13, ^bb8, ^bb9
^bb8: // pred: ^bb7
spirv.Store "CrossWorkgroup" %arg3, %9 : i32
spirv.Branch ^bb9
^bb9: // 2 preds: ^bb7, ^bb8
spirv.Return
}
}
The canonicalizer merges the true and false blocks into a single block and uses a block argument to distinguish the two branches:
spirv.BranchConditional %3, ^bb3(%arg1 : !spirv.ptr<i32, CrossWorkgroup>), ^bb3(%arg2 : !spirv.ptr<i32, CrossWorkgroup>)
^bb3(%4: !spirv.ptr<i32, CrossWorkgroup>): // 2 preds: ^bb2, ^bb2
On the true edge the block argument is %arg1; on the false edge it is %arg2.
As far as I can tell, the canonicalization itself is correct.
However, the problem appears when this IR is serialized to SPIR-V:
%5 = OpTypeFunction %void %_ptr_CrossWorkgroup_uint %_ptr_CrossWorkgroup_uint %_ptr_CrossWorkgroup_uint %_ptr_CrossWorkgroup_uint
%uint_1 = OpConstant %uint 1
%uint_0 = OpConstant %uint 0
%bool = OpTypeBool
%true = OpConstantTrue %bool
%19 = OpUndef %uint
%kernel_0d1d2d3d = OpFunction %void None %5
; %10, %11, %12, %13 are the four pointer parameters in order (%arg0..%arg3 in the MLIR dumps).
%10 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%11 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%12 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%13 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%14 = OpLabel
OpBranchConditional %true %20 %21
%20 = OpLabel
%22 = OpLoad %uint %10
OpBranch %21
%21 = OpLabel
%23 = OpPhi %uint %22 %20 %19 %14
%24 = OpINotEqual %bool %23 %uint_0
; Serialized from the MLIR conditional branch whose two successors are both ^bb3 with
; different block arguments. Both targets are now %25, so %21 is a single parent block of %25.
OpBranchConditional %24 %25 %25
%25 = OpLabel
; BUG: both (value, parent) pairs are (%11, %21). The false-edge value %12 (%arg2) is lost,
; and parent %21 is listed twice, while the SPIR-V spec requires exactly one pair per parent
; block. Presumably the two edges need distinct parent blocks (e.g. an intermediate block on
; one edge) for the %11/%12 choice to survive.
%26 = OpPhi %_ptr_CrossWorkgroup_uint %11 %21 %11 %21
OpBranchConditional %true %27 %28
%27 = OpLabel
%29 = OpPhi %_ptr_CrossWorkgroup_uint %26 %25
%30 = OpLoad %uint %29
OpBranch %28
%28 = OpLabel
%31 = OpPhi %uint %30 %27 %19 %25
OpBranch %32
%32 = OpLabel
%33 = OpPhi %uint %31 %28
OpBranch %34
%34 = OpLabel
%35 = OpLoad %v3ulong %__builtin_var_LocalInvocationId__
%36 = OpCompositeExtract %ulong %35 0
%37 = OpSConvert %uint %36
%38 = OpSLessThan %bool %37 %uint_1
; Store %33 through %13 (Out) only when the x component of LocalInvocationId is < 1.
OpBranchConditional %38 %39 %40
%39 = OpLabel
OpStore %13 %33
OpBranch %40
%40 = OpLabel
OpReturn
OpFunctionEnd
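The OpPhi lists the same pointer (%11, i.e. %arg1) for both incoming entries, and both entries name the same parent block %21; %12 (%arg2) never appears, so the false branch would load from TrueVal instead of FalseVal: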
%26 = OpPhi %_ptr_CrossWorkgroup_uint %11 %21 %11 %21
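(The serialized SPIR-V has no separate block for each edge: `OpBranchConditional %24 %25 %25` targets the same block twice. Logically, however, there are two distinct edges, one for true and one for false. Because an OpPhi can distinguish incoming values only by parent block, the serializer needs to give the two edges different parent blocks, for example by inserting a block on one edge, to preserve the %arg1/%arg2 choice.)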