diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index b5c04f760da2..0d6d98e25574 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -645,14 +645,29 @@ TVM_DLL const Op& ptx_mma_sp();
 TVM_DLL const Op& ptx_ldmatrix();
 
 /*!
- * \brief tvm intrinsics for ptx async copy from global to shared memory
- *
- * void ptx_cp_async(Var shared_ptr, Expr shared_offset, Var global_ptr, Expr global_offset, size_t
- * bytes);
+ * \brief tvm intrinsics for ptx async copy from global to shared memory using cp.async
  *
+ * void ptx_cp_async(Var shared_ptr,
+ *                   Expr shared_offset,
+ *                   Var global_ptr,
+ *                   Expr global_offset,
+ *                   size_t bytes);
  */
 TVM_DLL const Op& ptx_cp_async();
 
+/*!
+ * \brief tvm intrinsics for ptx async copy from global to shared memory using cp.async.bulk
+ *
+ * void ptx_cp_async_bulk(Var shared_ptr,
+ *                        Expr shared_offset,
+ *                        Var global_ptr,
+ *                        Expr global_offset,
+ *                        size_t bytes,
+ *                        Var barrier_ptr,
+ *                        Expr barrier_offset);
+ */
+TVM_DLL const Op& ptx_cp_async_bulk();
+
 /*!
  * \brief tvm intrinsics for ptx async copy commit and wait.
  *
@@ -666,7 +681,7 @@ TVM_DLL const Op& ptx_wait_group();
 /*!
  * \brief tvm intrinsics for ptx async copy barrier using cp.async.mbarrier.arrive
  *
- * ptx_cp_async_barrier(barrier_array, barrier_id)
+ * ptx_cp_async_barrier(Var barrier_ptr, Expr barrier_offset)
  *
  */
 TVM_DLL const Op& ptx_cp_async_barrier();
@@ -674,7 +689,7 @@ TVM_DLL const Op& ptx_cp_async_barrier();
 /*!
  * \brief tvm intrinsics for ptx barrier initialization of thread count using mbarrier.init
  *
- * ptx_init_barrier_thread_count(barrier_array, barrier_id, thread_count)
+ * ptx_init_barrier_thread_count(Var barrier_ptr, Expr barrier_offset, int thread_count)
  *
  */
 TVM_DLL const Op& ptx_init_barrier_thread_count();
@@ -682,15 +697,23 @@ TVM_DLL const Op& ptx_init_barrier_thread_count();
 /*!
  * \brief tvm intrinsics for ptx barrier arrival using mbarrier.arrive
  *
- * ptx_arrive_barrier(barrier_array, barrier_id)
+ * ptx_arrive_barrier(Var barrier_ptr, Expr barrier_offset)
  *
  */
 TVM_DLL const Op& ptx_arrive_barrier();
 
+/*!
+ * \brief tvm intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
+ *
+ * ptx_arrive_barrier_expect_tx(Var barrier_ptr, Expr barrier_offset, int byte_count)
+ *
+ */
+TVM_DLL const Op& ptx_arrive_barrier_expect_tx();
+
 /*!
  * \brief tvm intrinsics for ptx barrier wait using mbarrier.try_wait
  *
- * ptx_wait_barrier(barrier_array, barrier_id)
+ * ptx_wait_barrier(Var barrier_ptr, Expr barrier_offset)
  *
  */
 TVM_DLL const Op& ptx_wait_barrier();
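For context on how these new intrinsics fit together: they target the mbarrier
transaction-count pattern, where a bulk copy reports its completed bytes to an mbarrier
and consumers wait on it. A rough CUDA sketch of the lowered sequence follows; it is
illustrative only (the placeholder names `bar`, `smem`, `gmem`, `nbytes`, and
`thread_count` are assumptions), and the actual inline-asm strings are the ones added to
src/target/source/ptx.cc later in this diff.

    // Illustrative sketch, not emitted verbatim by the codegen.
    unsigned int bar_int  = cast_smem_ptr_to_int(bar);   // shared-memory address of the mbarrier word
    unsigned int smem_int = cast_smem_ptr_to_int(smem);  // shared-memory destination address
    // ptx_init_barrier_thread_count: initialize the mbarrier with the expected arrival count
    __asm__ __volatile__("mbarrier.init.shared.b64 [%0], %1;"
                         :: "r"(bar_int), "r"(thread_count));
    // ptx_cp_async_bulk: copy nbytes from global memory; completion is reported to the mbarrier
    __asm__ __volatile__(
        "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];"
        :: "r"(smem_int), "l"(gmem), "r"(nbytes), "r"(bar_int) : "memory");
    // ptx_arrive_barrier_expect_tx: arrive and raise the expected transaction count by nbytes
    __asm__ __volatile__("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
                         :: "r"(bar_int), "r"(nbytes));
    // ptx_wait_barrier: spin until all arrivals and all expected bytes have completed
    __asm__ __volatile__(
        "{ .reg .pred P; WAIT: mbarrier.try_wait.parity.shared.b64 P, [%0], %1; @P bra.uni DONE; bra.uni WAIT; DONE: }"
        :: "r"(bar_int), "r"(0));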
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index d7bebbacee05..337e06089583 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1847,6 +1847,7 @@ def wrapped(*args, **kwargs):
 ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier)
 ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count)
 ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
+ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
 ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
 assume = _op_wrapper(_tir_op.assume)
 undef = _op_wrapper(_tir_op.undef)
@@ -1876,6 +1877,7 @@ def wrapped(*args, **kwargs):
 ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
 ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
 ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
+ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk)
 mma_store = _dtype_forward(_tir_op.mma_store)
 mma_fill = _dtype_forward(_tir_op.mma_fill)
 vectorlow = _dtype_forward(_tir_op.vectorlow)
@@ -2115,11 +2117,13 @@ def wrapped(*args, **kwargs):
     "ptx_mma_sp",
     "ptx_ldmatrix",
     "ptx_cp_async",
+    "ptx_cp_async_bulk",
     "ptx_wait_group",
     "ptx_commit_group",
     "ptx_cp_async_barrier",
     "ptx_init_barrier_thread_count",
     "ptx_arrive_barrier",
+    "ptx_arrive_barrier_expect_tx",
     "ptx_wait_barrier",
     "mma_store",
     "mma_fill",
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 84c575333712..762fcb599f40 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -63,11 +63,13 @@
 from .op import (
     ptx_ldmatrix,
     ptx_cp_async,
+    ptx_cp_async_bulk,
     ptx_commit_group,
     ptx_wait_group,
     ptx_cp_async_barrier,
     ptx_init_barrier_thread_count,
     ptx_arrive_barrier,
+    ptx_arrive_barrier_expect_tx,
     ptx_wait_barrier,
 )
 from .op import vectorlow, vectorhigh, vectorcombine
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 7e1c520cc432..cb9227e8f2ea 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1335,7 +1335,7 @@ def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, sme
 
 
 def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
-    """TVM intrinsic for ptx async copy from global to shared memory
+    """TVM intrinsic for ptx async copy from global to shared memory using cp.async
     https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
 
     Parameters
@@ -1368,6 +1368,56 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, by
     )
 
 
+def ptx_cp_async_bulk(
+    dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_ptr, barrier_offset
+):
+    """TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
+
+    Parameters
+    ----------
+    dtype : str
+       The data type of the result.
+
+    shared_ptr : Var
+        The shared memory pointer variable.
+
+    shared_offset : Expr
+        The offset of shared memory pointer.
+
+    global_ptr : Var
+        The global memory pointer variable.
+
+    global_offset : Expr
+        The offset of global memory pointer.
+
+    bytes : int
+        The data size to copy.
+
+    barrier_ptr : Var
+        The barrier shared memory pointer variable.
+
+    barrier_offset : int
+        The offset of the barrier shared memory pointer.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        dtype,
+        "tir.ptx_cp_async_bulk",
+        shared_ptr,
+        shared_offset,
+        global_ptr,
+        global_offset,
+        bytes,
+        barrier_ptr,
+        barrier_offset,
+    )
+
+
 def ptx_commit_group():
     """TVM intrinsic for ptx async copy commit
     https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group
@@ -1397,37 +1447,40 @@ def ptx_wait_group(num):
     return call_intrin("", "tir.ptx_wait_group", num)
 
 
-def ptx_cp_async_barrier(barrier_arr, barrier_id):
+def ptx_cp_async_barrier(barrier_ptr, barrier_offset):
     """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
     https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
 
     Parameters
     ----------
-    barrier_arr : string
-        The name of the barrier array in shared memory
+    barrier_ptr : Var
+        The barrier shared memory pointer variable.
+
     barrier_id : int
-        Index into the barrier array
+        The offset of the barrier shared memory pointer.
 
     Returns
     -------
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tir.ptx_cp_async_barrier", barrier_arr, barrier_id)
+    return call_intrin("", "tir.ptx_cp_async_barrier", barrier_ptr, barrier_offset)
 
 
-def ptx_init_barrier_thread_count(barrier_arr, barrier_id, thread_count):
+def ptx_init_barrier_thread_count(barrier_ptr, barrier_offset, thread_count):
     """TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init
     https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
 
     Parameters
     ----------
-    barrier_arr : string
-        The name of the barrier array in shared memory
+    barrier_ptr : Var
+        The barrier shared memory pointer variable.
+
     barrier_id : int
-        Index into the barrier array
+        The offset of the barrier shared memory pointer.
+
     thread_count : int
-        Number of threads expected to arrive at the barrier
+        Number of threads expected to arrive at the barrier.
 
     Returns
     -------
@@ -1435,46 +1488,75 @@ def ptx_init_barrier_thread_count(barrier_arr, barrier_id, thread_count):
         The call expression.
     """
     return call_intrin(
-        "", "tir.ptx_init_barrier_thread_count", barrier_arr, barrier_id, thread_count
+        "", "tir.ptx_init_barrier_thread_count", barrier_ptr, barrier_offset, thread_count
     )
 
 
-def ptx_arrive_barrier(barrier_arr, barrier_id):
+def ptx_arrive_barrier(barrier_ptr, barrier_offset):
     """TVM intrinsic for ptx barrier arrival using mbarrier.arrive
     https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
 
     Parameters
     ----------
-    barrier_arr : string
-        The name of the barrier array in shared memory
+    barrier_ptr : Var
+        The barrier shared memory pointer variable.
+
+    barrier_offset : int
+        The offset of the barrier shared memory pointer.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("", "tir.ptx_arrive_barrier", barrier_ptr, barrier_offset)
+
+
+def ptx_arrive_barrier_expect_tx(barrier_ptr, barrier_offset, byte_count):
+    """TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation
+
+    Parameters
+    ----------
+    barrier_ptr : Var
+        The barrier shared memory pointer variable.
+
     barrier_id : int
-        Index into the barrier array
+        The offset of the barrier shared memory pointer.
+
+    byte_count : int
+        Increases the tx count of the mbarrier object to track completion of
+        additional async transactions.
 
     Returns
     -------
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tir.ptx_arrive_barrier", barrier_arr, barrier_id)
+    return call_intrin(
+        "", "tir.ptx_arrive_barrier_expect_tx", barrier_ptr, barrier_offset, byte_count
+    )
 
 
-def ptx_wait_barrier(barrier_arr, barrier_id):
+def ptx_wait_barrier(barrier_ptr, barrier_offset):
     """TVM intrinsic for ptx barrier wait using mbarrier.try_wait
     https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
 
     Parameters
     ----------
-    barrier_arr : string
-        The name of the barrier array in shared memory
+    barrier_ptr : Var
+        The barrier shared memory pointer variable.
+
     barrier_id : int
-        Index into the barrier array
+        The offset of the barrier shared memory pointer.
 
     Returns
     -------
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tir.ptx_wait_barrier", barrier_arr, barrier_id)
+    return call_intrin("", "tir.ptx_wait_barrier", barrier_ptr, barrier_offset)
 
 
 def vectorlow(dtype, vec):
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index edbe8be0303f..d880b978b5b9 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -953,14 +953,25 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     std::string src = this->PrintExpr(op->args[2]);
     std::string src_offset = this->PrintExpr(op->args[3]);
     std::string size = this->PrintExpr(op->args[4]);
-    // use size of argument list to indicate whether or not to use predicated cp.async
     need_cast_smem_ptr_to_int_ = true;
+    // use size of argument list to indicate whether or not to use predicated cp.async
     if (op->args.size() == 5) {
       this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
     } else {
       this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, size,
                                                      this->PrintExpr(op->args[5]));
     }
+  } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) {
+    need_cast_smem_ptr_to_int_ = true;
+    std::string dst = this->PrintExpr(op->args[0]);
+    std::string dst_offset = this->PrintExpr(op->args[1]);
+    std::string src = this->PrintExpr(op->args[2]);
+    std::string src_offset = this->PrintExpr(op->args[3]);
+    std::string size = this->PrintExpr(op->args[4]);
+    std::string barrier_ptr = this->PrintExpr(op->args[5]);
+    std::string barrier_offset = this->PrintExpr(op->args[6]);
+    this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier_ptr,
+                                        barrier_offset);
   } else if (op->op.same_as(builtin::ptx_commit_group())) {
     this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
   } else if (op->op.same_as(builtin::ptx_wait_group())) {
@@ -968,29 +979,31 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n";
   } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
     need_cast_smem_ptr_to_int_ = true;
-    std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
-    std::string barrier_id = this->PrintExpr(op->args[1]);
-    std::string barrier = barriers_arr + "[" + barrier_id + "]";
-    this->stream << PrintCpAsyncBarrierAsm(barrier);
+    std::string barrier_ptr = this->PrintExpr(op->args[0]);
+    std::string barrier_offset = this->PrintExpr(op->args[1]);
+    this->stream << PrintCpAsyncBarrierAsm(barrier_ptr, barrier_offset);
   } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
     need_cast_smem_ptr_to_int_ = true;
-    std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
-    std::string barrier_id = this->PrintExpr(op->args[1]);
-    std::string barrier = barriers_arr + "[" + barrier_id + "]";
+    std::string barrier_ptr = this->PrintExpr(op->args[0]);
+    std::string barrier_offset = this->PrintExpr(op->args[1]);
     std::string thread_count = this->PrintExpr(op->args[2]);
-    this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
+    this->stream << PrintInitBarrierThreadCountAsm(barrier_ptr, barrier_offset, thread_count);
   } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
     need_cast_smem_ptr_to_int_ = true;
-    std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
-    std::string barrier_id = this->PrintExpr(op->args[1]);
-    std::string barrier = barriers_arr + "[" + barrier_id + "]";
-    this->stream << PrintArriveBarrierAsm(barrier);
+    std::string barrier_ptr = this->PrintExpr(op->args[0]);
+    std::string barrier_offset = this->PrintExpr(op->args[1]);
+    this->stream << PrintArriveBarrierAsm(barrier_ptr, barrier_offset);
+  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
+    need_cast_smem_ptr_to_int_ = true;
+    std::string barrier_ptr = this->PrintExpr(op->args[0]);
+    std::string barrier_offset = this->PrintExpr(op->args[1]);
+    std::string byte_count = this->PrintExpr(op->args[2]);
+    this->stream << PrintArriveBarrierExpectTxAsm(barrier_ptr, barrier_offset, byte_count);
   } else if (op->op.same_as(builtin::ptx_wait_barrier())) {
     need_cast_smem_ptr_to_int_ = true;
-    std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
-    std::string barrier_id = this->PrintExpr(op->args[1]);
-    std::string barrier = barriers_arr + "[" + barrier_id + "]";
-    this->stream << PrintWaitBarrierAsm(barrier);
+    std::string barrier_ptr = this->PrintExpr(op->args[0]);
+    std::string barrier_offset = this->PrintExpr(op->args[1]);
+    this->stream << PrintWaitBarrierAsm(barrier_ptr, barrier_offset);
   } else if (op->op.same_as(builtin::ptx_ldg32())) {
     /*
     asm volatile (
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index 6ff57f43bd2d..dd7c7cb7c402 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -709,10 +709,38 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
   return predicated_asm_code;
 }
 
-std::string PrintCpAsyncBarrierAsm(const std::string& barrier) {
+std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr,
+                                const std::string& shared_elem_offset,
+                                const std::string& global_ptr,
+                                const std::string& global_elem_offset, const std::string& bytes,
+                                const std::string& barrier_ptr,
+                                const std::string& barrier_elem_offset) {
+  std::string asm_code = R"(
+  {
+    unsigned int smem_addr_int = cast_smem_ptr_to_int({smem_addr});
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier});
+    __asm__ __volatile__(
+      "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];"
+      :: "r"(smem_addr_int), "l"({global_ptr}), "r"({bytes}), "r"(barrier_addr_int)
+      : "memory"
+    );
+  }
+)";
+
+  Replacer replacer;
+  replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
+  replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
+  replacer.register_rule("{bytes}", bytes);
+  replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset);
+  asm_code = replacer.rewrite(asm_code);
+  return asm_code;
+}
+
+std::string PrintCpAsyncBarrierAsm(const std::string& barrier_ptr,
+                                   const std::string& barrier_elem_offset) {
   std::string predicated_asm_code = R"(
   {
-    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier});
     __asm__ __volatile__(
       "cp.async.mbarrier.arrive.shared.b64 [%0];"
       :: "r" (barrier_addr_int)
@@ -721,16 +749,17 @@ std::string PrintCpAsyncBarrierAsm(const std::string& barrier) {
 )";
 
   Replacer replacer;
-  replacer.register_rule("{barrier}", barrier);
+  replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset);
   predicated_asm_code = replacer.rewrite(predicated_asm_code);
   return predicated_asm_code;
 }
 
-std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
+std::string PrintInitBarrierThreadCountAsm(const std::string& barrier_ptr,
+                                           const std::string& barrier_elem_offset,
                                            const std::string& thread_count) {
   std::string predicated_asm_code = R"(
   {
-    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier});
     int thread_count = {thread_count};
     __asm__ __volatile__(
       "mbarrier.init.shared.b64 [%0], %1;"
@@ -740,16 +769,17 @@ std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
 )";
 
   Replacer replacer;
-  replacer.register_rule("{barrier}", barrier);
+  replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset);
   replacer.register_rule("{thread_count}", thread_count);
   predicated_asm_code = replacer.rewrite(predicated_asm_code);
   return predicated_asm_code;
 }
 
-std::string PrintArriveBarrierAsm(const std::string& barrier) {
+std::string PrintArriveBarrierAsm(const std::string& barrier_ptr,
+                                  const std::string& barrier_elem_offset) {
   std::string predicated_asm_code = R"(
   {
-    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier});
     __asm__ __volatile__(
       "{ .reg .b64 state; mbarrier.arrive.shared.b64 state, [%0]; }"
       :: "r"(barrier_addr_int)
@@ -758,15 +788,37 @@ std::string PrintArriveBarrierAsm(const std::string& barrier) {
 )";
 
   Replacer replacer;
-  replacer.register_rule("{barrier}", barrier);
+  replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset);
+  predicated_asm_code = replacer.rewrite(predicated_asm_code);
+  return predicated_asm_code;
+}
+
+std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier_ptr,
+                                          const std::string& barrier_elem_offset,
+                                          const std::string& byte_count) {
+  std::string predicated_asm_code = R"(
+  {
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier});
+    int byte_count = {byte_count};
+    __asm__ __volatile__(
+      "mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
+      :: "r"(barrier_addr_int), "r"(byte_count)
+    );
+  }
+)";
+
+  Replacer replacer;
+  replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset);
+  replacer.register_rule("{byte_count}", byte_count);
   predicated_asm_code = replacer.rewrite(predicated_asm_code);
   return predicated_asm_code;
 }
 
-std::string PrintWaitBarrierAsm(const std::string& barrier) {
+std::string PrintWaitBarrierAsm(const std::string& barrier_ptr,
+                                const std::string& barrier_elem_offset) {
   std::string predicated_asm_code = R"(
   {
-    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier});
     constexpr int phase_bit = 0;
     __asm__ __volatile__(
       "{ .reg .pred P; WAIT: mbarrier.try_wait.parity.shared.b64 P, [%0], %1; @P bra.uni DONE; bra.uni WAIT; DONE: }"
@@ -776,7 +828,7 @@ std::string PrintWaitBarrierAsm(const std::string& barrier) {
 )";
 
   Replacer replacer;
-  replacer.register_rule("{barrier}", barrier);
+  replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset);
   predicated_asm_code = replacer.rewrite(predicated_asm_code);
   return predicated_asm_code;
 }
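For reviewers, this is roughly what PrintCpAsyncBulkAsm expands to after the Replacer
substitutions, using the argument spellings from the new ptx_cp_async_bulk test added
below (A_shared, A, barrier, and tx come from that test; the exact printed expressions
depend on CodeGenCUDA::PrintExpr, so treat this as a sketch):

      {
        unsigned int smem_addr_int = cast_smem_ptr_to_int(A_shared + (tx * 128));
        unsigned int barrier_addr_int = cast_smem_ptr_to_int(barrier + 0);
        __asm__ __volatile__(
          "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];"
          :: "r"(smem_addr_int), "l"(A + (tx * 128)), "r"(256), "r"(barrier_addr_int)
          : "memory"
        );
      }

Note that {smem_addr} is built as shared_ptr + shared_elem_offset, so the offsets are in
elements of the buffer dtype (here float16), not in bytes; the 256 is the byte-count
argument of the intrinsic and is passed through unchanged.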
diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h
index 18519d85f6a4..a73180d40b77 100644
--- a/src/target/source/ptx.h
+++ b/src/target/source/ptx.h
@@ -108,31 +108,67 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
                                            const std::string& bytes,
                                            const std::string& predicate_value);
 
+/*!
+ * \brief Print ptx async copy from global to shared memory using cp.async.bulk
+ * \param shared_ptr: The pointer to the destination shared memory.
+ * \param shared_elem_offset: The offset into the shared memory.
+ * \param global_ptr: The pointer to the global memory.
+ * \param global_elem_offset: The offset into the global memory.
+ * \param bytes: The number of bytes to copy.
+ * \param barrier_ptr: The pointer to the barrier in shared memory.
+ * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ */
+std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr,
+                                const std::string& shared_elem_offset,
+                                const std::string& global_ptr,
+                                const std::string& global_elem_offset, const std::string& bytes,
+                                const std::string& barrier_ptr,
+                                const std::string& barrier_elem_offset);
+
 /*!
  * \brief Print ptx async copy barrier using cp.async.mbarrier.arrive
- * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ * \param barrier_ptr: The pointer to the barrier in shared memory.
+ * \param barrier_elem_offset: The offset to the barrier in shared memory.
  */
-std::string PrintCpAsyncBarrierAsm(const std::string& barrier);
+std::string PrintCpAsyncBarrierAsm(const std::string& barrier_ptr,
+                                   const std::string& barrier_elem_offset);
 
 /*!
  * \brief Print ptx barrier initialization of thread count using mbarrier.init
- * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
- * \param thread_count: The number of threads expected to arrive at the barrier
+ * \param barrier_ptr: The pointer to the barrier in shared memory.
+ * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param thread_count: The number of threads expected to arrive at the barrier.
  */
-std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
+std::string PrintInitBarrierThreadCountAsm(const std::string& barrier_ptr,
+                                           const std::string& barrier_elem_offset,
                                            const std::string& thread_count);
 
 /*!
  * \brief Print ptx barrier arrival using mbarrier.arrive
- * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ * \param barrier_ptr: The pointer to the barrier in shared memory.
+ * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ */
+std::string PrintArriveBarrierAsm(const std::string& barrier_ptr,
+                                  const std::string& barrier_elem_offset);
+
+/*!
+ * \brief Print ptx barrier arrival with expect tx operation using mbarrier.arrive.expect_tx
+ * \param barrier_ptr: The pointer to the barrier in shared memory.
+ * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param byte_count: Increases the tx count of the mbarrier object to track completion of
+ * additional async transactions.
  */
-std::string PrintArriveBarrierAsm(const std::string& barrier);
+std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier_ptr,
+                                          const std::string& barrier_elem_offset,
+                                          const std::string& byte_count);
 
 /*!
  * \brief Print ptx barrier wait using mbarrier.try_wait
- * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ * \param barrier_ptr: The pointer to the barrier in shared memory.
+ * \param barrier_elem_offset: The offset to the barrier in shared memory.
  */
-std::string PrintWaitBarrierAsm(const std::string& barrier);
+std::string PrintWaitBarrierAsm(const std::string& barrier_ptr,
+                                const std::string& barrier_elem_offset);
 
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 0ca61b409967..a4116abf136f 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -284,6 +284,11 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async)
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
                                          Integer(ScriptDtypePrintLocation::kFirst));
 
+TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         Integer(ScriptDtypePrintLocation::kFirst));
+
 TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
@@ -296,6 +301,8 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_init_barrier_thread_count)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 TIR_DEFINE_BUILTIN_FUNC(ptx_arrive_barrier)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_arrive_barrier_expect_tx)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 TIR_DEFINE_BUILTIN_FUNC(ptx_wait_barrier)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py
index 30e9ed2dfac5..e4922e1e0c76 100644
--- a/tests/python/unittest/test_tir_op_types.py
+++ b/tests/python/unittest/test_tir_op_types.py
@@ -234,6 +234,16 @@ def test_op_ptx_cp_async():
     assert expr.op.name == "tir.ptx_cp_async"
 
 
+def test_op_ptx_cp_async_bulk():
+    buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared")
+    buffer_local = tir.decl_buffer([8], "float16", scope="local")
+    barrier = tir.decl_buffer([1], "uint64", scope="shared")
+    expr = tir.ptx_cp_async_bulk(
+        "float16", buffer_shared.data, 0, buffer_local.data, 0, 16, barrier.data, 0
+    )
+    assert expr.op.name == "tir.ptx_cp_async_bulk"
+
+
 def test_op_ptx_commit_group():
     expr = tir.ptx_commit_group()
     assert expr.op.name == "tir.ptx_commit_group"
@@ -249,17 +259,22 @@ def test_op_ptx_cp_async_barrier():
     assert expr.op.name == "tir.ptx_cp_async_barrier"
 
 
-def ptx_init_barrier_thread_count():
+def test_op_ptx_init_barrier_thread_count():
     expr = tir.ptx_init_barrier_thread_count("barrier", 0, 32)
     assert expr.op.name == "tir.ptx_init_barrier_thread_count"
 
 
-def ptx_arrive_barrier():
+def test_op_ptx_arrive_barrier():
     expr = tir.ptx_arrive_barrier("barrier", 0)
     assert expr.op.name == "tir.ptx_arrive_barrier"
 
 
-def ptx_wait_barrier():
+def test_op_ptx_arrive_barrier_expect_tx():
+    expr = tir.ptx_arrive_barrier_expect_tx("barrier", 0, 32)
+    assert expr.op.name == "tir.ptx_arrive_barrier_expect_tx"
+
+
+def test_op_ptx_wait_barrier():
     expr = tir.ptx_wait_barrier("barrier", 0)
     assert expr.op.name == "tir.ptx_wait_barrier"
 
diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py
index 0e61f6d1b4f9..e6d3942ce500 100644
--- a/tests/python/unittest/test_tir_ptx_cp_async.py
+++ b/tests/python/unittest/test_tir_ptx_cp_async.py
@@ -61,5 +61,117 @@ def test_ptx_cp_async():
     tvm.testing.assert_allclose(B_nd.numpy(), A_np)
 
 
+@T.prim_func
+def ptx_cp_async_barrier(
+    A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "float16")
+) -> None:
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    bx = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(bx, 1)
+    T.launch_thread(tx, 32)
+    with T.block():
+        # Shared memory targets for 16-byte cp.async copies must be 16-byte aligned
+        # Problem: CUDA codegen does not support allocation alignment
+        # Workaround: Ensure that `A_shared` occurs before `barrier` in program order
+        #             by allocating and initializing `A_shared` before `barrier`
+        #             which should result in `A_shared` being 16+ byte aligned
+        #             given it will be the first shared memory allocation
+        # TODO(Straw) Add CUDA codegen support for allocation alignment
+        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
+        A_shared[0, 0] = 0
+
+        barrier = T.alloc_buffer([1], "uint64", scope="shared")
+        barrier[0] = 0
+
+        T.reads(A[0:32, 0:128])
+        T.writes(B[0:32, 0:128])
+
+        T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32, dtype=""))
+
+        for i in range(16):
+            T.evaluate(
+                T.ptx_cp_async(
+                    A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16"
+                )
+            )
+
+        T.evaluate(T.ptx_cp_async_barrier(barrier.data, 0, dtype=""))
+        T.evaluate(T.ptx_arrive_barrier(barrier.data, 0, dtype=""))
+        T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype=""))
+
+        for i in range(128):
+            B[tx, i] = A_shared[tx, i]
+
+
+@tvm.testing.requires_cuda_compute_version(8)
+def test_ptx_cp_async_barrier():
+    f = ptx_cp_async_barrier
+
+    mod = tvm.build(f, target="cuda")
+    A_np = np.random.rand(32, 128).astype("float16")
+    B_np = np.zeros((32, 128)).astype("float16")
+    dev = tvm.cuda(0)
+    A_nd = tvm.nd.array(A_np, device=dev)
+    B_nd = tvm.nd.array(B_np, device=dev)
+    mod(A_nd, B_nd)
+    tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+@T.prim_func
+def ptx_cp_async_bulk(A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "float16")) -> None:
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    bx = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(bx, 1)
+    T.launch_thread(tx, 32)
+    with T.block():
+        # Shared memory targets for cp.async.bulk must be 16-byte aligned
+        # Problem: CUDA codegen does not support allocation alignment
+        # Workaround: Ensure that `A_shared` occurs before `barrier` in program order
+        #             by allocating and initializing `A_shared` before `barrier`
+        #             which should result in `A_shared` being 16+ byte aligned
+        #             given it will be the first shared memory allocation
+        # TODO(Straw) Add CUDA codegen support for allocation alignment
+        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared", align=16)
+        A_shared[0, 0] = 0
+
+        barrier = T.alloc_buffer([1], "uint64", scope="shared")
+        barrier[0] = 0
+
+        T.reads(A[0:32, 0:128])
+        T.writes(B[0:32, 0:128])
+
+        T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32, dtype=""))
+
+        T.evaluate(
+            T.ptx_cp_async_bulk(
+                A_shared.data, tx * 128, A.data, tx * 128, 256, barrier.data, 0, dtype="float16"
+            )
+        )
+
+        T.evaluate(T.ptx_arrive_barrier_expect_tx(barrier.data, 0, 256, dtype=""))
+        T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype=""))
+
+        for i in range(128):
+            B[tx, i] = A_shared[tx, i]
+
+
+@tvm.testing.requires_cuda_compute_version(9)
+def test_ptx_cp_async_bulk():
+    f = ptx_cp_async_bulk
+
+    mod = tvm.build(f, target="cuda")
+    A_np = np.random.rand(32, 128).astype("float16")
+    B_np = np.zeros((32, 128)).astype("float16")
+    dev = tvm.cuda(0)
+    A_nd = tvm.nd.array(A_np, device=dev)
+    B_nd = tvm.nd.array(B_np, device=dev)
+    mod(A_nd, B_nd)
+    tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
 if __name__ == "__main__":
     test_ptx_cp_async()
+    test_ptx_cp_async_barrier()
+    test_ptx_cp_async_bulk()
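A note on the byte counts in ptx_cp_async_bulk above: each thread copies 128 float16
elements, i.e. 128 * 2 = 256 bytes, and that same 256 is passed both as the bytes
argument of T.ptx_cp_async_bulk and as the byte_count of T.ptx_arrive_barrier_expect_tx,
so every arrival raises the barrier's expected transaction count by exactly the number
of bytes its own bulk copy will complete.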
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index 5d866199e79b..ff70eeae81ab 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -199,15 +199,15 @@ def ptx_global_to_shared_copy_fp32x1_barrier(
         T.writes(B[0:32, 0:128], barrier[0:1])
 
         barrier[0] = 0
-        T.evaluate(T.ptx_init_barrier_thread_count("barrier", 0, 32, dtype=""))
+        T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32, dtype=""))
 
         T.attr("default", "async_scope", 1)
         for i in T.serial(128):
             A_shared[tx, i] = A[tx, i]
 
-        T.evaluate(T.ptx_cp_async_barrier("barrier", 0, dtype=""))
-        T.evaluate(T.ptx_arrive_barrier("barrier", 0, dtype=""))
-        T.evaluate(T.ptx_wait_barrier("barrier", 0, dtype=""))
+        T.evaluate(T.ptx_cp_async_barrier(barrier.data, 0, dtype=""))
+        T.evaluate(T.ptx_arrive_barrier(barrier.data, 0, dtype=""))
+        T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype=""))
 
         for i in range(128):
             B[tx, i] = A_shared[tx, i]