Skip to content

Commit

Permalink
[Hopper TMA] Add CUDA codegen support for bulk asynchronous copy (apa…
Browse files Browse the repository at this point in the history
…che#15656)

* [Hopper TMA] Add CUDA codegen support for bulk asynchronous copy

* fix typo in comments; use barrier ptr and offset rather than string
  • Loading branch information
adstraw authored Sep 5, 2023
1 parent 04ee895 commit d26fdcf
Show file tree
Hide file tree
Showing 11 changed files with 421 additions and 75 deletions.
39 changes: 31 additions & 8 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -645,14 +645,29 @@ TVM_DLL const Op& ptx_mma_sp();
TVM_DLL const Op& ptx_ldmatrix();

/*!
* \brief tvm intrinsics for ptx async copy from global to shared memory
*
* void ptx_cp_async(Var shared_ptr, Expr shared_offset, Var global_ptr, Expr global_offset, size_t
* bytes);
* \brief tvm intrinsics for ptx async copy from global to shared memory using cp.async
*
* void ptx_cp_async(Var shared_ptr,
* Expr shared_offset,
* Var global_ptr,
* Expr global_offset,
* size_t bytes);
*/
TVM_DLL const Op& ptx_cp_async();

/*!
* \brief tvm intrinsics for ptx async copy from global to shared memory using cp.async.bulk
*
* void ptx_cp_async_bulk(Var shared_ptr,
* Expr shared_offset,
* Var global_ptr,
* Expr global_offset,
* size_t bytes,
* Var barrier_ptr,
* Expr barrier_offset);
*/
TVM_DLL const Op& ptx_cp_async_bulk();

/*!
* \brief tvm intrinsics for ptx async copy commit and wait.
*
Expand All @@ -666,31 +681,39 @@ TVM_DLL const Op& ptx_wait_group();
/*!
* \brief tvm intrinsics for ptx async copy barrier using cp.async.mbarrier.arrive
*
* ptx_cp_async_barrier(barrier_array, barrier_id)
* ptx_cp_async_barrier(Var barrier_ptr, Expr barrier_offset)
*
*/
TVM_DLL const Op& ptx_cp_async_barrier();

/*!
* \brief tvm intrinsics for ptx barrier initialization of thread count using mbarrier.init
*
* ptx_init_barrier_thread_count(barrier_array, barrier_id, thread_count)
* ptx_init_barrier_thread_count(Var barrier_ptr, Expr barrier_offset, int thread_count)
*
*/
TVM_DLL const Op& ptx_init_barrier_thread_count();

/*!
* \brief tvm intrinsics for ptx barrier arrival using mbarrier.arrive
*
* ptx_arrive_barrier(barrier_array, barrier_id)
* ptx_arrive_barrier(Var barrier_ptr, Expr barrier_offset)
*
*/
TVM_DLL const Op& ptx_arrive_barrier();

/*!
* \brief tvm intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
*
* ptx_arrive_barrier_expect_tx(Var barrier_ptr, Expr barrier_offset, int byte_count)
*
*/
TVM_DLL const Op& ptx_arrive_barrier_expect_tx();

/*!
* \brief tvm intrinsics for ptx barrier wait using mbarrier.try_wait
*
* ptx_wait_barrier(barrier_array, barrier_id)
* ptx_wait_barrier(Var barrier_ptr, Expr barrier_offset)
*
*/
TVM_DLL const Op& ptx_wait_barrier();
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,7 @@ def wrapped(*args, **kwargs):
ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier)
ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count)
ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
Expand Down Expand Up @@ -1876,6 +1877,7 @@ def wrapped(*args, **kwargs):
ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk)
mma_store = _dtype_forward(_tir_op.mma_store)
mma_fill = _dtype_forward(_tir_op.mma_fill)
vectorlow = _dtype_forward(_tir_op.vectorlow)
Expand Down Expand Up @@ -2115,11 +2117,13 @@ def wrapped(*args, **kwargs):
"ptx_mma_sp",
"ptx_ldmatrix",
"ptx_cp_async",
"ptx_cp_async_bulk",
"ptx_wait_group",
"ptx_commit_group",
"ptx_cp_async_barrier",
"ptx_init_barrier_thread_count",
"ptx_arrive_barrier",
"ptx_arrive_barrier_expect_tx",
"ptx_wait_barrier",
"mma_store",
"mma_fill",
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,13 @@
from .op import (
ptx_ldmatrix,
ptx_cp_async,
ptx_cp_async_bulk,
ptx_commit_group,
ptx_wait_group,
ptx_cp_async_barrier,
ptx_init_barrier_thread_count,
ptx_arrive_barrier,
ptx_arrive_barrier_expect_tx,
ptx_wait_barrier,
)
from .op import vectorlow, vectorhigh, vectorcombine
Expand Down
126 changes: 104 additions & 22 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,7 +1335,7 @@ def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, sme


