Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPIR-V serialization appears to mistranslate block arguments (SSA in MLIR) into OpPhi nodes (SSA in SPIR-V) #646

Open
chengjunlu opened this issue May 4, 2023 · 2 comments
Labels
Triton Issues tracking Triton/IMEX collaboration

Comments

@chengjunlu
Copy link

We have a simple if-else test case in Triton:

@triton.jit
def kernel(Cond, TrueVal, FalseVal, Out):
    if tl.load(Cond):
        val = tl.load(TrueVal)
    else:
        val = tl.load(FalseVal)
    tl.store(Out, val)

The SPIR-V dialect IR produced by lowering, before any canonicalization, is:

// -----// IR Dump After ReconcileUnrealizedCasts (reconcile-unrealized-casts) //----- //
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>, "triton_gpu.num-warps" = 8 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi64>, Input>
  spirv.func @kernel_0d1d2d3d(%arg0: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg1: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg2: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg3: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>, sym_visibility = "public"} {
    %cst0_i32 = spirv.Constant 0 : i32
    %true = spirv.Constant true
    %0 = spirv.Undef : i32
    spirv.BranchConditional %true, ^bb1, ^bb2(%0 : i32)
  ^bb1:  // pred: ^bb0
    %1 = spirv.Load "CrossWorkgroup" %arg0 : i32
    spirv.Branch ^bb2(%1 : i32)
  ^bb2(%2: i32):  // 2 preds: ^bb0, ^bb1
    %3 = spirv.INotEqual %2, %cst0_i32 : i32
    spirv.BranchConditional %3, ^bb3, ^bb6
  ^bb3:  // pred: ^bb2
    %true_0 = spirv.Constant true
    %4 = spirv.Undef : i32
    spirv.BranchConditional %true_0, ^bb4, ^bb5(%4 : i32)
  ^bb4:  // pred: ^bb3
    %5 = spirv.Load "CrossWorkgroup" %arg1 : i32
    spirv.Branch ^bb5(%5 : i32)
  ^bb5(%6: i32):  // 2 preds: ^bb3, ^bb4
    spirv.Branch ^bb9(%6 : i32)
  ^bb6:  // pred: ^bb2
    %true_1 = spirv.Constant true
    %7 = spirv.Undef : i32
    spirv.BranchConditional %true_1, ^bb7, ^bb8(%7 : i32)
  ^bb7:  // pred: ^bb6
    %8 = spirv.Load "CrossWorkgroup" %arg2 : i32
    spirv.Branch ^bb8(%8 : i32)
  ^bb8(%9: i32):  // 2 preds: ^bb6, ^bb7
    spirv.Branch ^bb9(%9 : i32)
  ^bb9(%10: i32):  // 2 preds: ^bb5, ^bb8
    spirv.Branch ^bb10
  ^bb10:  // pred: ^bb9
    %true_2 = spirv.Constant true
    %__builtin_var_LocalInvocationId___addr = spirv.mlir.addressof @__builtin_var_LocalInvocationId__ : !spirv.ptr<vector<3xi64>, Input>
    %11 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64>
    %12 = spirv.CompositeExtract %11[0 : i32] : vector<3xi64>
    %13 = spirv.SConvert %12 : i64 to i32
    %cst1_i32 = spirv.Constant 1 : i32
    %cst1_i32_3 = spirv.Constant 1 : i32
    %14 = spirv.SLessThan %13, %cst1_i32_3 : i32
    %15 = spirv.LogicalAnd %true_2, %14 : i1
    spirv.BranchConditional %15, ^bb11, ^bb12
  ^bb11:  // pred: ^bb10
    %16 = spirv.Undef : i32
    %17 = spirv.Undef : i32
    spirv.Store "CrossWorkgroup" %arg3, %10 : i32
    spirv.Branch ^bb12
  ^bb12:  // 2 preds: ^bb10, ^bb11
    spirv.Return
  }
}

After running the canonicalizer on the SPIR-V dialect IR:

// -----// IR Dump After Canonicalizer (canonicalize) //----- //
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>, "triton_gpu.num-warps" = 8 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi64>, Input>
  spirv.func @kernel_0d1d2d3d(%arg0: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg1: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg2: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg3: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>, sym_visibility = "public"} {
    %cst1_i32 = spirv.Constant 1 : i32
    %cst0_i32 = spirv.Constant 0 : i32
    %true = spirv.Constant true
    %0 = spirv.Undef : i32
    spirv.BranchConditional %true, ^bb1, ^bb2(%0 : i32)
  ^bb1:  // pred: ^bb0
    %1 = spirv.Load "CrossWorkgroup" %arg0 : i32
    spirv.Branch ^bb2(%1 : i32)
  ^bb2(%2: i32):  // 2 preds: ^bb0, ^bb1
    %3 = spirv.INotEqual %2, %cst0_i32 : i32
    spirv.BranchConditional %3, ^bb3(%arg1 : !spirv.ptr<i32, CrossWorkgroup>), ^bb3(%arg2 : !spirv.ptr<i32, CrossWorkgroup>)
  ^bb3(%4: !spirv.ptr<i32, CrossWorkgroup>):  // 2 preds: ^bb2, ^bb2
    %5 = spirv.Undef : i32
    spirv.BranchConditional %true, ^bb4(%4 : !spirv.ptr<i32, CrossWorkgroup>), ^bb5(%5 : i32)
  ^bb4(%6: !spirv.ptr<i32, CrossWorkgroup>):  // pred: ^bb3
    %7 = spirv.Load "CrossWorkgroup" %6 : i32
    spirv.Branch ^bb5(%7 : i32)
  ^bb5(%8: i32):  // 2 preds: ^bb3, ^bb4
    spirv.Branch ^bb6(%8 : i32)
  ^bb6(%9: i32):  // pred: ^bb5
    spirv.Branch ^bb7
  ^bb7:  // pred: ^bb6
    %__builtin_var_LocalInvocationId___addr = spirv.mlir.addressof @__builtin_var_LocalInvocationId__ : !spirv.ptr<vector<3xi64>, Input>
    %10 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64>
    %11 = spirv.CompositeExtract %10[0 : i32] : vector<3xi64>
    %12 = spirv.SConvert %11 : i64 to i32
    %13 = spirv.SLessThan %12, %cst1_i32 : i32
    spirv.BranchConditional %13, ^bb8, ^bb9
  ^bb8:  // pred: ^bb7
    spirv.Store "CrossWorkgroup" %arg3, %9 : i32
    spirv.Branch ^bb9
  ^bb9:  // 2 preds: ^bb7, ^bb8
    spirv.Return
  }
}

The canonicalizer merges the true and false blocks into a single block (^bb3). It uses a block argument to carry the value that differs between the two branches:

spirv.BranchConditional %3, ^bb3(%arg1 : !spirv.ptr<i32, CrossWorkgroup>), ^bb3(%arg2 : !spirv.ptr<i32, CrossWorkgroup>)
^bb3(%4: !spirv.ptr<i32, CrossWorkgroup>):  // 2 preds: ^bb2, ^bb2

On the true edge, %arg1 is passed as the block argument; on the false edge, %arg2 is passed.

I think the canonicalization pass itself is working correctly.

However, the problem appears when the IR is serialized to SPIR-V:

%5 = OpTypeFunction %void %_ptr_CrossWorkgroup_uint %_ptr_CrossWorkgroup_uint %_ptr_CrossWorkgroup_uint %_ptr_CrossWorkgroup_uint
%uint_1 = OpConstant %uint 1
%uint_0 = OpConstant %uint 0
%bool = OpTypeBool
%true = OpConstantTrue %bool
%19 = OpUndef %uint
%kernel_0d1d2d3d = OpFunction %void None %5
%10 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%11 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%12 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%13 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%14 = OpLabel
OpBranchConditional %true %20 %21
%20 = OpLabel
%22 = OpLoad %uint %10
OpBranch %21
%21 = OpLabel
%23 = OpPhi %uint %22 %20 %19 %14
%24 = OpINotEqual %bool %23 %uint_0
OpBranchConditional %24 %25 %25
%25 = OpLabel
%26 = OpPhi %_ptr_CrossWorkgroup_uint %11 %21 %11 %21
OpBranchConditional %true %27 %28
%27 = OpLabel
%29 = OpPhi %_ptr_CrossWorkgroup_uint %26 %25
%30 = OpLoad %uint %29
OpBranch %28
%28 = OpLabel
%31 = OpPhi %uint %30 %27 %19 %25
OpBranch %32
%32 = OpLabel
%33 = OpPhi %uint %31 %28
OpBranch %34
%34 = OpLabel
%35 = OpLoad %v3ulong %__builtin_var_LocalInvocationId__
%36 = OpCompositeExtract %ulong %35 0
%37 = OpSConvert %uint %36
%38 = OpSLessThan %bool %37 %uint_1
OpBranchConditional %38 %39 %40
%39 = OpLabel
OpStore %13 %33
OpBranch %40
%40 = OpLabel
OpReturn
OpFunctionEnd

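The OpPhi uses the same pointer (%11, i.e. %arg1) for both incoming entries, and both entries name the same parent block (%21). As a result, the false-branch value %12 (%arg2) is lost: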
%26 = OpPhi %_ptr_CrossWorkgroup_uint %11 %21 %11 %21

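(In the serialized SPIR-V there are no longer two separate blocks for the two edges: `OpBranchConditional %24 %25 %25` sends both the true and false edges to the same block. OpPhi identifies each incoming value by its parent block, not by the edge. When both branch targets are the same block, the two incoming values therefore cannot be distinguished. Logically there should be two blocks, one for the true edge and one for the false edge. For example, the serializer could emit an intermediate block for at least one of the edges so that the OpPhi has two distinct parents.)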

@chengjunlu
Copy link
Author

The test passes with a workaround: removing the canonicalize pass.

Although we have a workaround, I think this issue should be treated as high priority.

FYI:
Here is the SPIR-V dialect IR with the workaround applied. Canonicalization is skipped, so the if and else blocks are not merged.

// -----// IR Dump After CSE (cse) //----- //
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>, "triton_gpu.num-warps" = 8 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi64>, Input>
  spirv.func @kernel_0d1d2d3d(%arg0: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg1: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg2: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg3: !spirv.ptr<i32, CrossWorkgroup> {tt.divisibility = 16 : i32}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>, sym_visibility = "public"} {
    %cst0_i32 = spirv.Constant 0 : i32
    %true = spirv.Constant true
    %0 = spirv.Undef : i32
    spirv.BranchConditional %true, ^bb1, ^bb2(%0 : i32)
  ^bb1:  // pred: ^bb0
    %1 = spirv.Load "CrossWorkgroup" %arg0 : i32
    spirv.Branch ^bb2(%1 : i32)
  ^bb2(%2: i32):  // 2 preds: ^bb0, ^bb1
    %3 = spirv.INotEqual %2, %cst0_i32 : i32
    spirv.BranchConditional %3, ^bb3, ^bb6
  ^bb3:  // pred: ^bb2
    spirv.BranchConditional %true, ^bb4, ^bb5(%0 : i32)
  ^bb4:  // pred: ^bb3
    %4 = spirv.Load "CrossWorkgroup" %arg1 : i32
    spirv.Branch ^bb5(%4 : i32)
  ^bb5(%5: i32):  // 2 preds: ^bb3, ^bb4
    spirv.Branch ^bb9(%5 : i32)
  ^bb6:  // pred: ^bb2
    spirv.BranchConditional %true, ^bb7, ^bb8(%0 : i32)
  ^bb7:  // pred: ^bb6
    %6 = spirv.Load "CrossWorkgroup" %arg2 : i32
    spirv.Branch ^bb8(%6 : i32)
  ^bb8(%7: i32):  // 2 preds: ^bb6, ^bb7
    spirv.Branch ^bb9(%7 : i32)
  ^bb9(%8: i32):  // 2 preds: ^bb5, ^bb8
    spirv.Branch ^bb10
  ^bb10:  // pred: ^bb9
    %__builtin_var_LocalInvocationId___addr = spirv.mlir.addressof @__builtin_var_LocalInvocationId__ : !spirv.ptr<vector<3xi64>, Input>
    %9 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64>
    %10 = spirv.CompositeExtract %9[0 : i32] : vector<3xi64>
    %11 = spirv.SConvert %10 : i64 to i32
    %cst1_i32 = spirv.Constant 1 : i32
    %12 = spirv.SLessThan %11, %cst1_i32 : i32
    %13 = spirv.LogicalAnd %true, %12 : i1
    spirv.BranchConditional %13, ^bb11, ^bb12
  ^bb11:  // pred: ^bb10
    spirv.Store "CrossWorkgroup" %arg3, %8 : i32
    spirv.Branch ^bb12
  ^bb12:  // 2 preds: ^bb10, ^bb11
    spirv.Return
  }
}

The corresponding serialized SPIR-V:

%5 = OpTypeFunction %void %_ptr_CrossWorkgroup_uint %_ptr_CrossWorkgroup_uint %_ptr_CrossWorkgroup_uint %_ptr_CrossWorkgroup_uint
%uint_0 = OpConstant %uint 0
%bool = OpTypeBool
%true = OpConstantTrue %bool
%18 = OpUndef %uint
%uint_1 = OpConstant %uint 1
%kernel_0d1d2d3d = OpFunction %void None %5
%10 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%11 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%12 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%13 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
%14 = OpLabel
OpBranchConditional %true %19 %20
%19 = OpLabel
%21 = OpLoad %uint %10
OpBranch %20
%20 = OpLabel
%22 = OpPhi %uint %21 %19 %18 %14
%23 = OpINotEqual %bool %22 %uint_0
OpBranchConditional %23 %24 %25
%24 = OpLabel
OpBranchConditional %true %26 %27
%26 = OpLabel
%28 = OpLoad %uint %11
OpBranch %27
%27 = OpLabel
%29 = OpPhi %uint %28 %26 %18 %24
OpBranch %30
%30 = OpLabel
%31 = OpPhi %uint %29 %27 %44 %32
OpBranch %33
%33 = OpLabel
%34 = OpLoad %v3ulong %__builtin_var_LocalInvocationId__
%35 = OpCompositeExtract %ulong %34 0
%36 = OpSConvert %uint %35
%38 = OpSLessThan %bool %36 %uint_1
%39 = OpLogicalAnd %bool %true %38
OpBranchConditional %39 %40 %41
%40 = OpLabel
OpStore %13 %31
OpBranch %41
%41 = OpLabel
OpReturn
%25 = OpLabel
OpBranchConditional %true %42 %32
%42 = OpLabel
%43 = OpLoad %uint %12
OpBranch %32
%32 = OpLabel
%44 = OpPhi %uint %43 %42 %18 %25
OpBranch %30
OpFunctionEnd

@chengjunlu chengjunlu added the Triton Issues tracking Triton/IMEX collaboration label May 6, 2023
@chengjunlu
Copy link
Author

chengjunlu commented Jun 8, 2023

Can we raise the priority of this issue?
It is not functionally blocking, but having to turn off the canonicalization pass is annoying: the unsimplified IR is hard to understand and debug.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Triton Issues tracking Triton/IMEX collaboration
Projects
None yet
Development

No branches or pull requests

1 participant