From fcbf8510d1e1b55f875dac67164ae63a8ab31976 Mon Sep 17 00:00:00 2001 From: John Horton Date: Mon, 27 May 2024 17:01:47 -0400 Subject: [PATCH] Some helper functions for Agent/AgentList --- edsl/agents/Agent.py | 28 ++++++++++++++++- edsl/agents/AgentList.py | 57 ++++++++++++++++++++++++++++++++++ edsl/scenarios/ScenarioList.py | 14 +++++++++ 3 files changed, 98 insertions(+), 1 deletion(-) diff --git a/edsl/agents/Agent.py b/edsl/agents/Agent.py index 75ce68d8..ebaadd6d 100644 --- a/edsl/agents/Agent.py +++ b/edsl/agents/Agent.py @@ -63,7 +63,7 @@ def __init__( ): """Initialize a new instance of Agent. - :param traits: A dictionary of traits that the agent has. The keys need to be + :param traits: A dictionary of traits that the agent has. The keys need to be valid identifiers. :param name: A name for the agent :param codebook: A codebook mapping trait keys to trait descriptions. :param instruction: Instructions for the agent in how to answer questions. @@ -485,6 +485,32 @@ def _table(self) -> tuple[dict, list]: table_data.append({"Attribute": attr_name, "Value": repr(attr_value)}) column_names = ["Attribute", "Value"] return table_data, column_names + + def remove_trait(self, trait: str) -> Agent: + """Remove a trait from the agent. + + Example usage: + + >>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5}) + >>> a.remove_trait("age") + Agent(traits = {'hair': 'brown', 'height': 5.5}) + """ + _ = self.traits.pop(trait) + return self + + def translate_traits(self, values_codebook: dict) -> Agent: + """Translate traits to a new codebook. + + >>> a = Agent(traits = {"age": 10, "hair": 1, "height": 5.5}) + >>> a.translate_traits({"hair": {1:"brown"}}) + Agent(traits = {'age': 10, 'hair': 'brown', 'height': 5.5}) + + :param values_codebook: The new codebook. + """ + for key, value in self.traits.items(): + if key in values_codebook: + self.traits[key] = values_codebook[key][value] + return self def rich_print(self): """Display an object as a rich table. diff --git a/edsl/agents/AgentList.py b/edsl/agents/AgentList.py index 011c4327..06812ca4 100644 --- a/edsl/agents/AgentList.py +++ b/edsl/agents/AgentList.py @@ -15,6 +15,7 @@ from rich import print_json from rich.table import Table import json +import csv from edsl.Base import Base from edsl.agents import Agent @@ -37,6 +38,62 @@ def __init__(self, data: Optional[list[Agent]] = None): else: super().__init__() + @classmethod + def from_csv(cls, file_path: str): + """Load AgentList from a CSV file. + + >>> import csv + >>> import os + >>> with open('/tmp/agents.csv', 'w') as f: + ... writer = csv.writer(f) + ... _ = writer.writerow(['age', 'hair', 'height']) + ... _ = writer.writerow([22, 'brown', 5.5]) + >>> al = AgentList.from_csv('/tmp/agents.csv') + >>> al + AgentList([Agent(traits = {'age': '22', 'hair': 'brown', 'height': '5.5'})]) + >>> os.remove('/tmp/agents.csv') + + :param file_path: The path to the CSV file. + """ + agent_list = [] + with open(file_path, "r") as f: + reader = csv.DictReader(f) + for row in reader: + agent_list.append(Agent(row)) + return cls(agent_list) + + def translate_traits(self, values_codebook: dict[str, str]): + """Translate traits to a new codebook. + + :param codebook: The new codebook. + """ + for agent in self.data: + agent.translate_traits(codebook) + return self + + def remove_trait(self, trait: str): + """Remove traits from the AgentList. + + :param traits: The traits to remove. + + >>> al = AgentList([Agent({'age': 22, 'hair': 'brown', 'height': 5.5}), Agent({'age': 22, 'hair': 'brown', 'height': 5.5})]) + >>> al.remove_trait('age') + AgentList([Agent(traits = {'hair': 'brown', 'height': 5.5}), Agent(traits = {'hair': 'brown', 'height': 5.5})]) + """ + for agent in self.data: + _ = agent.remove_trait(trait) + return self + + @staticmethod + def get_codebook(file_path: str): + """Return the codebook for a CSV file. + + :param file_path: The path to the CSV file. + """ + with open(file_path, "r") as f: + reader = csv.DictReader(f) + return {field: None for field in reader.fieldnames} + @add_edsl_version def to_dict(self): """Return dictionary of AgentList to serialization. diff --git a/edsl/scenarios/ScenarioList.py b/edsl/scenarios/ScenarioList.py index 48634708..248079a6 100644 --- a/edsl/scenarios/ScenarioList.py +++ b/edsl/scenarios/ScenarioList.py @@ -176,6 +176,20 @@ def __getitem__(self, key: Union[int, slice]) -> Any: return super().__getitem__(key) else: return self.to_dict()[key] + + def to_agent_list(self): + """Convert the ScenarioList to an AgentList. + + >>> s = ScenarioList([Scenario({'age': 22, 'hair': 'brown', 'height': 5.5}), Scenario({'age': 22, 'hair': 'brown', 'height': 5.5})]) + >>> s.to_agent_list() + AgentList([Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5}), Agent(traits = {'age': 22, 'hair': 'brown', 'height': 5.5})]) + + + """ + from edsl.agents.AgentList import AgentList + from edsl.agents.Agent import Agent + + return AgentList([Agent(traits = s.data) for s in self]) if __name__ == "__main__":