18
18
from random import Random
19
19
20
20
# mypy
21
- from typing import TYPE_CHECKING , Any
21
+ from typing import TYPE_CHECKING , Any , List
22
22
23
23
if TYPE_CHECKING :
24
24
# We ensure that these are not imported during runtime to prevent cyclic
@@ -127,11 +127,11 @@ def __contains__(self, agent: Agent) -> bool:
127
127
return agent in self ._agents
128
128
129
129
def select (
130
- self ,
131
- filter_func : Callable [[Agent ], bool ] | None = None ,
132
- n : int = 0 ,
133
- inplace : bool = False ,
134
- agent_type : type [Agent ] | None = None ,
130
+ self ,
131
+ filter_func : Callable [[Agent ], bool ] | None = None ,
132
+ n : int = 0 ,
133
+ inplace : bool = False ,
134
+ agent_type : type [Agent ] | None = None ,
135
135
) -> AgentSet :
136
136
"""
137
137
Select a subset of agents from the AgentSet based on a filter function and/or quantity limit.
@@ -154,7 +154,7 @@ def agent_generator(filter_func=None, agent_type=None, n=0):
154
154
count = 0
155
155
for agent in self :
156
156
if (not filter_func or filter_func (agent )) and (
157
- not agent_type or isinstance (agent , agent_type )
157
+ not agent_type or isinstance (agent , agent_type )
158
158
):
159
159
yield agent
160
160
count += 1
@@ -191,10 +191,10 @@ def shuffle(self, inplace: bool = False) -> AgentSet:
191
191
)
192
192
193
193
def sort (
194
- self ,
195
- key : Callable [[Agent ], Any ] | str ,
196
- ascending : bool = False ,
197
- inplace : bool = False ,
194
+ self ,
195
+ key : Callable [[Agent ], Any ] | str ,
196
+ ascending : bool = False ,
197
+ inplace : bool = False ,
198
198
) -> AgentSet :
199
199
"""
200
200
Sort the agents in the AgentSet based on a specified attribute or custom function.
@@ -227,7 +227,7 @@ def _update(self, agents: Iterable[Agent]):
227
227
return self
228
228
229
229
def do (
230
- self , method_name : str , * args , return_results : bool = False , ** kwargs
230
+ self , method_name : str , * args , return_results : bool = False , ** kwargs
231
231
) -> AgentSet | list [Any ]:
232
232
"""
233
233
Invoke a method on each agent in the AgentSet.
@@ -356,6 +356,41 @@ def random(self) -> Random:
356
356
"""
357
357
return self .model .random
358
358
359
+ def apply (self , func : Callable , axis : str = "agent" , args = (), result_type = None , ** kwargs ) -> List [Any ] | Any :
360
+ """
361
+ Apply a function to all agents in the AgentSet either to each agent individually or to the entire agentset.
362
+
363
+ Args:
364
+ func (Callable): The function to apply to each individual agent or the entire agentset
365
+ axis (str): {'agent', 'agetset'} The axis along which to apply the function.
366
+
367
+ * 'agent' means apply the function to each agent.
368
+ * 'agentset' means apply the function to the entire agentset.
369
+
370
+ args (list or tuple):Positional arguments to pass to the function.
371
+ kwargs (dict): Additional keyword arguments to pass as keywords arguments to the function.
372
+
373
+ Returns:
374
+ the result of applying the function along the specified axis. In case of axis=agent, it will be a list with
375
+ the return of func for each agent. In case of axis=agentset, it is the return of func.
376
+
377
+ Notes:
378
+ To maintain method chaining in case of axis=agentset, func should return an agentset
379
+
380
+ """
381
+ if axis == "agent" :
382
+ # TODO:: add a results_type to make it trivial to return a dataframe with agent.id and func results?
383
+ # TODO:: this is a good idea, but tricky because you don't know all column names
384
+ return [func (agent , * args , ** kwargs ) for agent in self ]
385
+ elif axis == "agentset" :
386
+ return func (self , * args , ** kwargs )
387
+ else :
388
+ raise ValueError (f"axis should be agent or agentset not { axis } " )
389
+
390
+
391
+ from pandas import DataFrame
392
+
393
+ DataFrame .apply ()
359
394
360
395
# consider adding for performance reasons
361
396
# for Sequence: __reversed__, index, and count
0 commit comments