From 92afe3334e9ca9ae36932c2c624c04f07c1cf1a4 Mon Sep 17 00:00:00 2001 From: Richard Edgar Date: Wed, 1 May 2024 05:43:21 -0400 Subject: [PATCH] [Feature] Improved JSON Schema support (#787) Add some extra support for JSON schema, to work better with the `langchain` examples: - Allow definitions to be given in a `definitions` key, as well as `$defs` - Allow for single-entry `allOf` blocks --- guidance/library/_json.py | 34 ++++++---- tests/library/test_json.py | 125 +++++++++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+), 11 deletions(-) diff --git a/guidance/library/_json.py b/guidance/library/_json.py index 93a429d7c..5cd57fe25 100644 --- a/guidance/library/_json.py +++ b/guidance/library/_json.py @@ -34,6 +34,9 @@ def _to_compact_json(target: Any) -> str: return json_dumps(target, separators=(",", ":")) +_DEFS_KEYS = ["$defs", "definitions"] + + @guidance(stateless=True) def _gen_json_int(lm): pos_nonzero = char_range("1", "9") + zero_or_more(char_range("0", "9")) @@ -244,6 +247,13 @@ def _gen_json( anyof_list=json_schema[ANYOF_STRING], definitions=definitions ) + ALLOF_STRING = "allOf" + if ALLOF_STRING in json_schema: + allof_list = json_schema[ALLOF_STRING] + if len(allof_list) != 1: + raise ValueError("Only support allOf with exactly one item") + return lm + _gen_json(allof_list[0], definitions) + REF_STRING = "$ref" if REF_STRING in json_schema: return lm + _get_definition( @@ -342,10 +352,11 @@ def json( else: schema = pydantic_to_json_schema(schema) - _DEFS_KEY = "$defs" definitions: Mapping[str, Callable[[], GrammarFunction]] = {} - if _DEFS_KEY in schema: - definitions = _build_definitions(schema[_DEFS_KEY]) + for dk in _DEFS_KEYS: + if dk in schema: + assert len(definitions) == 0, "Found duplicate definitions" + definitions = _build_definitions(schema[dk]) return lm + capture(_gen_json(schema, definitions), name=name) @@ -378,11 +389,12 @@ def _get_definition( definitions: Mapping[str, Callable[[], GrammarFunction]], ): assert definitions is not None - REF_START = "#/$defs/" - assert reference.startswith( - REF_START - ), f"Reference {reference} must start with {REF_START}" - - target_name = reference[len(REF_START) :] - definition = definitions[target_name] - return lm + definition() + target_definition = None + for dk in _DEFS_KEYS: + ref_start = f"#/{dk}/" + if reference.startswith(ref_start): + target_name = reference[len(ref_start) :] + target_definition = definitions[target_name] + + assert target_definition is not None + return lm + target_definition() diff --git a/tests/library/test_json.py b/tests/library/test_json.py index 1c463f070..dd7606dea 100644 --- a/tests/library/test_json.py +++ b/tests/library/test_json.py @@ -606,6 +606,55 @@ def test_simple_ref(self, target_obj): # The actual check _generate_and_check(target_obj, schema_obj) + @pytest.mark.parametrize( + "target_obj", + [ + dict(all_cats=[]), + dict(all_cats=[dict(name="Kasha")]), + dict(all_cats=[dict(name="Dawon"), dict(name="Barong")]), + ], + ) + def test_simple_ref_alt(self, target_obj): + # Uses 'definitions' rather than '$defs' + schema = """{ + "definitions": { + "Cat": { + "properties": { + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "name" + ], + "title": "Cat", + "type": "object" + } + }, + "properties": { + "all_cats": { + "items": { + "$ref": "#/definitions/Cat" + }, + "title": "All Cats", + "type": "array" + } + }, + "required": [ + "all_cats" + ], + "title": "CatList", + "type": "object" + }""" + + # First sanity check what we're setting up + schema_obj = json.loads(schema) + validate(instance=target_obj, schema=schema_obj) + + # The actual check + _generate_and_check(target_obj, schema_obj) + def test_nested_ref(self): schema = """{ "$defs": { @@ -741,6 +790,82 @@ def test_anyOf_objects(self, target_obj): _generate_and_check(target_obj, schema_obj) +class TestAllOf: + @pytest.mark.parametrize( + "my_int", + [0, 1, 100, 9876543210, 99, 737, 858, -1, -10, -20], + ) + def test_allOf_integer(self, my_int): + schema = """{ + "allOf" : [{ "type": "integer" }] + } + """ + # First sanity check what we're setting up + schema_obj = json.loads(schema) + validate(instance=my_int, schema=schema_obj) + + # The actual check + _generate_and_check(my_int, schema_obj) + + def test_allOf_ref(self): + schema = """{ + "definitions": { + "Cat": { + "properties": { + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "name" + ], + "title": "Cat", + "type": "object" + } + }, + "type": "object", + "properties": { + "my_cat": { + "allOf": [ + { + "$ref": "#/definitions/Cat" + } + ] + } + } + } + """ + + target_obj = dict(my_cat=dict(name="Sampson")) + # First sanity check what we're setting up + schema_obj = json.loads(schema) + validate(instance=target_obj, schema=schema_obj) + + # The actual check + _generate_and_check(target_obj, schema_obj) + + def test_allOf_bad_schema(self): + schema = """{ + "allOf" : [{ "type": "integer" }, { "type": "number" }] + } + """ + # First sanity check what we're setting up + schema_obj = json.loads(schema) + + TARGET_VALUE = 20 + validate(instance=TARGET_VALUE, schema=schema_obj) + + prepared_string = f"{_to_compact_json(TARGET_VALUE)}" + lm = models.Mock(prepared_string.encode()) + + # Run with the mock model + CAPTURE_KEY = "my_capture" + with pytest.raises(ValueError) as ve: + lm += gen_json(name=CAPTURE_KEY, schema=schema_obj) + assert ve.value.args[0] == "Only support allOf with exactly one item" + + class TestEnum: simple_schema = """{ "enum": [1,"2",false]