Skip to content

Commit

Permalink
Merge pull request #479 from expectedparrot/survey_sections
Browse files (browse the repository at this point in the history)
Modifying Surveys to support human use
  • Loading branch information
johnjosephhorton authored May 12, 2024
2 parents 0c41949 + 42ece21 commit a4c909d
Show file tree
Hide file tree
Showing 15 changed files with 208 additions and 102 deletions.
2 changes: 1 addition & 1 deletion edsl/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.20"
__version__ = "0.1.21"
29 changes: 16 additions & 13 deletions edsl/jobs/Jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,35 +92,38 @@ def by(

setattr(self, objects_key, new_objects) # update the job
return self

def prompts(self):
from edsl.results.Dataset import Dataset
interviews = self.interviews()
#data = []

interviews = self.interviews()
# data = []
interview_indices = []
question_indices = []
user_prompts = []
system_prompts = []
scenario_indices = []

for interview_index, interview in enumerate(interviews):
invigilators = list(interview._build_invigilators(debug = False))
invigilators = list(interview._build_invigilators(debug=False))
for question_index, invigilator in enumerate(invigilators):
prompts = invigilator.get_prompts()
user_prompts.append(prompts["user_prompt"])
system_prompts.append(prompts["system_prompt"])
interview_indices.append(interview_index)
scenario_indices.append(invigilator.scenario)
question_indices.append(invigilator.question.question_name)
#breakpoint()
return Dataset([
{'interview_index': interview_indices},
{'question_index': question_indices},
{'user_prompt': user_prompts},
{'scenario_index': scenario_indices},
{'system_prompt': system_prompts}])


# breakpoint()
return Dataset(
[
{"interview_index": interview_indices},
{"question_index": question_indices},
{"user_prompt": user_prompts},
{"scenario_index": scenario_indices},
{"system_prompt": system_prompts},
]
)

@staticmethod
def _turn_args_to_list(args):
"""Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments."""
Expand Down
42 changes: 24 additions & 18 deletions edsl/results/Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@

from edsl.results.ResultsExportMixin import ResultsExportMixin


class Dataset(UserList, ResultsExportMixin):
"""A class to represent a dataset of observations."""

def __init__(self, data: list[dict[str, Any]] = None):
"""Initialize the dataset with the given data."""
super().__init__(data)


def __len__(self) -> int:
"""Return the number of observations in the dataset.
Need to override the __len__ method to return the number of observations in the dataset because
Need to override the __len__ method to return the number of observations in the dataset because
otherwise, the UserList class would return the number of dictionaries in the dataset.
"""
#breakpoint()
# breakpoint()
_, values = list(self.data[0].items())[0]
return len(values)

Expand Down Expand Up @@ -56,10 +56,10 @@ def _repr_html_(self) -> str:

return data_to_html(self.data)

def shuffle(self, seed = None) -> Dataset:
def shuffle(self, seed=None) -> Dataset:
if seed is not None:
random.seed(seed)
random.seed(seed)

indices = None

for entry in self:
Expand All @@ -70,43 +70,49 @@ def shuffle(self, seed = None) -> Dataset:
entry[key] = [values[i] for i in indices]

return self

def sample(self, n:int = None, frac:float = None, with_replacement:bool = True, seed = None) -> 'Dataset':

def sample(
self,
n: int = None,
frac: float = None,
with_replacement: bool = True,
seed=None,
) -> "Dataset":
if seed is not None:
random.seed(seed)

# Validate the input for sampling parameters
if n is None and frac is None:
raise ValueError("Either 'n' or 'frac' must be provided for sampling.")
if n is not None and frac is not None:
raise ValueError("Only one of 'n' or 'frac' should be specified.")

# Get the length of the lists from the first entry
first_key, first_values = list(self[0].items())[0]
total_length = len(first_values)

# Determine the number of samples based on 'n' or 'frac'
if n is None:
n = int(total_length * frac)

if not with_replacement and n > total_length:
raise ValueError("Sample size cannot be greater than the number of available elements when sampling without replacement.")

raise ValueError(
"Sample size cannot be greater than the number of available elements when sampling without replacement."
)

# Sample indices based on the method chosen
if with_replacement:
indices = [random.randint(0, total_length - 1) for _ in range(n)]
else:
indices = random.sample(range(total_length), k=n)

# Apply the same indices to all entries
for entry in self:
key, values = list(entry.items())[0]
entry[key] = [values[i] for i in indices]

return self



def order_by(self, sort_key: str, reverse: bool = False) -> Dataset:
"""Return a new dataset with the observations sorted by the given key."""

Expand Down
4 changes: 3 additions & 1 deletion edsl/results/Result.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def sub_dicts(self) -> dict[str, dict]:
] = self.question_to_attributes[key]["question_type"]

