Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
apostolosfilippas committed May 12, 2024
1 parent d81576e commit ad123c2
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 67 deletions.
29 changes: 16 additions & 13 deletions edsl/jobs/Jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,35 +92,38 @@ def by(

setattr(self, objects_key, new_objects) # update the job
return self

def prompts(self):
    """Collect the prompts every interview of this job would send.

    Walks each interview returned by ``self.interviews()``, builds its
    invigilators (with ``debug=False``) and records one observation per
    invigilator.

    :return: A ``Dataset`` with five columns, in this order:
        ``interview_index`` (position of the interview in the enumeration),
        ``question_index`` (the invigilator's ``question.question_name``),
        ``user_prompt``, ``scenario_index`` (the invigilator's ``scenario``
        attribute itself) and ``system_prompt``.
    """
    from edsl.results.Dataset import Dataset

    # Insertion order of this mapping is the column order of the result.
    columns = {
        "interview_index": [],
        "question_index": [],
        "user_prompt": [],
        "scenario_index": [],
        "system_prompt": [],
    }

    for position, interview in enumerate(self.interviews()):
        # Materialize all invigilators of the interview before asking any of them for prompts.
        built = list(interview._build_invigilators(debug=False))
        for invigilator in built:
            rendered = invigilator.get_prompts()
            columns["user_prompt"].append(rendered["user_prompt"])
            columns["system_prompt"].append(rendered["system_prompt"])
            columns["interview_index"].append(position)
            columns["scenario_index"].append(invigilator.scenario)
            columns["question_index"].append(invigilator.question.question_name)

    return Dataset([{name: values} for name, values in columns.items()])

@staticmethod
def _turn_args_to_list(args):
"""Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments."""
Expand Down
42 changes: 24 additions & 18 deletions edsl/results/Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@

from edsl.results.ResultsExportMixin import ResultsExportMixin


class Dataset(UserList, ResultsExportMixin):
"""A class to represent a dataset of observations."""

def __init__(self, data: list[dict[str, Any]] = None):
"""Initialize the dataset with the given data.

:param data: A list of dictionaries. Each dictionary is expected to map a column name to the list of that column's values (this is the layout ``__len__`` reads). The value, including ``None``, is handed unchanged to the ``UserList`` initializer.
"""
super().__init__(data)


def __len__(self) -> int:
    """Return the number of observations in the dataset.

    Need to override the __len__ method to return the number of observations in the dataset because
    otherwise, the UserList class would return the number of dictionaries in the dataset.

    The count is the length of the value list of the first column. A dataset
    that holds no columns at all has zero observations.

    :return: The number of observations.
    """
    # Without this guard, len() and bool() on an empty dataset raise IndexError.
    if not self.data:
        return 0
    first_column_values = next(iter(self.data[0].values()))
    return len(first_column_values)

Expand Down Expand Up @@ -56,10 +56,10 @@ def _repr_html_(self) -> str:

return data_to_html(self.data)

def shuffle(self, seed = None) -> Dataset:
def shuffle(self, seed=None) -> Dataset:
if seed is not None:
random.seed(seed)
random.seed(seed)

indices = None

for entry in self:
Expand All @@ -70,43 +70,49 @@ def shuffle(self, seed = None) -> Dataset:
entry[key] = [values[i] for i in indices]

return self

def sample(
    self,
    n: int = None,
    frac: float = None,
    with_replacement: bool = True,
    seed=None,
) -> "Dataset":
    """Resample the observations of this dataset in place and return it.

    One list of row positions is drawn and applied to every column, so the
    values of an observation stay aligned across columns.

    :param n: The number of observations to draw. Exclusive with ``frac``.
    :param frac: The fraction of the current observations to draw; the resulting count is truncated with ``int``. Exclusive with ``n``.
    :param with_replacement: If True, the same observation may be drawn more than once.
    :param seed: Optional seed that makes the draw reproducible.
    :return: This dataset itself (mutated), not a copy.
    :raises ValueError: If neither or both of ``n`` and ``frac`` are given, if the sample size is negative, or if it exceeds the number of observations when sampling without replacement.
    """
    # Validate the input for sampling parameters
    if n is None and frac is None:
        raise ValueError("Either 'n' or 'frac' must be provided for sampling.")
    if n is not None and frac is not None:
        raise ValueError("Only one of 'n' or 'frac' should be specified.")

    # A dedicated generator produces the same draws as seeding the module-level
    # one, but leaves the process-wide random state untouched.
    rng = random if seed is None else random.Random(seed)

    # Get the length of the lists from the first entry
    total_length = len(next(iter(self[0].values())))

    # Determine the number of samples based on 'n' or 'frac'
    if n is None:
        n = int(total_length * frac)

    if n < 0:
        raise ValueError("Sample size cannot be negative.")

    if not with_replacement and n > total_length:
        raise ValueError(
            "Sample size cannot be greater than the number of available elements when sampling without replacement."
        )

    # Sample indices based on the method chosen
    if with_replacement:
        indices = [rng.randint(0, total_length - 1) for _ in range(n)]
    else:
        indices = rng.sample(range(total_length), k=n)

    # Apply the same indices to all entries
    for entry in self:
        key, values = list(entry.items())[0]
        entry[key] = [values[i] for i in indices]

    return self



def order_by(self, sort_key: str, reverse: bool = False) -> Dataset:
"""Return a new dataset with the observations sorted by the given key."""

Expand Down
4 changes: 3 additions & 1 deletion edsl/results/Result.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def sub_dicts(self) -> dict[str, dict]:
] = self.question_to_attributes[key]["question_type"]

