Skip to content

Commit

Permalink
[Refactor] Introduce EBNFScriptCreator (#133)
Browse files Browse the repository at this point in the history
This PR adds a class EBNFScriptCreator to handle construction of ebnf in
converters from other structures to ebnf, with auto renaming.
  • Loading branch information
Ubospica authored Dec 22, 2024
1 parent 8bc7b1d commit 49655f4
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 19 deletions.
86 changes: 86 additions & 0 deletions cpp/ebnf_script_creator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*!
* Copyright (c) 2023 by Contributors
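* \file xgrammar/ebnf_script_creator.cc
* \brief Implementation of EBNFScriptCreator, which builds EBNF scripts with unique rule names.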
*/
#include "ebnf_script_creator.h"

#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>

#include "support/logging.h"

namespace xgrammar {

/*!
 * \brief Implementation object behind EBNFScriptCreator.
 *
 * Keeps the rules in the order they were added and remembers every rule name that has
 * been handed out, so that each added rule can be given a name that is not yet in use.
 */
class EBNFScriptCreator::Impl {
 public:
  Impl() {}

  /*!
   * \brief Store a rule under a unique name derived from rule_name_hint.
   * \param rule_name_hint Suggested name for the rule.
   * \param rule_body The EBNF body of the rule.
   * \return The name under which the rule was stored.
   */
  std::string AddRule(const std::string& rule_name_hint, const std::string& rule_body);

  /*! \brief Render all stored rules as "name ::= body" lines, in insertion order. */
  std::string GetScript();

  /*!
   * \brief Look up the body stored for a rule name.
   * \param rule_name The name of the rule to look up.
   * \return The stored body, or an empty string when no rule has that name.
   */
  std::string GetRuleContent(const std::string& rule_name);

 private:
  /*! \brief Register and return a not-yet-used name based on rule_name_hint. */
  std::string GetRuleName(const std::string& rule_name_hint);

  /*! \brief The (name, body) pairs of all rules, in the order they were added. */
  std::vector<std::pair<std::string, std::string>> rules_;
  /*! \brief Every rule name handed out so far; used to detect name collisions. */
  std::unordered_set<std::string> rule_names_;
  /*!
   * \brief Exclusive upper bound on the numeric suffixes tried when a hint is already taken.
   *
   * A static constexpr constant instead of a per-instance const member: it costs no storage
   * in each object and does not delete the implicitly generated copy assignment operator.
   */
  static constexpr int NAME_SUFFIX_MAXIMUM = 10000;
};

/*!
 * \brief Register and return a rule name that has not been used before.
 *
 * The hint itself is used when it is still free. Otherwise the candidates
 * "<hint>_0", "<hint>_1", ... are tried in order, up to NAME_SUFFIX_MAXIMUM candidates,
 * and the first free one is taken. The chosen name is recorded in rule_names_.
 *
 * \param rule_name_hint Suggested name for the rule.
 * \return The unique name that was registered.
 * \note A FATAL message is logged when every candidate is already taken.
 */
std::string EBNFScriptCreator::Impl::GetRuleName(const std::string& rule_name_hint) {
  // insert() reports whether the element was newly added, so a single hash lookup both
  // checks for a collision and reserves the name.
  if (rule_names_.insert(rule_name_hint).second) {
    return rule_name_hint;
  }
  for (int i = 0; i < NAME_SUFFIX_MAXIMUM; ++i) {
    std::string rule_name = rule_name_hint + "_" + std::to_string(i);
    if (rule_names_.insert(rule_name).second) {
      return rule_name;
    }
  }
  XGRAMMAR_LOG(FATAL) << "Cannot find a unique rule name for " << rule_name_hint;
  // NOTE(review): presumably the FATAL log never returns (confirm in support/logging.h).
  // The explicit return keeps every control path of this non-void function well defined
  // even if it does.
  return std::string();
}

/*!
 * \brief Append a rule to the script under a unique name.
 * \param rule_name_hint Suggested name; GetRuleName decides the name actually used.
 * \param rule_body The EBNF body stored for the rule.
 * \return The name under which the rule was stored.
 */
std::string EBNFScriptCreator::Impl::AddRule(
    const std::string& rule_name_hint, const std::string& rule_body
) {
  // The unique name is resolved (and registered) before the pair is appended.
  rules_.emplace_back(GetRuleName(rule_name_hint), rule_body);
  return rules_.back().first;
}

std::string EBNFScriptCreator::Impl::GetScript() {
std::string script = "";
for (const auto& rule : rules_) {
script += rule.first + " ::= " + rule.second + "\n";
}
return script;
}

/*!
 * \brief Look up the body of a previously added rule.
 *
 * The rules are searched linearly in insertion order and the first match is used.
 *
 * \param rule_name The name of the rule to look up.
 * \return The stored body of the rule, or an empty string when no rule has that name.
 * A rule stored with an empty body therefore cannot be told apart from a missing rule.
 */
std::string EBNFScriptCreator::Impl::GetRuleContent(const std::string& rule_name) {
  // Capture by reference: the predicate is only used during this call, so copying the
  // name into the closure would be wasted work.
  const auto it = std::find_if(rules_.begin(), rules_.end(), [&rule_name](const auto& rule) {
    return rule.first == rule_name;
  });
  if (it == rules_.end()) {
    return "";
  }
  return it->second;
}

/*!
 * \brief Construct a creator that owns a newly allocated Impl holding no rules yet.
 *
 * The tag parameter carries no data; it only selects this constructor.
 */
EBNFScriptCreator::EBNFScriptCreator(EmptyConstructorTag) : pimpl_(std::make_shared<Impl>()) {}

/*!
 * \brief Add a rule to the script; forwards to the implementation object.
 * \param rule_name_hint Suggested name for the rule.
 * \param rule_body The EBNF body of the rule.
 * \return The name returned by Impl::AddRule for the stored rule.
 */
std::string EBNFScriptCreator::AddRule(
const std::string& rule_name_hint, const std::string& rule_body
) {
return pimpl_->AddRule(rule_name_hint, rule_body);
}

/*! \brief Return the script built so far; forwards to Impl::GetScript. */
std::string EBNFScriptCreator::GetScript() { return pimpl_->GetScript(); }

/*!
 * \brief Return the body stored for a rule; forwards to Impl::GetRuleContent.
 * \param rule_name The name of the rule to look up.
 * \return The value returned by Impl::GetRuleContent, which is an empty string when no
 * rule with that name exists.
 */
std::string EBNFScriptCreator::GetRuleContent(const std::string& rule_name) {
return pimpl_->GetRuleContent(rule_name);
}

} // namespace xgrammar
53 changes: 53 additions & 0 deletions cpp/ebnf_script_creator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*!
* Copyright (c) 2024 by Contributors
* \file xgrammar/ebnf_script_creator.h
* \brief The header for creating EBNF scripts.
*/

#ifndef XGRAMMAR_EBNF_SCRIPT_CREATOR_H_
#define XGRAMMAR_EBNF_SCRIPT_CREATOR_H_

#include <xgrammar/object.h>

#include <string>

namespace xgrammar {

/*!
* \brief A class for creating EBNF grammar scripts.
*
* This class helps build EBNF (Extended Backus-Naur Form) grammar scripts
* by managing rules and their content.
*/
class EBNFScriptCreator {
public:
/*! \brief Constructor using empty constructor tag pattern */
EBNFScriptCreator(EmptyConstructorTag);

/*!
* \brief Adds a new rule to the grammar
* \param rule_name_hint Suggested name for the rule
* \param rule_body The EBNF content/definition of the rule
* \return The actual name assigned to the rule
*/
std::string AddRule(const std::string& rule_name_hint, const std::string& rule_body);

/*!
* \brief Gets the complete EBNF grammar script
* \return The full EBNF grammar script as a string
*/
std::string GetScript();

/*!
* \brief Retrieves the content/definition of a specific rule
* \param rule_name The name of the rule to look up
* \return The EBNF content/definition of the specified rule
*/
std::string GetRuleContent(const std::string& rule_name);

XGRAMMAR_DEFINE_PIMPL_METHODS(EBNFScriptCreator);
};

} // namespace xgrammar

#endif // XGRAMMAR_EBNF_SCRIPT_CREATOR_H_
28 changes: 13 additions & 15 deletions cpp/json_schema_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <utility>
#include <vector>

#include "ebnf_script_creator.h"
#include "regex_converter.h"
#include "support/logging.h"

Expand Down Expand Up @@ -336,6 +337,8 @@ class JSONSchemaConverter {
const std::string& rule_name
);

// The EBNF script creator
EBNFScriptCreator ebnf_script_creator_{EmptyConstructorTag{}};
// The indent manager to get separators
std::optional<IndentManager> indentManager_;
// The root JSON schema
Expand All @@ -346,8 +349,6 @@ class JSONSchemaConverter {
bool allow_empty_;
// The colon separator
std::string colon_pattern_;
// The rules constructed
std::vector<std::pair<std::string, std::string>> rules_;
// The cache for basic rules. Mapping from the key of schema returned by GetSchemaCacheIndex()
// to the basic rule name.
std::map<std::string, std::string> basic_rules_cache_;
Expand Down Expand Up @@ -386,11 +387,7 @@ JSONSchemaConverter::JSONSchemaConverter(

std::string JSONSchemaConverter::Convert() {
CreateRuleFromSchema(json_schema_, "root");
std::string res;
for (auto& rule : rules_) {
res += rule.first + " ::= " + rule.second + "\n";
}
return res;
return ebnf_script_creator_.GetScript();
}

void JSONSchemaConverter::AddBasicRules() {
Expand Down Expand Up @@ -433,14 +430,14 @@ void JSONSchemaConverter::AddBasicRules() {
}

void JSONSchemaConverter::AddHelperRules() {
rules_.push_back(std::make_pair(
ebnf_script_creator_.AddRule(
kBasicEscape, "[\"\\\\/bfnrt] | \"u\" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]"
));
rules_.push_back(std::make_pair(
);
ebnf_script_creator_.AddRule(
kBasicStringSub,
"(\"\\\"\" | [^\"\\\\\\r\\n] " + kBasicStringSub + " | \"\\\\\" " + kBasicEscape + " " +
kBasicStringSub + ") (= [ \\n\\t]* [,}\\]:])"
));
);
}

void JSONSchemaConverter::CreateBasicRule(const picojson::value& schema, const std::string& name) {
Expand Down Expand Up @@ -485,8 +482,9 @@ std::string JSONSchemaConverter::CreateRuleFromSchema(
return basic_rules_cache_[idx];
}

rules_.push_back(std::make_pair(rule_name_hint, VisitSchema(schema, rule_name_hint)));
return rule_name_hint;
std::string rule_name =
ebnf_script_creator_.AddRule(rule_name_hint, VisitSchema(schema, rule_name_hint));
return rule_name;
}

std::string JSONSchemaConverter::GetSchemaCacheIndex(const picojson::value& schema) {
Expand Down Expand Up @@ -1036,7 +1034,7 @@ std::string JSONSchemaConverter::GetPartialRuleForPropertiesAllOptional(
std::string last_rule_body = "(" + mid_sep + " " + additional_prop_pattern + ")*";
std::string last_rule_name =
rule_name + "_part_" + std::to_string(static_cast<int>(properties.size()) - 1);
rules_.push_back(std::make_pair(last_rule_name, last_rule_body));
last_rule_name = ebnf_script_creator_.AddRule(last_rule_name, last_rule_body);
rule_names.back() = last_rule_name;
} else {
rule_names.back() = "\"\"";
Expand All @@ -1049,7 +1047,7 @@ std::string JSONSchemaConverter::GetPartialRuleForPropertiesAllOptional(
std::string cur_rule_body =
last_rule_name + " | " + mid_sep + " " + prop_pattern + " " + last_rule_name;
std::string cur_rule_name = rule_name + "_part_" + std::to_string(i);
rules_.push_back(std::make_pair(cur_rule_name, cur_rule_body));
cur_rule_name = ebnf_script_creator_.AddRule(cur_rule_name, cur_rule_body);
rule_names[i] = cur_rule_name;
}

Expand Down
6 changes: 3 additions & 3 deletions include/xgrammar/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@ class CompiledGrammar {
};

/*!
* \brief A cache to get the grammar state compiled grammar for grammar or schema. This class avoids
* \brief A cache to get the compiled grammar for grammar or schema. This class avoids
* redundant preprocessing of the grammar or schema when constructing a CompiledGrammar.
* \note This class is associated with a vocabulary when constructed. The vocabulary is used to
* create every grammar state compiled grammar. If multiple toke tables are used to create init
* create every compiled grammar. If multiple token tables are used to create init
* contexts, an instance of this class for each vocabulary should be created.
*/
class GrammarCompiler {
public:
/*!
* \brief Construct a GrammarCompiler with a vocabulary. This class will always
* create grammar state compiled grammars with this vocabulary.
* create compiled grammars with this vocabulary.
* \param decoded_vocab The vocabulary that the grammar will use.
*/
GrammarCompiler(
Expand Down
11 changes: 10 additions & 1 deletion include/xgrammar/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@

namespace xgrammar {

/*!
 * \brief A tag type used to select an "empty object" constructor.
 *
 * XGRAMMAR_DEFINE_PIMPL_METHODS already occupies the default constructor, which
 * constructs a null object. A class that also needs a constructor taking no real
 * arguments therefore declares one that accepts this tag, and callers pass
 * EmptyConstructorTag{} to pick it. The struct carries no data.
 */
struct EmptyConstructorTag {};

#define XGRAMMAR_DEFINE_PIMPL_METHODS(TypeName) \
public: \
class Impl; \
Expand All @@ -30,7 +39,7 @@ namespace xgrammar {
const Impl* operator->() const { return pimpl_.get(); } \
\
private: \
std::shared_ptr<Impl> pimpl_
std::shared_ptr<Impl> pimpl_;

} // namespace xgrammar

Expand Down

0 comments on commit 49655f4

Please sign in to comment.