diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index b204ab3f..0b159871 100644 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -253,6 +253,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { while (left <= right) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { delta_a_exp = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else { @@ -356,6 +357,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { delta_a_exp.real_ = 0.f; delta_a_exp.imag_ = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else { diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 8ecf126d..42a95b9d 100644 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -223,6 +223,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { while (left <= right) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { thread_data[i].x = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else { @@ -248,6 +249,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { thread_data[i].x = 0.f; thread_data[i].y = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else {