Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JSON] Add strict_properties kwarg to guidance.json to make JSON output terser by default #1068

Closed
wants to merge 9 commits into from
55 changes: 38 additions & 17 deletions guidance/library/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,10 +406,11 @@ def get_sibling_keys(node: Mapping[str, Any], key: str) -> set[str]:
class GenJson:
item_separator = ", "
key_separator = ": "
def __init__(self, schema: JSONSchema, separators: Optional[tuple[str, str]] = None) -> None:
def __init__(self, schema: JSONSchema, separators: Optional[tuple[str, str]] = None, strict_properties: bool = True) -> None:
self.schema = schema
if separators is not None:
self.item_separator, self.key_separator = separators
self.strict_properties = strict_properties

registry: referencing.Registry[JSONSchema] = referencing.Registry()
resource: referencing.Resource[JSONSchema] = referencing.jsonschema.DRAFT202012.create_resource(schema)
Expand Down Expand Up @@ -799,18 +800,8 @@ def any(self, lm):
self.json(json_schema={"type": "number"}),
self.json(json_schema={"type": "string"}),
# Recursive cases
self.json(
json_schema={
"type": "array",
"items": True,
},
),
self.json(
json_schema={
"type": "object",
"additionalProperties": True,
},
),
self.json(json_schema={"type": "array"}),
self.json(json_schema={"type": "object"}),
]
)

