Skip to content

Commit

Permalink
Merge branch 'fix_to_scenario_list_issue' into retry_refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjosephhorton committed Sep 16, 2024
2 parents b1408ad + b21a480 commit dcadf0a
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 5 deletions.
15 changes: 15 additions & 0 deletions edsl/questions/QuestionBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,21 @@ def html(

return rendered_html

@classmethod
def example_model(cls):
from edsl import Model

q = cls.example()
m = Model("test", canned_response=cls._simulate_answer(q)["answer"])

return m

@classmethod
def example_results(cls):
m = cls.example_model()
q = cls.example()
return q.by(m).run(cache=False)

def rich_print(self):
"""Print the question in a rich format."""
from rich.table import Table
Expand Down
Empty file.
Empty file.
6 changes: 5 additions & 1 deletion edsl/results/DatasetExportMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,11 @@ def to_scenario_list(self, remove_prefix: bool = True) -> list[dict]:
from edsl import ScenarioList, Scenario

list_of_dicts = self.to_dicts(remove_prefix=remove_prefix)
return ScenarioList([Scenario(d) for d in list_of_dicts])
scenarios = []
for d in list_of_dicts:
scenarios.append(Scenario(d))
return ScenarioList(scenarios)
# return ScenarioList([Scenario(d) for d in list_of_dicts])

def to_agent_list(self, remove_prefix: bool = True):
"""Convert the results to a list of dictionaries, one per agent.
Expand Down
12 changes: 10 additions & 2 deletions edsl/scenarios/FileStore.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,16 @@ def view(self):


class PDFFileStore(FileStore):
def __init__(self, filename):
super().__init__(filename, suffix=".pdf")
def __init__(
self,
filename,
binary: Optional[bool] = None,
suffix: Optional[str] = None,
base64_string: Optional[str] = None,
):
super().__init__(
filename, binary=binary, base64_string=base64_string, suffix=".pdf"
)

def view(self):
pdf_path = self.to_tempfile()
Expand Down
6 changes: 5 additions & 1 deletion edsl/scenarios/ScenarioList.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, data: Optional[list] = None, codebook: Optional[dict] = None)
def has_jinja_braces(self) -> bool:
"""Check if the ScenarioList has Jinja braces."""
return any([scenario.has_jinja_braces for scenario in self])

def convert_jinja_braces(self) -> ScenarioList:
"""Convert Jinja braces to Python braces."""
return ScenarioList([scenario.convert_jinja_braces() for scenario in self])
Expand Down Expand Up @@ -282,6 +282,10 @@ def _repr_html_(self) -> str:
for s in data["scenarios"]:
_ = s.pop("edsl_version")
_ = s.pop("edsl_class_name")
for scenario in data["scenarios"]:
for key, value in scenario.items():
if hasattr(value, "to_dict"):
data[key] = value.to_dict()
return data_to_html(data)

def tally(self, field) -> dict:
Expand Down
10 changes: 9 additions & 1 deletion edsl/utilities/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@
from typing import Callable, Union


class CustomEncoder(json.JSONEncoder):
def default(self, obj):
try:
return json.JSONEncoder.default(self, obj)
except TypeError:
return str(obj)


def time_it(func):
@wraps(func)
def wrapper(*args, **kwargs):
Expand Down Expand Up @@ -124,7 +132,7 @@ def data_to_html(data, replace_new_lines=False):
from pygments.formatters import HtmlFormatter
from IPython.display import HTML

json_str = json.dumps(data, indent=4)
json_str = json.dumps(data, indent=4, cls=CustomEncoder)
formatted_json = highlight(
json_str,
JsonLexer(),
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ license = "MIT"
name = "edsl"
readme = "README.md"
version = "0.1.33.dev1"
include = [
"edsl/questions/templates/**/*",
]

[tool.poetry.dependencies]
python = ">=3.9.1,<3.13"
Expand Down
13 changes: 13 additions & 0 deletions tests/questions/test_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest
from edsl import Question


@pytest.mark.parametrize("question_type", Question.available())
def test_individual_questions(question_type):
if question_type != "functional":
q = Question.example(question_type)
r = q.example_results()
_ = hash(r)
_ = r._repr_html_()
else:
pytest.skip("Skipping functional question type")

0 comments on commit dcadf0a

Please sign in to comment.