return {
"agent": self.agent.traits | {"agent_name": agent_name} | {"agent_instruction": self.agent.instruction},
"agent": self.agent.traits
| {"agent_name": agent_name}
| {"agent_instruction": self.agent.instruction},
"scenario": self.scenario,
"model": self.model.parameters | {"model": self.model.model},
"answer": self.answer,
Expand Down
35 changes: 21 additions & 14 deletions edsl/results/Results.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def __add__(self, other: Results) -> Results:
>>> r3 = r + r2
"""
if self.survey != other.survey:
raise Exception("The surveys are not the same so they cannot be added together.")
raise Exception(
"The surveys are not the same so they cannot be added together."
)
if self.created_columns != other.created_columns:
raise Exception(
"The created columns are not the same so they cannot be added together."
Expand Down Expand Up @@ -493,8 +495,8 @@ def new_result(old_result: Result, var_name: str) -> Result:
data=new_data,
created_columns=self.created_columns + [var_name],
)
def shuffle(self, seed = None) -> Results:

def shuffle(self, seed=None) -> Results:
"""Shuffle the results.
Example:
Expand All @@ -508,8 +510,14 @@ def shuffle(self, seed = None) -> Results:
new_data = self.data.copy()
random.shuffle(new_data)
return Results(survey=self.survey, data=new_data, created_columns=None)

def sample(self, n:int = None, frac:float = None, with_replacement:bool = True, seed = None) -> Results:

def sample(
self,
n: int = None,
frac: float = None,
with_replacement: bool = True,
seed=None,
) -> Results:
"""Sample the results.
:param n: An integer representing the number of samples to take.
Expand All @@ -527,18 +535,18 @@ def sample(self, n:int = None, frac:float = None, with_replacement:bool = True,

if n is None and frac is None:
raise Exception("You must specify either n or frac.")

if n is not None and frac is not None:
raise Exception("You cannot specify both n and frac.")

if frac is not None and n is None:
n = int(frac * len(self.data))

if with_replacement:
new_data = random.choices(self.data, k = n)
new_data = random.choices(self.data, k=n)
else:
new_data = random.sample(self.data, n)

return Results(survey=self.survey, data=new_data, created_columns=None)

def select(self, *columns: Union[str, list[str]]) -> Dataset:
Expand Down Expand Up @@ -665,13 +673,13 @@ def sort_key(item):
for col in columns:
# Parse the column into its data type and key
data_type, key = self._parse_column(col)

# Retrieve the value from the item based on the parsed data type and key
value = item.get_value(data_type, key)

# Convert the value to numeric if possible, and append it to the key components
key_components.append(to_numeric_if_possible(value))

# Convert the list of key components into a tuple to serve as the sorting key
return tuple(key_components)

Expand All @@ -682,7 +690,6 @@ def sort_key(item):
)
return Results(survey=self.survey, data=new_data, created_columns=None)


# def sort_by(self, column, reverse: bool = False) -> Results:
# """Sort the results by a column.

Expand Down
1 change: 1 addition & 0 deletions edsl/scenarios/Scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
remove_edsl_version,
)


class Scenario(Base, UserDict):
"""A Scenario is a dictionary of keys/values for parameterizing questions."""

Expand Down
Loading

0 comments on commit a4c909d

Please sign in to comment.