Skip to content

Commit

Permalink
Merge pull request #1051 from expectedparrot/question_budget_back
Browse files Browse the repository at this point in the history
Question budget back
  • Loading branch information
apostolosfilippas committed Sep 17, 2024
2 parents 76c8389 + 58082c8 commit 6d37ae9
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 65 deletions.
20 changes: 17 additions & 3 deletions edsl/jobs/interviews/Interview.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ async def _answer_question_and_record_task(
) -> "AgentResponseDict":
"""Answer a question and records the task."""

had_language_model_no_response_error = False

@retry(
stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
wait=wait_exponential(
Expand All @@ -277,6 +279,8 @@ async def _answer_question_and_record_task(
reraise=True,
)
async def attempt_answer():
nonlocal had_language_model_no_response_error

invigilator = self._get_invigilator(question)

if self._skip_this_question(question):
Expand Down Expand Up @@ -306,6 +310,7 @@ async def attempt_answer():

except asyncio.TimeoutError as e:
self._handle_exception(e, invigilator, task)
had_language_model_no_response_error = True
raise LanguageModelNoResponseError(
f"Language model timed out for question '{question.question_name}.'"
)
Expand All @@ -314,14 +319,17 @@ async def attempt_answer():
self._handle_exception(e, invigilator, task)

if "response" not in locals():
had_language_model_no_response_error = True
raise LanguageModelNoResponseError(
f"Language model did not return a response for question '{question.question_name}.'"
)

# it got fixed!
if question.question_name in self.exceptions:
# if it gets here, it means the no response error was fixed
if (
question.question_name in self.exceptions
and had_language_model_no_response_error
):
self.exceptions.record_fixed_question(question.question_name)
# breakpoint()

return response

Expand Down Expand Up @@ -375,6 +383,8 @@ def _handle_exception(
):
import copy

# breakpoint()

answers = copy.copy(self.answers)
exception_entry = InterviewExceptionEntry(
exception=e,
Expand All @@ -385,6 +395,10 @@ def _handle_exception(
task.task_status = TaskStatus.FAILED
self.exceptions.add(invigilator.question.question_name, exception_entry)

if self.raise_validation_errors:
if isinstance(e, QuestionAnswerValidationError):
raise e

if hasattr(self, "stop_on_exception"):
stop_on_exception = self.stop_on_exception
else:
Expand Down
2 changes: 1 addition & 1 deletion edsl/jobs/tasks/TaskHistory.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def exceptions(self):
def unfixed_exceptions(self):
"""
>>> len(TaskHistory.example().unfixed_exceptions)
0
4
"""
return [
i.exceptions
Expand Down
9 changes: 6 additions & 3 deletions edsl/questions/QuestionBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,13 @@ def html(
if scenario is None:
scenario = {}


prior_answers_dict = {}
for key, value in answers.items():
if not key.endswith("_comment") and not key.endswith("_generated_tokens"):
prior_answers_dict[key] = {"answer": value}

if isinstance(answers, dict):
for key, value in answers.items():
if not key.endswith("_comment") and not key.endswith("_generated_tokens"):
prior_answers_dict[key] = {"answer": value}

# breakpoint()

Expand Down
133 changes: 94 additions & 39 deletions edsl/questions/QuestionBudget.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,65 @@
from __future__ import annotations
import random
from typing import Any, Optional, Union
from typing import Any, Optional, Union, List


from pydantic import Field, BaseModel, validator

from edsl.questions.QuestionBase import QuestionBase
from edsl.questions.descriptors import IntegerDescriptor, QuestionOptionsDescriptor

from edsl.questions.ResponseValidatorABC import ResponseValidatorABC


class BudgewResponseValidator(ResponseValidatorABC):
valid_examples = []

invalid_examples = []

def fix(self, response, verbose=False):
if verbose:
print(f"Fixing list response: {response}")
answer = str(response.get("answer") or response.get("generated_tokens", ""))
if len(answer.split(",")) > 0:
return (
{"answer": answer.split(",")} | {"comment": response.get("comment")}
if "comment" in response
else {}
)


def create_budget_model(
budget_sum: float, permissive: bool, question_options: List[str]
):

class BudgetResponse(BaseModel):
answer: List[float] = Field(
...,
description="List of non-negative numbers representing budget allocation",
min_items=len(question_options),
max_items=len(question_options),
)
comment: Optional[str] = None
generated_tokens: Optional[str] = None

@validator("answer")
def validate_answer(cls, v):
if len(v) != len(question_options):
raise ValueError(f"Must provide {len(question_options)} values")
if any(x < 0 for x in v):
raise ValueError("All values must be non-negative")
total = sum(v)
if not permissive and total != budget_sum:
raise ValueError(f"Sum of numbers must equal {budget_sum}")
elif permissive and total > budget_sum:
raise ValueError(f"Sum of numbers cannot exceed {budget_sum}")
return v

class Config:
extra = "forbid"

return BudgetResponse


class QuestionBudget(QuestionBase):
"""This question prompts the agent to allocate a budget among options."""
Expand All @@ -12,16 +68,18 @@ class QuestionBudget(QuestionBase):
budget_sum: int = IntegerDescriptor(none_allowed=False)
question_options: list[str] = QuestionOptionsDescriptor(q_budget=True)
_response_model = None
response_validator_class = None
response_validator_class = BudgewResponseValidator

def __init__(
self,
question_name: str,
question_text: str,
question_options: list[str],
budget_sum: int,
include_comment: bool = True,
question_presentation: Optional[str] = None,
answering_instructions: Optional[str] = None,
permissive: bool = False,
):
"""Instantiate a new QuestionBudget.
Expand All @@ -36,20 +94,17 @@ def __init__(
self.budget_sum = budget_sum
self.question_presentation = question_presentation
self.answering_instructions = answering_instructions
self.permissive = permissive
self.include_comment = include_comment

################
# Answer methods
################
def _validate_answer(self, answer: dict[str, Any]) -> dict[str, Union[int, str]]:
"""Validate the answer."""
self._validate_answer_template_basic(answer)
self._validate_answer_key_value(answer, "answer", dict)
self._validate_answer_budget(answer)
return answer
def create_response_model(self):
return create_budget_model(
self.budget_sum, self.permissive, self.question_options
)

def _translate_answer_code_to_answer(
self, answer_codes: dict[str, int], scenario: "Scenario" = None
):
self, answer_code, combined_dict
) -> list[dict]:
"""
Translate the answer codes to the actual answers.
Expand All @@ -58,35 +113,35 @@ def _translate_answer_code_to_answer(
This code will translate that to "a".
"""
translated_codes = []
for answer_code, response in answer_codes.items():
translated_codes.append({self.question_options[int(answer_code)]: response})
for answer_code, question_option in zip(answer_code, self.question_options):
translated_codes.append({question_option: answer_code})

return translated_codes

def _simulate_answer(self, human_readable=True):
"""Simulate a valid answer for debugging purposes (what the validator expects)."""
from edsl.utilities.utilities import random_string

if human_readable:
keys = self.question_options
else:
keys = range(len(self.question_options))
remaining_budget = self.budget_sum
values = []
for _ in range(len(self.question_options)):
if _ == len(self.question_options) - 1:
# Assign remaining budget to the last value
values.append(remaining_budget)
else:
# Generate a random value between 0 and remaining budget
value = random.randint(0, remaining_budget)
values.append(value)
remaining_budget -= value
answer = dict(zip(keys, values))
return {
"answer": answer,
"comment": random_string(),
}
# def _simulate_answer(self, human_readable=True):
# """Simulate a valid answer for debugging purposes (what the validator expects)."""
# from edsl.utilities.utilities import random_string

# if human_readable:
# keys = self.question_options
# else:
# keys = range(len(self.question_options))
# remaining_budget = self.budget_sum
# values = []
# for _ in range(len(self.question_options)):
# if _ == len(self.question_options) - 1:
# # Assign remaining budget to the last value
# values.append(remaining_budget)
# else:
# # Generate a random value between 0 and remaining budget
# value = random.randint(0, remaining_budget)
# values.append(value)
# remaining_budget -= value
# answer = dict(zip(keys, values))
# return {
# "answer": answer,
# "comment": random_string(),
# }

@property
def question_html_content(self) -> str:
Expand Down
Empty file.
7 changes: 7 additions & 0 deletions edsl/questions/templates/budget/answering_instructions.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Return only a comma-separated list the values in the same order as the options, with 0s included, on one line, in square braces.

Example: if there are 4 options, the response should be "[25,25,25,25]" to allocate 25 to each option.

{% if include_comment %}
After the answer, you can put a comment explaining your choice on the next line.
{% endif %}
7 changes: 7 additions & 0 deletions edsl/questions/templates/budget/question_presentation.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{{question_text}}
The options are
{% for option in question_options %}
{{ loop.index0 }}: {{option}}
{% endfor %}
Allocate your budget of {{budget_sum}} among the options.

39 changes: 20 additions & 19 deletions tests/questions/test_QuestionBudget.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from edsl.questions.QuestionBudget import QuestionBudget, main


def test_QuestionBudget_main():
main()
# def test_QuestionBudget_main():
# main()


valid_question = {
Expand Down Expand Up @@ -103,7 +103,8 @@ def test_QuestionBudget_construction():


def test_QuestionBudget_answers():
valid_answer = {"answer": {"0": 25, "1": 25, "2": 25, "3": 25}, "comment": "Yum!"}
# valid_answer = {"answer": {"0": 25, "1": 25, "2": 25, "3": 25}, "comment": "Yum!"}
valid_answer = {"answer": [25, 25, 25, 25], "comment": "Yum!"}
q = QuestionBudget(**valid_question)
# answer must be an integer or interpretable as integer
q._validate_answer(valid_answer)
Expand Down Expand Up @@ -139,25 +140,25 @@ def test_QuestionBudget_extras():
q = QuestionBudget(**valid_question)
# instructions
# translate
assert q._translate_answer_code_to_answer({"0": 25, "1": 25, "2": 25, "3": 25}) == [
assert q._translate_answer_code_to_answer([25, 25, 25, 25], {}) == [
{"Pizza": 25},
{"Ice Cream": 25},
{"Burgers": 25},
{"Salad": 25},
]
# _simulate_answer
assert q._simulate_answer().keys() == q._simulate_answer(human_readable=True).keys()
simulated_answer = q._simulate_answer(human_readable=False)
assert isinstance(simulated_answer, dict)
assert "answer" in simulated_answer
assert "comment" in simulated_answer
assert isinstance(simulated_answer["answer"], dict)
assert all(
[type(k) == int and k in range(len(q.question_options))]
for k in simulated_answer["answer"].keys()
)
assert round(sum(simulated_answer["answer"].values())) == q.budget_sum
assert list(q._simulate_answer(human_readable=False)["answer"].keys()) == list(
range(len(q.question_options))
)
# form elements
# assert q._simulate_answer().keys() == q._simulate_answer(human_readable=True).keys()
# simulated_answer = q._simulate_answer(human_readable=False)
# assert isinstance(simulated_answer, dict)
# assert "answer" in simulated_answer
# assert "comment" in simulated_answer
# assert isinstance(simulated_answer["answer"], dict)
# assert all(
# [type(k) == int and k in range(len(q.question_options))]
# for k in simulated_answer["answer"].keys()
# )
# assert round(sum(simulated_answer["answer"].values())) == q.budget_sum
# assert list(q._simulate_answer(human_readable=False)["answer"].keys()) == list(
# range(len(q.question_options))
# )
# # form elements

0 comments on commit 6d37ae9

Please sign in to comment.