return {
"agent": self.agent.traits | {"agent_name": agent_name} | {"agent_instruction": self.agent.instruction},
"agent": self.agent.traits
| {"agent_name": agent_name}
| {"agent_instruction": self.agent.instruction},
"scenario": self.scenario,
"model": self.model.parameters | {"model": self.model.model},
"answer": self.answer,
Expand Down
35 changes: 21 additions & 14 deletions edsl/results/Results.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def __add__(self, other: Results) -> Results:
>>> r3 = r + r2
"""
if self.survey != other.survey:
raise Exception("The surveys are not the same so they cannot be added together.")
raise Exception(
"The surveys are not the same so they cannot be added together."
)
if self.created_columns != other.created_columns:
raise Exception(
"The created columns are not the same so they cannot be added together."
Expand Down Expand Up @@ -493,8 +495,8 @@ def new_result(old_result: Result, var_name: str) -> Result:
data=new_data,
created_columns=self.created_columns + [var_name],
)
def shuffle(self, seed = None) -> Results:

def shuffle(self, seed=None) -> Results:
"""Shuffle the results.
Example:
Expand All @@ -508,8 +510,14 @@ def shuffle(self, seed = None) -> Results:
new_data = self.data.copy()
random.shuffle(new_data)
return Results(survey=self.survey, data=new_data, created_columns=None)

def sample(self, n:int = None, frac:float = None, with_replacement:bool = True, seed = None) -> Results:

def sample(
self,
n: int = None,
frac: float = None,
with_replacement: bool = True,
seed=None,
) -> Results:
"""Sample the results.
:param n: An integer representing the number of samples to take.
Expand All @@ -527,18 +535,18 @@ def sample(self, n:int = None, frac:float = None, with_replacement:bool = True,

if n is None and frac is None:
raise Exception("You must specify either n or frac.")

if n is not None and frac is not None:
raise Exception("You cannot specify both n and frac.")

if frac is not None and n is None:
n = int(frac * len(self.data))

if with_replacement:
new_data = random.choices(self.data, k = n)
new_data = random.choices(self.data, k=n)
else:
new_data = random.sample(self.data, n)

return Results(survey=self.survey, data=new_data, created_columns=None)

def select(self, *columns: Union[str, list[str]]) -> Dataset:
Expand Down Expand Up @@ -665,13 +673,13 @@ def sort_key(item):
for col in columns:
# Parse the column into its data type and key
data_type, key = self._parse_column(col)

# Retrieve the value from the item based on the parsed data type and key
value = item.get_value(data_type, key)

# Convert the value to numeric if possible, and append it to the key components
key_components.append(to_numeric_if_possible(value))

# Convert the list of key components into a tuple to serve as the sorting key
return tuple(key_components)

Expand All @@ -682,7 +690,6 @@ def sort_key(item):
)
return Results(survey=self.survey, data=new_data, created_columns=None)


# def sort_by(self, column, reverse: bool = False) -> Results:
# """Sort the results by a column.

Expand Down
1 change: 1 addition & 0 deletions edsl/scenarios/Scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
remove_edsl_version,
)


class Scenario(Base, UserDict):
"""A Scenario is a dictionary of keys/values for parameterizing questions."""

