Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into riedgar-ms/azure-ai…
Browse files Browse the repository at this point in the history
…-studio-support-01
  • Loading branch information
riedgar-ms committed May 1, 2024
2 parents 3bcb48e + 92afe33 commit 559b341
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 11 deletions.
34 changes: 23 additions & 11 deletions guidance/library/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def _to_compact_json(target: Any) -> str:
return json_dumps(target, separators=(",", ":"))


_DEFS_KEYS = ["$defs", "definitions"]


@guidance(stateless=True)
def _gen_json_int(lm):
pos_nonzero = char_range("1", "9") + zero_or_more(char_range("0", "9"))
Expand Down Expand Up @@ -244,6 +247,13 @@ def _gen_json(
anyof_list=json_schema[ANYOF_STRING], definitions=definitions
)

ALLOF_STRING = "allOf"
if ALLOF_STRING in json_schema:
allof_list = json_schema[ALLOF_STRING]
if len(allof_list) != 1:
raise ValueError("Only support allOf with exactly one item")
return lm + _gen_json(allof_list[0], definitions)

REF_STRING = "$ref"
if REF_STRING in json_schema:
return lm + _get_definition(
Expand Down Expand Up @@ -342,10 +352,11 @@ def json(
else:
schema = pydantic_to_json_schema(schema)

_DEFS_KEY = "$defs"
definitions: Mapping[str, Callable[[], GrammarFunction]] = {}
if _DEFS_KEY in schema:
definitions = _build_definitions(schema[_DEFS_KEY])
for dk in _DEFS_KEYS:
if dk in schema:
assert len(definitions) == 0, "Found duplicate definitions"
definitions = _build_definitions(schema[dk])

return lm + capture(_gen_json(schema, definitions), name=name)

Expand Down Expand Up @@ -378,11 +389,12 @@ def _get_definition(
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
assert definitions is not None
REF_START = "#/$defs/"
assert reference.startswith(
REF_START
), f"Reference {reference} must start with {REF_START}"

target_name = reference[len(REF_START) :]
definition = definitions[target_name]
return lm + definition()
target_definition = None
for dk in _DEFS_KEYS:
ref_start = f"#/{dk}/"
if reference.startswith(ref_start):
target_name = reference[len(ref_start) :]
target_definition = definitions[target_name]

assert target_definition is not None
return lm + target_definition()
125 changes: 125 additions & 0 deletions tests/library/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,55 @@ def test_simple_ref(self, target_obj):
# The actual check
_generate_and_check(target_obj, schema_obj)

@pytest.mark.parametrize(
"target_obj",
[
dict(all_cats=[]),
dict(all_cats=[dict(name="Kasha")]),
dict(all_cats=[dict(name="Dawon"), dict(name="Barong")]),
],
)
def test_simple_ref_alt(self, target_obj):
# Uses 'definitions' rather than '$defs'
schema = """{
"definitions": {
"Cat": {
"properties": {
"name": {
"title": "Name",
"type": "string"
}
},
"required": [
"name"
],
"title": "Cat",
"type": "object"
}
},
"properties": {
"all_cats": {
"items": {
"$ref": "#/definitions/Cat"
},
"title": "All Cats",
"type": "array"
}
},
"required": [
"all_cats"
],
"title": "CatList",
"type": "object"
}"""

# First sanity check what we're setting up
schema_obj = json.loads(schema)
validate(instance=target_obj, schema=schema_obj)

# The actual check
_generate_and_check(target_obj, schema_obj)

def test_nested_ref(self):
schema = """{
"$defs": {
Expand Down Expand Up @@ -741,6 +790,82 @@ def test_anyOf_objects(self, target_obj):
_generate_and_check(target_obj, schema_obj)


class TestAllOf:
@pytest.mark.parametrize(
"my_int",
[0, 1, 100, 9876543210, 99, 737, 858, -1, -10, -20],
)
def test_allOf_integer(self, my_int):
schema = """{
"allOf" : [{ "type": "integer" }]
}
"""
# First sanity check what we're setting up
schema_obj = json.loads(schema)
validate(instance=my_int, schema=schema_obj)

# The actual check
_generate_and_check(my_int, schema_obj)

def test_allOf_ref(self):
schema = """{
"definitions": {
"Cat": {
"properties": {
"name": {
"title": "Name",
"type": "string"
}
},
"required": [
"name"
],
"title": "Cat",
"type": "object"
}
},
"type": "object",
"properties": {
"my_cat": {
"allOf": [
{
"$ref": "#/definitions/Cat"
}
]
}
}
}
"""

target_obj = dict(my_cat=dict(name="Sampson"))
# First sanity check what we're setting up
schema_obj = json.loads(schema)
validate(instance=target_obj, schema=schema_obj)

# The actual check
_generate_and_check(target_obj, schema_obj)

def test_allOf_bad_schema(self):
schema = """{
"allOf" : [{ "type": "integer" }, { "type": "number" }]
}
"""
# First sanity check what we're setting up
schema_obj = json.loads(schema)

TARGET_VALUE = 20
validate(instance=TARGET_VALUE, schema=schema_obj)

prepared_string = f"<s>{_to_compact_json(TARGET_VALUE)}"
lm = models.Mock(prepared_string.encode())

# Run with the mock model
CAPTURE_KEY = "my_capture"
with pytest.raises(ValueError) as ve:
lm += gen_json(name=CAPTURE_KEY, schema=schema_obj)
assert ve.value.args[0] == "Only support allOf with exactly one item"


class TestEnum:
simple_schema = """{
"enum": [1,"2",false]
Expand Down

0 comments on commit 559b341

Please sign in to comment.