Expand Down Expand Up @@ -940,7 +931,10 @@ def json(
elif target_type == JSONType.OBJECT:
option = self.object(
properties=json_schema.get(ObjectKeywords.PROPERTIES, {}),
additional_properties=json_schema.get(ObjectKeywords.ADDITIONAL_PROPERTIES, True),
# For "true" adherence to the JSON schema spec, additionalProperties should be `False` by default.
# However, defaulting to `False` is going to give a more terse output, which is probably
# more useful for most users.
additional_properties=json_schema.get(ObjectKeywords.ADDITIONAL_PROPERTIES, not self.strict_properties),
required=json_schema.get(ObjectKeywords.REQUIRED, set()),
)
else:
Expand All @@ -962,10 +956,11 @@ def json(
Type["pydantic.BaseModel"],
"pydantic.TypeAdapter",
] = None,
temperature: float = 0.0,
max_tokens: int = 100000000,
separators: Optional[tuple[str, str]] = None,
whitespace_flexible: bool = False,
strict_properties: bool = True,
temperature: float = 0.0,
max_tokens: int = 100000000,
**kwargs,
):
"""Generate valid JSON according to the supplied JSON schema or `pydantic` model.
Expand Down Expand Up @@ -1010,6 +1005,23 @@ def json(
- A JSON schema object. This is a JSON schema string which has been passed to ``json.loads()``
- A subclass of ``pydantic.BaseModel``
- An instance of ``pydantic.TypeAdapter``
separators : Optional[Tuple[str, str]]
The (item, key) separators to use when generating the JSON. For maximal compactness/token savings, use `(",", ":")`.
Note that most JSON "in the wild" will include whitespace around these separators, and compact JSON may be out-of-distribution
for some models.
Default: (", ", ": ")
whitespace_flexible : bool
If True, allow for whitespace to be inserted between tokens in the JSON output. This gives maximal control to the model in terms of
formatting (and therefore might be more likely to be in-distribution for some models), but guidance will be able to accelerate fewer
tokens.
strict_properties : bool
If True, the generated JSON will only include additional properties if they are explicitly allowed in the schema. If False, additional
properties will be allowed unless they are explicitly disallowed.
max_tokens : int
The maximum number of tokens to generate.
Note setting this to a small number will likely result in an incomplete JSON object.
temperature : float
The temperature to use when generating the JSON.
"""
if "compact" in kwargs:
warnings.warn("The 'compact' argument is deprecated and has no effect. It will be removed in a future release.", category=DeprecationWarning)
Expand All @@ -1020,6 +1032,9 @@ def json(
# Default schema is empty, "anything goes" schema
# TODO: consider default being `{"type": "object"}`
schema = {}
# In this case, we don't want to use strict_properties
# because we're not actually validating against a schema
strict_properties = False
elif isinstance(schema, (Mapping, bool, str)):
if isinstance(schema, str):
schema = cast(JSONSchema, json_loads(schema))
Expand All @@ -1038,10 +1053,16 @@ def json(
else:
skip_regex = None

body = GenJson(
schema=schema,
separators=separators,
strict_properties=strict_properties,
).root()

return lm + with_temperature(
subgrammar(
name,
body=GenJson(schema=schema, separators=separators).root(),
body=body,
skip_regex=skip_regex,
no_initial_skip=True,
max_tokens=max_tokens,
Expand Down
198 changes: 189 additions & 9 deletions tests/unit/library/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,34 @@


def generate_and_check(
target_obj: Any, schema_obj: Union[str, JSONSchema], desired_temperature: Optional[float] = None
target_obj: Any,
schema_obj: Union[str, JSONSchema],
desired_temperature: Optional[float] = None,
strict_properties: bool = False,
):
if isinstance(schema_obj, str):
schema_obj = json_loads(schema_obj)

# Sanity check what we're being asked
validate(instance=target_obj, schema=schema_obj)
validate(instance=target_obj, schema=schema_obj or {})
prepared_json = json_dumps(target_obj)
assert json.loads(prepared_json) == target_obj

# Now test that the grammar can recognize and generate prepared_json
# We partial in the grammar_callable
if desired_temperature is not None:
grammar_callable = partial(
gen_json, schema=schema_obj, temperature=desired_temperature
gen_json, schema=schema_obj, temperature=desired_temperature, strict_properties=strict_properties
)
else:
grammar_callable = partial(gen_json, schema=schema_obj)
grammar_callable = partial(gen_json, schema=schema_obj, strict_properties=strict_properties)

lm = _generate_and_check(
grammar_callable,
test_string=prepared_json,
)
check_run_with_temperature(lm, desired_temperature)
if desired_temperature is not None:
check_run_with_temperature(lm, desired_temperature)


def check_match_failure(
Expand All @@ -50,8 +54,9 @@ def check_match_failure(
failure_byte: Optional[bytes] = None,
allowed_bytes: Optional[Set[bytes]] = None,
schema_obj: Union[str, JSONSchema],
strict_properties: bool = False,
):
grammar = gen_json(schema=schema_obj)
grammar = gen_json(schema=schema_obj, strict_properties=strict_properties)

_check_match_failure(
bad_string=bad_string,
Expand Down Expand Up @@ -3237,7 +3242,7 @@ class TestWhitespace:
seps,
)
def test_separators(self, separators, schema, obj):
grammar = gen_json(schema=schema, separators=separators)
grammar = gen_json(schema=schema, separators=separators, strict_properties=False)
for seps in self.seps:
prepared_json = json.dumps(obj, separators=seps)
if separators == seps:
Expand Down Expand Up @@ -3267,14 +3272,13 @@ def test_separators(self, separators, schema, obj):
[None, 0, 2, 4],
)
def test_whitespace_flexibility(self, indent, separators, schema, obj):
grammar = gen_json(schema=schema, whitespace_flexible=True)
grammar = gen_json(schema=schema, strict_properties=False, whitespace_flexible=True)
prepared_json = json.dumps(obj, separators=separators, indent=indent)

assert grammar.match(prepared_json, raise_exceptions=True) is not None
model = models.Mock(f"<s>{prepared_json}".encode())
assert str(model + grammar) == prepared_json


class TestStringSchema:
def test_good(self):
schema = """{"type": "object", "properties": {"a": {"type": "string"}}}"""
Expand All @@ -3287,3 +3291,179 @@ def test_bad(self):
bad_string='{"a": 42}',
schema_obj=schema,
)

class TestStrictProperties:
@pytest.mark.parametrize(
"target_obj",
[
1,
1.0,
"2",
False,
[1, 2, 3],
{},
{"a": 1, "b": 2, "c": 3},
]
)
def test_none_schema_exempt(self, target_obj):
schema_obj = None
generate_and_check(target_obj, schema_obj, strict_properties=True)

@pytest.mark.parametrize(
"target_obj",
[
1,
1.0,
"2",
False,
[1, 2, 3],
{},
]
)
def test_empty_schema_good(self, target_obj):
schema_obj = {}
generate_and_check(target_obj, schema_obj, strict_properties=True)

@pytest.mark.parametrize(
"target_obj",
[
{"a": 1, "b": 2, "c": 3},
{"a": 1, "b": 2},
{"a": 1},
]
)
def test_empty_schema_bad(self, target_obj):
schema_obj = {}
check_match_failure(
bad_string=json_dumps(target_obj),
schema_obj=schema_obj,
strict_properties=True,
)

@pytest.mark.parametrize(
"target_obj",
[
{},
]
)
def test_object_schema_good(self, target_obj):
schema_obj = {"type": "object"}
generate_and_check(target_obj, schema_obj, strict_properties=True)

@pytest.mark.parametrize(
"target_obj",
[
1,
1.0,
"2",
False,
[1, 2, 3],
{"a": 1, "b": 2, "c": 3},
]
)
def test_object_schema_bad(self, target_obj):
schema_obj = {"type": "object"}
check_match_failure(
bad_string=json_dumps(target_obj),
schema_obj=schema_obj,
strict_properties=True,
)

@pytest.mark.parametrize(
"target_obj",
[
{},
{"a": 1},
{"b": 2},
{"a": 1, "b": 2},
]
)
def test_object_schema_with_properties_good(self, target_obj):
schema_obj = {
"type": "object",
"properties": {
"a": {"type": "integer"},
"b": {"type": "number"},
}
}
generate_and_check(target_obj, schema_obj, strict_properties=True)

@pytest.mark.parametrize(
"target_obj",
[
{"a": 1, "b": 2, "c": 3},
]
)
def test_object_schema_with_properties_bad(self, target_obj):
schema_obj = {
"type": "object",
"properties": {
"a": {"type": "integer"},
"b": {"type": "number"},
}
}
check_match_failure(
bad_string=json_dumps(target_obj),
schema_obj=schema_obj,
strict_properties=True,
)

@pytest.mark.parametrize(
"target_obj",
[
{"foo": {"a": 1}},
{"foo": {"a": 1.0}},
{"foo": {"a": "2"}},
{"foo": {"a": False}},
{"foo": {"a": [1, 2, 3]}},
{"foo": {"a": {}}},
]
)
def test_nested_empty_schema_good(self, target_obj):
schema_obj = {
"type": "object",
"properties": {
"foo": {"type": "object", "properties": {"a": {}}},
}
}
generate_and_check(target_obj, schema_obj, strict_properties=True)

@pytest.mark.parametrize(
"target_obj",
[
{"foo": {"a": {"x": 1}}},
{"foo": {"b": 1}},
]
)
def test_nested_empty_schema_bad(self, target_obj):
schema_obj = {
"type": "object",
"properties": {
"foo": {"type": "object", "properties": {"a": {}}},
}
}
check_match_failure(
bad_string=json_dumps(target_obj),
schema_obj=schema_obj,
strict_properties=True,
)

@pytest.mark.parametrize(
"target_obj",
[
{"a": 1},
{"b": 2},
{"a": 1, "b": 2},
{"c": "hello"},
]
)
def test_explicit_additional_properties_ok(self, target_obj):
schema_obj = {
"type": "object",
"properties": {
"a": {"type": "integer"},
"b": {"type": "number"},
},
"additionalProperties": {"type": "string"},
}
generate_and_check(target_obj, schema_obj, strict_properties=True)
Loading
Loading