From 35591d896e1e4f3f95dc8fa0f32fb98b0d1f4a03 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 10 Sep 2024 15:46:08 -0700 Subject: [PATCH] Support boolean JSON schemas (#1015) We were previously supporting `True`/`False` schemas only when nested in certain places like `items` and `additionalProperties`. This expands our coverage to handle top-level boolean schemas as well. Note that we will now raise a `ValueError` for `False` schemas, simply because there is nothing we can generate in this case. May need to revisit this (see #1018) --- guidance/library/_json.py | 49 +++++++++++++++++++-------------- tests/unit/library/test_json.py | 33 ++++++++++++++++++++++ 2 files changed, 61 insertions(+), 21 deletions(-) diff --git a/guidance/library/_json.py b/guidance/library/_json.py index 6348947de..682deaabe 100644 --- a/guidance/library/_json.py +++ b/guidance/library/_json.py @@ -27,6 +27,7 @@ from ._pydantic import pydantic_to_json_schema from ._subgrammar import lexeme, subgrammar +JSONSchema = Union[bool, Mapping[str, Any]] def _to_compact_json(target: Any) -> str: # See 'Compact Encoding': @@ -150,8 +151,8 @@ def _gen_json_string( def _gen_json_object( lm, *, - properties: Mapping[str, Any], - additional_properties: Union[bool, Mapping[str, Any]], + properties: Mapping[str, JSONSchema], + additional_properties: JSONSchema, required: Sequence[str], definitions: Mapping[str, Callable[[], GrammarFunction]], ): @@ -206,16 +207,12 @@ def _gen_list(lm, *, elements: tuple[GrammarFunction, ...], required: tuple[bool def _gen_json_array( lm, *, - prefix_items_schema: Sequence[Mapping[str, Any]], - item_schema: Union[bool, Mapping[str, Any]], + prefix_items_schema: Sequence[JSONSchema], + item_schema: JSONSchema, min_items: int, max_items: Optional[int], definitions: Mapping[str, Callable[[], GrammarFunction]], ): - if item_schema is True: - # True means that anything goes - item_schema = {} - if len(prefix_items_schema) < min_items and item_schema is False: raise ValueError( f"PrefixItems has too few elements ({len(prefix_items_schema)}) to" @@ -282,7 +279,7 @@ def _gen_json_array( def _process_anyOf( lm, *, - anyof_list: Sequence[Mapping[str, Any]], + anyof_list: Sequence[JSONSchema], definitions: Mapping[str, Callable[[], GrammarFunction]], ): options = [_gen_json(json_schema=item, definitions=definitions) for item in anyof_list] @@ -329,9 +326,14 @@ def _gen_json_any(lm): @guidance(stateless=True) def _gen_json( lm, - json_schema: Mapping[str, Any], + json_schema: JSONSchema, definitions: Mapping[str, Callable[[], GrammarFunction]], ): + if json_schema is True: + json_schema = {} + elif json_schema is False: + raise ValueError("No valid JSON can be generated from a schema of `False`") + validate_json_node_keys(json_schema) if Keyword.ANYOF in json_schema: @@ -403,7 +405,7 @@ def json( *, schema: Union[ None, - Mapping[str, Any], + JSONSchema, Type["pydantic.BaseModel"], "pydantic.TypeAdapter", ] = None, @@ -457,20 +459,25 @@ def json( If True, the generated JSON will be forced to be compact (no whitespace). If False, output will be whitespace-flexible (i.e. decided by the model). """ - if isinstance(schema, Mapping): + if schema is None: + # Default schema is empty, "anything goes" schema + # TODO: consider default being `{"type": "object"}` + schema = {} + elif isinstance(schema, (Mapping, bool)): # Raises jsonschema.exceptions.SchemaError or ValueError # if schema is not valid jsonschema.validators.Draft202012Validator.check_schema(schema) - elif schema is None: - schema = {} - else: + elif isinstance(schema, pydantic.TypeAdapter) or (isinstance(schema, type) and issubclass(schema, pydantic.BaseModel)): schema = pydantic_to_json_schema(schema) + else: + raise TypeError(f"Unsupported schema type: {type(schema)}") definitions: Mapping[str, Callable[[], GrammarFunction]] = {} - for dk in DEFS_KEYS: - if dk in schema: - assert len(definitions) == 0, "Found duplicate definitions" - definitions = _build_definitions(schema[dk]) + if isinstance(schema, Mapping): + for dk in DEFS_KEYS: + if dk in schema: + assert len(definitions) == 0, "Found duplicate definitions" + definitions = _build_definitions(schema[dk]) return lm + with_temperature( subgrammar( @@ -488,11 +495,11 @@ def json( def _build_definitions( - raw_definitions: Mapping[str, Any] + raw_definitions: Mapping[str, JSONSchema] ) -> Mapping[str, Callable[[], GrammarFunction]]: definitions: Dict[str, Callable[[], GrammarFunction]] = {} - def build_definition(json_schema: Mapping[str, Any]) -> Callable[[], GrammarFunction]: + def build_definition(json_schema: JSONSchema) -> Callable[[], GrammarFunction]: @guidance(stateless=True, dedent=False, cache=True) def closure(lm): return lm + _gen_json(json_schema=json_schema, definitions=definitions) diff --git a/tests/unit/library/test_json.py b/tests/unit/library/test_json.py index 2b79eac3b..3f3505c02 100644 --- a/tests/unit/library/test_json.py +++ b/tests/unit/library/test_json.py @@ -2218,3 +2218,36 @@ def test_all_required_properties_doesnt_blow_up(self, num_properties): HITS_MAGIC_NUMBER = 1 expected_hits = 0 assert cache_info.hits <= expected_hits + HITS_MAGIC_NUMBER + +class TestBooleanSchema: + @pytest.mark.parametrize( + "target_obj", + [ + 123, + "hello", + [1, 2, 3], + {"a": 1}, + None, + [{"a": 1}], + {"a": [1, 2, 3]}, + {"a": {"b": 1}}, + False, + True + ], + ) + def test_true_schema(self, target_obj): + # should be the same as an empty schema + schema_obj = True + generate_and_check(target_obj, schema_obj) + + @pytest.mark.parametrize( + "schema_obj", + [ + False, + {"type": "object", "properties": {"a": False}, "required": ["a"]}, + ] + ) + def test_false_schema(self, schema_obj): + with pytest.raises(ValueError) as ve: + gen_json(schema=schema_obj) + assert ve.value.args[0] == "No valid JSON can be generated from a schema of `False`"