Skip to content

Commit

Permalink
[Refactor] Rename compile_bnf_grammar to compile_grammar (#73)
Browse files Browse the repository at this point in the history
This PR renames compile_bnf_grammar to compile_grammar to reflect the latest API changes.
  • Loading branch information
Ubospica authored Nov 21, 2024
1 parent a99cb18 commit 3d37bc2
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 17 deletions.
6 changes: 3 additions & 3 deletions cpp/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ class GrammarCompiler::Impl {
cache_enabled_(cache_enabled),
compile_json_schema_cache_(GetCompileJSONSchemaCacheFunc(cache_enabled_)),
compile_builtin_json_grammar_cache_(GetCompileBuiltinJSONGrammarCacheFunc(cache_enabled_)),
compile_bnf_grammar_cache_(GetCompileGrammarCacheFunc(cache_enabled_)) {}
compile_grammar_cache_(GetCompileGrammarCacheFunc(cache_enabled_)) {}

CompiledGrammar CompileBuiltinJSONGrammar();

Expand Down Expand Up @@ -371,7 +371,7 @@ class GrammarCompiler::Impl {
/*! \brief The cache for the compiled grammar for JSON. */
ThreadSafeCache<CompiledGrammar> compile_builtin_json_grammar_cache_;
/*! \brief The cache for the compiled grammar for bnf grammar. */
ThreadSafeCache<GrammarKey, CompiledGrammar> compile_bnf_grammar_cache_;
ThreadSafeCache<GrammarKey, CompiledGrammar> compile_grammar_cache_;
};

CompiledGrammar GrammarCompiler::Impl::CompileBuiltinJSONGrammar() {
Expand Down Expand Up @@ -406,7 +406,7 @@ CompiledGrammar GrammarCompiler::Impl::CompileGrammar(const Grammar& grammar) {
return MultiThreadCompileGrammar(grammar, tokenizer_info_, max_threads_);
}
auto key = std::make_pair(grammar.ToString(), grammar->GetRootRule().name);
return compile_bnf_grammar_cache_.Get(key);
return compile_grammar_cache_.Get(key);
}

void GrammarCompiler::Impl::ClearCache() {
Expand Down
2 changes: 1 addition & 1 deletion cpp/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ PYBIND11_MODULE(xgrammar_bindings, m) {
py::call_guard<py::gil_scoped_release>()
)
.def(
"compile_bnf_grammar",
"compile_grammar",
&GrammarCompiler::CompileGrammar,
py::call_guard<py::gil_scoped_release>()
)
Expand Down
12 changes: 4 additions & 8 deletions python/xgrammar/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,22 +119,18 @@ def compile_builtin_json_grammar(self) -> CompiledGrammar:
return CompiledGrammar._create_from_handle(self._handle.compile_builtin_json_grammar())

@overload
def compile_bnf_grammar(
self, grammar: str, *, root_rule_name: str = "root"
) -> CompiledGrammar: ...
def compile_grammar(self, grammar: str, *, root_rule_name: str = "root") -> CompiledGrammar: ...

@overload
def compile_bnf_grammar(self, grammar: Grammar) -> CompiledGrammar: ...
def compile_grammar(self, grammar: Grammar) -> CompiledGrammar: ...

def compile_bnf_grammar(
def compile_grammar(
self, grammar: Union[str, Grammar], *, root_rule_name: str = "root"
) -> CompiledGrammar:
"""Compile a BNF grammar."""
if isinstance(grammar, str):
grammar = Grammar.from_ebnf(grammar, root_rule_name=root_rule_name)
return CompiledGrammar._create_from_handle(
self._handle.compile_bnf_grammar(grammar._handle)
)
return CompiledGrammar._create_from_handle(self._handle.compile_grammar(grammar._handle))

def clear_cache(self) -> None:
"""Clear all cached compiled grammars."""
Expand Down
4 changes: 2 additions & 2 deletions python/xgrammar/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _match_grammar_with_string(
if isinstance(grammar, str):
grammar = Grammar.from_ebnf(grammar)
grammar_compiler = GrammarCompiler(TokenizerInfo([]), cache_enabled=False)
compiled_grammar = grammar_compiler.compile_bnf_grammar(grammar)
compiled_grammar = grammar_compiler.compile_grammar(grammar)
matcher = GrammarMatcher(compiled_grammar, terminate_without_stop_token=True)
if not matcher._debug_accept_string(input_str, debug_print=debug_print):
return False
Expand Down Expand Up @@ -147,5 +147,5 @@ def _get_matcher_from_grammar_and_tokenizer_info(
if tokenizer_info is None:
tokenizer_info = TokenizerInfo([])
grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False)
compiled_grammar = grammar_compiler.compile_bnf_grammar(grammar)
compiled_grammar = grammar_compiler.compile_grammar(grammar)
return GrammarMatcher(compiled_grammar, **kwargs)
2 changes: 1 addition & 1 deletion tests/python/test_custom_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def test_fill_next_token_bitmask(
compiler = xgr.GrammarCompiler(tokenizer_info)

time_start = time.monotonic_ns()
matcher = xgr.GrammarMatcher(compiler.compile_bnf_grammar(json_grammar_ebnf))
matcher = xgr.GrammarMatcher(compiler.compile_grammar(json_grammar_ebnf))
time_end = time.monotonic_ns()
print(f"Time to init GrammarMatcher: {(time_end - time_start) / 1e3} us")

Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_grammar_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_compiled_grammar():
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgr.GrammarCompiler(tokenizer_info)
time_start = time.monotonic_ns()
context = compiler.compile_bnf_grammar(grammar)
context = compiler.compile_grammar(grammar)
time_end = time.monotonic_ns()
print(f"Time to get compiled grammar: {(time_end - time_start) / 1e3} us")

Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_regex_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def test_mask_generation(tokenizer_path: str, regex: str, instance: str):
grammar_compiler = xgr.GrammarCompiler(tokenizer_info, cache_enabled=False)

time_start = time.monotonic_ns()
matcher_compiled_grammar = grammar_compiler.compile_bnf_grammar(_regex_to_ebnf(regex))
matcher_compiled_grammar = grammar_compiler.compile_grammar(_regex_to_ebnf(regex))
time_end = time.monotonic_ns()
print(f"Time for preprocessing: {(time_end - time_start) / 1e3} us")
matcher = xgr.GrammarMatcher(matcher_compiled_grammar)
Expand Down

0 comments on commit 3d37bc2

Please sign in to comment.