Prefetch device transfer for ptrs to CPU (#529)
The FlashInfer kernel
[here](https://github.com/flashinfer-ai/flashinfer/blob/main/python/flashinfer/jit/batch_prefill_templ.py#L52C15-L53)
does:
```
qo_indptr = qo_indptr.to(torch::kCPU);
kv_indptr = kv_indptr.to(torch::kCPU);
```
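which triggers a blocking device-to-host copy when the tensors live on the
GPU, stalling the CPU worker until the device catches up. We would like to
avoid this stall for certain optimizations.

Accordingly, this PR schedules the device transfer ahead of time in the
Python code, before the kernel is called, so that the CPU worker is not
blocked. Because the tensors are then already on the CPU, the `.to(torch::kCPU)`
calls in the kernel should have nothing left to copy.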
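The pattern described above can be sketched in plain PyTorch as follows. This is a minimal illustration, not FlashInfer code: `prefetch_indptr_to_cpu` is a hypothetical helper, the tensor values are made up, and the explicit event wait is an assumption of this sketch (the diff itself only issues the `non_blocking=True` copies). The sketch falls back to a CPU tensor when no GPU is available.

```python
import torch


def prefetch_indptr_to_cpu(indptr: torch.Tensor):
    """Start a host copy of `indptr` (hypothetical helper, not a FlashInfer API).

    Returns the CPU tensor and a CUDA event to wait on before reading it,
    or None when the tensor was already on the CPU.
    """
    if indptr.is_cuda:
        # Enqueue the device-to-host copy without blocking the calling thread.
        host = indptr.to("cpu", non_blocking=True)
        # Record an event on the current stream so the copy can be awaited later.
        done = torch.cuda.Event()
        done.record()
        return host, done
    # Already on the CPU: nothing to transfer.
    return indptr, None


device = "cuda" if torch.cuda.is_available() else "cpu"
qo_indptr = torch.tensor([0, 3, 7, 12], dtype=torch.int32, device=device)

qo_indptr_host, copy_done = prefetch_indptr_to_cpu(qo_indptr)

# ... other CPU-side planning work could overlap with the copy here ...

# Assumption of this sketch: wait for the copy before the host reads the data.
if copy_done is not None:
    copy_done.synchronize()

batch_size = qo_indptr_host.numel() - 1
total_tokens = int(qo_indptr_host[-1])
```

The point of the event is that a non-blocking device-to-host copy returns before the data is guaranteed to be in host memory, so any host-side read needs some form of synchronization first; waiting on one event is cheaper than synchronizing the whole device.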
mvpatel2000 authored Oct 15, 2024
1 parent 93b5d4e commit ddef3f3
Showing 2 changed files with 6 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/flashinfer/decode.py
@@ -568,6 +568,9 @@ def plan(
if self.use_tensor_cores:
    self._qo_indptr_buf = qo_indptr.to(self.device)

qo_indptr = qo_indptr.to('cpu', non_blocking=True)
indptr = indptr.to('cpu', non_blocking=True)

data_type = canonicalize_torch_dtype(data_type)
if not q_data_type:
    q_data_type = data_type
3 changes: 3 additions & 0 deletions python/flashinfer/prefill.py
@@ -740,6 +740,9 @@ def plan(
self._custom_mask_buf = packed_custom_mask.to(self.device)
self._qk_indptr_buf = qk_indptr.to(self.device)

qo_indptr = qo_indptr.to('cpu', non_blocking=True)
paged_kv_indptr = paged_kv_indptr.to('cpu', non_blocking=True)

if packed_custom_mask is not None:
    mask_mode = MaskMode.CUSTOM.value
else: