[Bitmask] Add allocate_token_bitmask to xgr from testing (#72)
This PR replaces `from xgrammar.testing import _allocate_token_bitmask` with
`from xgrammar import allocate_token_bitmask`.

Passed all tests with `pytest .`
CharlieFRuan authored Nov 21, 2024
1 parent 3d37bc2 commit 035495f
Showing 9 changed files with 58 additions and 41 deletions.
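In short, callers that previously reached into the private testing helper now use the public API. A minimal sketch of the change (the `vocab_size` value below is a made-up example; in the tests it comes from `tokenizer_info.vocab_size`):

```python
import xgrammar as xgr

# Before this commit the helper was private:
#   from xgrammar.testing import _allocate_token_bitmask
#   token_bitmask = _allocate_token_bitmask(1, vocab_size)

# After this commit it is part of the public API:
vocab_size = 32000  # example value; normally tokenizer_info.vocab_size
token_bitmask = xgr.allocate_token_bitmask(1, vocab_size)
print(token_bitmask.shape, token_bitmask.dtype)  # (1, ceil(vocab_size / 32)), torch.int32
```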
4 changes: 3 additions & 1 deletion python/xgrammar/__init__.py
@@ -19,7 +19,9 @@
from .matcher import (
GrammarMatcher,
apply_token_bitmask_inplace,
- get_bitmask_dtype,
get_bitmask_shape,
+ allocate_token_bitmask,
+ bitmask_dtype,
)
from .tokenizer_info import TokenizerInfo, VocabType
from . import testing
41 changes: 29 additions & 12 deletions python/xgrammar/matcher.py
@@ -27,31 +27,48 @@
apply_token_bitmask_inplace as apply_token_bitmask_inplace_cuda,
)

+ bitmask_dtype = torch.int32


def get_bitmask_shape(batch_size: int, vocab_size: int) -> Tuple[int, int]:
"""Allocate the bitmask for the next token prediction. The bitmask is a int32 tensor on CPU
with shape (batch_size, ceil(vocab_size / 32)). If the batch size is None, the bitmask is
a 1D tensor with shape (ceil(vocab_size / 32),).
"""Return the shape of the bitmask (batch_size, ceil(vocab_size / 32))"""
return (batch_size, math.ceil(vocab_size / 32))


+ def allocate_token_bitmask(batch_size: int, vocab_size: int) -> torch.Tensor:
+ """Allocate the bitmask for the next token prediction. The bitmask is an int32 tensor on CPU
+ with shape (batch_size, ceil(vocab_size / 32)). This function defaults to
+ .. code:: python
+ return torch.empty(
+ xgr.get_bitmask_shape(batch_size, vocab_size),
+ dtype=xgr.bitmask_dtype,
+ pin_memory=True,
+ )
Parameters
----------
batch_size : int
The batch size of the bitmask.
vocab_size : int
The size of the vocabulary.
- batch_size : Optional[int], default: None
- The batch size of the bitmask. If None, the bitmask is a 1D tensor.
Returns
-------
bitmask : torch.Tensor
The shape of the bitmask.
"""
- return (batch_size, math.ceil(vocab_size / 32))

- def get_bitmask_dtype() -> torch.dtype:
- """Get the dtype of the bitmask."""
- return torch.int32
+ Note
+ ----
+ This is the default way of allocating a bitmask. You can also customize the implementation.
+ """
+ return torch.empty(
+ get_bitmask_shape(batch_size, vocab_size),
+ dtype=bitmask_dtype,
+ pin_memory=True,
+ )


def apply_token_bitmask_inplace(
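The docstring's Note above says the default allocation can be customized. A small sketch of what that might look like, using the newly exported `get_bitmask_shape` and `bitmask_dtype` helpers; the sizes are made-up examples, and dropping `pin_memory` is just one illustrative choice, not something this commit prescribes:

```python
import math

import torch
import xgrammar as xgr

batch_size, vocab_size = 4, 32000  # hypothetical sizes

# Custom allocation equivalent to the default, but without pinned host memory
# (for example on a machine where pinned memory is not wanted or unavailable).
token_bitmask = torch.empty(
    xgr.get_bitmask_shape(batch_size, vocab_size),  # (batch_size, ceil(vocab_size / 32))
    dtype=xgr.bitmask_dtype,                        # torch.int32
)
assert token_bitmask.shape == (batch_size, math.ceil(vocab_size / 32))
```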
10 changes: 1 addition & 9 deletions python/xgrammar/testing.py
@@ -23,7 +23,7 @@
from .base import _core
from .compiler import GrammarCompiler
from .grammar import Grammar
- from .matcher import GrammarMatcher, get_bitmask_dtype, get_bitmask_shape
+ from .matcher import GrammarMatcher, get_bitmask_shape
from .tokenizer_info import TokenizerInfo


@@ -110,14 +110,6 @@ def _match_grammar_with_string(
return matcher.is_terminated()


- def _allocate_token_bitmask(batch_size: int, vocab_size: int) -> torch.Tensor:
- return torch.empty(
- get_bitmask_shape(batch_size, vocab_size),
- dtype=get_bitmask_dtype(),
- pin_memory=True,
- )


def _get_masked_tokens_from_bitmask(
bitmask: torch.Tensor, vocab_size: int, index: int = 0
) -> List[int]:
11 changes: 11 additions & 0 deletions tests/README.md
@@ -0,0 +1,11 @@
To test, run `pytest .` under the `xgrammar` folder. You may need to do the following:

```bash
pip install sentencepiece
pip install protobuf
pip install -U "huggingface_hub[cli]"
huggingface-cli login --token YOUR_HF_TOKEN
```

Make sure you also have access to the gated models, which should only require you to agree to
some terms on the models' pages on Hugging Face.
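To run only one of the files touched by this PR rather than the whole suite, standard pytest file selection should work; the path below is taken from this commit's file list:

```bash
# Run a single test file with verbose output
pytest tests/python/test_grammar_matcher.py -v
```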
3 changes: 1 addition & 2 deletions tests/python/test_builtin_grammar_json.py
@@ -10,7 +10,6 @@

import xgrammar as xgr
from xgrammar.testing import (
- _allocate_token_bitmask,
_get_masked_tokens_from_bitmask,
_match_grammar_with_string,
)
@@ -290,7 +289,7 @@ def test_fill_next_token_bitmask(
time_end = time.monotonic_ns()
print(f"Time to init GrammarMatcher: {(time_end - time_start) / 1e3} us")

- token_bitmask = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
logits_gpu = torch.zeros(tokenizer_info.vocab_size, dtype=torch.float32, device="cuda")

input_bytes = input_str.encode("utf-8")
3 changes: 1 addition & 2 deletions tests/python/test_builtin_grammar_json_schema.py
@@ -9,7 +9,6 @@

import xgrammar as xgr
from xgrammar.testing import (
- _allocate_token_bitmask,
_get_masked_tokens_from_bitmask,
_get_matcher_from_grammar_and_tokenizer_info,
)
@@ -98,7 +97,7 @@ def test_fill_next_token_bitmask(tokenizer_path: str):
time_end = time.monotonic_ns()
print(f"Time to init GrammarMatcher: {(time_end - time_start) / 1e3} us")

- token_bitmask = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
logits_gpu = torch.zeros(tokenizer_info.vocab_size, dtype=torch.float32, device="cuda")

input_bytes = instance_str.encode("utf-8")
3 changes: 1 addition & 2 deletions tests/python/test_custom_grammar.py
@@ -12,7 +12,6 @@

import xgrammar as xgr
from xgrammar.testing import (
- _allocate_token_bitmask,
_get_masked_tokens_from_bitmask,
_match_grammar_with_string,
)
@@ -335,7 +334,7 @@ def test_fill_next_token_bitmask(
time_end = time.monotonic_ns()
print(f"Time to init GrammarMatcher: {(time_end - time_start) / 1e3} us")

- token_bitmask = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
logits_gpu = torch.zeros(tokenizer_info.vocab_size, dtype=torch.float32, device="cuda")

input_bytes = input_str.encode("utf-8")
21 changes: 10 additions & 11 deletions tests/python/test_grammar_matcher.py
@@ -10,7 +10,6 @@

import xgrammar as xgr
from xgrammar.testing import (
- _allocate_token_bitmask,
_get_masked_tokens_from_bitmask,
_get_matcher_from_grammar_and_tokenizer_info,
_match_grammar_with_string,
@@ -84,7 +83,7 @@ def test_fill_next_token_bitmask(
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
matcher = _get_matcher_from_grammar_and_tokenizer_info(json_grammar, tokenizer_info)

- token_bitmask = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)

input_bytes = input_str.encode("utf-8")
rejected_sizes = []
@@ -121,7 +120,7 @@ def test_token_operations():

tokenizer_info = xgr.TokenizerInfo(vocab)
matcher = _get_matcher_from_grammar_and_tokenizer_info(json_grammar, tokenizer_info)
- token_bitmask = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)

expected = [
["{"],
@@ -241,22 +240,22 @@ def test_rollback():

for i_1, i_2 in input_ids_splitted:
orig_result = []
- token_bitmask1 = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ token_bitmask1 = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
matcher.fill_next_token_bitmask(token_bitmask1)
orig_result.append(token_bitmask1)
assert matcher.accept_token(i_1)
- token_bitmask2 = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ token_bitmask2 = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
matcher.fill_next_token_bitmask(token_bitmask2)
orig_result.append(token_bitmask2)
assert matcher.accept_token(i_2)

matcher.rollback(2)
result_after_rollback = []
- new_token_bitmask1 = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ new_token_bitmask1 = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
matcher.fill_next_token_bitmask(new_token_bitmask1)
result_after_rollback.append(new_token_bitmask1)
assert matcher.accept_token(i_1)
- new_token_bitmask2 = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ new_token_bitmask2 = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
matcher.fill_next_token_bitmask(new_token_bitmask2)
result_after_rollback.append(new_token_bitmask2)
assert matcher.accept_token(i_2)
@@ -278,7 +277,7 @@ def test_reset():
orig_result = []

for i in input_ids:
- token_bitmask = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
matcher.fill_next_token_bitmask(token_bitmask)
orig_result.append(token_bitmask)
assert matcher.accept_token(i)
@@ -288,7 +287,7 @@ def test_reset():
result_after_reset = []

for i in input_ids:
- token_bitmask = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
matcher.fill_next_token_bitmask(token_bitmask)
result_after_reset.append(token_bitmask)
assert matcher.accept_token(i)
@@ -321,7 +320,7 @@ def test_termination():
matcher = _get_matcher_from_grammar_and_tokenizer_info(
json_grammar, tokenizer_info, max_rollback_tokens=5
)
- token_bitmask = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)

for i in input_ids:
matcher.fill_next_token_bitmask(token_bitmask)
@@ -360,7 +359,7 @@ def test_vocab_size():
tokenizer_info = xgr.TokenizerInfo(vocab, vocab_size=64)
matcher = _get_matcher_from_grammar_and_tokenizer_info(json_grammar, tokenizer_info)

- token_bitmask = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
matcher.fill_next_token_bitmask(token_bitmask)
assert token_bitmask.shape == (1, 2)

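The tests above all follow the same allocate-then-fill pattern. A condensed sketch of that flow is below; it assumes `matcher` and `tokenizer_info` have already been constructed as the tests do, that CUDA is available, and the final `apply_token_bitmask_inplace` call reflects my reading of the exported API rather than anything changed by this commit:

```python
import torch
import xgrammar as xgr

# Assumes `matcher` (xgr.GrammarMatcher) and `tokenizer_info` (xgr.TokenizerInfo)
# were built beforehand, as in the tests above.
token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
matcher.fill_next_token_bitmask(token_bitmask)

# Mask the logits of the next-token distribution in place on the GPU.
logits_gpu = torch.zeros(tokenizer_info.vocab_size, dtype=torch.float32, device="cuda")
xgr.apply_token_bitmask_inplace(logits_gpu, token_bitmask.to(logits_gpu.device))
```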
3 changes: 1 addition & 2 deletions tests/python/test_regex_converter.py
@@ -7,7 +7,6 @@

import xgrammar as xgr
from xgrammar.testing import (
- _allocate_token_bitmask,
_match_grammar_with_string,
_regex_to_ebnf,
)
@@ -309,7 +308,7 @@ def test_mask_generation(tokenizer_path: str, regex: str, instance: str):
time_end = time.monotonic_ns()
print(f"Time for preprocessing: {(time_end - time_start) / 1e3} us")
matcher = xgr.GrammarMatcher(matcher_compiled_grammar)
- token_bitmask = _allocate_token_bitmask(1, tokenizer_info.vocab_size)
+ token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)

for c in instance.encode("utf-8"):
time_start = time.monotonic_ns()
