Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prompt constructor refactor #1249

Merged
merged 4 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 2 additions & 23 deletions edsl/agents/InvigilatorBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,24 +236,6 @@ def example(
from edsl import Model

model = Model("test", canned_response="SPAM!")
# class TestLanguageModelGood(LanguageModel):
# """A test language model."""

# _model_ = "test"
# _parameters_ = {"temperature": 0.5}
# _inference_service_ = InferenceServiceType.TEST.value

# async def async_execute_model_call(
# self, user_prompt: str, system_prompt: str
# ) -> dict[str, Any]:
# await asyncio.sleep(0.1)
# if hasattr(self, "throw_an_exception"):
# raise Exception("Error!")
# return {"message": """{"answer": "SPAM!"}"""}

# def parse_response(self, raw_response: dict[str, Any]) -> str:
# """Parse the response from the model."""
# return raw_response["message"]

if throw_an_exception:
model.throw_an_exception = True
Expand All @@ -263,11 +245,8 @@ def example(

if not survey:
survey = Survey.example()
# if question:
# need to have the focal question name in the list of names
# survey._questions[0].question_name = question.question_name
# survey.add_question(question)
if question:

if question not in survey.questions and question is not None:
survey.add_question(question)

question = question or survey.questions[0]
Expand Down
223 changes: 123 additions & 100 deletions edsl/agents/PromptConstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,127 @@ def question_file_keys(self):
question_file_keys.append(var)
return question_file_keys

def build_replacement_dict(self, question_data: dict):
    """
    Builds a dictionary of replacement values by combining multiple data sources.

    Later sources take precedence over earlier ones: file placeholders are
    overridden by question data, then scenario values, prior answers, the
    agent, and finally the per-question rendering settings.
    """
    # Placeholder text standing in for files referenced by the scenario.
    file_placeholders = {
        key: f"<see file {key}>" for key in self.scenario_file_keys
    }

    # Scenario values, excluding the file keys handled above.
    non_file_scenario = {
        key: value
        for key, value in self.scenario.items()
        if key not in self.scenario_file_keys
    }

    # Per-question rendering flags, with defaults when unset.
    settings = {
        "use_code": getattr(self.question, "_use_code", True),
        "include_comment": getattr(self.question, "_include_comment", False),
    }

    # Merge everything in precedence order via dict unpacking.
    return {
        **file_placeholders,
        **question_data,
        **non_file_scenario,
        **self.prior_answers_dict(),
        "agent": self.agent,
        **settings,
    }

def _get_question_options(self, question_data):
    """Resolve the question options, expanding template references.

    When ``question_options`` is already a list (or absent) it is returned
    unchanged. When it is a string — a Jinja2 template such as
    ``"{{ my_options }}"`` — the single referenced variable is looked up
    first in the scenario and then in the prior answers. If a prior answer
    exists but its value is not (yet) a list, placeholder options are
    returned so rendering can still proceed.
    """
    question_options_entry = question_data.get("question_options", None)
    question_options = question_options_entry

    if isinstance(question_options_entry, str):
        # The options are a template string; extract the variable name it
        # references, e.g. "{{ age_levels }}" -> "age_levels".
        # NOTE(review): assumes exactly one undeclared variable — [0] would
        # raise IndexError on a string with none; confirm callers guarantee this.
        env = Environment()
        parsed_content = env.parse(question_options_entry)
        question_option_key = list(meta.find_undeclared_variables(parsed_content))[
            0
        ]
        # First, look for the referenced key in the scenario.
        if isinstance(self.scenario.get(question_option_key), list):
            question_options = self.scenario.get(question_option_key)

        # Otherwise the options might come from a prior question's answer.
        if self.prior_answers_dict().get(question_option_key) is not None:
            prior_question = self.prior_answers_dict().get(question_option_key)
            if hasattr(prior_question, "answer"):
                if isinstance(prior_question.answer, list):
                    question_options = prior_question.answer
                else:
                    # Prior question not answered yet (or answer is not a
                    # list): substitute placeholders.
                    question_options = [
                        "N/A",
                        "Will be populated by prior answer",
                        "These are placeholder options",
                    ]
    return question_options

def build_question_instructions_prompt(self):
    """Builds the question instructions prompt.

    Renders the question's instruction template with values drawn from the
    scenario, prior answers, and the agent, then prepends any survey-level
    instructions relevant to this question.

    Returns:
        Prompt: the fully rendered question instructions.
    """
    question_prompt = Prompt(self.question.get_instructions(model=self.model.model))

    # The question data is a dictionary, e.g.
    # {'question_text': 'Do you like school?', 'question_name': 'q0',
    #  'question_options': ['yes', 'no']}
    question_data = self.question.data.copy()

    # question_options may be a template string referencing a scenario value
    # or a prior answer; resolve it to a concrete list before rendering.
    if "question_options" in question_data:
        question_data["question_options"] = self._get_question_options(question_data)

    replacement_dict = self.build_replacement_dict(question_data)
    rendered_instructions = question_prompt.render(replacement_dict)

    # Anything left unrendered indicates a missing (or misspelled) variable.
    undefined_template_variables = (
        rendered_instructions.undefined_template_variables({})
    )

    # A leftover variable matching a question name usually means a prior
    # answer was not available; surface it for debugging.
    for question_name in self.survey.question_names:
        if question_name in undefined_template_variables:
            print(
                "Question name found in undefined_template_variables: ",
                question_name,
            )

    if undefined_template_variables:
        import warnings

        warnings.warn(
            f"Question instructions still has variables: {undefined_template_variables}."
        )

    # Survey-level instructions can apply to multiple follow-on questions;
    # prepend any that are relevant to this one.
    relevant_instructions = self.survey.relevant_instructions(
        self.question.question_name
    )

    if relevant_instructions:
        preamble_text = Prompt(text="")
        for instruction in relevant_instructions:
            preamble_text += instruction.text
        rendered_instructions = preamble_text + rendered_instructions

    return rendered_instructions

@property
def question_instructions_prompt(self) -> Prompt:
"""
Expand All @@ -118,109 +239,11 @@ def question_instructions_prompt(self) -> Prompt:
Prompt(text=\"""...
...
"""
# The user might have passed a custom prompt, which would be stored in _question_instructions_prompt
if not hasattr(self, "_question_instructions_prompt"):
# Gets the instructions for the question - this is how the question should be answered
question_prompt = Prompt(
self.question.get_instructions(model=self.model.model)
self._question_instructions_prompt = (
self.build_question_instructions_prompt()
)

# Get the data for the question - this is a dictionary of the question data
# e.g., {'question_text': 'Do you like school?', 'question_name': 'q0', 'question_options': ['yes', 'no']}
question_data = self.question.data.copy()

# check to see if the question_options is actually a string
# This is used when the user is using the question_options as a variable from a scenario
# if "question_options" in question_data:
if isinstance(self.question.data.get("question_options", None), str):
env = Environment()
parsed_content = env.parse(self.question.data["question_options"])
question_option_key = list(
meta.find_undeclared_variables(parsed_content)
)[0]

# look to see if the question_option_key is in the scenario
if isinstance(
question_options := self.scenario.get(question_option_key), list
):
question_data["question_options"] = question_options
self.question.question_options = question_options

# might be getting it from the prior answers
if self.prior_answers_dict().get(question_option_key) is not None:
prior_question = self.prior_answers_dict().get(question_option_key)
if hasattr(prior_question, "answer"):
if isinstance(prior_question.answer, list):
question_data["question_options"] = prior_question.answer
self.question.question_options = prior_question.answer
else:
placeholder_options = [
"N/A",
"Will be populated by prior answer",
"These are placeholder options",
]
question_data["question_options"] = placeholder_options
self.question.question_options = placeholder_options

replacement_dict = (
{key: f"<see file {key}>" for key in self.scenario_file_keys}
| question_data
| {
k: v
for k, v in self.scenario.items()
if k not in self.scenario_file_keys
} # don't include images in the replacement dict
| self.prior_answers_dict()
| {"agent": self.agent}
| {
"use_code": getattr(self.question, "_use_code", True),
"include_comment": getattr(
self.question, "_include_comment", False
),
}
)

rendered_instructions = question_prompt.render(replacement_dict)

# is there anything left to render?
undefined_template_variables = (
rendered_instructions.undefined_template_variables({})
)

# Check if it's the name of a question in the survey
for question_name in self.survey.question_names:
if question_name in undefined_template_variables:
print(
"Question name found in undefined_template_variables: ",
question_name,
)

if undefined_template_variables:
msg = f"Question instructions still has variables: {undefined_template_variables}."
import warnings

warnings.warn(msg)
# raise QuestionScenarioRenderError(
# f"Question instructions still has variables: {undefined_template_variables}."
# )

####################################
# Check if question has instructions - these are instructions in a Survey that can apply to multiple follow-on questions
####################################
relevant_instructions = self.survey.relevant_instructions(
self.question.question_name
)

if relevant_instructions != []:
# preamble_text = Prompt(
# text="You were given the following instructions: "
# )
preamble_text = Prompt(text="")
for instruction in relevant_instructions:
preamble_text += instruction.text
rendered_instructions = preamble_text + rendered_instructions

self._question_instructions_prompt = rendered_instructions
return self._question_instructions_prompt

@property
Expand Down
5 changes: 4 additions & 1 deletion edsl/jobs/interviews/InterviewExceptionEntry.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ def example(cls):
m = LanguageModel.example(test_model=True)
q = QuestionFreeText.example(exception_to_throw=ValueError)
results = q.by(m).run(
skip_retry=True, print_exceptions=False, raise_validation_errors=True
skip_retry=True,
print_exceptions=False,
raise_validation_errors=True,
disable_remote_inference=True,
)
return results.task_history.exceptions[0]["how_are_you"][0]

Expand Down
35 changes: 35 additions & 0 deletions tests/agents/test_prompt_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,41 @@
from edsl.questions import QuestionMultipleChoice as q


def test_option_expansion_from_current_answers():
    """A templated question_options referencing a prior question is expanded
    from the invigilator's current answers."""
    from edsl import QuestionList, QuestionMultipleChoice, Survey
    from edsl.agents.InvigilatorBase import InvigilatorBase

    age_question = QuestionList(
        question_text="What are some age levels?",
        question_name="age_levels",
    )
    choice_question = QuestionMultipleChoice(
        question_text="Here is a question",
        question_name="example_question",
        question_options="{{ age_levels }}",
    )
    invigilator = InvigilatorBase.example(
        question=choice_question,
        survey=Survey([age_question, choice_question]),
    )
    invigilator.current_answers = {"age_levels": ["10-20", "20-30"]}

    assert "example_question" in invigilator.prompt_constructor.prior_answers_dict()
    assert "10-20" in invigilator.prompt_constructor.question_instructions_prompt


def test_option_expansion_from_scenario():
    """A templated question_options referencing a scenario variable is
    expanded from the invigilator's scenario."""
    # Imports consolidated at the top of the test; stale debug comment removed.
    from edsl import QuestionMultipleChoice, Scenario
    from edsl.agents.InvigilatorBase import InvigilatorBase

    q = QuestionMultipleChoice(
        question_text="Here is a question",
        question_name="example_question",
        question_options="{{ age_levels }}",
    )
    i = InvigilatorBase.example(question=q)
    i.scenario = Scenario({"age_levels": ["10-20", "20-30"]})

    assert "10-20" in i.prompt_constructor.question_instructions_prompt


def test_system_prompt_traits_passed():
agent = Agent(traits={"age": 10, "hair": "brown", "height": 5.5})
i = agent._create_invigilator(question=q.example(), survey=q.example().to_survey())
Expand Down
Loading