def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
"""TVM intrinsic for ptx async copy from global to shared memory
"""TVM intrinsic for ptx async copy from global to shared memory using cp.async
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
Parameters
Expand Down Expand Up @@ -1368,6 +1368,56 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, by
)


def ptx_cp_async_bulk(
    dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_ptr, barrier_offset
):
    """TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk

    Parameters
    ----------
    dtype : str
        The data type of the result.
    shared_ptr : Var
        The shared memory pointer variable.
    shared_offset : Expr
        The offset of shared memory pointer.
    global_ptr : Var
        The global memory pointer variable.
    global_offset : Expr
        The offset of global memory pointer.
    bytes : int
        The data size to copy.
    barrier_ptr : Var
        The barrier shared memory pointer variable.
    barrier_offset : Expr
        The offset of the barrier shared memory pointer.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Completion of the bulk copy is tracked through the mbarrier addressed by
    # (barrier_ptr, barrier_offset) rather than the cp.async commit/wait groups.
    return call_intrin(
        dtype,
        "tir.ptx_cp_async_bulk",
        shared_ptr,
        shared_offset,
        global_ptr,
        global_offset,
        bytes,
        barrier_ptr,
        barrier_offset,
    )


def ptx_commit_group():
"""TVM intrinsic for ptx async copy commit
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group
Expand Down Expand Up @@ -1397,84 +1447,116 @@ def ptx_wait_group(num):
return call_intrin("", "tir.ptx_wait_group", num)


def ptx_cp_async_barrier(barrier_arr, barrier_id):
def ptx_cp_async_barrier(barrier_ptr, barrier_offset):
"""TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
Parameters
----------
barrier_arr : string
The name of the barrier array in shared memory
barrier_ptr : Var
The barrier shared memory pointer variable.
barrier_offset : int
Index into the barrier array
The offset of the barrier shared memory pointer.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_cp_async_barrier", barrier_arr, barrier_id)
return call_intrin("", "tir.ptx_cp_async_barrier", barrier_ptr, barrier_offset)


def ptx_init_barrier_thread_count(barrier_arr, barrier_id, thread_count):
def ptx_init_barrier_thread_count(barrier_ptr, barrier_offset, thread_count):
"""TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
Parameters
----------
barrier_arr : string
The name of the barrier array in shared memory
barrier_ptr : Var
The barrier shared memory pointer variable.
barrier_offset : int
Index into the barrier array
The offset of the barrier shared memory pointer.
thread_count : int
Number of threads expected to arrive at the barrier
Number of threads expected to arrive at the barrier.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
"", "tir.ptx_init_barrier_thread_count", barrier_arr, barrier_id, thread_count
"", "tir.ptx_init_barrier_thread_count", barrier_ptr, barrier_offset, thread_count
)


def ptx_arrive_barrier(barrier_arr, barrier_id):
def ptx_arrive_barrier(barrier_ptr, barrier_offset):
"""TVM intrinsic for ptx barrier arrival using mbarrier.arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
Parameters
----------
barrier_arr : string
The name of the barrier array in shared memory
barrier_ptr : Var
The barrier shared memory pointer variable.
barrier_offset : int
The offset of the barrier shared memory pointer.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_arrive_barrier", barrier_ptr, barrier_offset)


def ptx_arrive_barrier_expect_tx(barrier_ptr, barrier_offset, byte_count):
"""TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation
Parameters
----------
barrier_ptr : Var
The barrier shared memory pointer variable.
barrier_offset : int
Index into the barrier array
The offset of the barrier shared memory pointer.
byte_count : int
Increases the tx count of the mbarrier object to track completion of
additional async transactions.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_arrive_barrier", barrier_arr, barrier_id)
return call_intrin(
"", "tir.ptx_arrive_barrier_expect_tx", barrier_ptr, barrier_offset, byte_count
)


def ptx_wait_barrier(barrier_arr, barrier_id):
def ptx_wait_barrier(barrier_ptr, barrier_offset):
"""TVM intrinsic for ptx barrier wait using mbarrier.try_wait
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
Parameters
----------
barrier_arr : string
The name of the barrier array in shared memory
barrier_ptr : Var
The barrier shared memory pointer variable.
barrier_offset : int
Index into the barrier array
The offset of the barrier shared memory pointer.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_wait_barrier", barrier_arr, barrier_id)
return call_intrin("", "tir.ptx_wait_barrier", barrier_ptr, barrier_offset)


def vectorlow(dtype, vec):
Expand Down
Loading

0 comments on commit d26fdcf

Please sign in to comment.