diff --git a/guidance/library/_json.py b/guidance/library/_json.py index 0ca544726..e03e20dd2 100644 --- a/guidance/library/_json.py +++ b/guidance/library/_json.py @@ -406,10 +406,11 @@ def get_sibling_keys(node: Mapping[str, Any], key: str) -> set[str]: class GenJson: item_separator = ", " key_separator = ": " - def __init__(self, schema: JSONSchema, separators: Optional[tuple[str, str]] = None) -> None: + def __init__(self, schema: JSONSchema, separators: Optional[tuple[str, str]] = None, strict_properties: bool = True) -> None: self.schema = schema if separators is not None: self.item_separator, self.key_separator = separators + self.strict_properties = strict_properties registry: referencing.Registry[JSONSchema] = referencing.Registry() resource: referencing.Resource[JSONSchema] = referencing.jsonschema.DRAFT202012.create_resource(schema) @@ -799,18 +800,8 @@ def any(self, lm): self.json(json_schema={"type": "number"}), self.json(json_schema={"type": "string"}), # Recursive cases - self.json( - json_schema={ - "type": "array", - "items": True, - }, - ), - self.json( - json_schema={ - "type": "object", - "additionalProperties": True, - }, - ), + self.json(json_schema={"type": "array"}), + self.json(json_schema={"type": "object"}), ] ) @@ -940,7 +931,10 @@ def json( elif target_type == JSONType.OBJECT: option = self.object( properties=json_schema.get(ObjectKeywords.PROPERTIES, {}), - additional_properties=json_schema.get(ObjectKeywords.ADDITIONAL_PROPERTIES, True), + # For "true" adherence to the JSON schema spec, additionalProperties should be `True` by default. + # However, defaulting to `False` is going to give a more terse output, which is probably + # more useful for most users. 
+ additional_properties=json_schema.get(ObjectKeywords.ADDITIONAL_PROPERTIES, not self.strict_properties), required=json_schema.get(ObjectKeywords.REQUIRED, set()), ) else: @@ -962,10 +956,11 @@ def json( Type["pydantic.BaseModel"], "pydantic.TypeAdapter", ] = None, - temperature: float = 0.0, - max_tokens: int = 100000000, separators: Optional[tuple[str, str]] = None, whitespace_flexible: bool = False, + strict_properties: bool = True, + temperature: float = 0.0, + max_tokens: int = 100000000, **kwargs, ): """Generate valid JSON according to the supplied JSON schema or `pydantic` model. @@ -1010,6 +1005,23 @@ def json( - A JSON schema object. This is a JSON schema string which has been passed to ``json.loads()`` - A subclass of ``pydantic.BaseModel`` - An instance of ``pydantic.TypeAdapter`` + separators : Optional[Tuple[str, str]] + The (item, key) separators to use when generating the JSON. For maximal compactness/token savings, use `(",", ":")`. + Note that most JSON "in the wild" will include whitespace around these separators, and compact JSON may be out-of-distribution + for some models. + Default: (", ", ": ") + whitespace_flexible : bool + If True, allow for whitespace to be inserted between tokens in the JSON output. This gives maximal control to the model in terms of + formatting (and therefore might be more likely to be in-distribution for some models), but guidance will be able to accelerate fewer + tokens. + strict_properties : bool + If True, the generated JSON will only include additional properties if they are explicitly allowed in the schema. If False, additional + properties will be allowed unless they are explicitly disallowed. + max_tokens : int + The maximum number of tokens to generate. + Note setting this to a small number will likely result in an incomplete JSON object. + temperature : float + The temperature to use when generating the JSON. """ if "compact" in kwargs: warnings.warn("The 'compact' argument is deprecated and has no effect. 
It will be removed in a future release.", category=DeprecationWarning) @@ -1020,6 +1032,9 @@ def json( # Default schema is empty, "anything goes" schema # TODO: consider default being `{"type": "object"}` schema = {} + # In this case, we don't want to use strict_properties + # because we're not actually validating against a schema + strict_properties = False elif isinstance(schema, (Mapping, bool, str)): if isinstance(schema, str): schema = cast(JSONSchema, json_loads(schema)) @@ -1038,10 +1053,16 @@ def json( else: skip_regex = None + body = GenJson( + schema=schema, + separators=separators, + strict_properties=strict_properties, + ).root() + return lm + with_temperature( subgrammar( name, - body=GenJson(schema=schema, separators=separators).root(), + body=body, skip_regex=skip_regex, no_initial_skip=True, max_tokens=max_tokens, diff --git a/tests/unit/library/test_json.py b/tests/unit/library/test_json.py index d320fdb4f..6ab5adc82 100644 --- a/tests/unit/library/test_json.py +++ b/tests/unit/library/test_json.py @@ -17,13 +17,16 @@ def generate_and_check( - target_obj: Any, schema_obj: Union[str, JSONSchema], desired_temperature: Optional[float] = None + target_obj: Any, + schema_obj: Union[str, JSONSchema], + desired_temperature: Optional[float] = None, + strict_properties: bool = False, ): if isinstance(schema_obj, str): schema_obj = json_loads(schema_obj) # Sanity check what we're being asked - validate(instance=target_obj, schema=schema_obj) + validate(instance=target_obj, schema=schema_obj or {}) prepared_json = json_dumps(target_obj) assert json.loads(prepared_json) == target_obj @@ -31,16 +34,17 @@ def generate_and_check( # We partial in the grammar_callable if desired_temperature is not None: grammar_callable = partial( - gen_json, schema=schema_obj, temperature=desired_temperature + gen_json, schema=schema_obj, temperature=desired_temperature, strict_properties=strict_properties ) else: - grammar_callable = partial(gen_json, schema=schema_obj) + 
grammar_callable = partial(gen_json, schema=schema_obj, strict_properties=strict_properties) lm = _generate_and_check( grammar_callable, test_string=prepared_json, ) - check_run_with_temperature(lm, desired_temperature) + if desired_temperature is not None: + check_run_with_temperature(lm, desired_temperature) def check_match_failure( @@ -50,8 +54,9 @@ def check_match_failure( failure_byte: Optional[bytes] = None, allowed_bytes: Optional[Set[bytes]] = None, schema_obj: Union[str, JSONSchema], + strict_properties: bool = False, ): - grammar = gen_json(schema=schema_obj) + grammar = gen_json(schema=schema_obj, strict_properties=strict_properties) _check_match_failure( bad_string=bad_string, @@ -3237,7 +3242,7 @@ class TestWhitespace: seps, ) def test_separators(self, separators, schema, obj): - grammar = gen_json(schema=schema, separators=separators) + grammar = gen_json(schema=schema, separators=separators, strict_properties=False) for seps in self.seps: prepared_json = json.dumps(obj, separators=seps) if separators == seps: @@ -3267,14 +3272,13 @@ def test_separators(self, separators, schema, obj): [None, 0, 2, 4], ) def test_whitespace_flexibility(self, indent, separators, schema, obj): - grammar = gen_json(schema=schema, whitespace_flexible=True) + grammar = gen_json(schema=schema, strict_properties=False, whitespace_flexible=True) prepared_json = json.dumps(obj, separators=separators, indent=indent) assert grammar.match(prepared_json, raise_exceptions=True) is not None model = models.Mock(f"{prepared_json}".encode()) assert str(model + grammar) == prepared_json - class TestStringSchema: def test_good(self): schema = """{"type": "object", "properties": {"a": {"type": "string"}}}""" @@ -3287,3 +3291,179 @@ def test_bad(self): bad_string='{"a": 42}', schema_obj=schema, ) + +class TestStrictProperties: + @pytest.mark.parametrize( + "target_obj", + [ + 1, + 1.0, + "2", + False, + [1, 2, 3], + {}, + {"a": 1, "b": 2, "c": 3}, + ] + ) + def test_none_schema_exempt(self, 
target_obj): + schema_obj = None + generate_and_check(target_obj, schema_obj, strict_properties=True) + + @pytest.mark.parametrize( + "target_obj", + [ + 1, + 1.0, + "2", + False, + [1, 2, 3], + {}, + ] + ) + def test_empty_schema_good(self, target_obj): + schema_obj = {} + generate_and_check(target_obj, schema_obj, strict_properties=True) + + @pytest.mark.parametrize( + "target_obj", + [ + {"a": 1, "b": 2, "c": 3}, + {"a": 1, "b": 2}, + {"a": 1}, + ] + ) + def test_empty_schema_bad(self, target_obj): + schema_obj = {} + check_match_failure( + bad_string=json_dumps(target_obj), + schema_obj=schema_obj, + strict_properties=True, + ) + + @pytest.mark.parametrize( + "target_obj", + [ + {}, + ] + ) + def test_object_schema_good(self, target_obj): + schema_obj = {"type": "object"} + generate_and_check(target_obj, schema_obj, strict_properties=True) + + @pytest.mark.parametrize( + "target_obj", + [ + 1, + 1.0, + "2", + False, + [1, 2, 3], + {"a": 1, "b": 2, "c": 3}, + ] + ) + def test_object_schema_bad(self, target_obj): + schema_obj = {"type": "object"} + check_match_failure( + bad_string=json_dumps(target_obj), + schema_obj=schema_obj, + strict_properties=True, + ) + + @pytest.mark.parametrize( + "target_obj", + [ + {}, + {"a": 1}, + {"b": 2}, + {"a": 1, "b": 2}, + ] + ) + def test_object_schema_with_properties_good(self, target_obj): + schema_obj = { + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "number"}, + } + } + generate_and_check(target_obj, schema_obj, strict_properties=True) + + @pytest.mark.parametrize( + "target_obj", + [ + {"a": 1, "b": 2, "c": 3}, + ] + ) + def test_object_schema_with_properties_bad(self, target_obj): + schema_obj = { + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "number"}, + } + } + check_match_failure( + bad_string=json_dumps(target_obj), + schema_obj=schema_obj, + strict_properties=True, + ) + + @pytest.mark.parametrize( + "target_obj", + [ + {"foo": {"a": 1}}, + 
{"foo": {"a": 1.0}}, + {"foo": {"a": "2"}}, + {"foo": {"a": False}}, + {"foo": {"a": [1, 2, 3]}}, + {"foo": {"a": {}}}, + ] + ) + def test_nested_empty_schema_good(self, target_obj): + schema_obj = { + "type": "object", + "properties": { + "foo": {"type": "object", "properties": {"a": {}}}, + } + } + generate_and_check(target_obj, schema_obj, strict_properties=True) + + @pytest.mark.parametrize( + "target_obj", + [ + {"foo": {"a": {"x": 1}}}, + {"foo": {"b": 1}}, + ] + ) + def test_nested_empty_schema_bad(self, target_obj): + schema_obj = { + "type": "object", + "properties": { + "foo": {"type": "object", "properties": {"a": {}}}, + } + } + check_match_failure( + bad_string=json_dumps(target_obj), + schema_obj=schema_obj, + strict_properties=True, + ) + + @pytest.mark.parametrize( + "target_obj", + [ + {"a": 1}, + {"b": 2}, + {"a": 1, "b": 2}, + {"c": "hello"}, + ] + ) + def test_explicit_additional_properties_ok(self, target_obj): + schema_obj = { + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "number"}, + }, + "additionalProperties": {"type": "string"}, + } + generate_and_check(target_obj, schema_obj, strict_properties=True) diff --git a/tests/unit/library/test_pydantic.py b/tests/unit/library/test_pydantic.py index 9bda573f6..84bc90d09 100644 --- a/tests/unit/library/test_pydantic.py +++ b/tests/unit/library/test_pydantic.py @@ -1,7 +1,7 @@ import inspect from json import dumps as original_json_dumps from functools import partial -from typing import Any, Dict, Generic, List, Literal, Tuple, Type, TypeVar, Union, Set +from typing import Any, Dict, Generic, List, Literal, Tuple, Type, TypeVar, Union, Optional import pydantic import pytest @@ -55,6 +55,7 @@ def validate_string( def generate_and_check( target_obj: Any, pydantic_model: Union[Type[pydantic.BaseModel], pydantic.TypeAdapter], + strict_properties: bool = False, ): # Sanity check what we're being asked target_obj = validate_obj(target_obj, pydantic_model) @@ -62,19 
+63,21 @@ def generate_and_check( assert validate_string(prepared_json, pydantic_model) == target_obj # Check that the grammar can produce the literal prepared_json string - grammar_callable = partial(gen_json, schema=pydantic_model) + grammar_callable = partial(gen_json, schema=pydantic_model, strict_properties=strict_properties) _generate_and_check(grammar_callable, prepared_json) def check_match_failure( + *, bad_obj: Any, - good_bytes: bytes, - failure_byte: bytes, - allowed_bytes: Set[bytes], + good_bytes: Optional[bytes] = None, + failure_byte: Optional[bytes] = None, + allowed_bytes: Optional[set[bytes]] = None, pydantic_model: Union[Type[pydantic.BaseModel], pydantic.TypeAdapter], + strict_properties: bool = False, ): bad_string = json_dumps(bad_obj) - grammar = gen_json(schema=pydantic_model) + grammar = gen_json(schema=pydantic_model, strict_properties=strict_properties) _check_match_failure( bad_string=bad_string, good_bytes=good_bytes, @@ -292,3 +295,39 @@ def test_bad(self): allowed_bytes={b","}, # expect a comma to continue the object with "barks" pydantic_model=self.Model, ) + +class TestExtra: + class Unset(pydantic.BaseModel): + a: int + + class Ignore(pydantic.BaseModel): + a: int + model_config = pydantic.ConfigDict(extra="ignore") + + class Allow(pydantic.BaseModel): + a: int + model_config = pydantic.ConfigDict(extra="allow") + + class Forbid(pydantic.BaseModel): + a: int + model_config = pydantic.ConfigDict(extra="forbid") + + def test_unset(self): + obj = {"a": 42, "b": "hello"} + generate_and_check(target_obj=obj, pydantic_model=self.Unset, strict_properties=False) + check_match_failure(bad_obj=obj, pydantic_model=self.Unset, strict_properties=True) + + def test_ignore(self): + obj = {"a": 42, "b": "hello"} + generate_and_check(target_obj=obj, pydantic_model=self.Ignore, strict_properties=False) + check_match_failure(bad_obj=obj, pydantic_model=self.Ignore, strict_properties=True) + + def test_allow(self): + obj = {"a": 42, "b": "hello"} + 
generate_and_check(target_obj=obj, pydantic_model=self.Allow, strict_properties=False) + generate_and_check(target_obj=obj, pydantic_model=self.Allow, strict_properties=True) + + def test_forbid(self): + obj = {"a": 42, "b": "hello"} + check_match_failure(bad_obj=obj, pydantic_model=self.Forbid, strict_properties=False) + check_match_failure(bad_obj=obj, pydantic_model=self.Forbid, strict_properties=True)