From ec5c2a893a3ad0faabd3e9d2a11438263b202f5a Mon Sep 17 00:00:00 2001 From: Corentin Date: Wed, 13 Mar 2024 12:41:07 +0100 Subject: [PATCH 01/16] First refactoring step save --- scripts/run_simulation.py | 52 +++-- vivarium/simulator/sim_computation.py | 132 +------------ vivarium/simulator/simulator.py | 1 + vivarium/simulator/states.py | 266 ++++++++++++++++++++++++++ 4 files changed, 295 insertions(+), 156 deletions(-) create mode 100644 vivarium/simulator/states.py diff --git a/scripts/run_simulation.py b/scripts/run_simulation.py index efd449e..85c8219 100644 --- a/scripts/run_simulation.py +++ b/scripts/run_simulation.py @@ -2,9 +2,13 @@ import logging import numpy as np +import jax.numpy as jnp from vivarium.simulator import behaviors -from vivarium.simulator.sim_computation import dynamics_rigid, StateType +from vivarium.simulator.sim_computation import dynamics_rigid +from vivarium.simulator.states import SimulatorState, AgentState, ObjectState, NVEState, State +from vivarium.simulator.states import init_simulator_state, init_agent_state, init_object_state, init_nve_state, init_state + from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig from vivarium.controllers import converters from vivarium.simulator.simulator import Simulator @@ -35,42 +39,34 @@ def parse_args(): args = parse_args() logging.basicConfig(level=args.log_level.upper()) + + # TODO : set the state without the configs - simulator_config = SimulatorConfig( + simulator_state = init_simulator_state( box_size=args.box_size, n_agents=args.n_agents, n_objects=args.n_objects, num_steps_lax=args.num_steps_lax, - dt=args.dt, - freq=args.freq, neighbor_radius=args.neighbor_radius, + dt=args.dt, to_jit=args.to_jit, use_fori_loop=args.use_fori_loop ) - - agent_configs = [ - AgentConfig(idx=i, - x_position=np.random.rand() * simulator_config.box_size, - y_position=np.random.rand() * simulator_config.box_size, - orientation=np.random.rand() * 2. * np.pi) - for i in range(simulator_config.n_agents) - ] - - object_configs = [ - ObjectConfig(idx=simulator_config.n_agents + i, - x_position=np.random.rand() * simulator_config.box_size, - y_position=np.random.rand() * simulator_config.box_size, - orientation=np.random.rand() * 2. * np.pi) - for i in range(simulator_config.n_objects) - ] - - state = converters.set_state_from_config_dict( - { - StateType.AGENT: agent_configs, - StateType.OBJECT: object_configs, - StateType.SIMULATOR: [simulator_config] - } - ) + + agents_state = init_agent_state( + n_agents=args.n_agents, + ) + + object_state = init_object_state( + n_objects=args.n_objects, + ) + + nve_state = init_nve_state( + simulator_state=simulator_state, + diameter=diameter, + friction=friction, + seed=0 + ) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) diff --git a/vivarium/simulator/sim_computation.py b/vivarium/simulator/sim_computation.py index 6f47701..1418097 100644 --- a/vivarium/simulator/sim_computation.py +++ b/vivarium/simulator/sim_computation.py @@ -1,141 +1,17 @@ from functools import partial -from enum import Enum import jax import jax.numpy as jnp from jax import ops, vmap, lax -from jax_md import space, rigid_body, util, simulate, energy, quantity -from jax_md.dataclasses import dataclass - - +from jax_md import space, util, rigid_body, simulate, energy, quantity f32 = util.f32 -SPACE_NDIMS = 2 -class EntityType(Enum): - AGENT = 0 - OBJECT = 1 - - def to_state_type(self): - return StateType(self.value) - -class StateType(Enum): - AGENT = 0 - OBJECT = 1 - SIMULATOR = 2 - - def is_entity(self): - return self != StateType.SIMULATOR - - def to_entity_type(self): - assert self.is_entity() - return EntityType(self.value) - - - -@dataclass -class NVEState(simulate.NVEState): - entity_type: util.Array - entity_idx: util.Array # idx in XState (e.g. AgentState) - diameter: util.Array - friction: util.Array - exists: util.Array - - @property - def velocity(self) -> util.Array: - return self.momentum / self.mass - -@dataclass -class AgentState: - nve_idx: util.Array # idx in NVEState - prox: util.Array - motor: util.Array - behavior: util.Array - wheel_diameter: util.Array - speed_mul: util.Array - theta_mul: util.Array - proxs_dist_max: util.Array - proxs_cos_min: util.Array - color: util.Array - - -@dataclass -class ObjectState: - nve_idx: util.Array # idx in NVEState - color: util.Array - -@dataclass -class SimulatorState: - idx: util.Array - box_size: util.Array - n_agents: util.Array - n_objects: util.Array - num_steps_lax: util.Array - dt: util.Array - freq: util.Array - neighbor_radius: util.Array - to_jit: util.Array - use_fori_loop: util.Array - - @staticmethod - def get_type(attr): - if attr in ['idx', 'n_agents', 'n_objects', 'num_steps_lax']: - return int - elif attr in ['box_size', 'dt', 'freq', 'neighbor_radius']: - return float - elif attr in ['to_jit', 'use_fori_loop']: - return bool - else: - raise ValueError() - -@dataclass -class State: - simulator_state: SimulatorState - nve_state: NVEState - agent_state: AgentState - object_state: ObjectState - - def field(self, stype_or_nested_fields): - if isinstance(stype_or_nested_fields, StateType): - name = stype_or_nested_fields.name.lower() - nested_fields = (f'{name}_state', ) - else: - nested_fields = stype_or_nested_fields - - res = self - for f in nested_fields: - res = getattr(res, f) - - return res - - # Should we keep this function because it is duplicated below ? - # def nve_idx(self, etype): - # cond = self.e_cond(etype) - # return compress(range(len(cond)), cond) # https://stackoverflow.com/questions/21448225/getting-indices-of-true-values-in-a-boolean-list - - def nve_idx(self, etype, entity_idx): - return self.field(etype).nve_idx[entity_idx] - - def e_idx(self, etype): - return self.nve_state.entity_idx[self.nve_state.entity_type == etype.value] - - def e_cond(self, etype): - return self.nve_state.entity_type == etype.value - - def row_idx(self, field, nve_idx): - return nve_idx if field == 'nve_state' else self.nve_state.entity_idx[jnp.array(nve_idx)] - - def __getattr__(self, name): - def wrapper(e_type): - value = getattr(self.nve_state, name) - if isinstance(value, rigid_body.RigidBody): - return rigid_body.RigidBody(center=value.center[self.e_cond(e_type)], - orientation=value.orientation[self.e_cond(e_type)]) - else: - return value[self.e_cond(e_type)] - return wrapper +# Only work on 2D environments atm +SPACE_NDIMS = 2 +# TODO : Add documentation to functions below def normal(theta): return jnp.array([jnp.cos(theta), jnp.sin(theta)]) diff --git a/vivarium/simulator/simulator.py b/vivarium/simulator/simulator.py index 61c6efb..68e4edf 100644 --- a/vivarium/simulator/simulator.py +++ b/vivarium/simulator/simulator.py @@ -180,3 +180,4 @@ def get_change_time(self): def get_state(self): return self.state + \ No newline at end of file diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py new file mode 100644 index 0000000..cd2488b --- /dev/null +++ b/vivarium/simulator/states.py @@ -0,0 +1,266 @@ +from enum import Enum + +import matplotlib.colors as mcolors +import jax.numpy as jnp + +from jax import random +from jax_md import util, simulate, rigid_body +from jax_md.dataclasses import dataclass +from jax_md.rigid_body import RigidBody + + +# Helper function to transform a color string into rgb with matplotlib colors +def string_to_rgb(color_str): + return jnp.array(list(mcolors.to_rgb(color_str))) + +# TODO : Add documentation on these classes + +class EntityType(Enum): + AGENT = 0 + OBJECT = 1 + + def to_state_type(self): + return StateType(self.value) + +class StateType(Enum): + AGENT = 0 + OBJECT = 1 + SIMULATOR = 2 + + def is_entity(self): + return self != StateType.SIMULATOR + + def to_entity_type(self): + assert self.is_entity() + return EntityType(self.value) + + +# NVE (should maybe rename it entities) + +# No need to define position, momentum, force, and mass (i.e already in simulate.NVEState) +@dataclass +class NVEState(simulate.NVEState): + entity_type: util.Array + entity_idx: util.Array # idx in XState (e.g. AgentState) + diameter: util.Array + friction: util.Array + exists: util.Array + + @property + def velocity(self) -> util.Array: + return self.momentum / self.mass + +def init_nve_state( + simulator_state, + diameter, + friction, + seed, + ) -> NVEState: + """ + Initialize agent state with given parameters + """ + n_agents = simulator_state.n_agents[0] + n_objects = simulator_state.n_objects[0] + n_entities = n_agents + n_objects + + key = random.PRNGKey(seed) + key_pos, key_or = random.split(key, 2) + + # Assign random positions to each entities (will be changed in the future to allow defining custom positions) + positions = random.uniform(key_pos, (n_entities, 2)) * simulator_state.box_size + # Assign random orientations between 0 and 2*pi + orientations = random.uniform(key_or, (n_entities, 2)) * 2 * jnp.pi + + return NVEState( + position=RigidBody(center=positions, orientation=orientations), + # TODO: Why is momentum set to none ? + momentum=None, + # Should we indeed set the force and mass to 0 ? + force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), + mass=RigidBody(center=jnp.zeros((n_entities, 1)), orientation=jnp.zeros(n_entities)), + entity_type=jnp.array([EntityType.AGENT.value] * n_agents + [EntityType.OBJECT.value] * n_objects, dtype=int), + entity_idx = jnp.array(list(range(n_agents)) + list(range(n_objects))), + diameter=jnp.full((n_entities), diameter), + friction=jnp.full((n_entities), friction), + # Set all the entities to exist by default, but we should add a function to be able to change that + exists=jnp.ones(n_entities, dtype=int) + ) + + +# Agents + +@dataclass +class AgentState: + nve_idx: util.Array # idx in NVEState + prox: util.Array + motor: util.Array + behavior: util.Array + wheel_diameter: util.Array + speed_mul: util.Array + theta_mul: util.Array + proxs_dist_max: util.Array + proxs_cos_min: util.Array + color: util.Array + +# Could implement it as a static or class method +def init_agent_state( + n_agents, + behavior, + wheel_diameter, + speed_mul, + theta_mul, + prox_dist_max, + prox_cos_min, + color + ) -> AgentState: + """ + Initialize agent state with given parameters + """ + return AgentState( + nve_idx=jnp.arange(n_agents, dtype=int), + prox=jnp.zeros((n_agents, 2)), + motor=jnp.zeros((n_agents, 2)), + behavior=jnp.full((n_agents), behavior), + wheel_diameter=jnp.full((n_agents), wheel_diameter), + speed_mul=jnp.full((n_agents), speed_mul), + theta_mul=jnp.full((n_agents), theta_mul), + prox_dist_max=jnp.full((n_agents), prox_dist_max), + proxs_cos_min=jnp.full((n_agents), prox_cos_min), + color=jnp.tile(string_to_rgb(color), (n_agents, 1)) + ) + + +# Objects + +@dataclass +class ObjectState: + nve_idx: util.Array # idx in NVEState + color: util.Array + +def init_object_state(n_objects, color) -> ObjectState: + """ + Initialize object state with given parameters + """ + return ObjectState( + nve_idx=jnp.arange(n_objects, dtype=int), + color=jnp.tile(string_to_rgb(color), (n_objects, 1)) + ) + + +# Simulator + +@dataclass +class SimulatorState: + idx: util.Array + box_size: util.Array + n_agents: util.Array + n_objects: util.Array + num_steps_lax: util.Array + dt: util.Array + freq: util.Array + neighbor_radius: util.Array + to_jit: util.Array + use_fori_loop: util.Array + + @staticmethod + def get_type(attr): + if attr in ['idx', 'n_agents', 'n_objects', 'num_steps_lax']: + return int + elif attr in ['box_size', 'dt', 'freq', 'neighbor_radius']: + return float + elif attr in ['to_jit', 'use_fori_loop']: + return bool + else: + raise ValueError() + +def init_simulator_state( + box_size, + n_agents, + n_objects, + num_steps_lax, + dt, + freq, + neighbor_radius, + to_jit, + use_fori_loop + ) -> SimulatorState: + """ + Initialize simulator state with given parameters + """ + return SimulatorState( + idx=jnp.array([0]), + box_size=jnp.array([box_size]), + n_agents=jnp.array([n_agents]), + n_objects=jnp.array([n_objects]), + num_steps_lax=jnp.array([num_steps_lax], dtype=int), + dt=jnp.array([dt], dtype=float), + freq=jnp.array([freq], dtype=float), + neighbor_radius=jnp.array([neighbor_radius], dtype=float), + # Use 1*bool to transform True to 1 and False to 0 + to_jit= jnp.array([1*to_jit]), + use_fori_loop=jnp.array([1*use_fori_loop])) + +@dataclass +class State: + simulator_state: SimulatorState + nve_state: NVEState + agent_state: AgentState + object_state: ObjectState + + def field(self, stype_or_nested_fields): + if isinstance(stype_or_nested_fields, StateType): + name = stype_or_nested_fields.name.lower() + nested_fields = (f'{name}_state', ) + else: + nested_fields = stype_or_nested_fields + + res = self + for f in nested_fields: + res = getattr(res, f) + + return res + + # Should we keep this function because it is duplicated below ? + # def nve_idx(self, etype): + # cond = self.e_cond(etype) + # return compress(range(len(cond)), cond) # https://stackoverflow.com/questions/21448225/getting-indices-of-true-values-in-a-boolean-list + + def nve_idx(self, etype, entity_idx): + return self.field(etype).nve_idx[entity_idx] + + def e_idx(self, etype): + return self.nve_state.entity_idx[self.nve_state.entity_type == etype.value] + + def e_cond(self, etype): + return self.nve_state.entity_type == etype.value + + def row_idx(self, field, nve_idx): + return nve_idx if field == 'nve_state' else self.nve_state.entity_idx[jnp.array(nve_idx)] + + def __getattr__(self, name): + def wrapper(e_type): + value = getattr(self.nve_state, name) + if isinstance(value, rigid_body.RigidBody): + return rigid_body.RigidBody(center=value.center[self.e_cond(e_type)], + orientation=value.orientation[self.e_cond(e_type)]) + else: + return value[self.e_cond(e_type)] + return wrapper + +def init_state( + simulator_state, + agents_state, + objects_state, + nve_state + ) -> State: + + return State( + simulator_state=simulator_state, + agents_state=agents_state, + objects_state=objects_state, + nve_state=nve_state + ) + + + + \ No newline at end of file From 98046d4dc184c31c5378c0d676609d90e7e764b4 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Mon, 18 Mar 2024 18:15:11 +0100 Subject: [PATCH 02/16] Add all init functions in state.py --- vivarium/simulator/states.py | 270 +++++++++++++++++++---------------- 1 file changed, 147 insertions(+), 123 deletions(-) diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index cd2488b..968a9ea 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -1,3 +1,4 @@ +from typing import Optional, List, Union from enum import Enum import matplotlib.colors as mcolors @@ -9,12 +10,7 @@ from jax_md.rigid_body import RigidBody -# Helper function to transform a color string into rgb with matplotlib colors -def string_to_rgb(color_str): - return jnp.array(list(mcolors.to_rgb(color_str))) - # TODO : Add documentation on these classes - class EntityType(Enum): AGENT = 0 OBJECT = 1 @@ -35,8 +31,7 @@ def to_entity_type(self): return EntityType(self.value) -# NVE (should maybe rename it entities) - +# NVE (we could potentially rename it entities ? What do you think ?) # No need to define position, momentum, force, and mass (i.e already in simulate.NVEState) @dataclass class NVEState(simulate.NVEState): @@ -50,44 +45,6 @@ class NVEState(simulate.NVEState): def velocity(self) -> util.Array: return self.momentum / self.mass -def init_nve_state( - simulator_state, - diameter, - friction, - seed, - ) -> NVEState: - """ - Initialize agent state with given parameters - """ - n_agents = simulator_state.n_agents[0] - n_objects = simulator_state.n_objects[0] - n_entities = n_agents + n_objects - - key = random.PRNGKey(seed) - key_pos, key_or = random.split(key, 2) - - # Assign random positions to each entities (will be changed in the future to allow defining custom positions) - positions = random.uniform(key_pos, (n_entities, 2)) * simulator_state.box_size - # Assign random orientations between 0 and 2*pi - orientations = random.uniform(key_or, (n_entities, 2)) * 2 * jnp.pi - - return NVEState( - position=RigidBody(center=positions, orientation=orientations), - # TODO: Why is momentum set to none ? - momentum=None, - # Should we indeed set the force and mass to 0 ? - force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), - mass=RigidBody(center=jnp.zeros((n_entities, 1)), orientation=jnp.zeros(n_entities)), - entity_type=jnp.array([EntityType.AGENT.value] * n_agents + [EntityType.OBJECT.value] * n_objects, dtype=int), - entity_idx = jnp.array(list(range(n_agents)) + list(range(n_objects))), - diameter=jnp.full((n_entities), diameter), - friction=jnp.full((n_entities), friction), - # Set all the entities to exist by default, but we should add a function to be able to change that - exists=jnp.ones(n_entities, dtype=int) - ) - - -# Agents @dataclass class AgentState: @@ -102,52 +59,12 @@ class AgentState: proxs_cos_min: util.Array color: util.Array -# Could implement it as a static or class method -def init_agent_state( - n_agents, - behavior, - wheel_diameter, - speed_mul, - theta_mul, - prox_dist_max, - prox_cos_min, - color - ) -> AgentState: - """ - Initialize agent state with given parameters - """ - return AgentState( - nve_idx=jnp.arange(n_agents, dtype=int), - prox=jnp.zeros((n_agents, 2)), - motor=jnp.zeros((n_agents, 2)), - behavior=jnp.full((n_agents), behavior), - wheel_diameter=jnp.full((n_agents), wheel_diameter), - speed_mul=jnp.full((n_agents), speed_mul), - theta_mul=jnp.full((n_agents), theta_mul), - prox_dist_max=jnp.full((n_agents), prox_dist_max), - proxs_cos_min=jnp.full((n_agents), prox_cos_min), - color=jnp.tile(string_to_rgb(color), (n_agents, 1)) - ) - - -# Objects @dataclass class ObjectState: nve_idx: util.Array # idx in NVEState color: util.Array -def init_object_state(n_objects, color) -> ObjectState: - """ - Initialize object state with given parameters - """ - return ObjectState( - nve_idx=jnp.arange(n_objects, dtype=int), - color=jnp.tile(string_to_rgb(color), (n_objects, 1)) - ) - - -# Simulator @dataclass class SimulatorState: @@ -172,33 +89,7 @@ def get_type(attr): return bool else: raise ValueError() - -def init_simulator_state( - box_size, - n_agents, - n_objects, - num_steps_lax, - dt, - freq, - neighbor_radius, - to_jit, - use_fori_loop - ) -> SimulatorState: - """ - Initialize simulator state with given parameters - """ - return SimulatorState( - idx=jnp.array([0]), - box_size=jnp.array([box_size]), - n_agents=jnp.array([n_agents]), - n_objects=jnp.array([n_objects]), - num_steps_lax=jnp.array([num_steps_lax], dtype=int), - dt=jnp.array([dt], dtype=float), - freq=jnp.array([freq], dtype=float), - neighbor_radius=jnp.array([neighbor_radius], dtype=float), - # Use 1*bool to transform True to 1 and False to 0 - to_jit= jnp.array([1*to_jit]), - use_fori_loop=jnp.array([1*use_fori_loop])) + @dataclass class State: @@ -246,21 +137,154 @@ def wrapper(e_type): else: return value[self.e_cond(e_type)] return wrapper - + + +# Helper function to transform a color string into rgb with matplotlib colors +def _string_to_rgb(color_str): + return jnp.array(list(mcolors.to_rgb(color_str))) + + +def init_simulator_state( + box_size: float = 100., + n_agents: int = 10, + n_objects: int = 2, + num_steps_lax: int = 4, + dt: float = 0.1, + freq: float = 40., + neighbor_radius: float = 100., + to_jit: bool = True, + use_fori_loop: bool = False + ) -> SimulatorState: + """ + Initialize simulator state with given parameters + """ + return SimulatorState( + idx=jnp.array([0]), + box_size=jnp.array([box_size]), + n_agents=jnp.array([n_agents]), + n_objects=jnp.array([n_objects]), + num_steps_lax=jnp.array([num_steps_lax], dtype=int), + dt=jnp.array([dt], dtype=float), + freq=jnp.array([freq], dtype=float), + neighbor_radius=jnp.array([neighbor_radius], dtype=float), + # Use 1*bool to transform True to 1 and False to 0 + to_jit= jnp.array([1*to_jit]), + use_fori_loop=jnp.array([1*use_fori_loop])) + +# TODO : Should also add union float, list[float] for friction, diameter ... +def init_nve_state( + simulator_state: SimulatorState, + diameter: float = 5., + friction: float = 0.1, + agents_positions: Optional[Union[List[float], None]] = None, + objects_positions: Optional[Union[List[float], None]] = None, + seed: int = 0, + ) -> NVEState: + """ + Initialize nve state with given parameters + """ + n_agents = simulator_state.n_agents[0] + n_objects = simulator_state.n_objects[0] + n_entities = n_agents + n_objects + + key = random.PRNGKey(seed) + key_pos, key_or = random.split(key, 2) + + # If we have a list of agents positions, transform it into a jax array + if agents_positions: + agents_positions = jnp.array(agents_positions) + # Else initialize random positions + else: + agents_positions = random.uniform(key_pos, (n_agents, 2)) * simulator_state.box_size + + # Same for ojects positions + if objects_positions: + objects_positions = jnp.array(objects_positions) + else: + objects_positions = random.uniform(key_pos, (n_objects, 2)) * simulator_state.box_size + + agents_entities = jnp.full(n_agents, EntityType.AGENT.value) + object_entities = jnp.full(n_objects, EntityType.OBJECT.value) + entity_types = jnp.concatenate((agents_entities, object_entities), dtype=int) + + # Assign their positions to each entities + positions = jnp.concatenate((agents_positions, objects_positions)) + # Assign random orientations between 0 and 2*pi + orientations = random.uniform(key_or, (n_entities,)) * 2 * jnp.pi + + return NVEState( + position=RigidBody(center=positions, orientation=orientations), + # TODO: Why is momentum set to none ? + momentum=None, + # Should we indeed set the force and mass to 0 ? + force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), + mass=RigidBody(center=jnp.zeros((n_entities, 1)), orientation=jnp.zeros(n_entities)), + entity_type=entity_types, + entity_idx = jnp.array(list(range(n_agents)) + list(range(n_objects))), + diameter=jnp.full((n_entities), diameter), + friction=jnp.full((n_entities), friction), + # Set all the entities to exist by default, but we should add a function to be able to change that + exists=jnp.ones(n_entities, dtype=int) + ) + + +# Could implement it as a static or class method +def init_agent_state( + n_agents: int, + behavior: int = 0, + wheel_diameter: float = 2., + speed_mul: float = 1., + theta_mul: float = 1., + prox_dist_max: float = 100., + prox_cos_min: float = 0., + color: str = "blue" + ) -> AgentState: + """ + Initialize agent state with given parameters + """ + + # TODO : Allow to define custom list of behaviors, wheel_diameters ... (in fact for all parameters) + # TODO : if the shape if just 1 value, assign it to all agents + # TODO : else, ensure you are given a list of arguments and transform it into a jax array + + return AgentState( + nve_idx=jnp.arange(n_agents, dtype=int), + prox=jnp.zeros((n_agents, 2)), + motor=jnp.zeros((n_agents, 2)), + behavior=jnp.full((n_agents), behavior), + wheel_diameter=jnp.full((n_agents), wheel_diameter), + speed_mul=jnp.full((n_agents), speed_mul), + theta_mul=jnp.full((n_agents), theta_mul), + proxs_dist_max=jnp.full((n_agents), prox_dist_max), + proxs_cos_min=jnp.full((n_agents), prox_cos_min), + color=jnp.tile(_string_to_rgb(color), (n_agents, 1)) + ) + + +def init_object_state( + n_objects: int, + color: str = "red" + ) -> ObjectState: + """ + Initialize object state with given parameters + """ + return ObjectState( + nve_idx=jnp.arange(n_objects, dtype=int), + color=jnp.tile(_string_to_rgb(color), (n_objects, 1)) + ) + + def init_state( - simulator_state, - agents_state, - objects_state, - nve_state + simulator_state: SimulatorState, + agents_state: AgentState, + objects_state: ObjectState, + nve_state: NVEState ) -> State: - + return State( simulator_state=simulator_state, - agents_state=agents_state, - objects_state=objects_state, + agent_state=agents_state, + object_state=objects_state, nve_state=nve_state ) - - - \ No newline at end of file From b498f041da41c330f15db9e90d402e20005c1079 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Mon, 18 Mar 2024 18:17:27 +0100 Subject: [PATCH 03/16] Fix imports with new states file --- vivarium/controllers/config.py | 2 +- vivarium/controllers/converters.py | 16 +++++------ vivarium/controllers/notebook_controller.py | 2 +- vivarium/controllers/panel_controller.py | 2 +- vivarium/controllers/simulator_controller.py | 13 +++++---- vivarium/interface/panel_app.py | 2 +- vivarium/simulator/grpc_server/converters.py | 5 ++-- .../simulator/grpc_server/simulator_server.py | 27 +++++++++---------- vivarium/simulator/simulator.py | 3 +-- 9 files changed, 35 insertions(+), 37 deletions(-) diff --git a/vivarium/controllers/config.py b/vivarium/controllers/config.py index 73c2a19..0e5b0ff 100644 --- a/vivarium/controllers/config.py +++ b/vivarium/controllers/config.py @@ -2,7 +2,7 @@ from param import Parameterized import vivarium.simulator.behaviors as behaviors -from vivarium.simulator.sim_computation import StateType +from vivarium.simulator.states import StateType from jax_md.rigid_body import monomer diff --git a/vivarium/controllers/converters.py b/vivarium/controllers/converters.py index e3d91ae..1d0c030 100644 --- a/vivarium/controllers/converters.py +++ b/vivarium/controllers/converters.py @@ -1,20 +1,18 @@ +import typing +import dataclasses +from collections import namedtuple, defaultdict + +import jax_md import jax.numpy as jnp import numpy as np +import matplotlib.colors as mcolors -import jax_md -from jax_md.util import f32 from jax_md.rigid_body import RigidBody -import dataclasses -import typing -from collections import namedtuple, defaultdict - from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig, stype_to_config, config_to_stype -from vivarium.simulator.sim_computation import State, SimulatorState, NVEState, AgentState, ObjectState, EntityType, StateType +from vivarium.simulator.states import State, SimulatorState, NVEState, AgentState, ObjectState, EntityType, StateType from vivarium.simulator.behaviors import behavior_name_map, reversed_behavior_name_map -import matplotlib.colors as mcolors - agent_config_fields = AgentConfig.param.objects().keys() agent_state_fields = [f.name for f in jax_md.dataclasses.fields(AgentState)] diff --git a/vivarium/controllers/notebook_controller.py b/vivarium/controllers/notebook_controller.py index e9c9ef3..02b22f4 100644 --- a/vivarium/controllers/notebook_controller.py +++ b/vivarium/controllers/notebook_controller.py @@ -5,7 +5,7 @@ import logging from vivarium.controllers.simulator_controller import SimulatorController -from vivarium.simulator.sim_computation import StateType, EntityType +from vivarium.simulator.states import StateType, EntityType lg = logging.getLogger(__name__) diff --git a/vivarium/controllers/panel_controller.py b/vivarium/controllers/panel_controller.py index 485ad7e..d6117f6 100644 --- a/vivarium/controllers/panel_controller.py +++ b/vivarium/controllers/panel_controller.py @@ -1,7 +1,7 @@ from vivarium.controllers import converters from vivarium.controllers.config import AgentConfig, ObjectConfig, config_to_stype, Config from vivarium.controllers.simulator_controller import SimulatorController -from vivarium.simulator.sim_computation import EntityType, StateType +from vivarium.simulator.states import EntityType, StateType from vivarium.simulator.grpc_server.simulator_client import SimulatorGRPCClient import param diff --git a/vivarium/controllers/simulator_controller.py b/vivarium/controllers/simulator_controller.py index 9280076..f2bceae 100644 --- a/vivarium/controllers/simulator_controller.py +++ b/vivarium/controllers/simulator_controller.py @@ -1,13 +1,16 @@ +import time +import threading +import logging + +from contextlib import contextmanager + import param from vivarium.simulator.grpc_server.simulator_client import SimulatorGRPCClient from vivarium.controllers.config import SimulatorConfig -from vivarium.simulator.sim_computation import StateType +from vivarium.simulator.states import StateType from vivarium.controllers import converters -import time -import threading -from contextlib import contextmanager -import logging + lg = logging.getLogger(__name__) diff --git a/vivarium/interface/panel_app.py b/vivarium/interface/panel_app.py index 76dfbd8..684adc8 100644 --- a/vivarium/interface/panel_app.py +++ b/vivarium/interface/panel_app.py @@ -9,7 +9,7 @@ from vivarium.simulator.grpc_server.simulator_client import SimulatorGRPCClient from vivarium.controllers.panel_controller import PanelController -from vivarium.simulator.sim_computation import EntityType +from vivarium.simulator.states import EntityType pn.extension() diff --git a/vivarium/simulator/grpc_server/converters.py b/vivarium/simulator/grpc_server/converters.py index 5289db8..ae5178d 100644 --- a/vivarium/simulator/grpc_server/converters.py +++ b/vivarium/simulator/grpc_server/converters.py @@ -1,9 +1,10 @@ from jax_md.rigid_body import RigidBody -from vivarium.simulator.grpc_server.numproto.numproto import proto_to_ndarray, ndarray_to_proto -from vivarium.simulator.sim_computation import State, SimulatorState, NVEState, AgentState, ObjectState import simulator_pb2 +from vivarium.simulator.grpc_server.numproto.numproto import proto_to_ndarray, ndarray_to_proto +from vivarium.simulator.states import State, SimulatorState, NVEState, AgentState, ObjectState + def proto_to_state(state): return State(simulator_state=proto_to_simulator_state(state.simulator_state), diff --git a/vivarium/simulator/grpc_server/simulator_server.py b/vivarium/simulator/grpc_server/simulator_server.py index af9a411..cef5508 100644 --- a/vivarium/simulator/grpc_server/simulator_server.py +++ b/vivarium/simulator/grpc_server/simulator_server.py @@ -1,24 +1,21 @@ -from collections import defaultdict +import logging -from numproto.numproto import proto_to_ndarray +from concurrent import futures +from threading import Lock +from contextlib import contextmanager +from collections import defaultdict import grpc -import simulator_pb2_grpc import simulator_pb2 +import simulator_pb2_grpc -import numpy as np -import logging -from concurrent import futures -from threading import Lock -from contextlib import contextmanager +from numproto.numproto import proto_to_ndarray + +from vivarium.simulator.grpc_server.converters import state_to_proto +from vivarium.simulator.grpc_server.converters import nve_state_to_proto +from vivarium.simulator.grpc_server.converters import agent_state_to_proto +from vivarium.simulator.grpc_server.converters import object_state_to_proto -from vivarium.controllers.config import SimulatorConfig, AgentConfig, ObjectConfig -from vivarium.simulator.sim_computation import StateType -from vivarium.simulator.simulator import Simulator -from vivarium.simulator.sim_computation import dynamics_rigid -import vivarium.simulator.behaviors as behaviors -from vivarium.simulator.grpc_server.converters import state_to_proto, nve_state_to_proto, agent_state_to_proto, object_state_to_proto -from vivarium.controllers.converters import set_state_from_config_dict lg = logging.getLogger(__name__) diff --git a/vivarium/simulator/simulator.py b/vivarium/simulator/simulator.py index dbfcb86..bcf60c5 100644 --- a/vivarium/simulator/simulator.py +++ b/vivarium/simulator/simulator.py @@ -13,8 +13,7 @@ from jax import lax from jax_md import space, partition, dataclasses -from vivarium.simulator.sim_computation import EntityType, SimulatorState -from vivarium.controllers import converters +from vivarium.simulator.states import EntityType, SimulatorState lg = logging.getLogger(__name__) From 89ecef760e4527da40b2298c5e2a3fc2b8bb7435 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Mon, 18 Mar 2024 18:18:04 +0100 Subject: [PATCH 04/16] Initialize states without configs --- scripts/run_server.py | 44 +++++++++++++++++++-------------------- scripts/run_simulation.py | 33 +++++++++++++---------------- 2 files changed, 37 insertions(+), 40 deletions(-) diff --git a/scripts/run_server.py b/scripts/run_server.py index e59ce50..fa2be7a 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -1,14 +1,14 @@ import logging import argparse -import numpy as np - -import vivarium.simulator.behaviors as behaviors -from vivarium.controllers.config import SimulatorConfig, AgentConfig, ObjectConfig -from vivarium.simulator.sim_computation import StateType +from vivarium.simulator import behaviors +from vivarium.simulator.states import init_simulator_state +from vivarium.simulator.states import init_agent_state +from vivarium.simulator.states import init_object_state +from vivarium.simulator.states import init_nve_state +from vivarium.simulator.states import init_state from vivarium.simulator.simulator import Simulator from vivarium.simulator.sim_computation import dynamics_rigid -from vivarium.controllers.converters import set_state_from_config_dict from vivarium.simulator.grpc_server.simulator_server import serve lg = logging.getLogger(__name__) @@ -35,34 +35,34 @@ def parse_args(): logging.basicConfig(level=args.log_level.upper()) - simulator_config = SimulatorConfig( + simulator_state = init_simulator_state( box_size=args.box_size, n_agents=args.n_agents, n_objects=args.n_objects, num_steps_lax=args.num_steps_lax, + neighbor_radius=args.neighbor_radius, dt=args.dt, freq=args.freq, - neighbor_radius=args.neighbor_radius, to_jit=args.to_jit, use_fori_loop=args.use_fori_loop ) - agent_configs = [AgentConfig(idx=i, - x_position=np.random.rand() * simulator_config.box_size, - y_position=np.random.rand() * simulator_config.box_size, - orientation=np.random.rand() * 2. * np.pi) - for i in range(simulator_config.n_agents)] + agents_state = init_agent_state( + n_agents=args.n_agents, + ) + + objects_state = init_object_state( + n_objects=args.n_objects, + ) - object_configs = [ObjectConfig(idx=simulator_config.n_agents + i, - x_position=np.random.rand() * simulator_config.box_size, - y_position=np.random.rand() * simulator_config.box_size, - orientation=np.random.rand() * 2. * np.pi) - for i in range(simulator_config.n_objects)] + nve_state = init_nve_state(simulator_state) - state = set_state_from_config_dict({StateType.AGENT: agent_configs, - StateType.OBJECT: object_configs, - StateType.SIMULATOR: [simulator_config] - }) + state = init_state( + simulator_state=simulator_state, + agents_state=agents_state, + objects_state=objects_state, + nve_state=nve_state + ) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) lg.info('Simulator server started') diff --git a/scripts/run_simulation.py b/scripts/run_simulation.py index 78e0a75..511f954 100644 --- a/scripts/run_simulation.py +++ b/scripts/run_simulation.py @@ -1,17 +1,14 @@ import argparse import logging -import numpy as np -import jax.numpy as jnp - from vivarium.simulator import behaviors -from vivarium.simulator.sim_computation import dynamics_rigid -from vivarium.simulator.states import SimulatorState, AgentState, ObjectState, NVEState, State -from vivarium.simulator.states import init_simulator_state, init_agent_state, init_object_state, init_nve_state, init_state - -from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig -from vivarium.controllers import converters +from vivarium.simulator.states import init_simulator_state +from vivarium.simulator.states import init_agent_state +from vivarium.simulator.states import init_object_state +from vivarium.simulator.states import init_nve_state +from vivarium.simulator.states import init_state from vivarium.simulator.simulator import Simulator +from vivarium.simulator.sim_computation import dynamics_rigid lg = logging.getLogger(__name__) @@ -40,8 +37,6 @@ def parse_args(): logging.basicConfig(level=args.log_level.upper()) - # TODO : set the state without the configs - simulator_state = init_simulator_state( box_size=args.box_size, n_agents=args.n_agents, @@ -49,6 +44,7 @@ def parse_args(): num_steps_lax=args.num_steps_lax, neighbor_radius=args.neighbor_radius, dt=args.dt, + freq=args.freq, to_jit=args.to_jit, use_fori_loop=args.use_fori_loop ) @@ -57,17 +53,18 @@ def parse_args(): n_agents=args.n_agents, ) - object_state = init_object_state( + objects_state = init_object_state( n_objects=args.n_objects, ) - nve_state = init_nve_state( - simulator_state=simulator_state, - diameter=diameter, - friction=friction, - seed=0 - ) + nve_state = init_nve_state(simulator_state) + state = init_state( + simulator_state=simulator_state, + agents_state=agents_state, + objects_state=objects_state, + nve_state=nve_state + ) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) From 4308ce64d327a1cdf046530390908deb6b9c2fa0 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Wed, 20 Mar 2024 15:56:27 +0100 Subject: [PATCH 05/16] Fix init state values that turn into NaN in client side --- scripts/run_server.py | 10 +++------- vivarium/simulator/states.py | 23 ++++++++++++++--------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/scripts/run_server.py b/scripts/run_server.py index fa2be7a..f59b23a 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -47,15 +47,11 @@ def parse_args(): use_fori_loop=args.use_fori_loop ) - agents_state = init_agent_state( - n_agents=args.n_agents, - ) + agents_state = init_agent_state(simulator_state=simulator_state) - objects_state = init_object_state( - n_objects=args.n_objects, - ) + objects_state = init_object_state(simulator_state=simulator_state) - nve_state = init_nve_state(simulator_state) + nve_state = init_nve_state(simulator_state=simulator_state) state = init_state( simulator_state=simulator_state, diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index 968a9ea..3b88b8b 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -176,6 +176,8 @@ def init_nve_state( simulator_state: SimulatorState, diameter: float = 5., friction: float = 0.1, + mass_center: float = 1., + mass_orientation: float = 0.125, agents_positions: Optional[Union[List[float], None]] = None, objects_positions: Optional[Union[List[float], None]] = None, seed: int = 0, @@ -214,11 +216,11 @@ def init_nve_state( return NVEState( position=RigidBody(center=positions, orientation=orientations), - # TODO: Why is momentum set to none ? + # TODO: Why is momentum set to none ? momentum=None, # Should we indeed set the force and mass to 0 ? force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), - mass=RigidBody(center=jnp.zeros((n_entities, 1)), orientation=jnp.zeros(n_entities)), + mass=RigidBody(center=jnp.full((n_entities, 1), mass_center), orientation=jnp.full((n_entities), mass_orientation)), entity_type=entity_types, entity_idx = jnp.array(list(range(n_agents)) + list(range(n_objects))), diameter=jnp.full((n_entities), diameter), @@ -230,8 +232,8 @@ def init_nve_state( # Could implement it as a static or class method def init_agent_state( - n_agents: int, - behavior: int = 0, + simulator_state: SimulatorState, + behavior: int = 1, wheel_diameter: float = 2., speed_mul: float = 1., theta_mul: float = 1., @@ -242,10 +244,10 @@ def init_agent_state( """ Initialize agent state with given parameters """ - + n_agents = simulator_state.n_agents[0] # TODO : Allow to define custom list of behaviors, wheel_diameters ... (in fact for all parameters) - # TODO : if the shape if just 1 value, assign it to all agents - # TODO : else, ensure you are given a list of arguments and transform it into a jax array + # TODO : if the shape if just 1 value, assign it to all agents + # TODO : else, ensure you are given a list of arguments of size n_agents and transform it into a jax array return AgentState( nve_idx=jnp.arange(n_agents, dtype=int), @@ -262,14 +264,17 @@ def init_agent_state( def init_object_state( - n_objects: int, + simulator_state: SimulatorState, color: str = "red" ) -> ObjectState: """ Initialize object state with given parameters """ + n_agents, n_objects = simulator_state.n_agents[0], simulator_state.n_objects[0] + start_idx, stop_idx = n_agents, n_agents + n_objects + objects_nve_idx = jnp.arange(start_idx, stop_idx, dtype=int) return ObjectState( - nve_idx=jnp.arange(n_objects, dtype=int), + nve_idx=objects_nve_idx, color=jnp.tile(_string_to_rgb(color), (n_objects, 1)) ) From 868074e176cc46c3f01869c76e5da02c06645855 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Wed, 20 Mar 2024 16:12:29 +0100 Subject: [PATCH 06/16] Fix converters import error --- vivarium/simulator/simulator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vivarium/simulator/simulator.py b/vivarium/simulator/simulator.py index bcf60c5..806a6a8 100644 --- a/vivarium/simulator/simulator.py +++ b/vivarium/simulator/simulator.py @@ -13,6 +13,7 @@ from jax import lax from jax_md import space, partition, dataclasses +from vivarium.controllers import converters from vivarium.simulator.states import EntityType, SimulatorState lg = logging.getLogger(__name__) From b86da7ac6ad09ddc7c82d0fc8697a785326ffc42 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Wed, 20 Mar 2024 16:14:52 +0100 Subject: [PATCH 07/16] Change arguments for agents and objects states init --- scripts/run_simulation.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/scripts/run_simulation.py b/scripts/run_simulation.py index 511f954..356ac56 100644 --- a/scripts/run_simulation.py +++ b/scripts/run_simulation.py @@ -49,15 +49,11 @@ def parse_args(): use_fori_loop=args.use_fori_loop ) - agents_state = init_agent_state( - n_agents=args.n_agents, - ) + agents_state = init_agent_state(simulator_state=simulator_state) - objects_state = init_object_state( - n_objects=args.n_objects, - ) + objects_state = init_object_state(simulator_state=simulator_state) - nve_state = init_nve_state(simulator_state) + nve_state = init_nve_state(simulator_state=simulator_state) state = init_state( simulator_state=simulator_state, From 225c5f655cc7ae5c4801f11d5ee537c7432e6a56 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Wed, 20 Mar 2024 16:47:16 +0100 Subject: [PATCH 08/16] Add option to test initialization with custom positions --- scripts/run_server.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/scripts/run_server.py b/scripts/run_server.py index f59b23a..37c7f43 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -15,6 +15,7 @@ def parse_args(): parser = argparse.ArgumentParser(description='Simulator Configuration') + parser.add_argument('--custom_pos', action='store_true', help='Just a test arg to wether use custom or random pos') parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box') parser.add_argument('--n_agents', type=int, default=10, help='Number of agents') parser.add_argument('--n_objects', type=int, default=2, help='Number of objects') @@ -35,6 +36,17 @@ def parse_args(): logging.basicConfig(level=args.log_level.upper()) + # Define a custom list of init agents positions (ugly here but will be done in a yaml file later) + if args.custom_pos: + coord = 0 + positions = [] + # Here all the agents are placed on the box diagonal + for i in range(args.n_agents): + coord += 10 + positions.append([coord, coord]) + else: + positions = None + simulator_state = init_simulator_state( box_size=args.box_size, n_agents=args.n_agents, @@ -51,7 +63,7 @@ def parse_args(): objects_state = init_object_state(simulator_state=simulator_state) - nve_state = init_nve_state(simulator_state=simulator_state) + nve_state = init_nve_state(simulator_state=simulator_state, agents_positions=positions) state = init_state( simulator_state=simulator_state, From 4ee97dc7b880e101eae9da93264f8328f25e06e5 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Wed, 20 Mar 2024 17:33:29 +0100 Subject: [PATCH 09/16] Update tests for new state initialization --- tests/test_simulator.py | 46 --------------------- tests/test_simulator_init.py | 78 ++++++++++++++++++++++++++++++++++++ tests/test_simulator_run.py | 36 +++++++++++++++++ 3 files changed, 114 insertions(+), 46 deletions(-) delete mode 100644 tests/test_simulator.py create mode 100644 tests/test_simulator_init.py create mode 100644 tests/test_simulator_run.py diff --git a/tests/test_simulator.py b/tests/test_simulator.py deleted file mode 100644 index 7511973..0000000 --- a/tests/test_simulator.py +++ /dev/null @@ -1,46 +0,0 @@ -import numpy as np - -import vivarium.simulator.behaviors as behaviors -from vivarium.simulator.simulator import Simulator -from vivarium.simulator.sim_computation import dynamics_rigid, StateType -from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig -from vivarium.controllers import converters - - -NUM_STEPS = 50 - - -# First smoke test, we could split it into different parts later (initialization, run, ...) -def test_simulator_run(): - simulator_config = SimulatorConfig(to_jit=True) - - agent_configs = [ - AgentConfig( - idx=i, - x_position=np.random.rand() * simulator_config.box_size, - y_position=np.random.rand() * simulator_config.box_size, - orientation=np.random.rand() * 2. * np.pi) - for i in range(simulator_config.n_agents) - ] - - object_configs = [ - ObjectConfig( - idx=simulator_config.n_agents + i, - x_position=np.random.rand() * simulator_config.box_size, - y_position=np.random.rand() * simulator_config.box_size, - orientation=np.random.rand() * 2. * np.pi) - for i in range(simulator_config.n_objects) - ] - - state = converters.set_state_from_config_dict( - {StateType.AGENT: agent_configs, - StateType.OBJECT: object_configs, - StateType.SIMULATOR: [simulator_config] - } - ) - - simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) - - simulator.run(threaded=False, num_steps=NUM_STEPS) - - assert simulator \ No newline at end of file diff --git a/tests/test_simulator_init.py b/tests/test_simulator_init.py new file mode 100644 index 0000000..a283554 --- /dev/null +++ b/tests/test_simulator_init.py @@ -0,0 +1,78 @@ +from vivarium.simulator import behaviors +from vivarium.simulator.states import init_simulator_state +from vivarium.simulator.states import init_agent_state +from vivarium.simulator.states import init_object_state +from vivarium.simulator.states import init_nve_state +from vivarium.simulator.states import init_state +from vivarium.simulator.simulator import Simulator +from vivarium.simulator.sim_computation import dynamics_rigid + + +def test_init_simulator_no_args(): + """ Test the initialization of state without arguments """ + simulator_state = init_simulator_state() + agents_state = init_agent_state(simulator_state=simulator_state) + objects_state = init_object_state(simulator_state=simulator_state) + nve_state = init_nve_state(simulator_state=simulator_state) + + state = init_state( + simulator_state=simulator_state, + agents_state=agents_state, + objects_state=objects_state, + nve_state=nve_state + ) + + simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) + + assert simulator + + +def test_init_simulator_args(): + """ Test the initialization of state with arguments """ + box_size = 100.0 + n_agents = 10 + n_objects = 2 + + diameter = 5.0 + friction = 0.1 + behavior = 1 + wheel_diameter = 2.0 + speed_mul = 1.0 + theta_mul = 1.0 + prox_dist_max = 20.0 + prox_cos_min = 0.0 + color = "red" + + simulator_state = init_simulator_state( + box_size=box_size, + n_agents=n_agents, + n_objects=n_objects) + + nve_state = init_nve_state( + simulator_state, + diameter=diameter, + friction=friction) + + agent_state = init_agent_state( + simulator_state, + behavior=behavior, + wheel_diameter=wheel_diameter, + speed_mul=speed_mul, + theta_mul=theta_mul, + prox_dist_max=prox_dist_max, + prox_cos_min=prox_cos_min) + + object_state = init_object_state( + simulator_state, + color=color) + + state = init_state( + simulator_state=simulator_state, + agents_state=agent_state, + objects_state=object_state, + nve_state=nve_state + ) + + simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) + + assert simulator \ No newline at end of file diff --git a/tests/test_simulator_run.py b/tests/test_simulator_run.py new file mode 100644 index 0000000..8eff893 --- /dev/null +++ b/tests/test_simulator_run.py @@ -0,0 +1,36 @@ +from vivarium.simulator import behaviors +from vivarium.simulator.states import init_simulator_state +from vivarium.simulator.states import init_agent_state +from vivarium.simulator.states import init_object_state +from vivarium.simulator.states import init_nve_state +from vivarium.simulator.states import init_state +from vivarium.simulator.simulator import Simulator +from vivarium.simulator.sim_computation import dynamics_rigid + +NUM_STEPS = 50 + +# First smoke test, we could split it into different parts later (initialization, run, ...) +def test_simulator_run(): + + simulator_state = init_simulator_state() + + agents_state = init_agent_state(simulator_state=simulator_state) + + objects_state = init_object_state(simulator_state=simulator_state) + + nve_state = init_nve_state(simulator_state=simulator_state) + + state = init_state( + simulator_state=simulator_state, + agents_state=agents_state, + objects_state=objects_state, + nve_state=nve_state + ) + + simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) + + simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) + + simulator.run(threaded=False, num_steps=NUM_STEPS) + + assert simulator \ No newline at end of file From b7e98d9c26a5a327d093540a5378a43a7d53b982 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 21 Mar 2024 10:40:19 +0100 Subject: [PATCH 10/16] Reduce default value of agents proximeters range --- vivarium/simulator/states.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index 3b88b8b..db7913a 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -237,7 +237,7 @@ def init_agent_state( wheel_diameter: float = 2., speed_mul: float = 1., theta_mul: float = 1., - prox_dist_max: float = 100., + prox_dist_max: float = 40., prox_cos_min: float = 0., color: str = "blue" ) -> AgentState: From fd9674691141633e8706ddf490c058ccc7710107 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 21 Mar 2024 11:50:41 +0100 Subject: [PATCH 11/16] Refactor init_positions function and update comments --- vivarium/simulator/states.py | 51 ++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index db7913a..e3df7c6 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -66,6 +66,7 @@ class ObjectState: color: util.Array +# TODO : I think it would make more sense to have max_agents, max_objects here instead of n_*** @dataclass class SimulatorState: idx: util.Array @@ -111,7 +112,7 @@ def field(self, stype_or_nested_fields): return res - # Should we keep this function because it is duplicated below ? + # TODO : Should we keep this function because it is duplicated below ? # def nve_idx(self, etype): # cond = self.e_cond(etype) # return compress(range(len(cond)), cond) # https://stackoverflow.com/questions/21448225/getting-indices-of-true-values-in-a-boolean-list @@ -171,7 +172,21 @@ def init_simulator_state( to_jit= jnp.array([1*to_jit]), use_fori_loop=jnp.array([1*use_fori_loop])) -# TODO : Should also add union float, list[float] for friction, diameter ... + +def _init_positions(key_pos, positions, n_elements, box_size, n_dims=2): + assert (len(positions) == n_elements if positions else True) + # If positions are passed, transform them in jax array + if positions: + positions = jnp.array(positions) + # Else initialize random positions + else: + positions = random.uniform(key_pos, (n_elements, n_dims)) * box_size + return positions + +# TODO : Should also add union float, list[float] for friction, diameter .. +# Also think it would be easier to only handle the nve state in the physics engine. (if it doesn't make the simulation run slower) +# Here it makes a really long function, and it isn't very modular +# Plus we store different attributes of entities (e.g Agents positions and sensors) in different dataclasses def init_nve_state( simulator_state: SimulatorState, diameter: float = 5., @@ -190,30 +205,22 @@ def init_nve_state( n_entities = n_agents + n_objects key = random.PRNGKey(seed) - key_pos, key_or = random.split(key, 2) + key_pos, key_or = random.split(key) + key_ag, key_obj = random.split(key_pos) - # If we have a list of agents positions, transform it into a jax array - if agents_positions: - agents_positions = jnp.array(agents_positions) - # Else initialize random positions - else: - agents_positions = random.uniform(key_pos, (n_agents, 2)) * simulator_state.box_size + # If we have a list of agents or objects positions, transform it into a jax array, else initialize random positions + agents_positions = _init_positions(key_ag, agents_positions, n_agents, simulator_state.box_size) + objects_positions = _init_positions(key_obj, objects_positions, n_objects, simulator_state.box_size) + # Assign their positions to each entities + positions = jnp.concatenate((agents_positions, objects_positions)) - # Same for ojects positions - if objects_positions: - objects_positions = jnp.array(objects_positions) - else: - objects_positions = random.uniform(key_pos, (n_objects, 2)) * simulator_state.box_size + # Assign random orientations between 0 and 2*pi + orientations = random.uniform(key_or, (n_entities,)) * 2 * jnp.pi agents_entities = jnp.full(n_agents, EntityType.AGENT.value) object_entities = jnp.full(n_objects, EntityType.OBJECT.value) entity_types = jnp.concatenate((agents_entities, object_entities), dtype=int) - # Assign their positions to each entities - positions = jnp.concatenate((agents_positions, objects_positions)) - # Assign random orientations between 0 and 2*pi - orientations = random.uniform(key_or, (n_entities,)) * 2 * jnp.pi - return NVEState( position=RigidBody(center=positions, orientation=orientations), # TODO: Why is momentum set to none ? @@ -225,7 +232,7 @@ def init_nve_state( entity_idx = jnp.array(list(range(n_agents)) + list(range(n_objects))), diameter=jnp.full((n_entities), diameter), friction=jnp.full((n_entities), friction), - # Set all the entities to exist by default, but we should add a function to be able to change that + # Set all the entities to exist by default, but we should add a function to be able to change that exists=jnp.ones(n_entities, dtype=int) ) @@ -246,8 +253,8 @@ def init_agent_state( """ n_agents = simulator_state.n_agents[0] # TODO : Allow to define custom list of behaviors, wheel_diameters ... (in fact for all parameters) - # TODO : if the shape if just 1 value, assign it to all agents - # TODO : else, ensure you are given a list of arguments of size n_agents and transform it into a jax array + # if the shape if just 1 value, assign it to all agents + # else, ensure you are given a list of arguments of size n_agents and transform it into a jax array return AgentState( nve_idx=jnp.arange(n_agents, dtype=int), From 98a505b4cb6f17d9b0ea33b7e0e055221d4bba63 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 21 Mar 2024 12:24:18 +0100 Subject: [PATCH 12/16] Allow defining a initial number of existing entities --- scripts/run_server.py | 9 ++++++++- vivarium/simulator/states.py | 24 +++++++++++++++++++----- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/scripts/run_server.py b/scripts/run_server.py index 37c7f43..8e12ac6 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -18,7 +18,9 @@ def parse_args(): parser.add_argument('--custom_pos', action='store_true', help='Just a test arg to wether use custom or random pos') parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box') parser.add_argument('--n_agents', type=int, default=10, help='Number of agents') + parser.add_argument('--n_existing_agents', type=int, default=10, help='Number of agents') parser.add_argument('--n_objects', type=int, default=2, help='Number of objects') + parser.add_argument('--n_existing_objects', type=int, default=2, help='Number of agents') parser.add_argument('--num_steps-lax', type=int, default=4, help='Number of lax steps per loop') parser.add_argument('--dt', type=float, default=0.1, help='Time step size') parser.add_argument('--freq', type=float, default=40.0, help='Frequency parameter') @@ -63,7 +65,12 @@ def parse_args(): objects_state = init_object_state(simulator_state=simulator_state) - nve_state = init_nve_state(simulator_state=simulator_state, agents_positions=positions) + nve_state = init_nve_state( + simulator_state=simulator_state, + agents_positions=positions, + existing_agents=args.n_existing_agents, + existing_objects=args.n_existing_objects, + ) state = init_state( simulator_state=simulator_state, diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index e3df7c6..0b33aae 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -161,7 +161,7 @@ def init_simulator_state( """ return SimulatorState( idx=jnp.array([0]), - box_size=jnp.array([box_size]), + box_size=jnp.array([box_size]), n_agents=jnp.array([n_agents]), n_objects=jnp.array([n_objects]), num_steps_lax=jnp.array([num_steps_lax], dtype=int), @@ -183,6 +183,16 @@ def _init_positions(key_pos, positions, n_elements, box_size, n_dims=2): positions = random.uniform(key_pos, (n_elements, n_dims)) * box_size return positions +def _init_existing(n_existing, n_elements): + if n_existing: + assert n_existing <= n_elements + existing_arr = jnp.ones((n_existing)) + non_existing_arr = jnp.zeros((n_elements - n_existing)) + exists_array = jnp.concatenate((existing_arr, non_existing_arr)) + else: + exists_array = jnp.ones((n_elements)) + return exists_array + # TODO : Should also add union float, list[float] for friction, diameter .. # Also think it would be easier to only handle the nve state in the physics engine. (if it doesn't make the simulation run slower) # Here it makes a really long function, and it isn't very modular @@ -195,6 +205,8 @@ def init_nve_state( mass_orientation: float = 0.125, agents_positions: Optional[Union[List[float], None]] = None, objects_positions: Optional[Union[List[float], None]] = None, + existing_agents: Optional[Union[int, List[float], None]] = None, + existing_objects: Optional[Union[int, List[float], None]] = None, seed: int = 0, ) -> NVEState: """ @@ -221,19 +233,21 @@ def init_nve_state( object_entities = jnp.full(n_objects, EntityType.OBJECT.value) entity_types = jnp.concatenate((agents_entities, object_entities), dtype=int) + existing_agents = _init_existing(existing_agents, n_agents) + existing_objects = _init_existing(existing_objects, n_objects) + exists = jnp.concatenate((existing_agents, existing_objects), dtype=int) + + # TODO: Why is momentum set to none ? return NVEState( position=RigidBody(center=positions, orientation=orientations), - # TODO: Why is momentum set to none ? momentum=None, - # Should we indeed set the force and mass to 0 ? force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), mass=RigidBody(center=jnp.full((n_entities, 1), mass_center), orientation=jnp.full((n_entities), mass_orientation)), entity_type=entity_types, entity_idx = jnp.array(list(range(n_agents)) + list(range(n_objects))), diameter=jnp.full((n_entities), diameter), friction=jnp.full((n_entities), friction), - # Set all the entities to exist by default, but we should add a function to be able to change that - exists=jnp.ones(n_entities, dtype=int) + exists=exists ) From 0d423175fb3d932548bf79bc0363e7bec02a61d3 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 4 Apr 2024 00:00:19 +0200 Subject: [PATCH 13/16] Remove code to test custom positions in run_server --- scripts/run_server.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/scripts/run_server.py b/scripts/run_server.py index 8e12ac6..6eea169 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -15,7 +15,6 @@ def parse_args(): parser = argparse.ArgumentParser(description='Simulator Configuration') - parser.add_argument('--custom_pos', action='store_true', help='Just a test arg to wether use custom or random pos') parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box') parser.add_argument('--n_agents', type=int, default=10, help='Number of agents') parser.add_argument('--n_existing_agents', type=int, default=10, help='Number of agents') @@ -38,17 +37,6 @@ def parse_args(): logging.basicConfig(level=args.log_level.upper()) - # Define a custom list of init agents positions (ugly here but will be done in a yaml file later) - if args.custom_pos: - coord = 0 - positions = [] - # Here all the agents are placed on the box diagonal - for i in range(args.n_agents): - coord += 10 - positions.append([coord, coord]) - else: - positions = None - simulator_state = init_simulator_state( box_size=args.box_size, n_agents=args.n_agents, @@ -66,8 +54,7 @@ def parse_args(): objects_state = init_object_state(simulator_state=simulator_state) nve_state = init_nve_state( - simulator_state=simulator_state, - agents_positions=positions, + simulator_state=simulator_state, existing_agents=args.n_existing_agents, existing_objects=args.n_existing_objects, ) From adefccef9fa7646404194be5944637573fc78cfe Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 4 Apr 2024 13:10:38 +0200 Subject: [PATCH 14/16] Fix vmap bug on normal function and clean sim_computation --- vivarium/simulator/sim_computation.py | 129 +------------------------- 1 file changed, 2 insertions(+), 127 deletions(-) diff --git a/vivarium/simulator/sim_computation.py b/vivarium/simulator/sim_computation.py index 49fb2c6..243c68a 100644 --- a/vivarium/simulator/sim_computation.py +++ b/vivarium/simulator/sim_computation.py @@ -5,138 +5,13 @@ from jax import ops, vmap, lax from jax_md import space, rigid_body, util, simulate, energy, quantity -from jax_md.dataclasses import dataclass f32 = util.f32 -# Only work on 2D environments atm +# Only work on 2D environments atm SPACE_NDIMS = 2 - -class EntityType(Enum): - AGENT = 0 - OBJECT = 1 - - def to_state_type(self): - return StateType(self.value) - -class StateType(Enum): - AGENT = 0 - OBJECT = 1 - SIMULATOR = 2 - - def is_entity(self): - return self != StateType.SIMULATOR - - def to_entity_type(self): - assert self.is_entity() - return EntityType(self.value) - - - -@dataclass -class NVEState(simulate.NVEState): - entity_type: util.Array - entity_idx: util.Array # idx in XState (e.g. AgentState) - diameter: util.Array - friction: util.Array - exists: util.Array - - @property - def velocity(self) -> util.Array: - return self.momentum / self.mass - -@dataclass -class AgentState: - nve_idx: util.Array # idx in NVEState - prox: util.Array - motor: util.Array - behavior: util.Array - wheel_diameter: util.Array - speed_mul: util.Array - theta_mul: util.Array - proxs_dist_max: util.Array - proxs_cos_min: util.Array - color: util.Array - - -@dataclass -class ObjectState: - nve_idx: util.Array # idx in NVEState - color: util.Array - -@dataclass -class SimulatorState: - idx: util.Array - box_size: util.Array - n_agents: util.Array - n_objects: util.Array - num_steps_lax: util.Array - dt: util.Array - freq: util.Array - neighbor_radius: util.Array - to_jit: util.Array - use_fori_loop: util.Array - - @staticmethod - def get_type(attr): - if attr in ['idx', 'n_agents', 'n_objects', 'num_steps_lax']: - return int - elif attr in ['box_size', 'dt', 'freq', 'neighbor_radius']: - return float - elif attr in ['to_jit', 'use_fori_loop']: - return bool - else: - raise ValueError() - -@dataclass -class State: - simulator_state: SimulatorState - nve_state: NVEState - agent_state: AgentState - object_state: ObjectState - - def field(self, stype_or_nested_fields): - if isinstance(stype_or_nested_fields, StateType): - name = stype_or_nested_fields.name.lower() - nested_fields = (f'{name}_state', ) - else: - nested_fields = stype_or_nested_fields - - res = self - for f in nested_fields: - res = getattr(res, f) - - return res - - # Should we keep this function because it is duplicated below ? - # def nve_idx(self, etype): - # cond = self.e_cond(etype) - # return compress(range(len(cond)), cond) # https://stackoverflow.com/questions/21448225/getting-indices-of-true-values-in-a-boolean-list - - def nve_idx(self, etype, entity_idx): - return self.field(etype).nve_idx[entity_idx] - - def e_idx(self, etype): - return self.nve_state.entity_idx[self.nve_state.entity_type == etype.value] - - def e_cond(self, etype): - return self.nve_state.entity_type == etype.value - - def row_idx(self, field, nve_idx): - return nve_idx if field == 'nve_state' else self.nve_state.entity_idx[jnp.array(nve_idx)] - - def __getattr__(self, name): - def wrapper(e_type): - value = getattr(self.nve_state, name) - if isinstance(value, rigid_body.RigidBody): - return rigid_body.RigidBody(center=value.center[self.e_cond(e_type)], - orientation=value.orientation[self.e_cond(e_type)]) - else: - return value[self.e_cond(e_type)] - return wrapper - - +@vmap def normal(theta): return jnp.array([jnp.cos(theta), jnp.sin(theta)]) From a54f9208cc9d25785bca81e76d442c851df795e6 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 4 Apr 2024 13:15:31 +0200 Subject: [PATCH 15/16] Add collision parameters and clean files --- scripts/run_simulation.py | 4 ++-- tests/test_simulator_init.py | 17 +++++++++------ tests/test_simulator_run.py | 2 -- vivarium/simulator/states.py | 42 ++++++++++++++++-------------------- 4 files changed, 31 insertions(+), 34 deletions(-) diff --git a/scripts/run_simulation.py b/scripts/run_simulation.py index c128f9b..1279853 100644 --- a/scripts/run_simulation.py +++ b/scripts/run_simulation.py @@ -28,8 +28,8 @@ def parse_args(): # By default jit compile the code and use normal python loops parser.add_argument('--to_jit', action='store_false', help='Whether to use JIT compilation') parser.add_argument('--use_fori_loop', action='store_true', help='Whether to use fori loop') - parser.add_argument('--collision_eps', type=float, required=False, default=0.3) - parser.add_argument('--collision_alpha', type=float, required=False, default=0.7) + parser.add_argument('--collision_eps', type=float, required=False, default=0.1) + parser.add_argument('--collision_alpha', type=float, required=False, default=0.5) return parser.parse_args() diff --git a/tests/test_simulator_init.py b/tests/test_simulator_init.py index a283554..7ea3718 100644 --- a/tests/test_simulator_init.py +++ b/tests/test_simulator_init.py @@ -32,6 +32,8 @@ def test_init_simulator_args(): box_size = 100.0 n_agents = 10 n_objects = 2 + col_eps = 0.1 + col_alpha = 0.5 diameter = 5.0 friction = 0.1 @@ -45,13 +47,15 @@ def test_init_simulator_args(): simulator_state = init_simulator_state( box_size=box_size, - n_agents=n_agents, - n_objects=n_objects) + n_agents=n_agents, + n_objects=n_objects, + collision_eps=col_eps, + collision_alpha=col_alpha) nve_state = init_nve_state( simulator_state, - diameter=diameter, - friction=friction) + diameter=diameter, + friction=friction) agent_state = init_agent_state( simulator_state, @@ -63,15 +67,14 @@ def test_init_simulator_args(): prox_cos_min=prox_cos_min) object_state = init_object_state( - simulator_state, + simulator_state, color=color) state = init_state( simulator_state=simulator_state, agents_state=agent_state, objects_state=object_state, - nve_state=nve_state - ) + nve_state=nve_state) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) diff --git a/tests/test_simulator_run.py b/tests/test_simulator_run.py index 8eff893..5e9c23b 100644 --- a/tests/test_simulator_run.py +++ b/tests/test_simulator_run.py @@ -9,9 +9,7 @@ NUM_STEPS = 50 -# First smoke test, we could split it into different parts later (initialization, run, ...) def test_simulator_run(): - simulator_state = init_simulator_state() agents_state = init_agent_state(simulator_state=simulator_state) diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index 0b33aae..b53c9da 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -10,7 +10,7 @@ from jax_md.rigid_body import RigidBody -# TODO : Add documentation on these classes +# TODO : Add documentation on these classes class EntityType(Enum): AGENT = 0 OBJECT = 1 @@ -30,8 +30,6 @@ def to_entity_type(self): assert self.is_entity() return EntityType(self.value) - -# NVE (we could potentially rename it entities ? What do you think ?) # No need to define position, momentum, force, and mass (i.e already in simulate.NVEState) @dataclass class NVEState(simulate.NVEState): @@ -79,17 +77,19 @@ class SimulatorState: neighbor_radius: util.Array to_jit: util.Array use_fori_loop: util.Array + collision_alpha: util.Array + collision_eps: util.Array @staticmethod def get_type(attr): if attr in ['idx', 'n_agents', 'n_objects', 'num_steps_lax']: return int - elif attr in ['box_size', 'dt', 'freq', 'neighbor_radius']: + elif attr in ['box_size', 'dt', 'freq', 'neighbor_radius', 'collision_alpha', 'collision_eps']: return float elif attr in ['to_jit', 'use_fori_loop']: return bool else: - raise ValueError() + raise ValueError(f"Unknown attribute {attr}") @dataclass @@ -154,7 +154,9 @@ def init_simulator_state( freq: float = 40., neighbor_radius: float = 100., to_jit: bool = True, - use_fori_loop: bool = False + use_fori_loop: bool = False, + collision_alpha: float = 0.5, + collision_eps: float = 0.1 ) -> SimulatorState: """ Initialize simulator state with given parameters @@ -170,33 +172,32 @@ def init_simulator_state( neighbor_radius=jnp.array([neighbor_radius], dtype=float), # Use 1*bool to transform True to 1 and False to 0 to_jit= jnp.array([1*to_jit]), - use_fori_loop=jnp.array([1*use_fori_loop])) + use_fori_loop=jnp.array([1*use_fori_loop]), + collision_alpha=jnp.array([collision_alpha]), + collision_eps=jnp.array([collision_eps])) -def _init_positions(key_pos, positions, n_elements, box_size, n_dims=2): - assert (len(positions) == n_elements if positions else True) +def _init_positions(key_pos, positions, n_entities, box_size, n_dims=2): + assert (positions is None or len(positions) == n_entities) # If positions are passed, transform them in jax array if positions: positions = jnp.array(positions) # Else initialize random positions else: - positions = random.uniform(key_pos, (n_elements, n_dims)) * box_size + positions = random.uniform(key_pos, (n_entities, n_dims)) * box_size return positions -def _init_existing(n_existing, n_elements): +def _init_existing(n_existing, n_entities): if n_existing: - assert n_existing <= n_elements + assert n_existing <= n_entities existing_arr = jnp.ones((n_existing)) - non_existing_arr = jnp.zeros((n_elements - n_existing)) + non_existing_arr = jnp.zeros((n_entities - n_existing)) exists_array = jnp.concatenate((existing_arr, non_existing_arr)) else: - exists_array = jnp.ones((n_elements)) + exists_array = jnp.ones((n_entities)) return exists_array -# TODO : Should also add union float, list[float] for friction, diameter .. -# Also think it would be easier to only handle the nve state in the physics engine. (if it doesn't make the simulation run slower) -# Here it makes a really long function, and it isn't very modular -# Plus we store different attributes of entities (e.g Agents positions and sensors) in different dataclasses + def init_nve_state( simulator_state: SimulatorState, diameter: float = 5., @@ -237,7 +238,6 @@ def init_nve_state( existing_objects = _init_existing(existing_objects, n_objects) exists = jnp.concatenate((existing_agents, existing_objects), dtype=int) - # TODO: Why is momentum set to none ? return NVEState( position=RigidBody(center=positions, orientation=orientations), momentum=None, @@ -251,7 +251,6 @@ def init_nve_state( ) -# Could implement it as a static or class method def init_agent_state( simulator_state: SimulatorState, behavior: int = 1, @@ -266,9 +265,6 @@ def init_agent_state( Initialize agent state with given parameters """ n_agents = simulator_state.n_agents[0] - # TODO : Allow to define custom list of behaviors, wheel_diameters ... (in fact for all parameters) - # if the shape if just 1 value, assign it to all agents - # else, ensure you are given a list of arguments of size n_agents and transform it into a jax array return AgentState( nve_idx=jnp.arange(n_agents, dtype=int), From efa959edb14812b2aaef2a1906dba251d5abc2e9 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 4 Apr 2024 14:31:24 +0200 Subject: [PATCH 16/16] Fix scipy version --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index a000c53..8938e98 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ jax==0.4.23 jaxlib==0.4.23 jax-md==0.2.8 +scipy==1.12.0 # Interface panel==1.3.8