Skip to content

Commit

Permalink
Refactor init_positions function and update comments
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Mar 21, 2024
1 parent b7e98d9 commit fd96746
Showing 1 changed file with 29 additions and 22 deletions.
51 changes: 29 additions & 22 deletions vivarium/simulator/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class ObjectState:
color: util.Array


# TODO : I think it would make more sense to have max_agents, max_objects here instead of n_***
@dataclass
class SimulatorState:
idx: util.Array
Expand Down Expand Up @@ -111,7 +112,7 @@ def field(self, stype_or_nested_fields):

return res

# Should we keep this function because it is duplicated below ?
# TODO : Should we keep this function because it is duplicated below ?
# def nve_idx(self, etype):
# cond = self.e_cond(etype)
# return compress(range(len(cond)), cond) # https://stackoverflow.com/questions/21448225/getting-indices-of-true-values-in-a-boolean-list
Expand Down Expand Up @@ -171,7 +172,21 @@ def init_simulator_state(
to_jit= jnp.array([1*to_jit]),
use_fori_loop=jnp.array([1*use_fori_loop]))

# TODO : Should also add union float, list[float] for friction, diameter ...

def _init_positions(key_pos, positions, n_elements, box_size, n_dims=2):
assert (len(positions) == n_elements if positions else True)
# If positions are passed, transform them in jax array
if positions:
positions = jnp.array(positions)
# Else initialize random positions
else:
positions = random.uniform(key_pos, (n_elements, n_dims)) * box_size
return positions

# TODO : Should also add union float, list[float] for friction, diameter ..
# Also think it would be easier to only handle the nve state in the physics engine. (if it doesn't make the simulation run slower)
# Here it makes a really long function, and it isn't very modular
# Plus we store different attributes of entities (e.g Agents positions and sensors) in different dataclasses
def init_nve_state(
simulator_state: SimulatorState,
diameter: float = 5.,
Expand All @@ -190,30 +205,22 @@ def init_nve_state(
n_entities = n_agents + n_objects

key = random.PRNGKey(seed)
key_pos, key_or = random.split(key, 2)
key_pos, key_or = random.split(key)
key_ag, key_obj = random.split(key_pos)

# If we have a list of agents positions, transform it into a jax array
if agents_positions:
agents_positions = jnp.array(agents_positions)
# Else initialize random positions
else:
agents_positions = random.uniform(key_pos, (n_agents, 2)) * simulator_state.box_size
# If we have a list of agents or objects positions, transform it into a jax array, else initialize random positions
agents_positions = _init_positions(key_ag, agents_positions, n_agents, simulator_state.box_size)
objects_positions = _init_positions(key_obj, objects_positions, n_objects, simulator_state.box_size)
# Assign their positions to each entities
positions = jnp.concatenate((agents_positions, objects_positions))

# Same for ojects positions
if objects_positions:
objects_positions = jnp.array(objects_positions)
else:
objects_positions = random.uniform(key_pos, (n_objects, 2)) * simulator_state.box_size
# Assign random orientations between 0 and 2*pi
orientations = random.uniform(key_or, (n_entities,)) * 2 * jnp.pi

agents_entities = jnp.full(n_agents, EntityType.AGENT.value)
object_entities = jnp.full(n_objects, EntityType.OBJECT.value)
entity_types = jnp.concatenate((agents_entities, object_entities), dtype=int)

# Assign their positions to each entities
positions = jnp.concatenate((agents_positions, objects_positions))
# Assign random orientations between 0 and 2*pi
orientations = random.uniform(key_or, (n_entities,)) * 2 * jnp.pi

return NVEState(
position=RigidBody(center=positions, orientation=orientations),
# TODO: Why is momentum set to none ?
Expand All @@ -225,7 +232,7 @@ def init_nve_state(
entity_idx = jnp.array(list(range(n_agents)) + list(range(n_objects))),
diameter=jnp.full((n_entities), diameter),
friction=jnp.full((n_entities), friction),
# Set all the entities to exist by default, but we should add a function to be able to change that
# Set all the entities to exist by default, but we should add a function to be able to change that
exists=jnp.ones(n_entities, dtype=int)
)

Expand All @@ -246,8 +253,8 @@ def init_agent_state(
"""
n_agents = simulator_state.n_agents[0]
# TODO : Allow to define custom list of behaviors, wheel_diameters ... (in fact for all parameters)
# TODO : if the shape if just 1 value, assign it to all agents
# TODO : else, ensure you are given a list of arguments of size n_agents and transform it into a jax array
# if the shape if just 1 value, assign it to all agents
# else, ensure you are given a list of arguments of size n_agents and transform it into a jax array

return AgentState(
nve_idx=jnp.arange(n_agents, dtype=int),
Expand Down

0 comments on commit fd96746

Please sign in to comment.