Skip to content

Commit

Permalink
Remove try-catch block, make GenerateRangeRegex a static member, and …
Browse files Browse the repository at this point in the history
…update tests
  • Loading branch information
Ubospica committed Dec 6, 2024
1 parent adb76d2 commit 2ad33d4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 37 deletions.
66 changes: 31 additions & 35 deletions cpp/json_schema_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ class JSONSchemaConverter {
* will be ignored when finding the corresponding cache rule. */
std::string GetSchemaCacheIndex(const picojson::value& schema);

/*! \brief Generate the regex for the range. */
static std::string GenerateRangeRegex(std::optional<int> start, std::optional<int> end);

/*!
* \brief Create a rule with the given schema and rule name hint.
* \returns The name of the rule will be returned. That is not necessarily the same as the
Expand Down Expand Up @@ -654,7 +657,9 @@ std::string JSONSchemaConverter::VisitAny(
kBasicArray + " | " + kBasicObject;
}

std::string generateRangeRegex(std::optional<int> start, std::optional<int> end) {
std::string JSONSchemaConverter::GenerateRangeRegex(
std::optional<int> start, std::optional<int> end
) {
if (!start && !end) {
return "^\\d+$"; // Match any positive number if no start or end is specified
}
Expand Down Expand Up @@ -782,34 +787,30 @@ std::string JSONSchemaConverter::VisitInteger(
}
);
std::string range_regex = "";
try {
if (schema.count("minimum") || schema.count("maximum") || schema.count("exclusiveMinimum") ||
schema.count("exclusiveMaximum")) {
std::optional<int> start, end;
if (schema.count("minimum")) {
double start_double = schema.at("minimum").get<double>();
start = static_cast<int>(start_double);
}
if (schema.count("exclusiveMinimum")) {
double start_double = schema.at("exclusiveMinimum").get<double>();
start = static_cast<int>(start_double);
}
if (schema.count("maximum")) {
double end_double = schema.at("maximum").get<double>();
end = static_cast<int>(end_double);
}
if (schema.count("exclusiveMaximum")) {
double end_double = schema.at("exclusiveMaximum").get<double>();
end = static_cast<int>(end_double);
}
range_regex = generateRangeRegex(start, end);
if (schema.count("minimum") || schema.count("maximum") || schema.count("exclusiveMinimum") ||
schema.count("exclusiveMaximum")) {
std::optional<int> start, end;
if (schema.count("minimum")) {
double start_double = schema.at("minimum").get<double>();
start = static_cast<int>(start_double);
}
if (schema.count("exclusiveMinimum")) {
double start_double = schema.at("exclusiveMinimum").get<double>();
start = static_cast<int>(start_double);
}
if (schema.count("maximum")) {
double end_double = schema.at("maximum").get<double>();
end = static_cast<int>(end_double);
}
if (!range_regex.empty()) {
std::string converted_regex = RegexToEBNF(range_regex, false);
return converted_regex; // not " " for numbers
if (schema.count("exclusiveMaximum")) {
double end_double = schema.at("exclusiveMaximum").get<double>();
end = static_cast<int>(end_double);
}
} catch (const std::exception& e) {
XGRAMMAR_LOG(WARNING) << "Failed to convert range for integer schema";
range_regex = GenerateRangeRegex(start, end);
}
if (!range_regex.empty()) {
std::string converted_regex = RegexToEBNF(range_regex, false);
return converted_regex; // not " " for numbers
}
return "(\"0\" | \"-\"? [1-9] [0-9]*)";
}
Expand Down Expand Up @@ -846,14 +847,9 @@ std::string JSONSchemaConverter::VisitString(
}
);
if (schema.count("pattern")) {
try {
std::string regex_pattern = schema.at("pattern").get<std::string>();
std::string converted_regex = RegexToEBNF(regex_pattern, false);
return "\"\\\"\" " + converted_regex + " \"\\\"\"";
} catch (const std::exception& e) {
XGRAMMAR_LOG(WARNING) << "Failed to convert regex pattern "
<< schema.at("pattern").get<std::string>();
}
std::string regex_pattern = schema.at("pattern").get<std::string>();
std::string converted_regex = RegexToEBNF(regex_pattern, false);
return "\"\\\"\" " + converted_regex + " \"\\\"\"";
}
return "[\"] " + kBasicStringSub;
}
Expand Down
7 changes: 5 additions & 2 deletions tests/python/test_json_schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,18 +495,21 @@ class MainModel(BaseModel):


def test_complex_restrictions() -> None:

string_without_quotes = Annotated[str, WithJsonSchema({"type": "string", "pattern": r"[^\"]*"})]

class RestrictedModel(BaseModel):
restricted_string: string_without_quotes
restricted_value: Annotated[int, Field(strict=True, ge=0, lt=44)]

# working instance
instance = RestrictedModel(restricted_string="a", restricted_value=42)
instance = RestrictedModel(restricted_string="abd", restricted_value=42)
instance_str = json.dumps(instance.model_dump(mode="json"))
check_schema_with_json(RestrictedModel.model_json_schema(), instance_str)

instance_err = RestrictedModel(restricted_string='"', restricted_value=42)
instance_str = json.dumps(instance_err.model_dump(mode="json"))
check_schema_with_json(RestrictedModel.model_json_schema(), instance_str, check_accepted=False)

check_schema_with_json(
RestrictedModel.model_json_schema(),
'{"restricted_string": "j", "restricted_value": 45}',
Expand Down

0 comments on commit 2ad33d4

Please sign in to comment.