Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Prefetch device transfer for ptrs to CPU (#529)
The FlashInfer kernel [here](https://github.com/flashinfer-ai/flashinfer/blob/main/python/flashinfer/jit/batch_prefill_templ.py#L52C15-L53) does: ``` qo_indptr = qo_indptr.to(torch::kCPU); kv_indptr = kv_indptr.to(torch::kCPU); ``` which is a blocking device synchronization for the CPU worker. We would like to avoid this for certain optimizations. Accordingly, this PR schedules the device transfer ahead of time in the python code before the kernel to avoid blocking the CPU worker.
- Loading branch information