[Bug]: Issue with 1x1 Tensors (arith.cmpi) & tt.call #162

Open
JoeLi12345 opened this issue Aug 13, 2024 · 6 comments
Labels
bug Something isn't working

@JoeLi12345

Triton python code

import torch
import torch._inductor.config

def test_model_2(x):
    device = torch.device("cuda")
    # Linear -> LayerNorm -> CELU over a 10-feature input
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 10),
        torch.nn.LayerNorm(10),
        torch.nn.CELU(),
    )
    model.to(device)
    return model(x)

# Compile with the inductor backend; the Triton IR below comes from one of the generated kernels
compiled = torch.compile(test_model_2, backend="inductor")
y = compiled(torch.ones(10, 10, device="cuda"))

Triton IR

#loc = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":18:0)
#loc61 = loc("/home/joeli/miniconda3/envs/torch_to_triton/lib/python3.10/site-packages/triton/language/standard.py":265:0)
#loc63 = loc(unknown)
#loc65 = loc("/home/joeli/miniconda3/envs/torch_to_triton/lib/python3.10/site-packages/triton/language/standard.py":255:0)
module {
  tt.func public @triton_(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":18:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":18:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":18:0), %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":18:0), %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":18:0), %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":18:0), %arg6: i32 loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":18:0), %arg7: i32 loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":18:0)) attributes {noinline = false} {
    %c10_i32 = arith.constant 10 : i32 loc(#loc1)
    %c10_i32_0 = arith.constant 10 : i32 loc(#loc2)
    %0 = tt.get_program_id x : i32 loc(#loc3)
    %c1_i32 = arith.constant 1 : i32 loc(#loc4)
    %1 = arith.muli %0, %c1_i32 : i32 loc(#loc4)
    %2 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32> loc(#loc5)
    %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<1xi32> -> tensor<1x1xi32> loc(#loc6)
    %4 = tt.splat %1 : i32 -> tensor<1x1xi32> loc(#loc7)
    %5 = arith.addi %4, %3 : tensor<1x1xi32> loc(#loc7)
    %cst = arith.constant dense<10> : tensor<1x1xi32> loc(#loc8)
    %6 = arith.cmpi slt, %5, %cst : tensor<1x1xi32> loc(#loc8)
    %7 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc9)
    %8 = tt.expand_dims %7 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> loc(#loc10)
    %c0_i32 = arith.constant 0 : i32 loc(#loc11)
    %cst_1 = arith.constant dense<10> : tensor<1x16xi32> loc(#loc12)
    %9 = arith.cmpi slt, %8, %cst_1 : tensor<1x16xi32> loc(#loc12)
    %c10_i32_2 = arith.constant 10 : i32 loc(#loc13)
    %cst_3 = arith.constant dense<10> : tensor<1x1xi32> loc(#loc13)
    %10 = arith.muli %5, %cst_3 : tensor<1x1xi32> loc(#loc13)
    %11 = tt.broadcast %10 : tensor<1x1xi32> -> tensor<1x16xi32> loc(#loc14)
    %12 = arith.addi %8, %11 : tensor<1x16xi32> loc(#loc14)
    %13 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1x16x!tt.ptr<f32>> loc(#loc15)
    %14 = tt.addptr %13, %12 : tensor<1x16x!tt.ptr<f32>>, tensor<1x16xi32> loc(#loc15)
    %15 = tt.broadcast %6 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc16)
    %16 = arith.andi %9, %15 : tensor<1x16xi1> loc(#loc16)
    %cst_4 = arith.constant 0.000000e+00 : f32 loc(#loc17)
    %cst_5 = arith.constant dense<0.000000e+00> : tensor<1x16xf32> loc(#loc17)
    %17 = tt.load %14, %16, %cst_5 : tensor<1x16x!tt.ptr<f32>> loc(#loc17)
    %18 = tt.splat %arg3 : !tt.ptr<f32> -> tensor<1x16x!tt.ptr<f32>> loc(#loc18)
    %19 = tt.addptr %18, %8 : tensor<1x16x!tt.ptr<f32>>, tensor<1x16xi32> loc(#loc18)
    %cst_6 = arith.constant 0.000000e+00 : f32 loc(#loc19)
    %cst_7 = arith.constant dense<0.000000e+00> : tensor<1x16xf32> loc(#loc19)
    %20 = tt.load %19, %9, %cst_7 evictionPolicy = evict_last : tensor<1x16x!tt.ptr<f32>> loc(#loc19)
    %21 = tt.splat %arg4 : !tt.ptr<f32> -> tensor<1x16x!tt.ptr<f32>> loc(#loc20)
    %22 = tt.addptr %21, %8 : tensor<1x16x!tt.ptr<f32>>, tensor<1x16xi32> loc(#loc20)
    %cst_8 = arith.constant 0.000000e+00 : f32 loc(#loc21)
    %cst_9 = arith.constant dense<0.000000e+00> : tensor<1x16xf32> loc(#loc21)
    %23 = tt.load %22, %9, %cst_9 evictionPolicy = evict_last : tensor<1x16x!tt.ptr<f32>> loc(#loc21)
    %24 = tt.broadcast %6 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc22)
    %25 = arith.andi %9, %24 : tensor<1x16xi1> loc(#loc22)
    %c0_i32_10 = arith.constant 0 : i32 loc(#loc23)
    %cst_11 = arith.constant dense<0> : tensor<1x16xi32> loc(#loc23)
    %26 = arith.sitofp %cst_11 : tensor<1x16xi32> to tensor<1x16xf32> loc(#loc23)
    %27 = arith.select %25, %17, %26 : tensor<1x16xi1>, tensor<1x16xf32> loc(#loc23)
    %28 = tt.broadcast %6 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc24)
    %29 = arith.andi %9, %28 : tensor<1x16xi1> loc(#loc24)
    %c0_i32_12 = arith.constant 0 : i32 loc(#loc25)
    %cst_13 = arith.constant dense<0> : tensor<1x16xi32> loc(#loc25)
    %30 = arith.sitofp %cst_13 : tensor<1x16xi32> to tensor<1x16xf32> loc(#loc25)
    %31 = arith.select %29, %17, %30 : tensor<1x16xi1>, tensor<1x16xf32> loc(#loc25)
    %32 = tt.call @sum__fp32S1_16S__1cconstexpr_1__2cconstexpr_False_(%31) : (tensor<1x16xf32>) -> tensor<1xf32> loc(#loc26)
    %33 = tt.expand_dims %32 {axis = 1 : i32} : tensor<1xf32> -> tensor<1x1xf32> loc(#loc27)
    %c10_i32_14 = arith.constant 10 : i32 loc(#loc28)
    %cst_15 = arith.constant dense<10> : tensor<1x1xi32> loc(#loc28)
    %34 = arith.sitofp %cst_15 : tensor<1x1xi32> to tensor<1x1xf32> loc(#loc29)
    %35 = arith.divf %33, %34 : tensor<1x1xf32> loc(#loc30)
    %36 = tt.broadcast %35 : tensor<1x1xf32> -> tensor<1x16xf32> loc(#loc31)
    %37 = arith.subf %17, %36 : tensor<1x16xf32> loc(#loc31)
    %38 = arith.mulf %37, %37 : tensor<1x16xf32> loc(#loc32)
    %39 = tt.broadcast %6 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc33)
    %40 = arith.andi %9, %39 : tensor<1x16xi1> loc(#loc33)
    %c0_i32_16 = arith.constant 0 : i32 loc(#loc34)
    %cst_17 = arith.constant dense<0> : tensor<1x16xi32> loc(#loc34)
    %41 = arith.sitofp %cst_17 : tensor<1x16xi32> to tensor<1x16xf32> loc(#loc34)
    %42 = arith.select %40, %38, %41 : tensor<1x16xi1>, tensor<1x16xf32> loc(#loc34)
    %43 = tt.call @sum__fp32S1_16S__1cconstexpr_1__2cconstexpr_False_(%42) : (tensor<1x16xf32>) -> tensor<1xf32> loc(#loc35)
    %44 = tt.expand_dims %43 {axis = 1 : i32} : tensor<1xf32> -> tensor<1x1xf32> loc(#loc36)
    %cst_18 = arith.constant 1.000000e+01 : f32 loc(#loc37)
    %cst_19 = arith.constant dense<1.000000e+01> : tensor<1x1xf32> loc(#loc38)
    %45 = arith.divf %44, %cst_19 : tensor<1x1xf32> loc(#loc38)
    %cst_20 = arith.constant 9.99999974E-6 : f32 loc(#loc39)
    %cst_21 = arith.constant dense<9.99999974E-6> : tensor<1x1xf32> loc(#loc40)
    %46 = arith.addf %45, %cst_21 : tensor<1x1xf32> loc(#loc40)
    %47 = tt.extern_elementwise %46 {libname = "", libpath = "", pure = true, symbol = "__nv_rsqrtf"} : (tensor<1x1xf32>) -> tensor<1x1xf32> loc(#loc41)
    %48 = tt.broadcast %35 : tensor<1x1xf32> -> tensor<1x16xf32> loc(#loc42)
    %49 = arith.subf %17, %48 : tensor<1x16xf32> loc(#loc42)
    %50 = tt.broadcast %47 : tensor<1x1xf32> -> tensor<1x16xf32> loc(#loc43)
    %51 = arith.mulf %49, %50 : tensor<1x16xf32> loc(#loc43)
    %52 = arith.mulf %51, %20 : tensor<1x16xf32> loc(#loc44)
    %53 = arith.addf %52, %23 : tensor<1x16xf32> loc(#loc45)
    %cst_22 = arith.constant 0.000000e+00 : f32 loc(#loc46)
    %cst_23 = arith.constant dense<0.000000e+00> : tensor<1x16xf32> loc(#loc47)
    %54 = arith.cmpf ogt, %53, %cst_23 : tensor<1x16xf32> loc(#loc47)
    %55 = tt.extern_elementwise %53 {libname = "", libpath = "", pure = true, symbol = "__nv_expm1f"} : (tensor<1x16xf32>) -> tensor<1x16xf32> loc(#loc48)
    %56 = arith.select %54, %53, %55 : tensor<1x16xi1>, tensor<1x16xf32> loc(#loc49)
    gpu.barrier loc(#loc50)
    %57 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1x1x!tt.ptr<f32>> loc(#loc51)
    %58 = tt.addptr %57, %5 : tensor<1x1x!tt.ptr<f32>>, tensor<1x1xi32> loc(#loc51)
    tt.store %58, %47, %6 : tensor<1x1x!tt.ptr<f32>> loc(#loc52)
    %c10_i32_24 = arith.constant 10 : i32 loc(#loc53)
    %cst_25 = arith.constant dense<10> : tensor<1x1xi32> loc(#loc53)
    %59 = arith.muli %5, %cst_25 : tensor<1x1xi32> loc(#loc53)
    %60 = tt.broadcast %59 : tensor<1x1xi32> -> tensor<1x16xi32> loc(#loc54)
    %61 = arith.addi %8, %60 : tensor<1x16xi32> loc(#loc54)
    %62 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1x16x!tt.ptr<f32>> loc(#loc55)
    %63 = tt.addptr %62, %61 : tensor<1x16x!tt.ptr<f32>>, tensor<1x16xi32> loc(#loc55)
    %64 = tt.broadcast %6 : tensor<1x1xi1> -> tensor<1x16xi1> loc(#loc56)
    %65 = arith.andi %9, %64 : tensor<1x16xi1> loc(#loc56)
    tt.store %63, %56, %65 : tensor<1x16x!tt.ptr<f32>> loc(#loc57)
    %66 = tt.splat %arg5 : !tt.ptr<f32> -> tensor<1x1x!tt.ptr<f32>> loc(#loc58)
    %67 = tt.addptr %66, %5 : tensor<1x1x!tt.ptr<f32>>, tensor<1x1xi32> loc(#loc58)
    tt.store %67, %35, %6 : tensor<1x1x!tt.ptr<f32>> loc(#loc59)
    tt.return loc(#loc60)
  } loc(#loc)
  tt.func private @sum__fp32S1_16S__1cconstexpr_1__2cconstexpr_False_(%arg0: tensor<1x16xf32> loc("/home/joeli/miniconda3/envs/torch_to_triton/lib/python3.10/site-packages/triton/language/standard.py":265:0)) -> tensor<1xf32> attributes {noinline = false} {
    %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)):
      %1 = tt.call @_sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 loc(#loc62)
      tt.reduce.return %1 : f32 loc(#loc62)
    }) : (tensor<1x16xf32>) -> tensor<1xf32> loc(#loc62)
    tt.return %0 : tensor<1xf32> loc(#loc64)
  } loc(#loc61)

  tt.func private @_sum_combine__fp32_fp32__(%arg0: f32 loc("/home/joeli/miniconda3/envs/torch_to_triton/lib/python3.10/site-packages/triton/language/standard.py":255:0), %arg1: f32 loc("/home/joeli/miniconda3/envs/torch_to_triton/lib/python3.10/site-packages/triton/language/standard.py":255:0)) -> f32 attributes {noinline = false} {
    %0 = arith.addf %arg0, %arg1 : f32 loc(#loc66)
    tt.return %0 : f32 loc(#loc67)
  } loc(#loc65)
} loc(#loc)
#loc1 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":19:13)
#loc2 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":20:13)
#loc3 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":22:28)
#loc4 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":22:33)
#loc5 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":23:36)
#loc6 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":23:44)
#loc7 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":23:23)
#loc8 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":24:21)
#loc9 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":25:26)
#loc10 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":25:34)
#loc11 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":26:14)
#loc12 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":27:21)
#loc13 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":30:39)
#loc14 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":30:36)
#loc15 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":30:30)
#loc16 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":30:53)
#loc17 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":30:45)
#loc18 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":31:31)
#loc19 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":31:36)
#loc20 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":32:31)
#loc21 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":32:36)
#loc22 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":34:28)
#loc23 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":34:41)
#loc24 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":36:28)
#loc25 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":36:41)
#loc26 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":37:24)
#loc27 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":37:27)
#loc28 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":38:36)
#loc29 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":39:19)
#loc30 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":40:19)
#loc31 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":41:19)
#loc32 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":42:20)
#loc33 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":44:29)
#loc34 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":44:43)
#loc35 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":45:26)
#loc36 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":45:29)
#loc37 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":46:12)
#loc38 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":47:20)
#loc39 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":48:12)
#loc40 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":49:20)
#loc41 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":50:28)
#loc42 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":51:19)
#loc43 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":52:20)
#loc44 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":53:20)
#loc45 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":54:20)
#loc46 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":55:12)
#loc47 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":56:20)
#loc48 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":57:28)
#loc49 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":58:35)
#loc50 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":59:4)
#loc51 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":60:28)
#loc52 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":60:40)
#loc53 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":61:37)
#loc54 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":61:34)
#loc55 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":61:28)
#loc56 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":61:58)
#loc57 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":61:50)
#loc58 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":62:25)
#loc59 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":62:37)
#loc60 = loc("/tmp/torchinductor_joeli/7s/c7soccbyvhubhgkwekrswibupfdmo4fgvsojpfy5twsmezq3tj4e.py":62:4)
#loc62 = loc("/home/joeli/miniconda3/envs/torch_to_triton/lib/python3.10/site-packages/triton/language/standard.py":267:36)
#loc64 = loc("/home/joeli/miniconda3/envs/torch_to_triton/lib/python3.10/site-packages/triton/language/standard.py":267:11)
#loc66 = loc("/home/joeli/miniconda3/envs/torch_to_triton/lib/python3.10/site-packages/triton/language/standard.py":256:15)
#loc67 = loc("/home/joeli/miniconda3/envs/torch_to_triton/lib/python3.10/site-packages/triton/language/standard.py":256:11)
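For readers less familiar with inductor output, the kernel above can be read as a fused LayerNorm + CELU that processes one row per program: XBLOCK is 1 (the tt.make_range with end = 1), which is where the 1x1 index and mask tensors come from, and RBLOCK is 16, with the 10 valid columns selected by the mask. The pure-Python reference below is a sketch of what one program instance appears to compute, inferred from the IR rather than taken from the inductor source; the function name and argument order are hypothetical. It returns the three values the kernel stores: the row mean (arg5), the normalized and CELU-activated row (arg1), and the reciprocal standard deviation (arg0).

```python
import math
from typing import List, Tuple


def layernorm_celu_row(x: List[float], weight: List[float], bias: List[float],
                       eps: float = 1e-5) -> Tuple[float, List[float], float]:
    """Hypothetical reference for one program instance (one row) of the kernel above.

    Returns (mean, out, rstd), matching what the IR stores to arg5, arg1 and arg0.
    """
    n = len(x)                                  # rnumel = 10 in the IR
    mean = sum(x) / n                           # %32 (masked sum) / 10 -> %35
    var = sum((v - mean) ** 2 for v in x) / n   # %43 / 10 -> %45, biased variance
    rstd = 1.0 / math.sqrt(var + eps)           # %47: __nv_rsqrtf(var + 1e-5)
    out = []
    for v, w, b in zip(x, weight, bias):
        y = (v - mean) * rstd * w + b           # %49..%53: normalize, scale, shift
        out.append(y if y > 0 else math.expm1(y))  # %54..%56: CELU with alpha = 1
    return mean, out, rstd
```

Written this way, the 1x1 shapes in the IR correspond simply to the scalar per-row values (mean, rstd) and the single row index.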

Crash log

No response

Additional information

I'm running the following command on the MLIR file included above: triton-shared-opt --triton-to-linalg-experimental file.mlir. Two issues arise.

  1. 1x1 tensors: For this line in the MLIR file, %6 = arith.cmpi slt, %5, %cst : tensor<1x1xi32>, I get the following error:

/triton_shared/lib/Analysis/MaskAnalysis.cpp:318: mlir::LogicalResult mlir::triton::MaskState::parseCmp(mlir::arith::CmpIOp, mlir::Location, mlir::OpBuilder&): Assertion `cmpDim != -1 && "Unexpected case where no dimension has size larger than 1"' failed.

Why are 1x1 tensors unsupported? As a workaround, I commented out that assertion on line 318 of MaskAnalysis.cpp and ran the command again, which led to issue 2.

  2. Strange issue with tt.call and tt.func: For some reason, tt.call does not recognize the '_sum_combine__fp32_fp32__' function, even though it is defined with tt.func in the same MLIR file. Why is this?

/home/joeli/miniconda3/envs/torch_to_triton/lib/python3.10/site-packages/triton/language/standard.py:267:36: error: 'tt.call' op '_sum_combine__fp32_fp32__' does not reference a valid function
/home/joeli/miniconda3/envs/torch_to_triton/lib/python3.10/site-packages/triton/language/standard.py:267:36: note: see current operation: %1 = "tt.call"(%arg7, %arg8) <{callee = @_sum_combine__fp32_fp32__}> : (f32, f32) -> f32
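Regarding issue 1, the assertion text suggests that parseCmp looks for the single dimension of the compared index tensor whose size differs from 1 and applies the comparison bound to that dimension. For tensor<1x1xi32> every dimension has size 1, so there is no candidate. The Python model below is a simplified sketch under that assumption, not the triton-shared implementation: MaskState, find_cmp_dim and parse_slt are hypothetical names, and the mask is reduced to a scalar start offset plus per-dimension sizes. It illustrates that when all dimensions have size 1, bounding any one of them (here dimension 0) still yields a correct mask, which is one possible way to relax the assertion instead of deleting it.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class MaskState:
    """Hypothetical, simplified mask state: indices start, start+1, ... laid out along dims."""
    start: int
    dims: List[int]


def find_cmp_dim(dims: List[int], allow_all_ones: bool) -> int:
    """Return the one dimension the comparison bounds (modeled on the assertion text)."""
    cmp_dim = -1
    for i, size in enumerate(dims):
        if size != 1:
            if cmp_dim != -1:
                raise ValueError("more than one dimension with size larger than 1")
            cmp_dim = i
    if cmp_dim == -1:
        if not allow_all_ones:
            # Corresponds to the assertion that fails in MaskAnalysis.cpp
            raise AssertionError("Unexpected case where no dimension has size larger than 1")
        # Every dimension has size 1 (e.g. tensor<1x1xi32>), so bounding any of them
        # gives the same mask; dimension 0 is an arbitrary choice.
        cmp_dim = 0
    return cmp_dim


def parse_slt(lhs: MaskState, bound: int, allow_all_ones: bool = True) -> MaskState:
    """Model `lhs < bound`: shrink the compared dimension to its in-bounds prefix."""
    d = find_cmp_dim(lhs.dims, allow_all_ones)
    new_dims = list(lhs.dims)
    new_dims[d] = max(0, min(lhs.dims[d], bound - lhs.start))
    return MaskState(lhs.start, new_dims)
```

In this model, %6 for program 3 gives dims [1, 1] (row in bounds), program 12 gives [0, 1] (empty mask), and %9 (tensor<1x16xi32> compared against 10) gives [1, 10].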

JoeLi12345 added the bug (Something isn't working) label Aug 13, 2024
JoeLi12345 changed the title from "[Bug]: Issue with 1x1 Tensors & tt.call" to "[Bug]: Issue with 1x1 Tensors (arith.cmpi) & tt.call" Aug 13, 2024
@JoeLi12345
Author

Update: For issue 2, even this very basic module produces the same error (tt.call does not reference a valid function). I have no idea why.

module {

  tt.func @_sum_combine__fp32() -> f32 {
    %0 = arith.constant 42.0 : f32
    tt.return %0 : f32
  }

  tt.func @test() -> f32 {
    %0 = tt.call @_sum_combine__fp32() : () -> f32
    tt.return %0 : f32
  }

}

@JoeLi12345
Author

It seems the main issue is that tt.call (triton::CallOp) is currently not supported by triton-shared.
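One plausible reading of the error text, not confirmed against the triton-shared source: Triton's CallOp verifier resolves the callee through the module's symbol table and appears to accept only a tt.func. If a lowering pass rewrites the callee into a different function op (for example func.func) but leaves tt.call in place, the lookup fails even though a function with that name still exists, which would also explain why the minimal module above fails. The toy Python model below (all names are hypothetical; this is not an MLIR API) contrasts lowering only the functions with lowering each call together with its callee.

```python
from dataclasses import dataclass, field
from typing import Dict, List

# Which function op each call op may reference. The tt.call entry is an assumption
# based on the "does not reference a valid function" diagnostic.
CALLEE_KIND = {"tt.call": "tt.func", "func.call": "func.func"}


@dataclass
class Call:
    op: str        # "tt.call" or "func.call"
    callee: str


@dataclass
class Func:
    op: str        # "tt.func" or "func.func"
    name: str
    body: List[Call] = field(default_factory=list)


def verify(module: List[Func]) -> List[str]:
    """Resolve every call through a module-level symbol table."""
    symbols: Dict[str, Func] = {f.name: f for f in module}
    errors = []
    for f in module:
        for call in f.body:
            target = symbols.get(call.callee)
            if target is None or target.op != CALLEE_KIND[call.op]:
                errors.append(f"'{call.op}' op '{call.callee}' does not reference a valid function")
    return errors


def lower_funcs_only(module: List[Func]) -> List[Func]:
    """Rewrite tt.func -> func.func but leave tt.call untouched (reproduces the error)."""
    return [Func("func.func", f.name, list(f.body)) for f in module]


def lower_funcs_and_calls(module: List[Func]) -> List[Func]:
    """Rewrite functions and the calls that reference them in the same step."""
    return [
        Func("func.func", f.name,
             [Call("func.call", c.callee) if c.op == "tt.call" else c for c in f.body])
        for f in module
    ]


# The minimal module from the update above
minimal = [
    Func("tt.func", "_sum_combine__fp32"),
    Func("tt.func", "test", [Call("tt.call", "_sum_combine__fp32")]),
]
```

In this model, verify(minimal) and verify(lower_funcs_and_calls(minimal)) report no errors, while verify(lower_funcs_only(minimal)) reports exactly the diagnostic seen in this issue.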

@parsifal-47
Contributor

That is an easy fix: I can convert the Triton call to an LLVM call. The question is where to get the function that is being called.

@JoeLi12345
Copy link
Author

Thank you!

This function is defined in the same MLIR file. The file already contains the function definition; it just cannot be called through tt.call.

@parsifal-47
Contributor

Sounds good, let me get back to you with a patch.

@parsifal-47
Contributor

#164
