
Control flow for individual threads in a block #4737

Open
jiashenC opened this issue Sep 17, 2024 · 0 comments

jiashenC commented Sep 17, 2024

I am experimenting with whether I can build a hashtable in Triton. The code snippet below shows my example kernel:

import torch
import triton
import triton.language as tl

@triton.jit
def build_key_only_hashtable_kernel(
    key_ptr,
    bitmap_ptr,
    hashtable_ptr,
    size,
    hashtable_size,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offset = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offset < size
    bitmap = tl.load(bitmap_ptr + offset, mask=mask)
    mask = mask & bitmap.cast(tl.int1)
    key = tl.load(key_ptr + offset, mask=mask)
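    hashkey = key % hashtable_size
    # atomic_cas returns the value previously stored in each slot (0 marks an
    # empty slot); a nonzero old value means the slot was taken, so probe again
    flag = tl.atomic_cas(hashtable_ptr + hashkey, tl.zeros_like(key), key)
    flag = flag.cast(tl.int1)
    # `flag` holds one value per lane, so different lanes may need a different
    # number of probing iterations
    while flag:
        hashkey = (hashkey + 1) % hashtable_size
        flag = tl.atomic_cas(hashtable_ptr + hashkey, tl.zeros_like(key), key)
        flag = flag.cast(tl.int1)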

inp = torch.arange(0, 100, 1).to(torch.int32).to("cuda")
bitmap = torch.ones(100).to(torch.bool).to("cuda")
hashtable = torch.zeros(140).to(torch.int32).to("cuda")
grid = lambda meta: (triton.cdiv(100, meta["BLOCK_SIZE"]), )
build_key_only_hashtable_kernel[grid](
    inp, 
    bitmap,
    hashtable,
    100,
    140,
    BLOCK_SIZE=16,
)

It fails with the following error:

test/test_correct.py python3: /source/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::detail::TypedValue<mlir::IntegerType>, From = mlir::Value]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
Fatal Python error: Aborted

In this case, threads within each block may diverge. I wonder whether this is something Triton is capable of, or whether it is simply not supported yet.
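One possible workaround, sketched under assumptions rather than offered as a confirmed fix: Triton's `while` appears to expect a scalar condition, whereas `flag` above is a whole block of per-lane values. The sketch keeps the per-lane state as an `active` mask and drives the loop with a block-wide reduction (`tl.max` over that mask), so the loop runs until the slowest lane has inserted its key. Because `tl.atomic_cas` takes no mask argument, finished and masked-off lanes issue a CAS whose compare and new values are equal, which never changes the table. The names `build_key_only_hashtable_kernel_v2`, `build_hashtable`, and `reference_build` are hypothetical, 0 is kept as the empty-slot sentinel from the original kernel, and the loop never terminates if the table fills up.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def build_key_only_hashtable_kernel_v2(
    key_ptr, bitmap_ptr, hashtable_ptr, size, hashtable_size,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < size
    bitmap = tl.load(bitmap_ptr + offset, mask=mask)
    mask = mask & bitmap.to(tl.int1)
    # Masked-off lanes read key 0, so their (unused) probe index stays in bounds.
    key = tl.load(key_ptr + offset, mask=mask, other=0)
    hashkey = key % hashtable_size
    # `active` marks lanes that still have to insert their key.
    active = mask
    # Scalar loop condition: iterate while any lane in the block is still active.
    while tl.max(active.to(tl.int32), axis=0) > 0:
        # Inactive lanes compare against their own key and write the same key
        # back (cmp == val), so their CAS never modifies the table.
        cmp = tl.where(active, tl.zeros_like(key), key)
        old = tl.atomic_cas(hashtable_ptr + hashkey, cmp, key)
        # An active lane is done once it hit an empty slot (old value 0).
        active = active & (old != 0)
        # Lanes that are still active move on to the next slot (linear probing).
        hashkey = tl.where(active, (hashkey + 1) % hashtable_size, hashkey)


def build_hashtable(keys, bitmap, hashtable_size, block_size=16):
    """Hypothetical host helper: launches the sketch kernel, returns the table."""
    hashtable = torch.zeros(hashtable_size, dtype=torch.int32, device=keys.device)
    n = keys.numel()
    grid = (triton.cdiv(n, block_size),)
    build_key_only_hashtable_kernel_v2[grid](
        keys, bitmap, hashtable, n, hashtable_size, BLOCK_SIZE=block_size
    )
    return hashtable


def reference_build(keys, bitmap, hashtable_size):
    """Sequential CPU reference: same hash, linear probing, and 0-as-empty rule."""
    table = [0] * hashtable_size
    for key, keep in zip(keys, bitmap):
        if not keep:
            continue
        slot = key % hashtable_size
        while table[slot] != 0:
            slot = (slot + 1) % hashtable_size
        table[slot] = key
    return table


if __name__ == "__main__" and torch.cuda.is_available():
    keys = torch.arange(0, 100, dtype=torch.int32, device="cuda")
    bitmap = torch.ones(100, dtype=torch.bool, device="cuda")
    table = build_hashtable(keys, bitmap, 140)
    expected = reference_build(keys.tolist(), bitmap.tolist(), 140)
    # Slot positions can differ from the sequential order when keys collide,
    # so compare the table contents rather than the layout.
    assert sorted(table.tolist()) == sorted(expected)
```

Lanes that finish early simply idle (issuing no-op CAS operations) until every lane in the block is done. When keys collide, the final slot layout depends on the order in which the atomics land, which is why the check compares contents rather than positions.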

jiashenC changed the title from "LLVM complains about incompatible types" to "Control flow for individual threads in a block" on Sep 17, 2024.