Skip to content

Commit

Permalink
Merge pull request #569 from expectedparrot/agent_list_helper_methods
Browse files Browse the repository at this point in the history
Agent list helper methods
  • Loading branch information
apostolosfilippas committed May 27, 2024
2 parents f29cb61 + fcbf851 commit 5ed9ea6
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 1 deletion.
28 changes: 27 additions & 1 deletion edsl/agents/Agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
):
"""Initialize a new instance of Agent.
:param traits: A dictionary of traits that the agent has. The keys need to be
:param traits: A dictionary of traits that the agent has. The keys need to be valid identifiers.
:param name: A name for the agent
:param codebook: A codebook mapping trait keys to trait descriptions.
:param instruction: Instructions for the agent in how to answer questions.
Expand Down Expand Up @@ -485,6 +485,32 @@ def _table(self) -> tuple[dict, list]:
table_data.append({"Attribute": attr_name, "Value": repr(attr_value)})
column_names = ["Attribute", "Value"]
return table_data, column_names

def remove_trait(self, trait: str) -> Agent:
"""Remove a trait from the agent.
Example usage:
>>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5})
>>> a.remove_trait("age")
Agent(traits = {'hair': 'brown', 'height': 5.5})
"""
_ = self.traits.pop(trait)
return self

def translate_traits(self, values_codebook: dict) -> Agent:
"""Translate traits to a new codebook.
>>> a = Agent(traits = {"age": 10, "hair": 1, "height": 5.5})
>>> a.translate_traits({"hair": {1:"brown"}})
Agent(traits = {'age': 10, 'hair': 'brown', 'height': 5.5})
:param values_codebook: The new codebook.
"""
for key, value in self.traits.items():
if key in values_codebook:
self.traits[key] = values_codebook[key][value]
return self

def rich_print(self):
"""Display an object as a rich table.
Expand Down
57 changes: 57 additions & 0 deletions edsl/agents/AgentList.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from rich import print_json
from rich.table import Table
import json
import csv

from edsl.Base import Base
from edsl.agents import Agent
Expand All @@ -37,6 +38,62 @@ def __init__(self, data: Optional[list[Agent]] = None):
else:
super().__init__()

@classmethod
def from_csv(cls, file_path: str):
"""Load AgentList from a CSV file.
>>> import csv
>>> import os
>>> with open('/tmp/agents.csv', 'w') as f:
... writer = csv.writer(f)
... _ = writer.writerow(['age', 'hair', 'height'])
... _ = writer.writerow([22, 'brown', 5.5])
>>> al = AgentList.from_csv('/tmp/agents.csv')
>>> al
AgentList([Agent(traits = {'age': '22', 'hair': 'brown', 'height': '5.5'})])
>>> os.remove('/tmp/agents.csv')
:param file_path: The path to the CSV file.
"""
agent_list = []
with open(file_path, "r") as f:
reader = csv.DictReader(f)
for row in reader:
agent_list.append(Agent(row))
return cls(agent_list)

def translate_traits(self, values_codebook: dict[str, str]):
"""Translate traits to a new codebook.
:param codebook: The new codebook.
"""
for agent in self.data:
agent.translate_traits(codebook)
return self

def remove_trait(self, trait: str):
"""Remove traits from the AgentList.
:param traits: The traits to remove.
>>> al = AgentList([Agent({'age': 22, 'hair': 'brown', 'height': 5.5}), Agent({'age': 22, 'hair': 'brown', 'height': 5.5})])
>>> al.remove_trait('age')
AgentList([Agent(traits = {'hair': 'brown', 'height': 5.5}), Agent(traits = {'hair': 'brown', 'height': 5.5})])
"""
for agent in self.data:
_ = agent.remove_trait(trait)
return self

@staticmethod
def get_codebook(file_path: str):
"""Return the codebook for a CSV file.
:param file_path: The path to the CSV file.
"""
with open(file_path, "r") as f:
reader = csv.DictReader(f)
return {field: None for field in reader.fieldnames}

@add_edsl_version
def to_dict(self):
"""Return dictionary of AgentList to serialization.
Expand Down
14 changes: 14 additions & 0 deletions edsl/scenarios/ScenarioList.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,20 @@ def __getitem__(self, key: Union[int, slice]) -> Any:
return super().__getitem__(key)
else:
return self.to_dict()[key]

def to_agent_list(self):
"""Convert the ScenarioList to an AgentList.
>>> s = ScenarioList([Scenario({'age': 22, 'hair': 'brown', 'height': 5.5}), Scenario({'age': 22, 'hair': 'brown', 'height': 5.5})])
>>> s.to_agent_list()
AgentList([Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5}), Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5})])
"""
from edsl.agents.AgentList import AgentList
from edsl.agents.Agent import Agent

return AgentList([Agent(traits = s.data) for s in self])


if __name__ == "__main__":
Expand Down

0 comments on commit 5ed9ea6

Please sign in to comment.