[Feature] Update apply mask kernels #128

Merged · 4 commits · Dec 12, 2024
62 changes: 62 additions & 0 deletions cpp/matcher.cc
@@ -61,6 +61,68 @@ void _DebugGetMaskedTokensFromBitmask(
}
}

void ApplyTokenBitmaskInplaceCPU(
    DLTensor* logits, const DLTensor& bitmask, std::optional<std::vector<int>> indices
) {
  XGRAMMAR_CHECK(logits->device.device_type == kDLCPU)
      << "The provided logits's device is not valid: should be CPU";
  XGRAMMAR_CHECK(bitmask.device.device_type == kDLCPU)
      << "The provided bitmask's device is not valid: should be CPU";
  int batch_size;
  int vocab_size;
  if (logits->ndim == 2) {
    batch_size = logits->shape[0];
    vocab_size = logits->shape[1];
  } else {
    batch_size = 1;
    vocab_size = logits->shape[0];
  }
  int bitmask_size = GetBitmaskSize(vocab_size);
  if (bitmask.ndim == 2) {
    XGRAMMAR_CHECK(bitmask.shape[0] == batch_size)
        << "The provided bitmask's batch size is not consistent with logits";
    XGRAMMAR_CHECK(bitmask.shape[1] == bitmask_size)
        << "The provided bitmask's bitmask size is not consistent with logits";
  } else {
    XGRAMMAR_CHECK(bitmask.ndim == 1)
        << "The provided bitmask's shape is not valid: should be (batch_size, vocab_size)";
    XGRAMMAR_CHECK(bitmask.shape[0] == bitmask_size)
        << "The provided bitmask's bitmask size is not consistent with logits";
  }
  XGRAMMAR_CHECK(
      logits->dtype.code == kDLFloat && logits->dtype.bits == 32 && logits->dtype.lanes == 1
  ) << "The provided logits's dtype is not valid: should be float32";
  XGRAMMAR_CHECK(
      bitmask.dtype.code == kDLInt && bitmask.dtype.bits == 32 && bitmask.dtype.lanes == 1
  ) << "The provided bitmask's dtype is not valid: should be int32";

  std::vector<int> indices_value;
  if (indices.has_value()) {
    indices_value = indices.value();
    std::sort(indices_value.begin(), indices_value.end());
    indices_value.erase(
        std::unique(indices_value.begin(), indices_value.end()), indices_value.end()
    );
    XGRAMMAR_CHECK(indices_value.back() < batch_size)
        << "The provided indices is out of bounds: " << indices_value.back()
        << " >= " << batch_size;
  } else {
    indices_value.resize(batch_size);
    for (int i = 0; i < batch_size; ++i) {
      indices_value[i] = i;
    }
  }

  for (auto idx : indices_value) {
    uint32_t* data_ptr = reinterpret_cast<uint32_t*>(bitmask.data) + idx * bitmask_size;
    DynamicBitset bitset(vocab_size, data_ptr);
    auto logits_ptr = reinterpret_cast<float*>(logits->data) + idx * vocab_size;
    for (int i = bitset.FindFirstZero(); i != -1; i = bitset.FindNextZero(i)) {
      logits_ptr[i] = -std::numeric_limits<float>::infinity();
    }
  }
}

/*
* Note on the matching algorithm
*
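For readers skimming the diff, here is a rough pure-PyTorch equivalent of the new CPU kernel above (not part of this PR; the function name `apply_token_bitmask_reference` is made up, and the logic mirrors the `_bitmask_to_bool_mask` helper removed from `apply_token_bitmask_inplace_cpu.py` further down). Each int32 word of the bitmask packs 32 tokens, least-significant bit first, and every zero bit drives the corresponding logit to -inf.

```python
import torch


def apply_token_bitmask_reference(logits: torch.Tensor, bitmask: torch.Tensor) -> None:
    """Unoptimized sketch of ApplyTokenBitmaskInplaceCPU for 2D float32/int32 inputs."""
    batch_size, vocab_size = logits.shape
    # Unpack each int32 word into 32 bits, least-significant bit first.
    shifts = torch.arange(32, dtype=torch.int32)
    bits = (bitmask.unsqueeze(-1) >> shifts) & 1  # (batch, words, 32)
    keep = bits.reshape(batch_size, -1)[:, :vocab_size].bool()  # (batch, vocab)
    # Zero bits mark rejected tokens: set their logits to -inf in place.
    logits.masked_fill_(~keep, float("-inf"))
```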
3 changes: 3 additions & 0 deletions cpp/pybind/pybind.cc
@@ -84,4 +84,7 @@ PYBIND11_MODULE(xgrammar_bindings, m) {
)
.def("_regex_to_ebnf", &RegexToEBNF)
.def("_get_masked_tokens_from_bitmask", &Matcher_DebugGetMaskedTokensFromBitmask);

auto pyKernelsModule = m.def_submodule("kernels");
pyKernelsModule.def("apply_token_bitmask_inplace_cpu", &Kernels_ApplyTokenBitmaskInplaceCPU);
}
34 changes: 34 additions & 0 deletions cpp/pybind/python_methods.cc
@@ -8,6 +8,7 @@
#include <xgrammar/xgrammar.h>

#include <algorithm>
#include <array>
#include <chrono>
#include <cstdint>
#include <iostream>
@@ -92,4 +93,37 @@ std::vector<int> Matcher_DebugGetMaskedTokensFromBitmask(
return result;
}

void Kernels_ApplyTokenBitmaskInplaceCPU(
    intptr_t logits_ptr,
    std::pair<int64_t, int64_t> logits_shape,
    intptr_t bitmask_ptr,
    std::pair<int64_t, int64_t> bitmask_shape,
    std::optional<std::vector<int>> indices
) {
  std::array<int64_t, 2> logits_shape_arr = {logits_shape.first, logits_shape.second};
  std::array<int64_t, 2> bitmask_shape_arr = {bitmask_shape.first, bitmask_shape.second};

  DLTensor logits_dltensor{
      reinterpret_cast<void*>(logits_ptr),
      DLDevice{kDLCPU, 0},
      2,
      DLDataType{kDLFloat, 32, 1},
      logits_shape_arr.data(),
      nullptr,
      0
  };

  DLTensor bitmask_dltensor{
      reinterpret_cast<void*>(bitmask_ptr),
      DLDevice{kDLCPU, 0},
      2,
      GetBitmaskDLType(),
      bitmask_shape_arr.data(),
      nullptr,
      0
  };

  ApplyTokenBitmaskInplaceCPU(&logits_dltensor, bitmask_dltensor, indices);
}

} // namespace xgrammar
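One consequence of the wrapper above: the DLTensors are built with `strides` set to `nullptr` and `byte_offset` set to 0, so the raw pointers are assumed to describe compact, row-major buffers. A hedged sketch of what that implies for callers on the Python side (tensor names here are illustrative only):

```python
import torch

logits = torch.randn(4, 32000, dtype=torch.float32)
view = logits[:, ::2]  # non-contiguous view: data_ptr() plus a shape no longer describe it
# Materialize a compact copy before handing raw pointers to the binding.
safe = view.contiguous()
assert safe.is_contiguous()
```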
8 changes: 8 additions & 0 deletions cpp/pybind/python_methods.h
@@ -37,6 +37,14 @@ std::vector<int> Matcher_DebugGetMaskedTokensFromBitmask(
intptr_t token_bitmask_ptr, std::vector<int64_t> shape, int32_t vocab_size, int32_t index
);

void Kernels_ApplyTokenBitmaskInplaceCPU(
    intptr_t logits_ptr,
    std::pair<int64_t, int64_t> logits_shape,
    intptr_t bitmask_ptr,
    std::pair<int64_t, int64_t> bitmask_shape,
    std::optional<std::vector<int>> indices
);

} // namespace xgrammar

#endif // XGRAMMAR_PYBIND_PYTHON_METHODS_H_
6 changes: 6 additions & 0 deletions include/xgrammar/matcher.h
@@ -26,6 +26,12 @@ void _DebugGetMaskedTokensFromBitmask(
std::vector<int>* rejected_tokens, const DLTensor& token_bitmask, int vocab_size, int index = 0
);

void ApplyTokenBitmaskInplaceCPU(
    DLTensor* logits,
    const DLTensor& bitmask,
    std::optional<std::vector<int>> indices = std::nullopt
);

/*!
* \brief A stateful matcher to match tokens to the specified BNF grammar. This class is the core
* logic of the grammar-guided generation.
1 change: 1 addition & 0 deletions python/xgrammar/__init__.py
@@ -8,5 +8,6 @@
apply_token_bitmask_inplace,
bitmask_dtype,
get_bitmask_shape,
reset_token_bitmask,
)
from .tokenizer_info import TokenizerInfo, VocabType
1 change: 0 additions & 1 deletion python/xgrammar/kernels/__init__.py
@@ -1,5 +1,4 @@
"""The kernels for XGrammar."""

from .apply_token_bitmask_inplace_cpu import apply_token_bitmask_inplace_cpu
from .apply_token_bitmask_inplace_cuda import apply_token_bitmask_inplace_cuda
from .apply_token_bitmask_inplace_triton import apply_token_bitmask_inplace_triton
66 changes: 25 additions & 41 deletions python/xgrammar/kernels/apply_token_bitmask_inplace_cpu.py
@@ -1,57 +1,41 @@
"""CPU implementation for in-place applying token mask."""

import time
from typing import List, Optional, Union

import torch


def _bitmask_to_bool_mask(bitmask: torch.Tensor, vocab_size: int) -> torch.Tensor:
bits_per_block = 32
bitmask_size = bitmask.size(-1)
# Expand bitmask to bits
shifts = torch.arange(bits_per_block, device=bitmask.device, dtype=torch.int32)
bits = (bitmask.unsqueeze(-1) >> shifts) & 1 # Shape (*, bits_per_block)
bits = bits.view(bitmask.size(0), -1) # Shape (batch_size, bitmask_size * bits_per_block)
bool_mask = bits[:, :vocab_size].to(torch.bool) # Truncate to vocab_size
return bool_mask
from ..base import _core


def apply_token_bitmask_inplace_cpu(
logits: torch.Tensor,
bitmask: torch.Tensor,
indices: Optional[Union[List[int], torch.Tensor]] = None,
) -> None:
"""Exactly the same as `apply_token_bitmask_inplace()`, but `logits` is on the CPU.
So we use CPU implementation rather than launching a CUDA kernel.
"""
"""Apply token bitmask in-place on CPU."""
if logits.device.type != "cpu":
raise ValueError("logits must be on CPU")
if bitmask.device != logits.device:
raise ValueError("bitmask must be on the same device as logits")
if bitmask.dim() != logits.dim():
raise ValueError(
f"bitmask and logits must have the same number of dimensions, but "
+ f"got {bitmask.dim()} and {logits.dim()}"
)
if bitmask.device.type != "cpu":
raise ValueError("bitmask must be on CPU")
if logits.dtype != torch.float32:
raise ValueError("logits must be of type float32")
if bitmask.dtype != torch.int32:
raise ValueError("bitmask must be of type int32")
if logits.dim() != 1 and logits.dim() != 2:
raise ValueError("logits should be 1D or 2D, but got {}D".format(logits.dim()))
if bitmask.dim() != 1 and bitmask.dim() != 2:
raise ValueError("Unsupported logits and bitmask dimensions: {}".format(bitmask.dim()))

# Expand both to (batch_size, xxx_size), where batch_size is 1
if logits.dim() == 1:
logits = logits.unsqueeze(0)
bitmask = bitmask.unsqueeze(0)

batch_size, vocab_size = logits.size()
if indices is None:
if batch_size != bitmask.size(0):
raise ValueError("Batch size of logits and bitmask must match")
bool_mask = _bitmask_to_bool_mask(bitmask, vocab_size) # Shape (batch_size, vocab_size)
logits.masked_fill_(~bool_mask, -float("inf"))
else:
if not isinstance(indices, torch.Tensor):
indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
len_indices = len(indices)
if len_indices != bitmask.size(0):
raise ValueError("The length of indices and bitmask's batch size must match.")
bool_mask = _bitmask_to_bool_mask(bitmask, vocab_size) # Shape (len_indices, vocab_size)
logits[indices] = logits[indices].masked_fill_(~bool_mask, -float("inf"))
raise ValueError("bitmask should be 1D or 2D, but got {}D".format(bitmask.dim()))

logits_shape = (1, logits.shape[0]) if logits.dim() == 1 else (logits.shape[0], logits.shape[1])
bitmask_shape = (
(1, bitmask.shape[0]) if bitmask.dim() == 1 else (bitmask.shape[0], bitmask.shape[1])
)

_core.kernels.apply_token_bitmask_inplace_cpu(
logits.data_ptr(),
logits_shape,
bitmask.data_ptr(),
bitmask_shape,
indices,
)
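A minimal usage sketch for the rewritten entry point (illustrative, not from the PR; it assumes the package is built with the new `kernels` binding, and the tiny vocabulary size is made up):

```python
import torch
from xgrammar.kernels import apply_token_bitmask_inplace_cpu

vocab_size = 8  # one int32 word per row covers up to 32 tokens
logits = torch.zeros(2, vocab_size, dtype=torch.float32)
# Bits are read least-significant first: row 0 keeps only tokens 0 and 3,
# row 1 keeps every token (-1 has all 32 bits set).
bitmask = torch.tensor([[0b1001], [-1]], dtype=torch.int32)

apply_token_bitmask_inplace_cpu(logits, bitmask)
assert logits[0, 1] == float("-inf") and logits[0, 0] == 0.0
```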