Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
apostolosfilippas committed Sep 20, 2024
1 parent e924d4c commit 92afba8
Show file tree
Hide file tree
Showing 13 changed files with 76 additions and 63 deletions.
12 changes: 6 additions & 6 deletions edsl/agents/Agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,9 @@ def data(self) -> dict:
if dynamic_traits_func:
func = inspect.getsource(dynamic_traits_func)
raw_data["dynamic_traits_function_source_code"] = func
raw_data["dynamic_traits_function_name"] = (
self.dynamic_traits_function_name
)
raw_data[
"dynamic_traits_function_name"
] = self.dynamic_traits_function_name
if hasattr(self, "answer_question_directly"):
raw_data.pop(
"answer_question_directly", None
Expand All @@ -604,9 +604,9 @@ def data(self) -> dict:
raw_data["answer_question_directly_source_code"] = inspect.getsource(
answer_question_directly_func
)
raw_data["answer_question_directly_function_name"] = (
self.answer_question_directly_function_name
)
raw_data[
"answer_question_directly_function_name"
] = self.answer_question_directly_function_name

return raw_data

Expand Down
27 changes: 15 additions & 12 deletions edsl/agents/PromptConstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,18 @@ class PromptComponent(enum.Enum):
def get_jinja2_variables(template_str: str) -> Set[str]:
    """
    Collect the variable names referenced by a Jinja2 template.

    The template string is parsed with a freshly constructed ``Environment``
    and the parsed result is passed to ``meta.find_undeclared_variables``.

    Args:
        template_str (str): The Jinja2 template string

    Returns:
        Set[str]: A set of variable names found in the template
    """
    parsed_template = Environment().parse(template_str)
    return meta.find_undeclared_variables(parsed_template)


class PromptList(UserList):
separator = Prompt(" ")

Expand Down Expand Up @@ -262,7 +263,7 @@ def prior_answers_dict(self) -> dict:
if (new_question := question.split("_comment")[0]) in d:
d[new_question].comment = answer
return d

@property
def question_image_keys(self):
raw_question_text = self.question.question_text
Expand All @@ -272,7 +273,7 @@ def question_image_keys(self):
if var in self.scenario_image_keys:
question_image_keys.append(var)
return question_image_keys

@property
def question_instructions_prompt(self) -> Prompt:
"""
Expand All @@ -285,8 +286,6 @@ def question_instructions_prompt(self) -> Prompt:
if not hasattr(self, "_question_instructions_prompt"):
question_prompt = self.question.get_instructions(model=self.model.model)



# Are any of the scenario values ImageInfo

question_data = self.question.data.copy()
Expand All @@ -295,7 +294,6 @@ def question_instructions_prompt(self) -> Prompt:
# This is used when the user is using the question_options as a variable from a scenario
# if "question_options" in question_data:
if isinstance(self.question.data.get("question_options", None), str):

env = Environment()
parsed_content = env.parse(self.question.data["question_options"])
question_option_key = list(
Expand All @@ -311,7 +309,11 @@ def question_instructions_prompt(self) -> Prompt:
replacement_dict = (
{key: "<see image>" for key in self.scenario_image_keys}
| question_data
| {k:v for k,v in self.scenario.items() if k not in self.scenario_image_keys} # don't include images in the replacement dict
| {
k: v
for k, v in self.scenario.items()
if k not in self.scenario_image_keys
} # don't include images in the replacement dict
| self.prior_answers_dict()
| {"agent": self.agent}
| {
Expand All @@ -322,14 +324,13 @@ def question_instructions_prompt(self) -> Prompt:
}
)


rendered_instructions = question_prompt.render(replacement_dict)

# is there anything left to render?
undefined_template_variables = (
rendered_instructions.undefined_template_variables({})
)

# Check if it's the name of a question in the survey
for question_name in self.survey.question_names:
if question_name in undefined_template_variables:
Expand Down Expand Up @@ -428,7 +429,9 @@ def get_prompts(self) -> Dict[str, Prompt]:
if len(self.question_image_keys) > 1:
raise ValueError("We can only handle one image per question.")
elif len(self.question_image_keys) == 1:
prompts["encoded_image"] = self.scenario[self.question_image_keys[0]].encoded_image
prompts["encoded_image"] = self.scenario[
self.question_image_keys[0]
].encoded_image

return prompts

Expand Down
2 changes: 2 additions & 0 deletions edsl/data_transfer_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import NamedTuple, Dict, List, Optional, Any
from dataclasses import dataclass


class ModelInputs(NamedTuple):
"This is what was sent by the agent to the model"
user_prompt: str
Expand Down Expand Up @@ -53,6 +54,7 @@ class ImageInfo:
file_size: int
encoded_image: str


# from collections import UserDict


Expand Down
16 changes: 8 additions & 8 deletions edsl/jobs/interviews/Interview.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ def __init__(
self.debug = debug
self.iteration = iteration
self.cache = cache
self.answers: dict[str, str] = (
Answers()
) # will get filled in as interview progresses
self.answers: dict[
str, str
] = Answers() # will get filled in as interview progresses
self.sidecar_model = sidecar_model

# self.stop_on_exception = False
Expand Down Expand Up @@ -418,11 +418,11 @@ def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
"""
current_question_index: int = self.to_index[current_question.question_name]

next_question: Union[int, EndOfSurvey] = (
self.survey.rule_collection.next_question(
q_now=current_question_index,
answers=self.answers | self.scenario | self.agent["traits"],
)
next_question: Union[
int, EndOfSurvey
] = self.survey.rule_collection.next_question(
q_now=current_question_index,
answers=self.answers | self.scenario | self.agent["traits"],
)

next_question_index = next_question.next_q
Expand Down
18 changes: 9 additions & 9 deletions edsl/jobs/runners/JobsRunnerAsyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,19 +167,19 @@ async def _build_interview_task(

prompt_dictionary = {}
for answer_key_name in answer_key_names:
prompt_dictionary[answer_key_name + "_user_prompt"] = (
question_name_to_prompts[answer_key_name]["user_prompt"]
)
prompt_dictionary[answer_key_name + "_system_prompt"] = (
question_name_to_prompts[answer_key_name]["system_prompt"]
)
prompt_dictionary[
answer_key_name + "_user_prompt"
] = question_name_to_prompts[answer_key_name]["user_prompt"]
prompt_dictionary[
answer_key_name + "_system_prompt"
] = question_name_to_prompts[answer_key_name]["system_prompt"]

raw_model_results_dictionary = {}
for result in valid_results:
question_name = result.question_name
raw_model_results_dictionary[question_name + "_raw_model_response"] = (
result.raw_model_response
)
raw_model_results_dictionary[
question_name + "_raw_model_response"
] = result.raw_model_response
raw_model_results_dictionary[question_name + "_cost"] = result.cost
one_use_buys = (
"NA"
Expand Down
1 change: 0 additions & 1 deletion edsl/jobs/runners/JobsRunnerStatus.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def elapsed_time(self):


class JobsRunnerStatus:

def __init__(
self,
jobs_runner: "JobsRunnerAsyncio",
Expand Down
8 changes: 5 additions & 3 deletions edsl/questions/QuestionBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def fake_data_factory(self):
if not hasattr(self, "_fake_data_factory"):
from polyfactory.factories.pydantic_factory import ModelFactory

class FakeData(ModelFactory[self.response_model]): ...
class FakeData(ModelFactory[self.response_model]):
...

self._fake_data_factory = FakeData
return self._fake_data_factory
Expand Down Expand Up @@ -477,12 +478,13 @@ def html(
if scenario is None:
scenario = {}


prior_answers_dict = {}

if isinstance(answers, dict):
for key, value in answers.items():
if not key.endswith("_comment") and not key.endswith("_generated_tokens"):
if not key.endswith("_comment") and not key.endswith(
"_generated_tokens"
):
prior_answers_dict[key] = {"answer": value}

# breakpoint()
Expand Down
1 change: 0 additions & 1 deletion edsl/questions/QuestionBudget.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def fix(self, response, verbose=False):
def create_budget_model(
budget_sum: float, permissive: bool, question_options: List[str]
):

class BudgetResponse(BaseModel):
answer: List[float] = Field(
...,
Expand Down
5 changes: 2 additions & 3 deletions edsl/questions/QuestionFreeText.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,11 @@ class FreeTextResponseValidator(ResponseValidatorABC):

def fix(self, response, verbose=False):
return {
'answer': str(response.get('generated_tokens')),
'generated_tokens': str(response.get('generated_tokens'))
"answer": str(response.get("generated_tokens")),
"generated_tokens": str(response.get("generated_tokens")),
}



class QuestionFreeText(QuestionBase):
"""This question prompts the agent to respond with free text."""

Expand Down
40 changes: 25 additions & 15 deletions edsl/questions/Quick.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
from edsl import QuestionFreeText, QuestionMultipleChoice, Survey, QuestionList, Question
from edsl import (
QuestionFreeText,
QuestionMultipleChoice,
Survey,
QuestionList,
Question,
)

def Quick(question_text):

def Quick(question_text):
q_type = QuestionMultipleChoice(
question_text = f"A researcher is asking a language model this: {question_text}. What is the most appropriate type of question to ask?",
question_name = "potential_question_type",
question_options = ["multiple_choice", "list", "free_text"])

question_text=f"A researcher is asking a language model this: {question_text}. What is the most appropriate type of question to ask?",
question_name="potential_question_type",
question_options=["multiple_choice", "list", "free_text"],
)

q_name = QuestionFreeText(
question_text = f"A researcher is asking a language model this: {question_text}. What is a good name for this question that's a valid python identifier? Just return the proposed identifer",
question_name = "potential_question_name")

q_options = QuestionList(question_text = f"A research is asking this question: { question_text }. What are the possible options for this question?",
question_name = "potential_question_options")

survey = Survey([q_type, q_name, q_options]).add_skip_rule(q_options, "{{ potential_question_type }} != 'multiple_choice'")
question_text=f"A researcher is asking a language model this: {question_text}. What is a good name for this question that's a valid python identifier? Just return the proposed identifer",
question_name="potential_question_name",
)

q_options = QuestionList(
question_text=f"A research is asking this question: { question_text }. What are the possible options for this question?",
question_name="potential_question_options",
)

survey = Survey([q_type, q_name, q_options]).add_skip_rule(
q_options, "{{ potential_question_type }} != 'multiple_choice'"
)
return survey
# results = survey.run()
# question_type = results.select("potential_question_type").first()
Expand All @@ -27,5 +39,3 @@ def Quick(question_text):
# return Question(question_type, question_name = question_name)
# else:
# return Question(question_type, question_name = question_name, question_options = question_options)


2 changes: 0 additions & 2 deletions edsl/scenarios/FileStore.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,6 @@ def view(self):


class SQLiteFileStore(FileStore):

def __init__(
self,
filename,
Expand Down Expand Up @@ -309,7 +308,6 @@ def view(self):


class HTMLFileStore(FileStore):

def __init__(
self,
filename,
Expand Down
6 changes: 4 additions & 2 deletions edsl/scenarios/Scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ def from_url(cls, url: str, field_name: Optional[str] = "text") -> "Scenario":
return cls({"url": url, field_name: text})

@classmethod
def from_image(cls, image_path: str, image_name: Optional[str] = None) -> 'Scenario':
def from_image(
cls, image_path: str, image_name: Optional[str] = None
) -> "Scenario":
"""
Creates a scenario with a base64 encoding of an image.
Expand Down Expand Up @@ -489,7 +491,7 @@ def rich_print(self) -> "Table":
return table

@classmethod
def example(cls, randomize: bool = False, has_image = False) -> Scenario:
def example(cls, randomize: bool = False, has_image=False) -> Scenario:
"""
Returns an example Scenario instance.
Expand Down
1 change: 0 additions & 1 deletion edsl/scenarios/ScenarioListPdfMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def fetch_and_save_pdf(url, filename):


class ScenarioListPdfMixin:

@classmethod
def from_pdf(cls, filename_or_url, collapse_pages=False):
# Check if the input is a URL
Expand Down

0 comments on commit 92afba8

Please sign in to comment.