From 7b31407e8c1a309176f1936eda293a471dd66fa4 Mon Sep 17 00:00:00 2001
From: Yixin Dong
Date: Sat, 23 Nov 2024 20:50:15 -0500
Subject: [PATCH] [Fix] Fix C++ binding for fill_next_token_bitmask (#93)

This PR fixes the C++ binding of fill_next_token_bitmask. The bound C++
function takes a raw data pointer and a shape list rather than a
torch.Tensor, so the Python wrapper now passes bitmask.data_ptr() and
list(bitmask.shape) to the handle, and first validates that the bitmask
is a CPU tensor of the expected bitmask_dtype, raising a ValueError
otherwise.
---
 python/xgrammar/matcher.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/xgrammar/matcher.py b/python/xgrammar/matcher.py
index 9cdf645..74da739 100644
--- a/python/xgrammar/matcher.py
+++ b/python/xgrammar/matcher.py
@@ -204,7 +204,11 @@ def fill_next_token_bitmask(self, bitmask: torch.Tensor, index: int = 0) -> None
         index : int, default: 0
             The batch id of the bitmask.
         """
-        self._handle.fill_next_token_bitmask(bitmask, index)
+        if bitmask.device.type != "cpu":
+            raise ValueError("bitmask should be on CPU.")
+        if bitmask.dtype != bitmask_dtype:
+            raise ValueError(f"bitmask should be of type {bitmask_dtype}.")
+        self._handle.fill_next_token_bitmask(bitmask.data_ptr(), list(bitmask.shape), index)
 
     def find_jump_forward_string(self) -> str:
         """Find the jump-forward string for jump-forward decoding. This is the longest string that
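
Usage note: with this change, callers must hand fill_next_token_bitmask a
CPU tensor of the library's bitmask_dtype. A minimal sketch of allocating
and filling such a mask, assuming bitmask_dtype is torch.int32 with 32
mask bits packed per word; `matcher` and `vocab_size` are hypothetical
stand-ins for an existing xgrammar.GrammarMatcher and the tokenizer's
vocabulary size:

    import math

    import torch

    # `matcher` and `vocab_size` are placeholders for an existing
    # xgrammar.GrammarMatcher and the tokenizer's vocabulary size.
    batch_size = 1
    num_words = math.ceil(vocab_size / 32)  # assumed: 32 mask bits per int32 word
    bitmask = torch.zeros((batch_size, num_words), dtype=torch.int32, device="cpu")
    matcher.fill_next_token_bitmask(bitmask, index=0)  # fills the row at batch id 0

A CUDA-resident or wrongly typed mask now fails fast with a ValueError in
Python, before the raw pointer ever reaches the C++ binding.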