diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index c9c7bd551..c33eb2ec4 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -989,6 +989,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;
 
+    at::Tensor tmp_rng_state;
     if ( rng_state.has_value() ) {
         params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
     } else if ( is_dropout ) {
@@ -997,6 +998,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         std::lock_guard<std::mutex> lock(gen->mutex_);
         params.philox_args = gen->philox_cuda_state(counter_offset);
         auto seeds = at::cuda::philox::unpack(params.philox_args);
+        tmp_rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
+        params.rng_state = reinterpret_cast<uint64_t*>(tmp_rng_state.data_ptr());
         params.rng_state[0] = std::get<0>(seeds);
         params.rng_state[1] = std::get<1>(seeds);
     }
@@ -1243,6 +1246,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;
 
+    at::Tensor tmp_rng_state;
     if ( rng_state.has_value() ) {
         params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
     } else if ( is_dropout ) {
@@ -1251,6 +1255,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         std::lock_guard<std::mutex> lock(gen->mutex_);
         params.philox_args = gen->philox_cuda_state(counter_offset);
         auto seeds = at::cuda::philox::unpack(params.philox_args);
+        tmp_rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
+        params.rng_state = reinterpret_cast<uint64_t*>(tmp_rng_state.data_ptr());
         params.rng_state[0] = std::get<0>(seeds);
         params.rng_state[1] = std::get<1>(seeds);
     }
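
What the patch does, in brief: tmp_rng_state is declared ahead of the branch so it outlives it, and in the branch that draws a fresh Philox state (no rng_state tensor passed in), params.rng_state is pointed at the newly allocated 2-element int64 tensor before the seed and offset are written through it, so those writes have valid backing storage. Below is a minimal standalone sketch of that pattern, not flash-attention code: the Params struct, the set_rng_state helper, and the plain seed/offset arguments (standing in for the unpacked Philox state) are illustrative assumptions, and the backing tensor is allocated on the CPU here so the host-side writes in the sketch are well defined.

#include <cstdint>
#include <iostream>
#include <optional>
#include <torch/torch.h>

// Illustrative stand-in for the kernel parameter struct; only the field
// touched by the patch is shown.
struct Params {
    uint64_t *rng_state = nullptr;  // destination for the Philox {seed, offset} pair
};

// Sketch of the pattern used in the patch: reuse the caller-provided rng_state
// tensor if there is one, otherwise allocate a 2-element int64 tensor as
// backing storage and write the seed/offset pair through it. The backing
// tensor is returned so it stays alive as long as the caller needs it.
at::Tensor set_rng_state(Params &params,
                         const std::optional<at::Tensor> &rng_state,
                         uint64_t seed, uint64_t offset) {
    at::Tensor backing;  // plays the role of tmp_rng_state in the diff
    if (rng_state.has_value()) {
        backing = rng_state.value();
        params.rng_state = reinterpret_cast<uint64_t *>(backing.data_ptr());
    } else {
        // CPU allocation keeps the host-side writes below well defined;
        // the patch itself uses `opts`, i.e. the input tensor's options.
        backing = torch::empty({2}, torch::dtype(torch::kInt64));
        params.rng_state = reinterpret_cast<uint64_t *>(backing.data_ptr());
        params.rng_state[0] = seed;    // std::get<0>(seeds) in the diff
        params.rng_state[1] = offset;  // std::get<1>(seeds) in the diff
    }
    return backing;
}

int main() {
    Params params;
    // No rng_state supplied: the helper allocates backing storage and fills it.
    at::Tensor state = set_rng_state(params, std::nullopt, /*seed=*/42, /*offset=*/1234);
    std::cout << state << std::endl;  // prints the stored seed/offset pair
    return 0;
}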