Skip to content

Commit

Permalink
Support boolean JSON schemas (#1015)
Browse files Browse the repository at this point in the history
We were previously supporting `True`/`False` schemas only when nested in
certain places like `items` and `additionalProperties`. This expands our
coverage to handle top-level boolean schemas as well.

Note that we will now raise a `ValueError` for `False` schemas, simply
because there is nothing we can generate in this case.
May need to revisit this (see #1018)
  • Loading branch information
hudson-ai authored Sep 10, 2024
1 parent 418fc03 commit 35591d8
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 21 deletions.
49 changes: 28 additions & 21 deletions guidance/library/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ._pydantic import pydantic_to_json_schema
from ._subgrammar import lexeme, subgrammar

JSONSchema = Union[bool, Mapping[str, Any]]

def _to_compact_json(target: Any) -> str:
# See 'Compact Encoding':
Expand Down Expand Up @@ -150,8 +151,8 @@ def _gen_json_string(
def _gen_json_object(
lm,
*,
properties: Mapping[str, Any],
additional_properties: Union[bool, Mapping[str, Any]],
properties: Mapping[str, JSONSchema],
additional_properties: JSONSchema,
required: Sequence[str],
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
Expand Down Expand Up @@ -206,16 +207,12 @@ def _gen_list(lm, *, elements: tuple[GrammarFunction, ...], required: tuple[bool
def _gen_json_array(
lm,
*,
prefix_items_schema: Sequence[Mapping[str, Any]],
item_schema: Union[bool, Mapping[str, Any]],
prefix_items_schema: Sequence[JSONSchema],
item_schema: JSONSchema,
min_items: int,
max_items: Optional[int],
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
if item_schema is True:
# True means that anything goes
item_schema = {}

if len(prefix_items_schema) < min_items and item_schema is False:
raise ValueError(
f"PrefixItems has too few elements ({len(prefix_items_schema)}) to"
Expand Down Expand Up @@ -282,7 +279,7 @@ def _gen_json_array(
def _process_anyOf(
lm,
*,
anyof_list: Sequence[Mapping[str, Any]],
anyof_list: Sequence[JSONSchema],
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
options = [_gen_json(json_schema=item, definitions=definitions) for item in anyof_list]
Expand Down Expand Up @@ -329,9 +326,14 @@ def _gen_json_any(lm):
@guidance(stateless=True)
def _gen_json(
lm,
json_schema: Mapping[str, Any],
json_schema: JSONSchema,
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
if json_schema is True:
json_schema = {}
elif json_schema is False:
raise ValueError("No valid JSON can be generated from a schema of `False`")

validate_json_node_keys(json_schema)

if Keyword.ANYOF in json_schema:
Expand Down Expand Up @@ -403,7 +405,7 @@ def json(
*,
schema: Union[
None,
Mapping[str, Any],
JSONSchema,
Type["pydantic.BaseModel"],
"pydantic.TypeAdapter",
] = None,
Expand Down Expand Up @@ -457,20 +459,25 @@ def json(
If True, the generated JSON will be forced to be compact (no whitespace).
If False, output will be whitespace-flexible (i.e. decided by the model).
"""
if isinstance(schema, Mapping):
if schema is None:
# Default schema is empty, "anything goes" schema
# TODO: consider default being `{"type": "object"}`
schema = {}
elif isinstance(schema, (Mapping, bool)):
# Raises jsonschema.exceptions.SchemaError or ValueError
# if schema is not valid
jsonschema.validators.Draft202012Validator.check_schema(schema)
elif schema is None:
schema = {}
else:
elif isinstance(schema, pydantic.TypeAdapter) or (isinstance(schema, type) and issubclass(schema, pydantic.BaseModel)):
schema = pydantic_to_json_schema(schema)
else:
raise TypeError(f"Unsupported schema type: {type(schema)}")

definitions: Mapping[str, Callable[[], GrammarFunction]] = {}
for dk in DEFS_KEYS:
if dk in schema:
assert len(definitions) == 0, "Found duplicate definitions"
definitions = _build_definitions(schema[dk])
if isinstance(schema, Mapping):
for dk in DEFS_KEYS:
if dk in schema:
assert len(definitions) == 0, "Found duplicate definitions"
definitions = _build_definitions(schema[dk])

return lm + with_temperature(
subgrammar(
Expand All @@ -488,11 +495,11 @@ def json(


def _build_definitions(
raw_definitions: Mapping[str, Any]
raw_definitions: Mapping[str, JSONSchema]
) -> Mapping[str, Callable[[], GrammarFunction]]:
definitions: Dict[str, Callable[[], GrammarFunction]] = {}

def build_definition(json_schema: Mapping[str, Any]) -> Callable[[], GrammarFunction]:
def build_definition(json_schema: JSONSchema) -> Callable[[], GrammarFunction]:
@guidance(stateless=True, dedent=False, cache=True)
def closure(lm):
return lm + _gen_json(json_schema=json_schema, definitions=definitions)
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/library/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2218,3 +2218,36 @@ def test_all_required_properties_doesnt_blow_up(self, num_properties):
HITS_MAGIC_NUMBER = 1
expected_hits = 0
assert cache_info.hits <= expected_hits + HITS_MAGIC_NUMBER

class TestBooleanSchema:
@pytest.mark.parametrize(
"target_obj",
[
123,
"hello",
[1, 2, 3],
{"a": 1},
None,
[{"a": 1}],
{"a": [1, 2, 3]},
{"a": {"b": 1}},
False,
True
],
)
def test_true_schema(self, target_obj):
# should be the same as an empty schema
schema_obj = True
generate_and_check(target_obj, schema_obj)

@pytest.mark.parametrize(
"schema_obj",
[
False,
{"type": "object", "properties": {"a": False}, "required": ["a"]},
]
)
def test_false_schema(self, schema_obj):
with pytest.raises(ValueError) as ve:
gen_json(schema=schema_obj)
assert ve.value.args[0] == "No valid JSON can be generated from a schema of `False`"

0 comments on commit 35591d8

Please sign in to comment.