diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index b5c04f760da2..0d6d98e25574 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -645,14 +645,29 @@ TVM_DLL const Op& ptx_mma_sp(); TVM_DLL const Op& ptx_ldmatrix(); /*! - * \brief tvm intrinsics for ptx async copy from global to shared memory - * - * void ptx_cp_async(Var shared_ptr, Expr shared_offset, Var global_ptr, Expr global_offset, size_t - * bytes); + * \brief tvm intrinsics for ptx async copy from global to shared memory using cp.async * + * void ptx_cp_async(Var shared_ptr, + * Expr shared_offset, + * Var global_ptr, + * Expr global_offset, + * size_t bytes); */ TVM_DLL const Op& ptx_cp_async(); +/*! + * \brief tvm intrinsics for ptx async copy from global to shared memory using cp.async.bulk + * + * void ptx_cp_async_bulk(Var shared_ptr, + * Expr shared_offset, + * Var global_ptr, + * Expr global_offset, + * size_t bytes, + * Var barrier_ptr, + * Expr barrier_offset); + */ +TVM_DLL const Op& ptx_cp_async_bulk(); + /*! * \brief tvm intrinsics for ptx async copy commit and wait. * @@ -666,7 +681,7 @@ TVM_DLL const Op& ptx_wait_group(); /*! * \brief tvm intrinsics for ptx async copy barrier using cp.async.mbarrier.arrive * - * ptx_cp_async_barrier(barrier_array, barrier_id) + * ptx_cp_async_barrier(Var barrier_ptr, Expr barrier_offset) * */ TVM_DLL const Op& ptx_cp_async_barrier(); @@ -674,7 +689,7 @@ TVM_DLL const Op& ptx_cp_async_barrier(); /*! * \brief tvm intrinsics for ptx barrier initialization of thread count using mbarrier.init * - * ptx_init_barrier_thread_count(barrier_array, barrier_id, thread_count) + * ptx_init_barrier_thread_count(Var barrier_ptr, Expr barrier_offset, int thread_count) * */ TVM_DLL const Op& ptx_init_barrier_thread_count(); @@ -682,15 +697,23 @@ TVM_DLL const Op& ptx_init_barrier_thread_count(); /*! * \brief tvm intrinsics for ptx barrier arrival using mbarrier.arrive * - * ptx_arrive_barrier(barrier_array, barrier_id) + * ptx_arrive_barrier(Var barrier_ptr, Expr barrier_offset) * */ TVM_DLL const Op& ptx_arrive_barrier(); +/*! + * \brief tvm intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx + * + * ptx_arrive_barrier_expect_tx(Var barrier_ptr, Expr barrier_offset, int byte_count) + * + */ +TVM_DLL const Op& ptx_arrive_barrier_expect_tx(); + /*!
* \brief tvm intrinsics for ptx barrier wait using mbarrier.try_wait * - * ptx_wait_barrier(barrier_array, barrier_id) + * ptx_wait_barrier(Var barrier_ptr, Expr barrier_offset) * */ TVM_DLL const Op& ptx_wait_barrier(); diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index d7bebbacee05..337e06089583 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1847,6 +1847,7 @@ def wrapped(*args, **kwargs): ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier) ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count) ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier) +ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx) ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier) assume = _op_wrapper(_tir_op.assume) undef = _op_wrapper(_tir_op.undef) @@ -1876,6 +1877,7 @@ def wrapped(*args, **kwargs): ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) +ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk) mma_store = _dtype_forward(_tir_op.mma_store) mma_fill = _dtype_forward(_tir_op.mma_fill) vectorlow = _dtype_forward(_tir_op.vectorlow) @@ -2115,11 +2117,13 @@ def wrapped(*args, **kwargs): "ptx_mma_sp", "ptx_ldmatrix", "ptx_cp_async", + "ptx_cp_async_bulk", "ptx_wait_group", "ptx_commit_group", "ptx_cp_async_barrier", "ptx_init_barrier_thread_count", "ptx_arrive_barrier", + "ptx_arrive_barrier_expect_tx", "ptx_wait_barrier", "mma_store", "mma_fill", diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 84c575333712..762fcb599f40 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -63,11 +63,13 @@ from .op import ( ptx_ldmatrix, ptx_cp_async, + ptx_cp_async_bulk, ptx_commit_group, ptx_wait_group, ptx_cp_async_barrier, ptx_init_barrier_thread_count, ptx_arrive_barrier, + ptx_arrive_barrier_expect_tx, ptx_wait_barrier, ) from .op import vectorlow, vectorhigh, vectorcombine diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 7e1c520cc432..cb9227e8f2ea 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1335,7 +1335,7 @@ def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, sme def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes): - """TVM intrinsic for ptx async copy from global to shared memory + """TVM intrinsic for ptx async copy from global to shared memory using cp.async https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async Parameters @@ -1368,6 +1368,56 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, by ) +def ptx_cp_async_bulk( + dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_ptr, barrier_offset +): + """TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk + + Parameters + ---------- + dtype : str + The data type of the result. + + shared_ptr : Var + The shared memory pointer variable. + + shared_offset : Expr + The offset of shared memory pointer. + + global_ptr : Var + The global memory pointer variable. + + global_offset : Expr + The offset of global memory pointer. + + bytes : int + The data size to copy. 
+ + barrier_ptr : Var + The barrier shared memory pointer variable. + + barrier_offset : int + The offset of the barrier shared memory pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + dtype, + "tir.ptx_cp_async_bulk", + shared_ptr, + shared_offset, + global_ptr, + global_offset, + bytes, + barrier_ptr, + barrier_offset, + ) + + def ptx_commit_group(): """TVM intrinsic for ptx async copy commit https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group @@ -1397,37 +1447,40 @@ def ptx_wait_group(num): return call_intrin("", "tir.ptx_wait_group", num) -def ptx_cp_async_barrier(barrier_arr, barrier_id): +def ptx_cp_async_barrier(barrier_ptr, barrier_offset): """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive Parameters ---------- - barrier_arr : string - The name of the barrier array in shared memory + barrier_ptr : Var + The barrier shared memory pointer variable. + barrier_id : int - Index into the barrier array + The offset of the barrier shared memory pointer. Returns ------- call : PrimExpr The call expression. """ - return call_intrin("", "tir.ptx_cp_async_barrier", barrier_arr, barrier_id) + return call_intrin("", "tir.ptx_cp_async_barrier", barrier_ptr, barrier_offset) -def ptx_init_barrier_thread_count(barrier_arr, barrier_id, thread_count): +def ptx_init_barrier_thread_count(barrier_ptr, barrier_offset, thread_count): """TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init Parameters ---------- - barrier_arr : string - The name of the barrier array in shared memory + barrier_ptr : Var + The barrier shared memory pointer variable. + barrier_id : int - Index into the barrier array + The offset of the barrier shared memory pointer. + thread_count : int - Number of threads expected to arrive at the barrier + Number of threads expected to arrive at the barrier. Returns ------- @@ -1435,46 +1488,75 @@ def ptx_init_barrier_thread_count(barrier_arr, barrier_id, thread_count): The call expression. """ return call_intrin( - "", "tir.ptx_init_barrier_thread_count", barrier_arr, barrier_id, thread_count + "", "tir.ptx_init_barrier_thread_count", barrier_ptr, barrier_offset, thread_count ) -def ptx_arrive_barrier(barrier_arr, barrier_id): +def ptx_arrive_barrier(barrier_ptr, barrier_offset): """TVM intrinsic for ptx barrier arrival using mbarrier.arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive Parameters ---------- - barrier_arr : string - The name of the barrier array in shared memory + barrier_ptr : Var + The barrier shared memory pointer variable. + + barrier_offset : int + The offset of the barrier shared memory pointer. + + Returns + ------- + call : PrimExpr + The call expression.
+ """ + return call_intrin("", "tir.ptx_arrive_barrier", barrier_ptr, barrier_offset) + + +def ptx_arrive_barrier_expect_tx(barrier_ptr, barrier_offset, byte_count): + """TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation + + Parameters + ---------- + barrier_ptr : Var + The barrier shared memory pointer variable. + barrier_offset : int - Index into the barrier array + The offset of the barrier shared memory pointer. + + byte_count : int + Increases the tx count of the mbarrier object to track completion of + additional async transactions. Returns ------- call : PrimExpr The call expression. """ - return call_intrin("", "tir.ptx_arrive_barrier", barrier_arr, barrier_id) + return call_intrin( + "", "tir.ptx_arrive_barrier_expect_tx", barrier_ptr, barrier_offset, byte_count + ) -def ptx_wait_barrier(barrier_arr, barrier_id): +def ptx_wait_barrier(barrier_ptr, barrier_offset): """TVM intrinsic for ptx barrier wait using mbarrier.try_wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait Parameters ---------- - barrier_arr : string - The name of the barrier array in shared memory + barrier_ptr : Var + The barrier shared memory pointer variable. + barrier_id : int - Index into the barrier array + The offset of the barrier shared memory pointer. Returns ------- call : PrimExpr The call expression. """ - return call_intrin("", "tir.ptx_wait_barrier", barrier_arr, barrier_id) + return call_intrin("", "tir.ptx_wait_barrier", barrier_ptr, barrier_offset) def vectorlow(dtype, vec): diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index edbe8be0303f..d880b978b5b9 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -953,14 +953,25 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string src = this->PrintExpr(op->args[2]); std::string src_offset = this->PrintExpr(op->args[3]); std::string size = this->PrintExpr(op->args[4]); - // use size of argument list to indicate whether or not to use predicated cp.async need_cast_smem_ptr_to_int_ = true; + // use size of argument list to indicate whether or not to use predicated cp.async if (op->args.size() == 5) { this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); } else { this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, size, this->PrintExpr(op->args[5])); } + } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) { + need_cast_smem_ptr_to_int_ = true; + std::string dst = this->PrintExpr(op->args[0]); + std::string dst_offset = this->PrintExpr(op->args[1]); + std::string src = this->PrintExpr(op->args[2]); + std::string src_offset = this->PrintExpr(op->args[3]); + std::string size = this->PrintExpr(op->args[4]); + std::string barrier_ptr = this->PrintExpr(op->args[5]); + std::string barrier_offset = this->PrintExpr(op->args[6]); + this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier_ptr, + barrier_offset); } else if (op->op.same_as(builtin::ptx_commit_group())) { this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n"; } else if
(op->op.same_as(builtin::ptx_wait_group())) { @@ -968,29 +979,31 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n"; } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { need_cast_smem_ptr_to_int_ = true; - std::string barriers_arr = Downcast<StringImm>(op->args[0])->value; - std::string barrier_id = this->PrintExpr(op->args[1]); - std::string barrier = barriers_arr + "[" + barrier_id + "]"; - this->stream << PrintCpAsyncBarrierAsm(barrier); + std::string barrier_ptr = this->PrintExpr(op->args[0]); + std::string barrier_offset = this->PrintExpr(op->args[1]); + this->stream << PrintCpAsyncBarrierAsm(barrier_ptr, barrier_offset); } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) { need_cast_smem_ptr_to_int_ = true; - std::string barriers_arr = Downcast<StringImm>(op->args[0])->value; - std::string barrier_id = this->PrintExpr(op->args[1]); - std::string barrier = barriers_arr + "[" + barrier_id + "]"; + std::string barrier_ptr = this->PrintExpr(op->args[0]); + std::string barrier_offset = this->PrintExpr(op->args[1]); std::string thread_count = this->PrintExpr(op->args[2]); - this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count); + this->stream << PrintInitBarrierThreadCountAsm(barrier_ptr, barrier_offset, thread_count); } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { need_cast_smem_ptr_to_int_ = true; - std::string barriers_arr = Downcast<StringImm>(op->args[0])->value; - std::string barrier_id = this->PrintExpr(op->args[1]); - std::string barrier = barriers_arr + "[" + barrier_id + "]"; - this->stream << PrintArriveBarrierAsm(barrier); + std::string barrier_ptr = this->PrintExpr(op->args[0]); + std::string barrier_offset = this->PrintExpr(op->args[1]); + this->stream << PrintArriveBarrierAsm(barrier_ptr, barrier_offset); + } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { + need_cast_smem_ptr_to_int_ = true; + std::string barrier_ptr = this->PrintExpr(op->args[0]); + std::string barrier_offset = this->PrintExpr(op->args[1]); + std::string byte_count = this->PrintExpr(op->args[2]); + this->stream << PrintArriveBarrierExpectTxAsm(barrier_ptr, barrier_offset, byte_count); } else if (op->op.same_as(builtin::ptx_wait_barrier())) { need_cast_smem_ptr_to_int_ = true; - std::string barriers_arr = Downcast<StringImm>(op->args[0])->value; - std::string barrier_id = this->PrintExpr(op->args[1]); - std::string barrier = barriers_arr + "[" + barrier_id + "]"; - this->stream << PrintWaitBarrierAsm(barrier); + std::string barrier_ptr = this->PrintExpr(op->args[0]); + std::string barrier_offset = this->PrintExpr(op->args[1]); + this->stream << PrintWaitBarrierAsm(barrier_ptr, barrier_offset); } else if (op->op.same_as(builtin::ptx_ldg32())) { /* asm volatile ( diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index 6ff57f43bd2d..dd7c7cb7c402 100644 --- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -709,10 +709,38 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, return predicated_asm_code; } -std::string PrintCpAsyncBarrierAsm(const std::string& barrier) { +std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, const std::string& bytes, + const std::string& barrier_ptr, + const std::string& barrier_elem_offset) { + std::string asm_code = R"( + { + 
unsigned int smem_addr_int = cast_smem_ptr_to_int({smem_addr}); + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); + __asm__ __volatile__( + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" + :: "r"(smem_addr_int), "l"({global_ptr}), "r"({bytes}), "r"(barrier_addr_int) + : "memory" + ); + } +)"; + + Replacer replacer; + replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); + replacer.register_rule("{bytes}", bytes); + replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset); + asm_code = replacer.rewrite(asm_code); + return asm_code; +} + +std::string PrintCpAsyncBarrierAsm(const std::string& barrier_ptr, + const std::string& barrier_elem_offset) { std::string predicated_asm_code = R"( { - unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier}); + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); __asm__ __volatile__( "cp.async.mbarrier.arrive.shared.b64 [%0];" :: "r" (barrier_addr_int) @@ -721,16 +749,17 @@ std::string PrintCpAsyncBarrierAsm(const std::string& barrier) { )"; Replacer replacer; - replacer.register_rule("{barrier}", barrier); + replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset); predicated_asm_code = replacer.rewrite(predicated_asm_code); return predicated_asm_code; } -std::string PrintInitBarrierThreadCountAsm(const std::string& barrier, +std::string PrintInitBarrierThreadCountAsm(const std::string& barrier_ptr, + const std::string& barrier_elem_offset, const std::string& thread_count) { std::string predicated_asm_code = R"( { - unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier}); + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); int thread_count = {thread_count}; __asm__ __volatile__( "mbarrier.init.shared.b64 [%0], %1;" @@ -740,16 +769,17 @@ std::string PrintInitBarrierThreadCountAsm(const std::string& barrier, )"; Replacer replacer; - replacer.register_rule("{barrier}", barrier); + replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset); replacer.register_rule("{thread_count}", thread_count); predicated_asm_code = replacer.rewrite(predicated_asm_code); return predicated_asm_code; } -std::string PrintArriveBarrierAsm(const std::string& barrier) { +std::string PrintArriveBarrierAsm(const std::string& barrier_ptr, + const std::string& barrier_elem_offset) { std::string predicated_asm_code = R"( { - unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier}); + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); __asm__ __volatile__( "{ .reg .b64 state; mbarrier.arrive.shared.b64 state, [%0]; }" :: "r"(barrier_addr_int) @@ -758,15 +788,37 @@ std::string PrintArriveBarrierAsm(const std::string& barrier) { )"; Replacer replacer; - replacer.register_rule("{barrier}", barrier); + replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + +std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier_ptr, + const std::string& barrier_elem_offset, + const std::string& byte_count) { + std::string predicated_asm_code = R"( + { + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); + int byte_count = {byte_count}; + __asm__ __volatile__( + "mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + :: "r"(barrier_addr_int), "r"(byte_count) + ); + } 
+)"; + + Replacer replacer; + replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset); + replacer.register_rule("{byte_count}", byte_count); predicated_asm_code = replacer.rewrite(predicated_asm_code); return predicated_asm_code; } -std::string PrintWaitBarrierAsm(const std::string& barrier) { +std::string PrintWaitBarrierAsm(const std::string& barrier_ptr, + const std::string& barrier_elem_offset) { std::string predicated_asm_code = R"( { - unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier}); + unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); constexpr int phase_bit = 0; __asm__ __volatile__( "{ .reg .pred P; WAIT: mbarrier.try_wait.parity.shared.b64 P, [%0], %1; @P bra.uni DONE; bra.uni WAIT; DONE: }" @@ -776,7 +828,7 @@ std::string PrintWaitBarrierAsm(const std::string& barrier) { )"; Replacer replacer; - replacer.register_rule("{barrier}", barrier); + replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset); predicated_asm_code = replacer.rewrite(predicated_asm_code); return predicated_asm_code; } diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h index 18519d85f6a4..a73180d40b77 100644 --- a/src/target/source/ptx.h +++ b/src/target/source/ptx.h @@ -108,31 +108,67 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, const std::string& bytes, const std::string& predicate_value); +/*! + * \brief Print ptx async copy from global to shared memory using cp.async.bulk + * \param shared_ptr: The pointer to the destination shared memory. + * \param shared_elem_offset: The offset into the shared memory. + * \param global_ptr: The pointer to the global memory. + * \param global_elem_offset: The offset into the global memory. + * \param bytes: The number of bytes to copy. + * \param barrier_ptr: The pointer to the barrier in shared memory. + * \param barrier_elem_offset: The offset to the barrier in shared memory. + */ +std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, const std::string& bytes, + const std::string& barrier_ptr, + const std::string& barrier_elem_offset); + /*! * \brief Print ptx async copy barrier using cp.async.mbarrier.arrive - * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index] + * \param barrier_ptr: The pointer to the barrier in shared memory. + * \param barrier_elem_offset: The offset to the barrier in shared memory. */ -std::string PrintCpAsyncBarrierAsm(const std::string& barrier); +std::string PrintCpAsyncBarrierAsm(const std::string& barrier_ptr, + const std::string& barrier_elem_offset); /*! * \brief Print ptx barrier initialization of thread count using mbarrier.init - * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index] - * \param thread_count: The number of threads expected to arrive at the barrier + * \param barrier_ptr: The pointer to the barrier in shared memory. + * \param barrier_elem_offset: The offset to the barrier in shared memory. + * \param thread_count: The number of threads expected to arrive at the barrier. */ -std::string PrintInitBarrierThreadCountAsm(const std::string& barrier, +std::string PrintInitBarrierThreadCountAsm(const std::string& barrier_ptr, + const std::string& barrier_elem_offset, const std::string& thread_count); /*! 
* \brief Print ptx barrier arrival using mbarrier.arrive - * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index] + * \param barrier_ptr: The pointer to the barrier in shared memory. + * \param barrier_elem_offset: The offset to the barrier in shared memory. + */ +std::string PrintArriveBarrierAsm(const std::string& barrier_ptr, + const std::string& barrier_elem_offset); + +/*! + * \brief Print ptx barrier arrival with expect tx operation using mbarrier.arrive.expect_tx + * \param barrier_ptr: The pointer to the barrier in shared memory. + * \param barrier_elem_offset: The offset to the barrier in shared memory. + * \param byte_count: Increases the tx count of the mbarrier object to track completion of + * additional async transactions. */ -std::string PrintArriveBarrierAsm(const std::string& barrier); +std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier_ptr, + const std::string& barrier_elem_offset, + const std::string& byte_count); /*! * \brief Print ptx barrier wait using mbarrier.try_wait - * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index] + * \param barrier_ptr: The pointer to the barrier in shared memory. + * \param barrier_elem_offset: The offset to the barrier in shared memory. */ -std::string PrintWaitBarrierAsm(const std::string& barrier); +std::string PrintWaitBarrierAsm(const std::string& barrier_ptr, + const std::string& barrier_elem_offset); } // namespace codegen } // namespace tvm diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 0ca61b409967..a4116abf136f 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -284,6 +284,11 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async) .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation", Integer(ScriptDtypePrintLocation::kFirst)); +TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk) + .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); + TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group) .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -296,6 +301,8 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_init_barrier_thread_count) .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TIR_DEFINE_BUILTIN_FUNC(ptx_arrive_barrier) .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(ptx_arrive_barrier_expect_tx) + .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TIR_DEFINE_BUILTIN_FUNC(ptx_wait_barrier) .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py index 30e9ed2dfac5..e4922e1e0c76 100644 --- a/tests/python/unittest/test_tir_op_types.py +++ b/tests/python/unittest/test_tir_op_types.py @@ -234,6 +234,16 @@ def test_op_ptx_cp_async(): assert expr.op.name == "tir.ptx_cp_async" +def test_op_ptx_cp_async_bulk(): + buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared") + buffer_local = tir.decl_buffer([8], "float16", scope="local") + barrier = tir.decl_buffer([1], "uint64", scope="shared") + expr = tir.ptx_cp_async_bulk( + "float16", buffer_shared.data, 0, buffer_local.data, 0, 16, barrier.data, 0 + ) + assert expr.op.name == "tir.ptx_cp_async_bulk" + + def test_op_ptx_commit_group(): expr = tir.ptx_commit_group() assert expr.op.name ==
"tir.ptx_commit_group" @@ -249,17 +259,22 @@ def test_op_ptx_cp_async_barrier(): assert expr.op.name == "tir.ptx_cp_async_barrier" -def ptx_init_barrier_thread_count(): +def test_op_ptx_init_barrier_thread_count(): expr = tir.ptx_init_barrier_thread_count("barrier", 0, 32) assert expr.op.name == "tir.ptx_init_barrier_thread_count" -def ptx_arrive_barrier(): +def test_op_ptx_arrive_barrier(): expr = tir.ptx_arrive_barrier("barrier", 0) assert expr.op.name == "tir.ptx_arrive_barrier" -def ptx_wait_barrier(): +def test_op_ptx_arrive_barrier_expect_tx(): + expr = tir.ptx_arrive_barrier_expect_tx("barrier", 0, 32) + assert expr.op.name == "tir.ptx_arrive_barrier_expect_tx" + + +def test_op_ptx_wait_barrier(): expr = tir.ptx_wait_barrier("barrier", 0) assert expr.op.name == "tir.ptx_wait_barrier" diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py index 0e61f6d1b4f9..e6d3942ce500 100644 --- a/tests/python/unittest/test_tir_ptx_cp_async.py +++ b/tests/python/unittest/test_tir_ptx_cp_async.py @@ -61,5 +61,117 @@ def test_ptx_cp_async(): tvm.testing.assert_allclose(B_nd.numpy(), A_np) +@T.prim_func +def ptx_cp_async_barrier( + A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "float16") +) -> None: + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + bx = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(bx, 1) + T.launch_thread(tx, 32) + with T.block(): + # Shared memory targets for cp.async.bulk must be 16 byte aligned + # Problem: CUDA codegen does not support allocation alignment + # Workaround: Ensure that `A_shared` occurs before `barrier` in program order + # by allocating and initializing `A_shared` before `barrier` + # which should result in `A_shared` being 16+ byte aligned + # given it will be the first shared memory allocation + # TODO(Straw) Add CUDA codegen support for allocation alignment + A_shared = T.alloc_buffer([32, 128], "float16", scope="shared") + A_shared[0, 0] = 0 + + barrier = T.alloc_buffer([1], "uint64", scope="shared") + barrier[0] = 0 + + T.reads(A[0:32, 0:128]) + T.writes(B[0:32, 0:128]) + + T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32, dtype="")) + + for i in range(16): + T.evaluate( + T.ptx_cp_async( + A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16" + ) + ) + + T.evaluate(T.ptx_cp_async_barrier(barrier.data, 0, dtype="")) + T.evaluate(T.ptx_arrive_barrier(barrier.data, 0, dtype="")) + T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype="")) + + for i in range(128): + B[tx, i] = A_shared[tx, i] + + +@tvm.testing.requires_cuda_compute_version(8) +def test_ptx_cp_async_barrier(): + f = ptx_cp_async_barrier + + mod = tvm.build(f, target="cuda") + A_np = np.random.rand(32, 128).astype("float16") + B_np = np.zeros((32, 128)).astype("float16") + dev = tvm.cuda(0) + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + tvm.testing.assert_allclose(B_nd.numpy(), A_np) + + +@T.prim_func +def ptx_cp_async_bulk(A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "float16")) -> None: + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + bx = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(bx, 1) + T.launch_thread(tx, 32) + with T.block(): + # Shared memory targets for cp.async.bulk must be 16 byte aligned + # Problem: CUDA codegen does not support allocation alignment + # Workaround: Ensure that `A_shared` occurs 
before `barrier` in program order + # by allocating and initializing `A_shared` before `barrier` + # which should result in `A_shared` being 16+ byte aligned + # given it will be the first shared memory allocation + # TODO(Straw) Add CUDA codegen support for allocation alignment + A_shared = T.alloc_buffer([32, 128], "float16", scope="shared", align=16) + A_shared[0, 0] = 0 + + barrier = T.alloc_buffer([1], "uint64", scope="shared") + barrier[0] = 0 + + T.reads(A[0:32, 0:128]) + T.writes(B[0:32, 0:128]) + + T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32, dtype="")) + + T.evaluate( + T.ptx_cp_async_bulk( + A_shared.data, tx * 128, A.data, tx * 128, 256, barrier.data, 0, dtype="float16" + ) + ) + + T.evaluate(T.ptx_arrive_barrier_expect_tx(barrier.data, 0, 256, dtype="")) + T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype="")) + + for i in range(128): + B[tx, i] = A_shared[tx, i] + + +@tvm.testing.requires_cuda_compute_version(9) +def test_ptx_cp_async_bulk(): + f = ptx_cp_async_bulk + + mod = tvm.build(f, target="cuda") + A_np = np.random.rand(32, 128).astype("float16") + B_np = np.zeros((32, 128)).astype("float16") + dev = tvm.cuda(0) + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + tvm.testing.assert_allclose(B_nd.numpy(), A_np) + + if __name__ == "__main__": test_ptx_cp_async() + test_ptx_cp_async_barrier() + test_ptx_cp_async_bulk() diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 5d866199e79b..ff70eeae81ab 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -199,15 +199,15 @@ def ptx_global_to_shared_copy_fp32x1_barrier( T.writes(B[0:32, 0:128], barrier[0:1]) barrier[0] = 0 - T.evaluate(T.ptx_init_barrier_thread_count("barrier", 0, 32, dtype="")) + T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32, dtype="")) T.attr("default", "async_scope", 1) for i in T.serial(128): A_shared[tx, i] = A[tx, i] - T.evaluate(T.ptx_cp_async_barrier("barrier", 0, dtype="")) - T.evaluate(T.ptx_arrive_barrier("barrier", 0, dtype="")) - T.evaluate(T.ptx_wait_barrier("barrier", 0, dtype="")) + T.evaluate(T.ptx_cp_async_barrier(barrier.data, 0, dtype="")) + T.evaluate(T.ptx_arrive_barrier(barrier.data, 0, dtype="")) + T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype="")) for i in range(128): B[tx, i] = A_shared[tx, i]
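Illustrative note (not part of the patch): the new intrinsics lower to the mbarrier/cp.async.bulk sequence shown below. This is a minimal CUDA sketch assembled from the inline-asm snippets added in ptx.cc, mirroring the ptx_cp_async_bulk test above: every thread initializes the barrier for 32 arrivals, issues a 256-byte bulk copy tracked by the barrier, arrives with an expected transaction count of 256 bytes, and spin-waits on phase 0. The kernel name, the <<<1, 32>>> launch shape, and the use of __cvta_generic_to_shared in place of TVM's cast_smem_ptr_to_int helper are assumptions for illustration; like the new test, it requires sm_90.

// cp_async_bulk_sketch.cu -- illustrative only; assumes sm_90 and a <<<1, 32>>> launch
#include <cuda_fp16.h>
#include <cstdint>

__global__ void cp_async_bulk_sketch(const half* __restrict__ src, half* __restrict__ dst) {
  // cp.async.bulk needs a 16-byte-aligned shared-memory destination.
  __shared__ alignas(16) half smem[32 * 128];
  __shared__ uint64_t barrier[1];

  unsigned smem_int = static_cast<unsigned>(__cvta_generic_to_shared(&smem[threadIdx.x * 128]));
  unsigned bar_int = static_cast<unsigned>(__cvta_generic_to_shared(&barrier[0]));

  // ptx_init_barrier_thread_count: expect 32 arrivals on the mbarrier.
  asm volatile("mbarrier.init.shared.b64 [%0], %1;" :: "r"(bar_int), "r"(32));
  __syncthreads();

  // ptx_cp_async_bulk: copy 256 bytes per thread, completion tracked by the mbarrier.
  asm volatile(
      "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];"
      :: "r"(smem_int), "l"(src + threadIdx.x * 128), "r"(256), "r"(bar_int)
      : "memory");

  // ptx_arrive_barrier_expect_tx: arrive and add 256 bytes to the expected tx count.
  asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" :: "r"(bar_int), "r"(256));

  // ptx_wait_barrier: spin until phase 0 of the mbarrier completes.
  asm volatile(
      "{ .reg .pred P; WAIT: mbarrier.try_wait.parity.shared.b64 P, [%0], %1; "
      "@P bra.uni DONE; bra.uni WAIT; DONE: }"
      :: "r"(bar_int), "r"(0));

  // Copy the staged tile back out so the result is observable, as the test does.
  for (int i = 0; i < 128; ++i) {
    dst[threadIdx.x * 128 + i] = smem[threadIdx.x * 128 + i];
  }
}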