[CUDA] JIT compilation for ApplyTokenBitmask kernel
This PR introduces JIT compilation for the ApplyTokenBitmask CUDA kernel,
enabled by the cuda-python package.

With JIT compilation, we can remove the AOT kernel compilation, which
introduces an extra dependency when building the package.
MasterJH5574 committed Nov 18, 2024
1 parent 33e3dc9 commit 642a209
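
For context, below is a minimal, standalone sketch of what runtime compilation through cuda-python's NVRTC and driver-API bindings can look like. It is not the code added by this PR (the new Python-side kernel files are not part of this excerpt): the kernel source, the `jit_compile` helper, the hard-coded `compute_80` architecture, and the float32-only signature are all illustrative assumptions.

```python
# Hedged sketch only: JIT-compile a token-bitmask kernel with cuda-python
# (NVRTC + driver API) instead of building it ahead of time. The kernel source
# and names are illustrative, not the actual kernel shipped by this PR.
import numpy as np
from cuda import cuda, nvrtc

KERNEL_SOURCE = r"""
extern "C" __global__ void apply_token_bitmask_inplace(
    float* logits, const int* bitmask, int batch_size, int vocab_size) {
  int batch = blockIdx.y;
  int token = blockIdx.x * blockDim.x + threadIdx.x;
  if (batch >= batch_size || token >= vocab_size) return;
  // Each int packs 32 tokens of the bitmask; a 0 bit means the token is masked out.
  int word = bitmask[batch * ((vocab_size + 31) / 32) + token / 32];
  if (((word >> (token % 32)) & 1) == 0) {
    logits[batch * vocab_size + token] = __int_as_float(0xff800000);  // -inf
  }
}
"""

def _check(result):
    # cuda-python calls return (error_code, *values); raise on any failure code.
    code, *values = result
    if code not in (nvrtc.nvrtcResult.NVRTC_SUCCESS, cuda.CUresult.CUDA_SUCCESS):
        raise RuntimeError(f"CUDA/NVRTC call failed: {code}")
    return values

def jit_compile(source: str, func_name: bytes, arch: bytes = b"compute_80"):
    # 1. Compile the CUDA C++ source to PTX with NVRTC.
    (prog,) = _check(nvrtc.nvrtcCreateProgram(source.encode(), b"kernel.cu", 0, [], []))
    opts = [b"--gpu-architecture=" + arch]
    _check(nvrtc.nvrtcCompileProgram(prog, len(opts), opts))
    (ptx_size,) = _check(nvrtc.nvrtcGetPTXSize(prog))
    ptx = b" " * ptx_size
    _check(nvrtc.nvrtcGetPTX(prog, ptx))

    # 2. Load the PTX into a CUDA context and look up the kernel handle.
    _check(cuda.cuInit(0))
    (device,) = _check(cuda.cuDeviceGet(0))
    (_context,) = _check(cuda.cuCtxCreate(0, device))
    ptx_buf = np.char.array(ptx)
    (module,) = _check(cuda.cuModuleLoadData(ptx_buf.ctypes.data))
    (kernel,) = _check(cuda.cuModuleGetFunction(module, func_name))
    return kernel  # launch later via cuda.cuLaunchKernel with packed arguments

if __name__ == "__main__":
    kernel = jit_compile(KERNEL_SOURCE, b"apply_token_bitmask_inplace")
    print("JIT-compiled kernel handle:", kernel)
```

Compiling the kernel at runtime like this is what lets the AOT path go away, which is presumably why the `cpp/kernels` sources and the kernel-related CMake options are deleted in the diffs below.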
Showing 15 changed files with 224 additions and 249 deletions.
6 changes: 0 additions & 6 deletions CMakeLists.txt
@@ -16,7 +16,6 @@ endif()

option(XGRAMMAR_BUILD_PYTHON_BINDINGS "Build Python bindings" ON)
option(XGRAMMAR_BUILD_CXX_TESTS "Build C++ tests" ON)
option(XGRAMMAR_BUILD_CUDA_KERNELS "Build CUDA kernels" ON)
set(XGRAMMAR_CUDA_ARCHITECTURES
native
CACHE STRING "CUDA architectures"
@@ -37,7 +36,6 @@ endif()
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Build Python bindings: ${XGRAMMAR_BUILD_PYTHON_BINDINGS}")
message(STATUS "Build C++ tests: ${XGRAMMAR_BUILD_CXX_TESTS}")
message(STATUS "Build CUDA kernels: ${XGRAMMAR_BUILD_CUDA_KERNELS}")
message(STATUS "CUDA architectures: ${XGRAMMAR_CUDA_ARCHITECTURES}")

if(MSVC)
@@ -63,10 +61,6 @@ list(FILTER XGRAMMAR_SOURCES_PATH EXCLUDE REGEX "${PROJECT_SOURCE_DIR}/cpp/pybin
add_library(xgrammar STATIC ${XGRAMMAR_SOURCES_PATH})
target_include_directories(xgrammar PUBLIC ${XGRAMMAR_INCLUDE_PATH})

if(XGRAMMAR_BUILD_KERNELS)
add_subdirectory(${PROJECT_SOURCE_DIR}/cpp/kernels)
endif()

if(XGRAMMAR_BUILD_PYTHON_BINDINGS)
add_subdirectory(${PROJECT_SOURCE_DIR}/cpp/pybind)
endif()
1 change: 0 additions & 1 deletion cmake/config.cmake
@@ -1,7 +1,6 @@
set(CMAKE_BUILD_TYPE RelWithDebInfo)
set(XGRAMMAR_BUILD_PYTHON_BINDINGS ON)
set(XGRAMMAR_BUILD_CXX_TESTS OFF)
set(XGRAMMAR_BUILD_KERNELS ON)
# set it to your own architecture
set(XGRAMMAR_CUDA_ARCHITECTURES
native
29 changes: 0 additions & 29 deletions cpp/kernels/CMakeLists.txt

This file was deleted.

110 changes: 0 additions & 110 deletions cpp/kernels/apply_token_mask_inplace.cu

This file was deleted.

20 changes: 0 additions & 20 deletions cpp/kernels/kernels.h

This file was deleted.

9 changes: 0 additions & 9 deletions cpp/pybind/CMakeLists.txt
@@ -39,18 +39,9 @@ find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib
# -D_GLIBCXX_USE_CXX11_ABI=0. So we compile bindings separately.
file(GLOB_RECURSE XGRAMMAR_BINDINGS_PATH ${PROJECT_SOURCE_DIR}/cpp/*.cc)

if(XGRAMMAR_BUILD_KERNELS)
file(GLOB_RECURSE XGRAMMAR_KERNELS_PATH ${PROJECT_SOURCE_DIR}/cpp/kernels/*.cu)
list(APPEND XGRAMMAR_BINDINGS_PATH ${XGRAMMAR_KERNELS_PATH})
endif()

pybind11_add_module(xgrammar_bindings ${XGRAMMAR_BINDINGS_PATH})
target_include_directories(xgrammar_bindings PUBLIC ${XGRAMMAR_INCLUDE_PATH})

if(XGRAMMAR_BUILD_KERNELS)
target_compile_definitions(xgrammar_bindings PUBLIC -DXGRAMMAR_BUILD_KERNELS)
endif()

target_link_libraries(xgrammar_bindings PUBLIC ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
set(LIB_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/python/xgrammar")
set_target_properties(xgrammar_bindings PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${LIB_OUTPUT_DIRECTORY})
5 changes: 0 additions & 5 deletions cpp/pybind/pybind.cc
@@ -74,9 +74,4 @@ PYBIND11_MODULE(xgrammar_bindings, m) {
.def_property_readonly("vocab_size", &GrammarMatcher::GetVocabSize)
.def_property_readonly("max_rollback_tokens", &GrammarMatcher::GetMaxRollbackTokens)
.def_property_readonly("stop_token_ids", &GrammarMatcher::GetStopTokenIds);
#ifdef XGRAMMAR_BUILD_KERNELS
pyGrammarMatcher.def_static(
"apply_token_bitmask_inplace", &GrammarMatcher_ApplyTokenBitmaskInplace
);
#endif
}
56 changes: 0 additions & 56 deletions cpp/pybind/python_methods.cc
@@ -5,7 +5,6 @@

#include "python_methods.h"

#include <ATen/DLConvertor.h>
#include <xgrammar/xgrammar.h>

#include <algorithm>
@@ -16,10 +15,6 @@
#include "../support/dynamic_bitset.h"
#include "../support/logging.h"

#ifdef XGRAMMAR_BUILD_KERNELS
#include "../kernels/kernels.h"
#endif

namespace xgrammar {

// Parse the EBNF string but not normalize it
@@ -130,55 +125,4 @@ std::vector<int> GrammarMatcher_DebugGetMaskedTokensFromBitmask(
return result;
}

#ifdef XGRAMMAR_BUILD_KERNELS
void GrammarMatcher_ApplyTokenBitmaskInplace(torch::Tensor logits, torch::Tensor token_bitmask) {
auto logits_shape = logits.sizes();
int batch_size = 1;
int vocab_size;
if (logits_shape.size() == 1) {
vocab_size = logits_shape[0];
} else if (logits_shape.size() == 2) {
batch_size = logits_shape[0];
vocab_size = logits_shape[1];
} else {
XGRAMMAR_LOG(FATAL) << "logits tensor must be 1D or 2D";
}

auto bitmask_shape = token_bitmask.sizes();
int expected_bitmask_size = DynamicBitset::GetBufferSize(vocab_size);
if (bitmask_shape.size() == 1) {
XGRAMMAR_CHECK(bitmask_shape[0] == expected_bitmask_size)
<< "The last dimension of the token bitmask tensor must be " << expected_bitmask_size
<< ", but got " << bitmask_shape[0];
} else if (bitmask_shape.size() == 2) {
XGRAMMAR_CHECK(bitmask_shape[0] == batch_size)
<< "The first dimension of the token bitmask tensor must be " << batch_size << ", but got "
<< bitmask_shape[0];
XGRAMMAR_CHECK(bitmask_shape[1] == expected_bitmask_size)
<< "The last dimension of the token bitmask tensor must be " << expected_bitmask_size
<< ", but got " << bitmask_shape[1];
} else {
XGRAMMAR_LOG(FATAL) << "token_bitmask tensor must be 1D or 2D";
}

DTypeFlag dtype_flag;
if (logits.dtype() == torch::kFloat16) {
dtype_flag = DTypeFlag::DTYPE_FLOAT16;
} else if (logits.dtype() == torch::kFloat32) {
dtype_flag = DTypeFlag::DTYPE_FLOAT32;
} else if (logits.dtype() == torch::kFloat64) {
dtype_flag = DTypeFlag::DTYPE_FLOAT64;
} else {
XGRAMMAR_LOG(FATAL) << "logits tensor must be of type float16, float32, or float64";
}

XGRAMMAR_CHECK(token_bitmask.dtype() == torch::kInt32)
<< "token bitmask tensor must be of type int32";

ApplyTokenBitmaskInplace(
logits.data_ptr(), dtype_flag, token_bitmask.data_ptr<int32_t>(), batch_size, vocab_size
);
}
#endif

} // namespace xgrammar
4 changes: 0 additions & 4 deletions cpp/pybind/python_methods.h
@@ -40,10 +40,6 @@ std::vector<int> GrammarMatcher_DebugGetMaskedTokensFromBitmask(
GrammarMatcher& matcher, torch::Tensor token_bitmask, int batch_id
);

#ifdef XGRAMMAR_BUILD_KERNELS
void GrammarMatcher_ApplyTokenBitmaskInplace(torch::Tensor logits, torch::Tensor token_bitmask);
#endif

} // namespace xgrammar

#endif // XGRAMMAR_PYBIND_PYTHON_METHODS_H_
@@ -14,5 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

__version__ = "0.0.3"