From a00492e3fa4ed4386c6cd409b40b83731c1c090c Mon Sep 17 00:00:00 2001 From: jayhshah Date: Fri, 26 Jul 2024 00:09:49 +0000 Subject: [PATCH] clean up unneeded methods and variants --- hopper/flash_api.cpp | 2 - hopper/flash_fwd_kernel.h | 70 +-- hopper/kernel_traits.h | 26 +- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 620 +---------------------- hopper/named_barrier.hpp | 4 - hopper/setup.py | 33 +- hopper/softmax.h | 8 +- 7 files changed, 45 insertions(+), 718 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index ad765c7d0..abee9a55b 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -223,7 +223,6 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split // run_mha_fwd_(params, stream); // }); if (!params.is_e4m3) { - #if 0 if (params.is_bf16) { if (params.d == 64) { run_mha_fwd_(params, stream); @@ -241,7 +240,6 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split run_mha_fwd_(params, stream); } } - #endif } else { if (params.d == 64) { run_mha_fwd_(params, stream); diff --git a/hopper/flash_fwd_kernel.h b/hopper/flash_fwd_kernel.h index e6575efba..d5518616c 100644 --- a/hopper/flash_fwd_kernel.h +++ b/hopper/flash_fwd_kernel.h @@ -194,8 +194,8 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup; static constexpr int kBlockM = Ktraits::kBlockM; // static constexpr int kBlockN = Ktraits::kBlockN; - static constexpr int kHeadDim = Ktraits::kHeadDim; - static constexpr bool Delay_V_release = Is_causal && kHeadDim == 128; + // static constexpr int kHeadDim = Ktraits::kHeadDim; + static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128; using CollectiveMainloop = CollectiveMainloopFwd; using CollectiveEpilogue = CollectiveEpilogueFwd; @@ -238,11 +238,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, if (warp_idx == 0 && lane_predicate) { shared_storage.barrier_Q.init(1 /*numThreads*/); -#ifndef NO_UNION - #ifndef NEW_FP8_EPI_BARRIER shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/); - #endif -#endif } // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{}); @@ -266,15 +262,9 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, if (warp_group_idx == 0) { // Producer cutlass::arch::warpgroup_reg_dealloc(); - - #ifdef USE_TRI_MMA_FP8 - PipelineState smem_pipe_write_k = cutlass::make_producer_start_state(); - PipelineState smem_pipe_write_v = cutlass::make_producer_start_state(); - PipelineState smem_pipe_read_v; - #else PipelineState smem_pipe_write = cutlass::make_producer_start_state(); PipelineState smem_pipe_read, smem_pipe_release; - #endif + int work_idx = 0; @@ -289,32 +279,22 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, if (Is_causal && n_block_max <= 0) { scheduler.prefetch_next_work(scheduler_params, work_tile_info); scheduler.broadcast_next_work(work_tile_info); - // TODO: remove this + // need to sync producer warpgroup cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast(FwdNamedBarriers::ProducerWG) /*id*/); continue; } - #ifdef USE_TRI_MMA_FP8 - collective_mainloop.load_fp8_ver1( - mainloop_params, pipeline_k, pipeline_v, pipeline_vt, - smem_pipe_write_k, smem_pipe_write_v, smem_pipe_read_v, shared_storage, - scheduler, scheduler_params, work_tile_info, block_coord, work_idx); - #else + collective_mainloop.load_fp8( mainloop_params, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, smem_pipe_read, shared_storage, - scheduler, scheduler_params, work_tile_info, block_coord, work_idx); - #endif + scheduler, scheduler_params, work_tile_info, block_coord, work_idx); ++work_idx; - // need to sync producer warpgroup - // TODO: remove this - // if (Is_causal) - // cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast(FwdNamedBarriers::ProducerWG) /*id*/); + // don't need to sync producer warpgroup here + // if constexpr (Is_causal) { + // cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast(FwdNamedBarriers::ProducerWG) /*id*/); } } - #ifdef USE_TRI_MMA_FP8 - collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v); - #else collective_mainloop.load_tail_one_write(pipeline_k, pipeline_v, smem_pipe_write); - #endif + } else { // Consumer cutlass::arch::warpgroup_reg_alloc(); @@ -322,12 +302,8 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, TileScheduler scheduler(&shared_storage.tile_count_semaphore); // Initialize matmul objects. typename Ktraits::TiledMma1 tiled_mma1; - #ifdef USE_TRI_MMA_FP8 - PipelineState smem_pipe_read_k, smem_pipe_read_vt; - #else PipelineState smem_pipe_read; PipelineState smem_pipe_release; - #endif collective_mainloop.mma_init_fp8(); scheduler.init_consumer(); @@ -349,32 +325,16 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord); continue; } - - #ifdef USE_TRI_MMA_FP8 - collective_mainloop.mma_fp8_ver1( - mainloop_params, pipeline_k, pipeline_vt, - smem_pipe_read_k, smem_pipe_read_vt, - tOrO, softmax, n_block_max, - threadIdx.x - NumCopyThreads, work_idx, m_block, - shared_storage); - #else - // collective_mainloop.mma_fp8( - // mainloop_params, pipeline_k, pipeline_vt, smem_pipe_read, - // smem_pipe_release, tOrO, softmax, n_block_max, - // threadIdx.x - NumCopyThreads, work_idx, m_block, - // shared_storage); - collective_mainloop.mma_fp8_ver2( + + collective_mainloop.mma_fp8( mainloop_params, pipeline_k, pipeline_vt, smem_pipe_read, smem_pipe_release, tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, - shared_storage); - #endif + shared_storage); - #ifdef COLUMN_PERMUTE + #ifndef NO_FP8_COLUMN_PERMUTE collective_epilogue.store_fp8(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, - threadIdx.x - NumCopyThreads, block_coord); - // collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, - // threadIdx.x - NumCopyThreads, block_coord); + threadIdx.x - NumCopyThreads, block_coord); #else collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, threadIdx.x - NumCopyThreads, block_coord); diff --git a/hopper/kernel_traits.h b/hopper/kernel_traits.h index 17593db0a..bbf438bca 100644 --- a/hopper/kernel_traits.h +++ b/hopper/kernel_traits.h @@ -39,33 +39,15 @@ struct SharedStorageQKVOVt { struct { cute::array_aligned> smem_q; cute::array_aligned> smem_k; - cute::array_aligned> smem_v; -#ifdef NO_UNION - cute::array_aligned> smem_v_out; - cute::array_aligned> smem_o; -#else + cute::array_aligned> smem_v; union { cute::array_aligned> smem_v_out; cute::array_aligned> smem_o; }; -#endif - // union { - // struct { - // cute::array_aligned> smem_v; - // cute::array_aligned> smem_v_out; - // }; - // struct { - // cute::array_aligned> smem_o; - // }; - // }; }; struct { - cutlass::arch::ClusterTransactionBarrier barrier_Q; -#ifndef NO_UNION - #ifndef NEW_FP8_EPI_BARRIER + cutlass::arch::ClusterTransactionBarrier barrier_Q; cutlass::arch::ClusterBarrier barrier_O; - #endif -#endif typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; typename cutlass::PipelineAsync::SharedStorage pipeline_vt; @@ -155,7 +137,7 @@ struct Flash_fwd_kernel_traits { }; -// Traits struct for fp8 kernel +// Traits struct for fp8 kernel with in-kernel transpose template struct Flash_fwd_kernel_traits_fp8 { @@ -230,7 +212,7 @@ struct Flash_fwd_kernel_traits_fp8 { decltype(composition(SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{}))); using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{})); -#ifdef COLUMN_PERMUTE +#ifndef NO_FP8_COLUMN_PERMUTE using SmemShapeSTSM = Shape, Shape<_8, _8>>; #else using SmemShapeSTSM = Shape, Shape<_16, _4>>; diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index bc85f45cf..5dedd6d07 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -36,7 +36,7 @@ struct SmemTransposeFp8_64x64 { using stsm_thread_shape = Shape<_4, _1, _8, _4>; // using stsm_thread_stride = Stride<_1, _0, _4, _32>; -#ifdef COLUMN_PERMUTE +#ifndef NO_FP8_COLUMN_PERMUTE using stsm_value_shape = Shape<_4, _4, _1, _2>; using stsm_value_stride = Stride<_1, _8, _0, _4>; #else @@ -312,149 +312,6 @@ struct CollectiveMainloopFwd { scheduler.broadcast_next_work(work_tile_info); } - - template - CUTLASS_DEVICE void - load_fp8_ver1(Params const& mainloop_params, - MainloopPipeline pipeline_k, - MainloopPipeline pipeline_v, - MainloopPipelineNoTMA pipeline_vt, - PipelineState& smem_pipe_write_k, - PipelineState& smem_pipe_write_v, - PipelineState& smem_pipe_read_v, - SharedStorage &shared_storage, - Scheduler& scheduler, - typename Scheduler::Params const& scheduler_params, - typename Scheduler::WorkTileInfo& work_tile_info, - cute::tuple block_coord, - int work_idx - ) { - - using SmemLayoutTransposeV = typename Ktraits::SmemLayoutTransposeV; - using SmemLayoutTransposeVt = typename Ktraits::SmemLayoutTransposeVt; - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); - - Tensor sV_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutTransposeV{})); - Tensor sVt_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutTransposeVt{})); - - auto smem_transpose_V = SmemTransposeFp8_64x64(); - auto do_transpose_V = [&](int stage) { - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < shape<2>(SmemLayoutTransposeV{}); ++j) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < shape<1>(SmemLayoutTransposeV{}); ++i) { - smem_transpose_V(flatten(sV_divide(_, i, j, stage)), - flatten(sVt_divide(_, i, j, stage))); - } - } - }; - - Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.shape_Q); - Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_K); - Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_K); - - auto [m_block, bidh, bidb] = block_coord; - int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); - - // Prepare the TMA loads - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - - Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); - Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); - auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, - group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x)); // (TMA), (TMA) - auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout{}, - group_modes<0, 2>(sK), group_modes<0, 2>(gK)); // (TMA, k), (TMA, PIPE) - auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout{}, - group_modes<0, 2>(sV), group_modes<0, 2>(gV)); // (TMA, k), (TMA, PIPE) - - uint16_t mcast_mask_kv = 0; - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); - } - } - - int n_block_max = get_n_block_max(mainloop_params, m_block); - int n_block = n_block_max - 1; - - int lane_predicate = cute::elect_one_sync(); - int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); - if (warp_idx_in_warpgroup == 0 && lane_predicate) { - pipeline_k.producer_acquire(smem_pipe_write_k); - copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), - tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index())); - ++smem_pipe_write_k; - } - - // Wait for the MMA warpgroups to say that smem_q is ready - cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - - if (warp_idx_in_warpgroup == 0 && lane_predicate) { - shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); - copy(mainloop_params.tma_load_Q.with(reinterpret_cast(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ); - // do first V load - pipeline_v.producer_acquire(smem_pipe_write_v); - copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), - tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index())); - } - - --n_block; - - // Wait for warp 1 to signal that smem_v are ready and V can be copied from gmem - // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the - // TMA store on O first, call TMA multicast load on V, before CTA 1 can finishing TMA store on O. - shared_storage.barrier_O.wait((work_idx + 1) % 2); - - #pragma unroll 2 - for (; n_block >= 0; --n_block) { - if (warp_idx_in_warpgroup == 0 && lane_predicate) { - pipeline_k.producer_acquire(smem_pipe_write_k); - copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), - tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index())); - ++smem_pipe_write_k; - } - - pipeline_v.consumer_wait(smem_pipe_read_v); - pipeline_vt.producer_acquire(smem_pipe_write_v); - do_transpose_V(smem_pipe_read_v.index()); - pipeline_vt.producer_commit(smem_pipe_write_v); - pipeline_v.consumer_release(smem_pipe_read_v); - - ++smem_pipe_write_v; - ++smem_pipe_read_v; - - if (warp_idx_in_warpgroup == 0 && lane_predicate) { - pipeline_v.producer_acquire(smem_pipe_write_v); - copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), - tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index())); - } - } - - scheduler.prefetch_next_work(scheduler_params, work_tile_info); - - // do last transpose - pipeline_v.consumer_wait(smem_pipe_read_v); - pipeline_vt.producer_acquire(smem_pipe_write_v); - do_transpose_V(smem_pipe_read_v.index()); - pipeline_vt.producer_commit(smem_pipe_write_v); - pipeline_v.consumer_release(smem_pipe_read_v); - - ++smem_pipe_write_v; - ++smem_pipe_read_v; - - scheduler.broadcast_next_work(work_tile_info); - } - template CUTLASS_DEVICE void load_fp8(Params const& mainloop_params, @@ -552,7 +409,6 @@ struct CollectiveMainloopFwd { shared_storage.barrier_O.wait((work_idx + 1) % 2); - #if 1 CUTLASS_PRAGMA_UNROLL for (int iter = 0; iter < kStages && n_block > 0; ++iter, --n_block) { pipeline_v.consumer_wait(smem_pipe_read); @@ -574,7 +430,6 @@ struct CollectiveMainloopFwd { } } - #endif #pragma unroll 2 for (; n_block > 0; --n_block) { @@ -618,24 +473,15 @@ struct CollectiveMainloopFwd { if (warp_idx_in_warpgroup == 0 && lane_predicate) { shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); copy(mainloop_params.tma_load_Q.with(reinterpret_cast(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ); - // do first V load pipeline_v.producer_acquire(smem_pipe_write); copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv), tVgV(_, n_block), tVsV(_, smem_pipe_write.index())); } - - // Wait for warp 1 to signal that smem_v are ready and V can be copied from gmem - // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the - // TMA store on O first, call TMA multicast load on V, before CTA 1 can finishing TMA store on O. - // NOTE: for fp8 can replace with NamedBarrier. - - #ifndef NO_UNION - #ifdef NEW_FP8_EPI_BARRIER - cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::OutEmpty) /*id*/); - #else + + // With fp8, smem_o is in union with smem_v_out, + // so could use NamedBarrier instead of ClusterBarrier. + // But, this doesn't appear to have any benefit. shared_storage.barrier_O.wait((work_idx + 1) % 2); - #endif - #endif pipeline_v.consumer_wait(smem_pipe_read); // pipeline_vt.producer_acquire(smem_pipe_write); @@ -696,7 +542,6 @@ struct CollectiveMainloopFwd { // scheduler.broadcast_next_work(work_tile_info); } - } /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster @@ -770,16 +615,9 @@ struct CollectiveMainloopFwd { } CUTLASS_DEVICE void - mma_init_fp8() { - + mma_init_fp8() { // Tell producer (warpgroup 0) that smem_q is ready cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - #ifdef NEW_FP8_EPI_BARRIER - // For fp8, use NamedBarrier::OutEmpty for epilogue sync - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::OutEmpty) /*id*/); - #endif - - // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::TransposeVProducer) /*id*/); if constexpr (!UseSchedulerBarrier) { return; } static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); @@ -956,443 +794,9 @@ struct CollectiveMainloopFwd { return; } - template - CUTLASS_DEVICE void - mma_fp8(Params const& mainloop_params, - MainloopPipeline pipeline_k, - MainloopPipelineNoTMA pipeline_vt, - PipelineState& smem_pipe_read, - PipelineState& smem_pipe_release, - FrgTensorO& tOrO, - Softmax& softmax, - int n_block_count, - int thread_idx, - int work_idx, - int m_block, - SharedStorage& shared_storage - ) { - static_assert(is_rmem::value, "O tensor must be rmem resident."); - - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); - Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutVt{}); - - typename Ktraits::TiledMma0 tiled_mma0; - typename Ktraits::TiledMma1 tiled_mma1; - auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx); - auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx); - - // Allocate "fragments/descriptors" for first matmul. - Tensor tSrQ = threadMma0.partition_fragment_A(sQ); - Tensor tSrK = threadMma0.partition_fragment_B(sK); - // Allocate "fragments/descriptors" for second matmul. - // Note: S becomes P. - Tensor tOrV = threadMma1.partition_fragment_B(sVt); - - auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - }; - - tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; - int const seqlen_q = get<0>(mainloop_params.shape_Q); - int const seqlen_k = get<0>(mainloop_params.shape_K); - int n_block = n_block_count - 1; - - // consumer wait for Q - cutlass::ConsumerToken barrier_token = static_cast(shared_storage.barrier_Q.try_wait(work_idx % 2)); - if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); } - - // Allocate accumulator for S - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); - // first consumer wait for K - consumer_wait(pipeline_k, smem_pipe_read); // wait K - - warp_scheduler_barrier_sync(); // used for pingpong - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - #ifndef NO_UNION - warp_scheduler_barrier_arrive(); // used for pingpong - // overlap first gemm with prior epilogue - if (work_idx != 0) { - #ifdef NEW_FP8_EPI_BARRIER - tma_store_wait<0>(); - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::OutEmpty) /*id*/); - #else - int lane_predicate = cute::elect_one_sync(); - if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) { - tma_store_wait<0>(); - #pragma unroll - for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { - shared_storage.barrier_O.arrive(cta_id, lane_predicate); - } - } - #endif - } - warpgroup_wait<0>(); - #else - warpgroup_wait<0>(); - warp_scheduler_barrier_arrive(); // used for pingpong - #endif - - - #ifndef RELEASE_PATTERN - pipeline_k.consumer_release(smem_pipe_read); // release current K - #endif - -#if 0 - auto col_limit_causal = [&](int row, int n_block) { - return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM; - }; - { - Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); - Tensor tScS = threadMma0.partition_C(cS); - #pragma unroll - for (int i = 0; i < size(tSrS); ++i) { - if constexpr (!Is_causal) { // Just masking based on col - if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } - } else { // mask based on both row and col - // using std::min is faster than doing col >= limit0 or col >= limit1 - // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the - // right hand side can be negative and might be converted to a very large unsigned integer. - if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN, - col_limit_causal(int(get<0>(tScS(i))), n_block))) { - tSrS(i) = -INFINITY; - } - } - } - } -#endif - -#ifdef USE_CUSTOM_SOFTMAX - softmax.template online_softmax_and_rescale_o(tSrS, tOrO, mainloop_params.softmax_scale_log2); -#else - softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); -#endif - // for fp8 use different layout reshape - Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout())); - permute_regs_A_to_C(tOrP); - -#ifndef USE_CUSTOM_SOFTMAX - Tensor scores_scale = make_fragment_like(softmax.row_max); - clear(scores_scale); -#endif - - consumer_wait(pipeline_vt, smem_pipe_read); // wait V - flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); - - #ifndef RELEASE_PATTERN - pipeline_vt.consumer_release(smem_pipe_read); // release current V - #endif - - ++smem_pipe_read; // advance pipeline state read - // ++smem_pipe_release; // advance pipeline state release (if not staggering) - --n_block; - -#if 0 - constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; - // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal - #pragma unroll - for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) { - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); - consumer_wait(pipeline_k, smem_pipe_read_k); - warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); - if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); } - consumer_wait(pipeline_vt, smem_pipe_read_v); - flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - warp_scheduler_barrier_arrive(); - warpgroup_wait<1>(); - pipeline_k.consumer_release(smem_pipe_read_k); // release K - Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); - Tensor tScS = threadMma0.partition_C(cS); - #pragma unroll - for (int i = 0; i < size(tSrS); ++i) { - if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block - 1)) { - tSrS(i) = -INFINITY; - } - } - cute::copy(softmax.template max(tSrS, mainloop_params.softmax_scale_log2), scores_scale); - softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); - warpgroup_wait<0>(); - pipeline_vt.consumer_release(smem_pipe_read_v); // release V - ++smem_pipe_read_k; - ++smem_pipe_read_v; - cute::copy(make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())), tOrP); - } -#endif - -#if 1 - constexpr int extra_iterations = kStages - 1; - CUTLASS_PRAGMA_UNROLL - for (int iter = 0; iter < extra_iterations && n_block >= 0 ; ++iter) { - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); - #ifdef RELEASE_PATTERN - pipeline_k.consumer_release(smem_pipe_release); // release previous K - #endif - consumer_wait(pipeline_k, smem_pipe_read); // wait K - warp_scheduler_barrier_sync(); // used for pingpong - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - warp_scheduler_barrier_arrive(); // used for pingpong - // warpgroup_wait<0>(); - - #ifndef RELEASE_PATTERN - pipeline_k.consumer_release(smem_pipe_read); // release current K - #endif - - #ifdef USE_CUSTOM_SOFTMAX - softmax.template online_softmax_and_rescale_o - (tSrS, tOrO, mainloop_params.softmax_scale_log2); - #else - cute::copy(softmax.template max(tSrS, mainloop_params.softmax_scale_log2), scores_scale); - softmax.rescale_o(tOrO, scores_scale); - softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); - #endif - - Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout())); - permute_regs_A_to_C(tOrP); - - #ifdef RELEASE_PATTERN - pipeline_vt.consumer_release(smem_pipe_release); // release previous V - #endif - consumer_wait(pipeline_vt, smem_pipe_read); // wait V - flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); - #ifndef RELEASE_PATTERN - pipeline_vt.consumer_release(smem_pipe_read); // release current V - #endif - - ++smem_pipe_read; // advance pipeline state read - #ifdef RELEASE_PATTERN - ++smem_pipe_release; // advance pipeline state release - #endif - - --n_block; - } -#endif - - CUTLASS_PRAGMA_NO_UNROLL - for (; n_block >= 0; --n_block) { - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); - #ifdef RELEASE_PATTERN - pipeline_k.consumer_release(smem_pipe_release); // release previous K - #endif - consumer_wait(pipeline_k, smem_pipe_read); // wait K - warp_scheduler_barrier_sync(); // used for pingpong - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - warp_scheduler_barrier_arrive(); // used for pingpong - // warpgroup_wait<0>(); - - #ifndef RELEASE_PATTERN - pipeline_k.consumer_release(smem_pipe_read); // release current K - #endif - - #ifdef USE_CUSTOM_SOFTMAX - softmax.template online_softmax_and_rescale_o - (tSrS, tOrO, mainloop_params.softmax_scale_log2); - #else - cute::copy(softmax.template max(tSrS, mainloop_params.softmax_scale_log2), scores_scale); - softmax.rescale_o(tOrO, scores_scale); - softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); - #endif - - Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout())); - permute_regs_A_to_C(tOrP); - - #ifdef RELEASE_PATTERN - pipeline_vt.consumer_release(smem_pipe_release); // release previous V - #endif - consumer_wait(pipeline_vt, smem_pipe_read); // wait V - flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); - #ifndef RELEASE_PATTERN - pipeline_vt.consumer_release(smem_pipe_read); // release current V - #endif - - ++smem_pipe_read; // advance pipeline state read - #ifdef RELEASE_PATTERN - ++smem_pipe_release; // advance pipeline state release - #endif - } - - // Tell warpgroup 0 that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - - #ifdef USE_CUSTOM_SOFTMAX - Tensor scores_scale = make_fragment_like(softmax.row_max); - #endif - cute::copy(softmax.template finalize(tSrS, mainloop_params.softmax_scale_log2), scores_scale); - softmax.rescale_o(tOrO, scores_scale); - -#ifdef RELEASE_PATTERN - pipeline_k.consumer_release(smem_pipe_release); // release end K - pipeline_vt.consumer_release(smem_pipe_release); // release end V - ++smem_pipe_release; // advance pipeline state release -#endif - return; - - } - - template - CUTLASS_DEVICE void - mma_fp8_ver1(Params const& mainloop_params, - MainloopPipeline pipeline_k, - MainloopPipelineNoTMA pipeline_vt, - PipelineState& smem_pipe_read_k, - PipelineState& smem_pipe_read_vt, - FrgTensorO& tOrO, - Softmax& softmax, - int n_block_count, - int thread_idx, - int work_idx, - int m_block, - SharedStorage& shared_storage - ) { - static_assert(is_rmem::value, "O tensor must be rmem resident."); - - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); - Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutVt{}); - - typename Ktraits::TiledMma0 tiled_mma0; - typename Ktraits::TiledMma1 tiled_mma1; - auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx); - auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx); - - // Allocate "fragments/descriptors" for first matmul. - Tensor tSrQ = threadMma0.partition_fragment_A(sQ); - Tensor tSrK = threadMma0.partition_fragment_B(sK); - // Allocate "fragments/descriptors" for second matmul. - // Note: S becomes P. - Tensor tOrV = threadMma1.partition_fragment_B(sVt); - - auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - }; - - tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; - int const seqlen_q = get<0>(mainloop_params.shape_Q); - int const seqlen_k = get<0>(mainloop_params.shape_K); - int n_block = n_block_count - 1; - - // consumer wait for Q - cutlass::ConsumerToken barrier_token = static_cast(shared_storage.barrier_Q.try_wait(work_idx % 2)); - if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); } - - // Allocate accumulator for S - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); - // first consumer wait for K - consumer_wait(pipeline_k, smem_pipe_read_k); // wait K - - warp_scheduler_barrier_sync(); // used for pingpong - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); - // overlap first gemm with prior epilogue - if (work_idx != 0) { - int lane_predicate = cute::elect_one_sync(); - if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) { - tma_store_wait<0>(); - #pragma unroll - for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { - shared_storage.barrier_O.arrive(cta_id, lane_predicate); - } - } - } - warpgroup_wait<0>(); - warp_scheduler_barrier_arrive(); // used for pingpong - pipeline_k.consumer_release(smem_pipe_read_k); // release current K - ++smem_pipe_read_k; - - softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); - - // for fp8 use different layout reshape - Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout())); - - Tensor scores_scale = make_fragment_like(softmax.row_max); - clear(scores_scale); - - // CUTLASS_PRAGMA_NO_UNROLL - // for (; n_block > 0; --n_block) { - - // softmax.rescale_o(tOrO, scores_scale); - // consumer_wait(pipeline_vt, smem_pipe_read_vt); // wait V - // warp_scheduler_barrier_sync(); // used for pingpong - // flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_vt.index()), tOrO); - // warp_scheduler_barrier_arrive(); // used for pingpong - // pipeline_vt.consumer_release(smem_pipe_read_vt); // release current V - - // Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); - - // consumer_wait(pipeline_k, smem_pipe_read_k); // wait K - // flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); - - // pipeline_k.consumer_release(smem_pipe_read_k); // release current K - - // cute::copy(softmax.template max(tSrS, mainloop_params.softmax_scale_log2), scores_scale); - - // softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); - - - // ++smem_pipe_read_k; // advance pipeline state read - // ++smem_pipe_read_vt; // advance pipeline state read - - // cute::copy(make_tensor(convert_type(tSrS).data(), - // convert_layout_acc_Aregs_fp8(tSrS.layout())), tOrP); - // permute_regs_A_to_C(tOrP); - // } - - CUTLASS_PRAGMA_NO_UNROLL - for (; n_block > 0; --n_block) { - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); - consumer_wait(pipeline_k, smem_pipe_read_k); // wait K - warp_scheduler_barrier_sync(); // used for pingpong - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); - permute_regs_A_to_C(tOrP); - softmax.rescale_o(tOrO, scores_scale); - consumer_wait(pipeline_vt, smem_pipe_read_vt); // wait V - flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_vt.index()), tOrO); - warp_scheduler_barrier_arrive(); // used for pingpong - warpgroup_wait<1>(); - - pipeline_k.consumer_release(smem_pipe_read_k); // release current K - - cute::copy(softmax.template max(tSrS, mainloop_params.softmax_scale_log2), scores_scale); - - softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); - - warpgroup_wait<0>(); - pipeline_vt.consumer_release(smem_pipe_read_vt); // release current V - ++smem_pipe_read_k; // advance pipeline state read - ++smem_pipe_read_vt; // advance pipeline state read - - cute::copy(make_tensor(convert_type(tSrS).data(), - convert_layout_acc_Aregs_fp8(tSrS.layout())), tOrP); - - } - - // Tell warpgroup 0 that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - - permute_regs_A_to_C(tOrP); - softmax.rescale_o(tOrO, scores_scale); - consumer_wait(pipeline_vt, smem_pipe_read_vt); - flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_vt.index()), tOrO); - - cute::copy(softmax.template finalize(tSrS, mainloop_params.softmax_scale_log2), scores_scale); - warpgroup_wait<0>(); - pipeline_vt.consumer_release(smem_pipe_read_vt); // release V, otherwise producers will hang - ++smem_pipe_read_vt; - - softmax.rescale_o(tOrO, scores_scale); - return; - } - template CUTLASS_DEVICE void - mma_fp8_ver2(Params const& mainloop_params, + mma_fp8(Params const& mainloop_params, MainloopPipeline pipeline_k, MainloopPipelineNoTMA pipeline_vt, PipelineState& smem_pipe_read, @@ -1442,8 +846,7 @@ struct CollectiveMainloopFwd { consumer_wait(pipeline_k, smem_pipe_read); warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - // warp_scheduler_barrier_arrive(); + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); if (work_idx != 0) { int lane_predicate = cute::elect_one_sync(); if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) { @@ -1453,7 +856,7 @@ struct CollectiveMainloopFwd { shared_storage.barrier_O.arrive(cta_id, lane_predicate); } } - } + } warpgroup_wait<0>(); warp_scheduler_barrier_arrive(); pipeline_k.consumer_release(smem_pipe_read); @@ -1527,9 +930,7 @@ struct CollectiveMainloopFwd { if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); } ++smem_pipe_read; } - } - #if 1 - else { + } else { CUTLASS_PRAGMA_UNROLL for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter, --n_block) { Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); @@ -1557,7 +958,6 @@ struct CollectiveMainloopFwd { ++smem_pipe_read; } } - #endif if constexpr(Delay_V_release) { warp_scheduler_barrier_sync(); diff --git a/hopper/named_barrier.hpp b/hopper/named_barrier.hpp index 67be81cde..202f1aaeb 100644 --- a/hopper/named_barrier.hpp +++ b/hopper/named_barrier.hpp @@ -19,10 +19,6 @@ enum class FwdNamedBarriers { WarpSchedulerWG2 = 5, WarpSchedulerWG3 = 6, ProducerWG = 7 - -// #ifdef NEW_FP8_EPI_BARRIER -// OutEmpty = 7 -// #endif }; } // flash \ No newline at end of file diff --git a/hopper/setup.py b/hopper/setup.py index 13921e01b..68ba650c8 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -110,15 +110,15 @@ def append_nvcc_threads(nvcc_extra_args): cutlass_dir = repo_dir / "csrc" / "cutlass" sources = [ "flash_api.cpp", - # "flash_fwd_hdim64_fp16_sm90.cu", - # "flash_fwd_hdim64_bf16_sm90.cu", - # "flash_fwd_hdim128_fp16_sm90.cu", - # "flash_fwd_hdim128_bf16_sm90.cu", - # "flash_fwd_hdim256_fp16_sm90.cu", - # "flash_fwd_hdim256_bf16_sm90.cu", - # "flash_bwd_hdim64_fp16_sm90.cu", - # "flash_bwd_hdim128_fp16_sm90.cu", - # "flash_bwd_hdim256_fp16_sm90.cu", + "flash_fwd_hdim64_fp16_sm90.cu", + "flash_fwd_hdim64_bf16_sm90.cu", + "flash_fwd_hdim128_fp16_sm90.cu", + "flash_fwd_hdim128_bf16_sm90.cu", + "flash_fwd_hdim256_fp16_sm90.cu", + "flash_fwd_hdim256_bf16_sm90.cu", + "flash_bwd_hdim64_fp16_sm90.cu", + "flash_bwd_hdim128_fp16_sm90.cu", + "flash_bwd_hdim256_fp16_sm90.cu", "flash_fwd_hdim64_e4m3_sm90.cu", "flash_fwd_hdim128_e4m3_sm90.cu", "flash_fwd_hdim256_e4m3_sm90.cu" @@ -136,18 +136,11 @@ def append_nvcc_threads(nvcc_extra_args): "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", - # "--ptxas-options=-v", # printing out number of registers - # "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers - # "-lineinfo", + "--ptxas-options=-v", # printing out number of registers + "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers + "-lineinfo", "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging - "-DNDEBUG", # Important, otherwise performance is severely impacted - "-DCOLUMN_PERMUTE", - # "-DDISABLE_CAUSAL", - # "-DUSE_TRI_MMA_FP8" - # "-DUSE_CUSTOM_SOFTMAX", - # "-DNO_UNION" - # "-DNEW_FP8_EPI_BARRIER", - "-DRELEASE_PATTERN", + "-DNDEBUG", # Important, otherwise performance is severely impacted ] include_dirs = [ # Path(this_dir) / "fmha-pipeline", diff --git a/hopper/softmax.h b/hopper/softmax.h index 1f4e0c9e3..89a48527e 100644 --- a/hopper/softmax.h +++ b/hopper/softmax.h @@ -128,8 +128,7 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor &tenso //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct Softmax { - +struct Softmax { constexpr static float max_offset = 8.0f; using TensorT = decltype(make_tensor(Shape>{})); @@ -194,8 +193,7 @@ struct Softmax { } return scores_scale; }; - - // TODO: handle offset + template __forceinline__ __device__ TensorT finalize(Tensor0 &acc_s, float softmax_scale_log2, float rp_dropout=1.0) { // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) @@ -226,7 +224,7 @@ struct Softmax { } }; - // combined method + // combined online softmax method with arbitrary predication template __forceinline__ __device__ void