Skip to content

SPIR-V serialization appears to mistranslate block arguments (SSA in MLIR) into OpPhi nodes (SSA in SPIR-V) #646

Open
@chengjunlu

Description

@chengjunlu

We have a simple if-else test case in Triton:

# Minimal if/else reproducer. The branch condition is a runtime value loaded
# from memory, so the lowered IR contains a real conditional branch. `val`
# must then be merged from two predecessors: a block argument in MLIR, an
# OpPhi in SPIR-V.
#
# Expected result: *Out = *TrueVal if *Cond != 0 else *FalseVal.
# In the SPIR-V dialect dump of this issue, all four arguments are lowered
# to !spirv.ptr<i32, CrossWorkgroup>.
@triton.jit
def kernel(Cond, TrueVal, FalseVal, Out):
    # A nonzero value at *Cond selects the true branch (lowered to an
    # INotEqual against 0 followed by a conditional branch).
    if tl.load(Cond):
        val = tl.load(TrueVal)
    else:
        val = tl.load(FalseVal)
    # `val` is defined on both paths. Its merge point is the construct whose
    # SPIR-V serialization is reported as wrong in this issue.
    tl.store(Out, val)

The SPIR-V dialect IR immediately after lowering is:

// -----// IR Dump After ReconcileUnrealizedCasts (reconcile-unrealized-casts) //----- //
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>, "triton_gpu.num-warps" = 8 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi64>, Input>
  spirv.func @kernel_0d1d2d3d(%arg0: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg1: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg2: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg3: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>, sym_visibility = "public"} {
    %cst0_i32 = spirv.Constant 0 : i32
    %true = spirv.Constant true
    %0 = spirv.Undef : i32
    spirv.BranchConditional %true, ^bb1, ^bb2(%0 : i32)
  ^bb1:  // pred: ^bb0
    %1 = spirv.Load "CrossWorkgroup" %arg0 : i32
    spirv.Branch ^bb2(%1 : i32)
  ^bb2(%2: i32):  // 2 preds: ^bb0, ^bb1
    %3 = spirv.INotEqual %2, %cst0_i32 : i32
    spirv.BranchConditional %3, ^bb3, ^bb6
  ^bb3:  // pred: ^bb2
    %true_0 = spirv.Constant true
    %4 = spirv.Undef : i32
    spirv.BranchConditional %true_0, ^bb4, ^bb5(%4 : i32)
  ^bb4:  // pred: ^bb3
    %5 = spirv.Load "CrossWorkgroup" %arg1 : i32
    spirv.Branch ^bb5(%5 : i32)
  ^bb5(%6: i32):  // 2 preds: ^bb3, ^bb4
    spirv.Branch ^bb9(%6 : i32)
  ^bb6:  // pred: ^bb2
    %true_1 = spirv.Constant true
    %7 = spirv.Undef : i32
    spirv.BranchConditional %true_1, ^bb7, ^bb8(%7 : i32)
  ^bb7:  // pred: ^bb6
    %8 = spirv.Load "CrossWorkgroup" %arg2 : i32
    spirv.Branch ^bb8(%8 : i32)
  ^bb8(%9: i32):  // 2 preds: ^bb6, ^bb7
    spirv.Branch ^bb9(%9 : i32)
  ^bb9(%10: i32):  // 2 preds: ^bb5, ^bb8
    spirv.Branch ^bb10
  ^bb10:  // pred: ^bb9
    %true_2 = spirv.Constant true
    %__builtin_var_LocalInvocationId___addr = spirv.mlir.addressof @__builtin_var_LocalInvocationId__ : !spirv.ptr<vector<3xi64>, Input>
    %11 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64>
    %12 = spirv.CompositeExtract %11[0 : i32] : vector<3xi64>
    %13 = spirv.SConvert %12 : i64 to i32
    %cst1_i32 = spirv.Constant 1 : i32
    %cst1_i32_3 = spirv.Constant 1 : i32
    %14 = spirv.SLessThan %13, %cst1_i32_3 : i32
    %15 = spirv.LogicalAnd %true_2, %14 : i1
    spirv.BranchConditional %15, ^bb11, ^bb12
  ^bb11:  // pred: ^bb10
    %16 = spirv.Undef : i32
    %17 = spirv.Undef : i32
    spirv.Store "CrossWorkgroup" %arg3, %10 : i32
    spirv.Branch ^bb12
  ^bb12:  // 2 preds: ^bb10, ^bb11
    spirv.Return
  }
}

After SPIR-V dialect canonicalization it becomes:

// -----// IR Dump After Canonicalizer (canonicalize) //----- //
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>, "triton_gpu.num-warps" = 8 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi64>, Input>
  spirv.func @kernel_0d1d2d3d(%arg0: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg1: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg2: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg3: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>, sym_visibility = "public"} {
    %cst1_i32 = spirv.Constant 1 : i32
    %cst0_i32 = spirv.Constant 0 : i32
    %true = spirv.Constant true
    %0 = spirv.Undef : i32
    spirv.BranchConditional %true, ^bb1, ^bb2(%0 : i32)
  ^bb1:  // pred: ^bb0
    %1 = spirv.Load "CrossWorkgroup" %arg0 : i32
    spirv.Branch ^bb2(%1 : i32)
  ^bb2(%2: i32):  // 2 preds: ^bb0, ^bb1
    %3 = spirv.INotEqual %2, %cst0_i32 : i32
    spirv.BranchConditional %3, ^bb3(%arg1 : !spirv.ptr<i32, CrossWorkgroup>), ^bb3(%arg2 : !spirv.ptr<i32, CrossWorkgroup>)
  ^bb3(%4: !spirv.ptr<i32, CrossWorkgroup>):  // 2 preds: ^bb2, ^bb2
    %5 = spirv.Undef : i32
    spirv.BranchConditional %true, ^bb4(%4 : !spirv.ptr<i32, CrossWorkgroup>), ^bb5(%5 : i32)
  ^bb4(%6: !spirv.ptr<i32, CrossWorkgroup>):  // pred: ^bb3
    %7 = spirv.Load "CrossWorkgroup" %6 : i32
    spirv.Branch ^bb5(%7 : i32)
  ^bb5(%8: i32):  // 2 preds: ^bb3, ^bb4
    spirv.Branch ^bb6(%8 : i32)
  ^bb6(%9: i32):  // pred: ^bb5
    spirv.Branch ^bb7
  ^bb7:  // pred: ^bb6
    %__builtin_var_LocalInvocationId___addr = spirv.mlir.addressof @__builtin_var_LocalInvocationId__ : !spirv.ptr<vector<3xi64>, Input>
    %10 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64>
    %11 = spirv.CompositeExtract %10[0 : i32] : vector<3xi64>
    %12 = spirv.SConvert %11 : i64 to i32
    %13 = spirv.SLessThan %12, %cst1_i32 : i32
    spirv.BranchConditional %13, ^bb8, ^bb9
  ^bb8:  // pred: ^bb7
    spirv.Store "CrossWorkgroup" %arg3, %9 : i32
    spirv.Branch ^bb9
  ^bb9:  // 2 preds: ^bb7, ^bb8
    spirv.Return
  }
}

Canonicalization merges the true and false blocks into one block and uses a block argument to distinguish the two branches:
spirv.BranchConditional %3, ^bb3(%arg1 : !spirv.ptr<i32, CrossWorkgroup>), ^bb3(%arg2 : !spirv.ptr<i32, CrossWorkgroup>)
^bb3(%4: !spirv.ptr<i32, CrossWorkgroup>):  // 2 preds: ^bb2, ^bb2

For the true condition it passes %arg1 as the block argument; for the false condition it passes %arg2.

I believe the canonicalization pass is behaving correctly here.

However, the problem appears when the module is serialized to SPIR-V:

%5 = OpTypeFunction %void %_ptr_CrossWorkgroup_uint %_ptr_CrossWorkgroup_uint %_ptr_CrossWorkgroup_uint %_ptr_CrossWorkgroup_uint
%uint_1 = OpConstant %uint 1
%uint_0 = OpConstant %uint 0
%bool = OpTypeBool
%true = OpConstantTrue %bool
%19 = OpUndef %uint
%kernel_0d1d2d3d = OpFunction %void None %5
%10 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%11 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%12 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%13 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%14 = OpLabel
OpBranchConditional %true %20 %21
%20 = OpLabel
%22 = OpLoad %uint %10
OpBranch %21
%21 = OpLabel
%23 = OpPhi %uint %22 %20 %19 %14
%24 = OpINotEqual %bool %23 %uint_0
OpBranchConditional %24 %25 %25
%25 = OpLabel
%26 = OpPhi %_ptr_CrossWorkgroup_uint %11 %21 %11 %21
OpBranchConditional %true %27 %28
%27 = OpLabel
%29 = OpPhi %_ptr_CrossWorkgroup_uint %26 %25
%30 = OpLoad %uint %29
OpBranch %28
%28 = OpLabel
%31 = OpPhi %uint %30 %27 %19 %25
OpBranch %32
%32 = OpLabel
%33 = OpPhi %uint %31 %28
OpBranch %34
%34 = OpLabel
%35 = OpLoad %v3ulong %__builtin_var_LocalInvocationId__
%36 = OpCompositeExtract %ulong %35 0
%37 = OpSConvert %uint %36
%38 = OpSLessThan %bool %37 %uint_1
OpBranchConditional %38 %39 %40
%39 = OpLabel
OpStore %13 %33
OpBranch %40
%40 = OpLabel
OpReturn
OpFunctionEnd

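The OpPhi lists the same predecessor block (%21) twice and uses the same pointer (%11, i.e. %arg1) for both entries. As a result, the false-branch value %12 (%arg2) is lost: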
%26 = OpPhi %_ptr_CrossWorkgroup_uint %11 %21 %11 %21

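(The SPIR-V output has no separate block for each edge: both targets of `OpBranchConditional %24 %25 %25` are %25, presumably because the blocks were merged. Logically, though, the true and false edges are distinct. An OpPhi can list each parent block only once, so one of the edges needs its own block (for example, an extra block inserted on one edge) for the phi to select between %11 and %12.)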

Metadata

Metadata

Assignees

No one assigned

    Labels

    TritonIssues tracking Triton/IMEX collaboration

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions