Skip to content

Commit

Permalink
Look at the temperatures
Browse files Browse the repository at this point in the history
  • Loading branch information
riedgar-ms committed Apr 30, 2024
1 parent e41a5c2 commit 2ccec1a
Showing 1 changed file with 67 additions and 29 deletions.
96 changes: 67 additions & 29 deletions tests/library/test_json.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, Union

import pytest
from jsonschema import validate
Expand All @@ -10,7 +10,9 @@
from guidance.library._json import _to_compact_json


def _generate_and_check(target_obj: Any, schema_obj):
def _generate_and_check(
target_obj: Any, schema_obj, desired_temperature: Union[float, None] = None
):
# Sanity check what we're being asked
validate(instance=target_obj, schema=schema_obj)

Expand All @@ -19,11 +21,34 @@ def _generate_and_check(target_obj: Any, schema_obj):

# Run with the mock model
CAPTURE_KEY = "my_capture"
lm += gen_json(name=CAPTURE_KEY, schema=schema_obj)
if desired_temperature is not None:
lm += gen_json(
name=CAPTURE_KEY, schema=schema_obj, temperature=desired_temperature
)
else:
lm += gen_json(name=CAPTURE_KEY, schema=schema_obj)

# Make sure the round trip works
assert json.loads(lm[CAPTURE_KEY]) == target_obj

# Check on some temperatures
if desired_temperature is not None:
assert len(lm.engine.called_temperatures) > 0
# Make sure that at least one temperature matches exactly
temperature_matches = [
x == desired_temperature for x in lm.engine.called_temperatures
]
assert any(temperature_matches)
# Check that all temperatures were 0 or the desired temperature
# If there has been a forced byte, then get_logits() is
# called with a temperature of zero
assert all(
[
(x == desired_temperature or x == 0)
for x in lm.engine.called_temperatures
]
)


def _check_match_failure(bad_string, failure_byte, schema_obj):
grammar = gen_json(schema=schema_obj)
Expand All @@ -45,14 +70,15 @@ def test_null():


@pytest.mark.parametrize("target_obj", [True, False])
def test_boolean(target_obj):
@pytest.mark.parametrize("temperature", [None, 0.1, 1])
def test_boolean(target_obj, temperature):
schema = """{"type": "boolean" }"""

# First sanity check what we're setting up
schema_obj = json.loads(schema)
validate(instance=target_obj, schema=schema_obj)

_generate_and_check(target_obj, schema_obj)
_generate_and_check(target_obj, schema_obj, desired_temperature=temperature)


class TestInteger:
Expand Down Expand Up @@ -110,13 +136,14 @@ class TestNumber:
123.6,
],
)
def test_number(self, target_obj):
@pytest.mark.parametrize("temperature", [None, 0.1, 1])
def test_number(self, target_obj, temperature):
# First sanity check what we're setting up
schema_obj = json.loads(TestNumber.schema)
validate(instance=target_obj, schema=schema_obj)

# The actual check
_generate_and_check(target_obj, schema_obj)
_generate_and_check(target_obj, schema_obj, desired_temperature=temperature)

@pytest.mark.parametrize(
["bad_string", "failure_byte"],
Expand Down Expand Up @@ -154,15 +181,16 @@ def test_bad_number(self, bad_string, failure_byte):
"Some more symbols: ; are useful!",
],
)
def test_string_schema(my_string: str):
@pytest.mark.parametrize("temperature", [None, 0.1, 1])
def test_string_schema(my_string: str, temperature):
schema = """{ "type": "string" }"""

# First sanity check what we're setting up
schema_obj = json.loads(schema)
validate(instance=my_string, schema=schema_obj)

# The actual check
_generate_and_check(my_string, schema_obj)
_generate_and_check(my_string, schema_obj, desired_temperature=temperature)


class TestSimpleObject:
Expand Down Expand Up @@ -208,7 +236,8 @@ def test_object_with_many_properties(self):
# The actual check
_generate_and_check(target_obj, schema_obj)

def test_directly_nested_object(self):
@pytest.mark.parametrize("temperature", [None, 0.1, 1])
def test_directly_nested_object(self, temperature):
schema = """{
"type": "object",
"properties": {
Expand Down Expand Up @@ -236,9 +265,10 @@ def test_directly_nested_object(self):
validate(instance=target_obj, schema=schema_obj)

# The actual check
_generate_and_check(target_obj, schema_obj)
_generate_and_check(target_obj, schema_obj, desired_temperature=temperature)

def test_object_containing_list(self):
@pytest.mark.parametrize("temperature", [None, 0.1, 1])
def test_object_containing_list(self, temperature):
schema = """{
"type": "object",
"properties" : {
Expand All @@ -261,7 +291,7 @@ def test_object_containing_list(self):
validate(instance=target_obj, schema=schema_obj)

# The actual check
_generate_and_check(target_obj, schema_obj)
_generate_and_check(target_obj, schema_obj, desired_temperature=temperature)

@pytest.mark.parametrize(
["bad_string", "failure_byte"],
Expand All @@ -286,7 +316,8 @@ def test_bad_object(self, bad_string, failure_byte):
class TestSimpleArray:
# These are array without references
@pytest.mark.parametrize("target_obj", [[], [0], [34, 56], [1, 2, 3], [9, 8, 7, 6]])
def test_integer_list(self, target_obj):
@pytest.mark.parametrize("temperature", [None, 0.1, 1])
def test_integer_list(self, target_obj, temperature):
schema = """{
"type" : "array",
"items" : {
Expand All @@ -300,7 +331,7 @@ def test_integer_list(self, target_obj):
validate(instance=target_obj, schema=schema_obj)

# The actual check
_generate_and_check(target_obj, schema_obj)
_generate_and_check(target_obj, schema_obj, desired_temperature=temperature)

@pytest.mark.parametrize("target_obj", [[], ["a"], ["b c", "d, e"]])
def test_string_list(self, target_obj):
Expand All @@ -323,7 +354,8 @@ def test_string_list(self, target_obj):
"target_obj",
[[], [dict(a=1)], [dict(a=2), dict(a=3)], [dict(a=4), dict(a=5), dict(a=6)]],
)
def test_object_list(self, target_obj):
@pytest.mark.parametrize("temperature", [None, 0.1, 1])
def test_object_list(self, target_obj, temperature):
schema = """{
"type" : "array",
"items" : {
Expand All @@ -342,7 +374,7 @@ def test_object_list(self, target_obj):
validate(instance=target_obj, schema=schema_obj)

# The actual check
_generate_and_check(target_obj, schema_obj)
_generate_and_check(target_obj, schema_obj, desired_temperature=temperature)

@pytest.mark.parametrize(
["bad_string", "failure_byte"],
Expand Down Expand Up @@ -606,7 +638,8 @@ def test_simple_ref(self, target_obj):
# The actual check
_generate_and_check(target_obj, schema_obj)

def test_nested_ref(self):
@pytest.mark.parametrize("temperature", [None, 0.1, 1])
def test_nested_ref(self, temperature):
schema = """{
"$defs": {
"A": {
Expand Down Expand Up @@ -659,12 +692,13 @@ def test_nested_ref(self):
validate(instance=target_obj, schema=schema_obj)

# The actual check
_generate_and_check(target_obj, schema_obj)
_generate_and_check(target_obj, schema_obj, desired_temperature=temperature)


class TestAnyOf:
@pytest.mark.parametrize("target_obj", [123, True])
def test_anyOf_simple(self, target_obj):
@pytest.mark.parametrize("temperature", [None, 0.1, 1])
def test_anyOf_simple(self, target_obj, temperature):
schema = """{
"anyOf": [
{
Expand All @@ -681,7 +715,7 @@ def test_anyOf_simple(self, target_obj):
validate(instance=target_obj, schema=schema_obj)

# The actual check
_generate_and_check(target_obj, schema_obj)
_generate_and_check(target_obj, schema_obj, desired_temperature=temperature)

@pytest.mark.parametrize(
"target_obj",
Expand All @@ -690,7 +724,8 @@ def test_anyOf_simple(self, target_obj):
dict(my_val=dict(my_str="Some long string or other")),
],
)
def test_anyOf_objects(self, target_obj):
@pytest.mark.parametrize("temperature", [None, 0.1, 1])
def test_anyOf_objects(self, target_obj, temperature):
schema = """{
"$defs": {
"A": {
Expand Down Expand Up @@ -738,7 +773,7 @@ def test_anyOf_objects(self, target_obj):
validate(instance=target_obj, schema=schema_obj)

# The actual check
_generate_and_check(target_obj, schema_obj)
_generate_and_check(target_obj, schema_obj, desired_temperature=temperature)


class TestEnum:
Expand All @@ -752,13 +787,14 @@ class TestEnum:
}"""

@pytest.mark.parametrize("target_obj", [1, "2", False])
def test_enum(self, target_obj):
@pytest.mark.parametrize("temperature", [None, 0.1, 1])
def test_enum(self, target_obj, temperature):
# First sanity check what we're setting up
schema_obj = json.loads(self.simple_schema)
validate(instance=target_obj, schema=schema_obj)

# The actual check
_generate_and_check(target_obj, schema_obj)
_generate_and_check(target_obj, schema_obj, desired_temperature=temperature)

@pytest.mark.parametrize(
"bad_obj, failure_byte",
Expand Down Expand Up @@ -820,13 +856,14 @@ class TestAdditionalProperties:
"""

@pytest.mark.parametrize("target_obj", [{}, {"a": 1}, {"a": 1, "b": 2}])
def test_simple_additional_properties(self, target_obj):
@pytest.mark.parametrize("temperature", [None, 0.1, 1])
def test_simple_additional_properties(self, target_obj, temperature):
# First sanity check what we're setting up
schema_obj = json.loads(self.simple_schema)
validate(instance=target_obj, schema=schema_obj)

# The actual check
_generate_and_check(target_obj, schema_obj)
_generate_and_check(target_obj, schema_obj, desired_temperature=temperature)

@pytest.mark.parametrize(
"bad_obj, failure_byte",
Expand Down Expand Up @@ -868,13 +905,14 @@ def test_anyOf_bad_type(self, bad_obj, failure_byte):
{"mystr": "hello", "a": 1, "b": 2},
],
)
def test_properties_and_additional_properties(self, target_obj):
@pytest.mark.parametrize("temperature", [None, 0.1, 1])
def test_properties_and_additional_properties(self, target_obj, temperature):
# First sanity check what we're setting up
schema_obj = json.loads(self.combined_schema)
validate(instance=target_obj, schema=schema_obj)

# The actual check
_generate_and_check(target_obj, schema_obj)
_generate_and_check(target_obj, schema_obj, desired_temperature=temperature)

@pytest.mark.parametrize(
"bad_obj, failure_byte",
Expand Down

0 comments on commit 2ccec1a

Please sign in to comment.