diff --git a/rwkv_pip_package/src/rwkv/cuda/rwkv5.cu b/rwkv_pip_package/src/rwkv/cuda/rwkv5.cu index a7f13c3a..d955697f 100644 --- a/rwkv_pip_package/src/rwkv/cuda/rwkv5.cu +++ b/rwkv_pip_package/src/rwkv/cuda/rwkv5.cu @@ -15,7 +15,7 @@ __global__ void kernel_forward(const int B, const int T, const int C, const int const int i = threadIdx.x; _w += h*_N_; _u += h*_N_; - _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!! + _state += b*H*_N_*_N_ + h*_N_*_N_ + i*_N_; // Correct if B >= 1 !!! __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];