diff --git a/scripts/run_simulation.py b/scripts/run_simulation.py index c128f9b..1279853 100644 --- a/scripts/run_simulation.py +++ b/scripts/run_simulation.py @@ -28,8 +28,8 @@ def parse_args(): # By default jit compile the code and use normal python loops parser.add_argument('--to_jit', action='store_false', help='Whether to use JIT compilation') parser.add_argument('--use_fori_loop', action='store_true', help='Whether to use fori loop') - parser.add_argument('--collision_eps', type=float, required=False, default=0.3) - parser.add_argument('--collision_alpha', type=float, required=False, default=0.7) + parser.add_argument('--collision_eps', type=float, required=False, default=0.1) + parser.add_argument('--collision_alpha', type=float, required=False, default=0.5) return parser.parse_args() diff --git a/tests/test_simulator_init.py b/tests/test_simulator_init.py index a283554..7ea3718 100644 --- a/tests/test_simulator_init.py +++ b/tests/test_simulator_init.py @@ -32,6 +32,8 @@ def test_init_simulator_args(): box_size = 100.0 n_agents = 10 n_objects = 2 + col_eps = 0.1 + col_alpha = 0.5 diameter = 5.0 friction = 0.1 @@ -45,13 +47,15 @@ def test_init_simulator_args(): simulator_state = init_simulator_state( box_size=box_size, - n_agents=n_agents, - n_objects=n_objects) + n_agents=n_agents, + n_objects=n_objects, + collision_eps=col_eps, + collision_alpha=col_alpha) nve_state = init_nve_state( simulator_state, - diameter=diameter, - friction=friction) + diameter=diameter, + friction=friction) agent_state = init_agent_state( simulator_state, @@ -63,15 +67,14 @@ def test_init_simulator_args(): prox_cos_min=prox_cos_min) object_state = init_object_state( - simulator_state, + simulator_state, color=color) state = init_state( simulator_state=simulator_state, agents_state=agent_state, objects_state=object_state, - nve_state=nve_state - ) + nve_state=nve_state) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) diff --git a/tests/test_simulator_run.py b/tests/test_simulator_run.py index 8eff893..5e9c23b 100644 --- a/tests/test_simulator_run.py +++ b/tests/test_simulator_run.py @@ -9,9 +9,7 @@ NUM_STEPS = 50 -# First smoke test, we could split it into different parts later (initialization, run, ...) def test_simulator_run(): - simulator_state = init_simulator_state() agents_state = init_agent_state(simulator_state=simulator_state) diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index 0b33aae..b53c9da 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -10,7 +10,7 @@ from jax_md.rigid_body import RigidBody -# TODO : Add documentation on these classes +# TODO : Add documentation on these classes class EntityType(Enum): AGENT = 0 OBJECT = 1 @@ -30,8 +30,6 @@ def to_entity_type(self): assert self.is_entity() return EntityType(self.value) - -# NVE (we could potentially rename it entities ? What do you think ?) # No need to define position, momentum, force, and mass (i.e already in simulate.NVEState) @dataclass class NVEState(simulate.NVEState): @@ -79,17 +77,19 @@ class SimulatorState: neighbor_radius: util.Array to_jit: util.Array use_fori_loop: util.Array + collision_alpha: util.Array + collision_eps: util.Array @staticmethod def get_type(attr): if attr in ['idx', 'n_agents', 'n_objects', 'num_steps_lax']: return int - elif attr in ['box_size', 'dt', 'freq', 'neighbor_radius']: + elif attr in ['box_size', 'dt', 'freq', 'neighbor_radius', 'collision_alpha', 'collision_eps']: return float elif attr in ['to_jit', 'use_fori_loop']: return bool else: - raise ValueError() + raise ValueError(f"Unknown attribute {attr}") @dataclass @@ -154,7 +154,9 @@ def init_simulator_state( freq: float = 40., neighbor_radius: float = 100., to_jit: bool = True, - use_fori_loop: bool = False + use_fori_loop: bool = False, + collision_alpha: float = 0.5, + collision_eps: float = 0.1 ) -> SimulatorState: """ Initialize simulator state with given parameters @@ -170,33 +172,32 @@ def init_simulator_state( neighbor_radius=jnp.array([neighbor_radius], dtype=float), # Use 1*bool to transform True to 1 and False to 0 to_jit= jnp.array([1*to_jit]), - use_fori_loop=jnp.array([1*use_fori_loop])) + use_fori_loop=jnp.array([1*use_fori_loop]), + collision_alpha=jnp.array([collision_alpha]), + collision_eps=jnp.array([collision_eps])) -def _init_positions(key_pos, positions, n_elements, box_size, n_dims=2): - assert (len(positions) == n_elements if positions else True) +def _init_positions(key_pos, positions, n_entities, box_size, n_dims=2): + assert (positions is None or len(positions) == n_entities) # If positions are passed, transform them in jax array if positions: positions = jnp.array(positions) # Else initialize random positions else: - positions = random.uniform(key_pos, (n_elements, n_dims)) * box_size + positions = random.uniform(key_pos, (n_entities, n_dims)) * box_size return positions -def _init_existing(n_existing, n_elements): +def _init_existing(n_existing, n_entities): if n_existing: - assert n_existing <= n_elements + assert n_existing <= n_entities existing_arr = jnp.ones((n_existing)) - non_existing_arr = jnp.zeros((n_elements - n_existing)) + non_existing_arr = jnp.zeros((n_entities - n_existing)) exists_array = jnp.concatenate((existing_arr, non_existing_arr)) else: - exists_array = jnp.ones((n_elements)) + exists_array = jnp.ones((n_entities)) return exists_array -# TODO : Should also add union float, list[float] for friction, diameter .. -# Also think it would be easier to only handle the nve state in the physics engine. (if it doesn't make the simulation run slower) -# Here it makes a really long function, and it isn't very modular -# Plus we store different attributes of entities (e.g Agents positions and sensors) in different dataclasses + def init_nve_state( simulator_state: SimulatorState, diameter: float = 5., @@ -237,7 +238,6 @@ def init_nve_state( existing_objects = _init_existing(existing_objects, n_objects) exists = jnp.concatenate((existing_agents, existing_objects), dtype=int) - # TODO: Why is momentum set to none ? return NVEState( position=RigidBody(center=positions, orientation=orientations), momentum=None, @@ -251,7 +251,6 @@ def init_nve_state( ) -# Could implement it as a static or class method def init_agent_state( simulator_state: SimulatorState, behavior: int = 1, @@ -266,9 +265,6 @@ def init_agent_state( Initialize agent state with given parameters """ n_agents = simulator_state.n_agents[0] - # TODO : Allow to define custom list of behaviors, wheel_diameters ... (in fact for all parameters) - # if the shape if just 1 value, assign it to all agents - # else, ensure you are given a list of arguments of size n_agents and transform it into a jax array return AgentState( nve_idx=jnp.arange(n_agents, dtype=int),