Expand Down
52 changes: 33 additions & 19 deletions edsl/surveys/Survey.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(

if name is not None:
import warnings

warnings.warn("name is deprecated.")

# @property
Expand Down Expand Up @@ -175,30 +176,41 @@ def _set_memory_plan(self, prior_questions_func):
prior_questions=prior_questions_func(i),
)

def add_question_group(self,
start_question: Union[QuestionBase, str],
end_question: Union[QuestionBase, str],
group_name: str) -> None:
def add_question_group(
self,
start_question: Union[QuestionBase, str],
end_question: Union[QuestionBase, str],
group_name: str,
) -> None:
"""Add a group of questions to the survey."""

if not group_name.isidentifier():
raise ValueError(f"Group name {group_name} is not a valid identifier.")

if group_name in self.question_groups:
raise ValueError(f"Group name {group_name} already exists in the survey.")

if group_name in self.question_name_to_index:
raise ValueError(f"Group name {group_name} already exists as a question name in the survey.")
raise ValueError(
f"Group name {group_name} already exists as a question name in the survey."
)

start_index = self._get_question_index(start_question)
end_index = self._get_question_index(end_question)

if start_index > end_index:
raise ValueError(f"Start index {start_index} is greater than end index {end_index}.")
raise ValueError(
f"Start index {start_index} is greater than end index {end_index}."
)

for existing_group_name, (existing_start_index, existing_end_index) in self.question_groups.items():
for existing_group_name, (
existing_start_index,
existing_end_index,
) in self.question_groups.items():
if start_index < existing_start_index and end_index > existing_end_index:
raise ValueError(f"Group {group_name} contains the questions in the new group.")
raise ValueError(
f"Group {group_name} contains the questions in the new group."
)
if start_index > existing_start_index and end_index < existing_end_index:
raise ValueError(f"Group {group_name} is contained in the new group.")
if start_index < existing_start_index and end_index > existing_start_index:
Expand All @@ -207,8 +219,8 @@ def add_question_group(self,
raise ValueError(f"Group {group_name} overlaps with the new group.")

self.question_groups[group_name] = (start_index, end_index)
#print("Added group")
#print(self.question_groups)
# print("Added group")
# print(self.question_groups)
return self

def add_targeted_memory(
Expand Down Expand Up @@ -441,7 +453,7 @@ def gen_path_through_survey(self) -> Generator[QuestionBase, dict, None]:
question = self.first_question()
while not question == EndOfSurvey:
self.answers = yield question
## TODO: This should also include survey and agent attributes
## TODO: This should also include survey and agent attributes
question = self.next_question(question, self.answers)

@property
Expand Down Expand Up @@ -525,6 +537,7 @@ def __getitem__(self, index) -> QuestionBase:

def diff(self, other):
from rich import print

for key, value in self.to_dict().items():
if value != other.to_dict()[key]:
print(f"Key: {key}")
Expand Down Expand Up @@ -559,11 +572,12 @@ def from_dict(cls, data: dict) -> Survey:
"""Deserialize the dictionary back to a Survey object."""
questions = [QuestionBase.from_dict(q_dict) for q_dict in data["questions"]]
memory_plan = MemoryPlan.from_dict(data["memory_plan"])
survey = cls(questions=questions,
memory_plan=memory_plan,
rule_collection=RuleCollection.from_dict(data["rule_collection"]),
question_groups = data["question_groups"]
)
survey = cls(
questions=questions,
memory_plan=memory_plan,
rule_collection=RuleCollection.from_dict(data["rule_collection"]),
question_groups=data["question_groups"],
)
return survey

###################
Expand Down
4 changes: 2 additions & 2 deletions edsl/surveys/SurveyExportMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class SurveyExportMixin:
"""A mixin class for exporting surveys to different formats."""

def docx(self, filename = None) -> Union["Document", None]:
def docx(self, filename=None) -> Union["Document", None]:
"""Generate a docx document for the survey."""
doc = Document()
doc.add_heading("EDSL Survey")
Expand Down Expand Up @@ -58,7 +58,7 @@ def code(self, filename: str = None, survey_var_name: str = "survey") -> list[st

return formatted_code

def html(self, filename = None) -> str:
def html(self, filename=None) -> str:
"""Generate the html for the survey."""
html_text = []
for question in self._questions:
Expand Down

0 comments on commit ad123c2

Please sign in to comment.