[Bug]Why the ptr type of tt.atomic_rmw don't allow TT_TensorPtr? #4672

Open
tfruan2000 opened this issue Sep 7, 2024 · 2 comments

@tfruan2000
Contributor

tfruan2000 commented Sep 7, 2024

Hi, guys~

I am a bit confused about the definition of tt.atomic_rmw in TritonOps.td.

Currently, the type verification for the ptr and val operands is done using getPointerTypeSameShape.

def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
  SameOperandsAndResultShape,
  SameOperandsAndResultEncoding,
  MemoryEffects<[MemRead<GlobalMemory>]>,
  MemoryEffects<[MemWrite<GlobalMemory>]>,
  TypesMatchWith<"ptr type matches value type", "val", "ptr", 
                 "getPointerTypeSameShape($_self)">, // here, used `getPointerTypeSameShape`
  ...

However, this op behaves much like tt.load and tt.store, and those two ops use getPointeeType for verification instead.

def TT_StoreOp : TT_Op<"store", [
  SameLoadStoreOperandsShape,
  SameLoadStoreOperandsEncoding,
  MemoryEffects<[MemWrite<GlobalMemory>]>,
  TypesMatchWith<"value type matches ptr type", "ptr", "value",
                 "getPointeeType($_self)">,  // here, used `getPointeeType`

As a result, the following IR using tt.store is valid:

  %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16>
  %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : !tt.ptr<tensor<128x32xf16>>
  tt.store %0, %cst : !tt.ptr<tensor<128x32xf16>>

while the corresponding IR for tt.atomic_rmw is rejected:

error: 'tt.atomic_rmw' op failed to verify that ptr type matches value type
  %1 = tt.atomic_rmw fadd, relaxed, gpu, %0, %cst, %mask : (!tt.ptr<tensor<128x32xf16>>, tensor<128x32xf16>, tensor<128x32xi1>) -> tensor<128x32xf16>
       ^
tmp.mlir:12:8: note: see current operation: %10 = "tt.atomic_rmw"(%9, %8, %arg1) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 1 : i32}> : (!tt.ptr<tensor<128x32xf16>>, tensor<128x32xf16>, tensor<128x32xi1>) -> tensor<128x32xf16>
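
To make the failure concrete, here is a rough Python model of the current constraint. It is only a sketch: the classes Scalar, Ptr and Tensor and the function verify_atomic_rmw_current are hypothetical stand-ins, and the behaviour of get_pointer_type_same_shape is assumed from the helper's name and the error above (a tensor of values maps to a same-shape tensor of pointers), not copied from Triton's source.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Scalar:
    name: str  # e.g. "f16"


@dataclass(frozen=True)
class Ptr:
    pointee: "Type"  # models !tt.ptr<pointee>


@dataclass(frozen=True)
class Tensor:
    shape: Tuple[int, ...]
    elem: "Type"  # models tensor<shape x elem>


Type = Union[Scalar, Ptr, Tensor]


def get_pointer_type_same_shape(ty):
    """Assumed behaviour: tensor<NxT> -> tensor<Nx!tt.ptr<T>>, T -> !tt.ptr<T>."""
    if isinstance(ty, Tensor):
        return Tensor(ty.shape, Ptr(ty.elem))
    return Ptr(ty)


def verify_atomic_rmw_current(ptr, val):
    # Models TypesMatchWith<..., "val", "ptr", "getPointerTypeSameShape($_self)">:
    # the type derived from val must equal the type of ptr.
    return get_pointer_type_same_shape(val) == ptr


f16 = Scalar("f16")
val = Tensor((128, 32), f16)               # tensor<128x32xf16>
tensor_of_ptrs = Tensor((128, 32), Ptr(f16))  # tensor<128x32x!tt.ptr<f16>>
ptr_to_tensor = Ptr(val)                   # !tt.ptr<tensor<128x32xf16>>

print(verify_atomic_rmw_current(tensor_of_ptrs, val))  # True
print(verify_atomic_rmw_current(ptr_to_tensor, val))   # False
```

Under this model, the type derived from val is always a tensor of pointers, so a pointer to a tensor can never satisfy the check, which is consistent with the "ptr type matches value type" error.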

In addition, the ptr operand of atomic_rmw is declared as TT_PtrLike, which does not allow TT_TensorPtr (!tt.ptr<tensor<>>).

Could someone explain why tt.atomic_rmw is not defined the same way as tt.load, for example:

def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
  SameLoadStoreOperandsAndResultShape, // before: SameOperandsAndResultShape
  SameLoadStoreOperandsAndResultEncoding, // before:  SameOperandsAndResultEncoding,
  MemoryEffects<[MemRead<GlobalMemory>]>,
  MemoryEffects<[MemWrite<GlobalMemory>]>,
  TypesMatchWith<"ptr type matches value type", "ptr", "val",
                 "getPointeeType($_self)">, // before:  "val", "ptr",  "getPointerTypeSameShape($_self)"
  ...
]> {
    ...
    let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op,
                      AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, // before: TT_PtrLike:$ptr
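
As a rough illustration of what the proposed constraint would accept, the Python sketch below models the check in the other direction (derive the value type from ptr). The classes Scalar, Ptr and Tensor and the function verify_atomic_rmw_proposed are hypothetical, and get_pointee_type only approximates what getPointeeType is assumed to do; it says nothing about whether later lowering passes would handle a tensor pointer on an atomic op.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Scalar:
    name: str  # e.g. "f16"


@dataclass(frozen=True)
class Ptr:
    pointee: "Type"  # models !tt.ptr<pointee>


@dataclass(frozen=True)
class Tensor:
    shape: Tuple[int, ...]
    elem: "Type"  # models tensor<shape x elem>


Type = Union[Scalar, Ptr, Tensor]


def get_pointee_type(ty):
    """Assumed behaviour: strip one pointer level, keeping any tensor shape."""
    if isinstance(ty, Tensor) and isinstance(ty.elem, Ptr):
        # tensor<Nx!tt.ptr<T>> -> tensor<NxT>
        return Tensor(ty.shape, ty.elem.pointee)
    if isinstance(ty, Ptr):
        # !tt.ptr<T> -> T, where T may itself be a tensor
        return ty.pointee
    return ty


def verify_atomic_rmw_proposed(ptr, val):
    # Models TypesMatchWith<..., "ptr", "val", "getPointeeType($_self)">:
    # the type derived from ptr must equal the type of val.
    return get_pointee_type(ptr) == val


f16 = Scalar("f16")
val = Tensor((128, 32), f16)

print(verify_atomic_rmw_proposed(Ptr(f16), f16))                     # scalar pointer
print(verify_atomic_rmw_proposed(Tensor((128, 32), Ptr(f16)), val))  # tensor of pointers
print(verify_atomic_rmw_proposed(Ptr(val), val))                     # pointer to tensor
```

In this model all three pointer forms pass, while a pointer whose pointee differs from val (for example a different element type) is still rejected.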

thx~

@tfruan2000
Contributor Author

I can make the following IR legal by changing the .td file like this:

%1 = tt.atomic_rmw fadd, relaxed, gpu, %0, %cst, %mask : (!tt.ptr<tensor<128x32xf16>>, tensor<128x32xf16>, tensor<128x32xi1>) -> tensor<128x32xf16>
[image]

but I’m not sure whether this change is correct.

tfruan2000 changed the title from "[Bug]Why tt.atomic_rmw don't use getPointeeType to verify the type of val and ptr?" to "[Bug]Why the ptr type of tt.atomic_rmw don't allow TT_TensorPtr?" on Sep 7, 2024
@tfruan2000
Contributor Author

I noticed that the PR "Support block pointer semantics" #1392 was the first to add the TT_TensorPtr type to tt.load and tt.store. Could it be that the pointer type for atomic operations was overlooked, or is it meaningless for atomic ops to support TT_TensorPtr the way load and store do?
