Skip to content

Commit

Permalink
Add collision parameters and clean files
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Apr 4, 2024
1 parent adefcce commit a54f920
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 34 deletions.
4 changes: 2 additions & 2 deletions scripts/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def parse_args():
# By default jit compile the code and use normal python loops
parser.add_argument('--to_jit', action='store_false', help='Whether to use JIT compilation')
parser.add_argument('--use_fori_loop', action='store_true', help='Whether to use fori loop')
parser.add_argument('--collision_eps', type=float, required=False, default=0.3)
parser.add_argument('--collision_alpha', type=float, required=False, default=0.7)
parser.add_argument('--collision_eps', type=float, required=False, default=0.1)
parser.add_argument('--collision_alpha', type=float, required=False, default=0.5)

return parser.parse_args()

Expand Down
17 changes: 10 additions & 7 deletions tests/test_simulator_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def test_init_simulator_args():
box_size = 100.0
n_agents = 10
n_objects = 2
col_eps = 0.1
col_alpha = 0.5

diameter = 5.0
friction = 0.1
Expand All @@ -45,13 +47,15 @@ def test_init_simulator_args():

simulator_state = init_simulator_state(
box_size=box_size,
n_agents=n_agents,
n_objects=n_objects)
n_agents=n_agents,
n_objects=n_objects,
collision_eps=col_eps,
collision_alpha=col_alpha)

nve_state = init_nve_state(
simulator_state,
diameter=diameter,
friction=friction)
diameter=diameter,
friction=friction)

agent_state = init_agent_state(
simulator_state,
Expand All @@ -63,15 +67,14 @@ def test_init_simulator_args():
prox_cos_min=prox_cos_min)

object_state = init_object_state(
simulator_state,
simulator_state,
color=color)

state = init_state(
simulator_state=simulator_state,
agents_state=agent_state,
objects_state=object_state,
nve_state=nve_state
)
nve_state=nve_state)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

Expand Down
2 changes: 0 additions & 2 deletions tests/test_simulator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@

NUM_STEPS = 50

# First smoke test, we could split it into different parts later (initialization, run, ...)
def test_simulator_run():

simulator_state = init_simulator_state()

agents_state = init_agent_state(simulator_state=simulator_state)
Expand Down
42 changes: 19 additions & 23 deletions vivarium/simulator/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jax_md.rigid_body import RigidBody


# TODO : Add documentation on these classes
# TODO : Add documentation on these classes
class EntityType(Enum):
AGENT = 0
OBJECT = 1
Expand All @@ -30,8 +30,6 @@ def to_entity_type(self):
assert self.is_entity()
return EntityType(self.value)


# NVE (we could potentially rename it entities ? What do you think ?)
# No need to define position, momentum, force, and mass (i.e already in simulate.NVEState)
@dataclass
class NVEState(simulate.NVEState):
Expand Down Expand Up @@ -79,17 +77,19 @@ class SimulatorState:
neighbor_radius: util.Array
to_jit: util.Array
use_fori_loop: util.Array
collision_alpha: util.Array
collision_eps: util.Array

@staticmethod
def get_type(attr):
if attr in ['idx', 'n_agents', 'n_objects', 'num_steps_lax']:
return int
elif attr in ['box_size', 'dt', 'freq', 'neighbor_radius']:
elif attr in ['box_size', 'dt', 'freq', 'neighbor_radius', 'collision_alpha', 'collision_eps']:
return float
elif attr in ['to_jit', 'use_fori_loop']:
return bool
else:
raise ValueError()
raise ValueError(f"Unknown attribute {attr}")


@dataclass
Expand Down Expand Up @@ -154,7 +154,9 @@ def init_simulator_state(
freq: float = 40.,
neighbor_radius: float = 100.,
to_jit: bool = True,
use_fori_loop: bool = False
use_fori_loop: bool = False,
collision_alpha: float = 0.5,
collision_eps: float = 0.1
) -> SimulatorState:
"""
Initialize simulator state with given parameters
Expand All @@ -170,33 +172,32 @@ def init_simulator_state(
neighbor_radius=jnp.array([neighbor_radius], dtype=float),
# Use 1*bool to transform True to 1 and False to 0
to_jit= jnp.array([1*to_jit]),
use_fori_loop=jnp.array([1*use_fori_loop]))
use_fori_loop=jnp.array([1*use_fori_loop]),
collision_alpha=jnp.array([collision_alpha]),
collision_eps=jnp.array([collision_eps]))


def _init_positions(key_pos, positions, n_elements, box_size, n_dims=2):
assert (len(positions) == n_elements if positions else True)
def _init_positions(key_pos, positions, n_entities, box_size, n_dims=2):
assert (positions is None or len(positions) == n_entities)
# If positions are passed, transform them in jax array
if positions:
positions = jnp.array(positions)
# Else initialize random positions
else:
positions = random.uniform(key_pos, (n_elements, n_dims)) * box_size
positions = random.uniform(key_pos, (n_entities, n_dims)) * box_size
return positions

def _init_existing(n_existing, n_elements):
def _init_existing(n_existing, n_entities):
if n_existing:
assert n_existing <= n_elements
assert n_existing <= n_entities
existing_arr = jnp.ones((n_existing))
non_existing_arr = jnp.zeros((n_elements - n_existing))
non_existing_arr = jnp.zeros((n_entities - n_existing))
exists_array = jnp.concatenate((existing_arr, non_existing_arr))
else:
exists_array = jnp.ones((n_elements))
exists_array = jnp.ones((n_entities))
return exists_array

# TODO : Should also add union float, list[float] for friction, diameter ..
# Also think it would be easier to only handle the nve state in the physics engine. (if it doesn't make the simulation run slower)
# Here it makes a really long function, and it isn't very modular
# Plus we store different attributes of entities (e.g Agents positions and sensors) in different dataclasses

def init_nve_state(
simulator_state: SimulatorState,
diameter: float = 5.,
Expand Down Expand Up @@ -237,7 +238,6 @@ def init_nve_state(
existing_objects = _init_existing(existing_objects, n_objects)
exists = jnp.concatenate((existing_agents, existing_objects), dtype=int)

# TODO: Why is momentum set to none ?
return NVEState(
position=RigidBody(center=positions, orientation=orientations),
momentum=None,
Expand All @@ -251,7 +251,6 @@ def init_nve_state(
)


# Could implement it as a static or class method
def init_agent_state(
simulator_state: SimulatorState,
behavior: int = 1,
Expand All @@ -266,9 +265,6 @@ def init_agent_state(
Initialize agent state with given parameters
"""
n_agents = simulator_state.n_agents[0]
# TODO : Allow to define custom list of behaviors, wheel_diameters ... (in fact for all parameters)
# if the shape if just 1 value, assign it to all agents
# else, ensure you are given a list of arguments of size n_agents and transform it into a jax array

return AgentState(
nve_idx=jnp.arange(n_agents, dtype=int),
Expand Down

0 comments on commit a54f920

Please sign in to comment.