Skip to content

Commit

Permalink
Update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Nov 15, 2024
1 parent e98d489 commit d56990c
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 4 deletions.
1 change: 1 addition & 0 deletions vivarium/environments/base_env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# TODO : Update this file to make it match with current architecture
import logging as lg

from functools import partial
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def set_to_none_if_all_none(lst):
return lst



### Helper functions to generate elements sub states of the state

def init_entities(
Expand Down
3 changes: 3 additions & 0 deletions vivarium/environments/braitenberg/simple.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# TODO : reorganize this directory to put simple / selective sensing in subdirectories
# TODO : move the classes and init functions in separate files in the simple env

import logging as lg

from enum import Enum
Expand Down
34 changes: 31 additions & 3 deletions vivarium/environments/physics_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# TODO : Add documentation for all functions
from functools import partial

import jax
Expand All @@ -8,10 +7,8 @@
from jax_md import rigid_body, util, simulate, energy, quantity
f32 = util.f32


SPACE_NDIMS = 2

# Helper functions for collisions
def collision_energy(displacement_fn, r_a, r_b, l_a, l_b, epsilon, alpha, mask):
"""Compute the collision energy between a pair of particles
Expand Down Expand Up @@ -66,6 +63,12 @@ def total_collision_energy(positions, diameter, neighbor, displacement, exists_m

# Functions to compute the verlet force on the whole system
def friction_force(state, exists_mask):
"""Compute the friction force on the system
:param state: current state of the system
:param exists_mask: mask to specify which particles exist
:return: friction force on the system
"""
cur_vel = state.entities.momentum.center / state.entities.mass.center
# stack the mask to give it the same shape as cur_vel (that has 2 rows for forward and angular velocities)
mask = jnp.stack([exists_mask] * 2, axis=1)
Expand All @@ -89,6 +92,11 @@ def collision_force(state, neighbor, exists_mask, displacement):


def verlet_force_fn(displacement):
"""Compute the verlet force on the whole system
:param displacement: displacement function of jax_md
:return: force function of the system
"""
coll_force_fn = quantity.force(partial(total_collision_energy, displacement=displacement))

def collision_force(state, neighbor, exists_mask):
Expand All @@ -111,9 +119,23 @@ def force_fn(state, neighbor, exists_mask):


def dynamics_fn(displacement, shift, force_fn=None):
"""Compute the dynamics of the system
:param displacement: displacement function of jax_md
:param shift: shift function of jax_md
:param force_fn: given force function, defaults to None
:return: init_fn, step_fn functions of jax_md to compute the dynamics of the system
"""
force_fn = force_fn(displacement) if force_fn else verlet_force_fn(displacement)

def init_fn(state, key, kT=0.):
"""Initialize the system
:param state: current state of the system
:param key: random key
:param kT: kT, defaults to 0.
:return: new state of the system
"""
key, _ = jax.random.split(key)
assert state.entities.momentum is None
assert not jnp.any(state.entities.force.center) and not jnp.any(state.entities.force.orientation)
Expand All @@ -135,6 +157,12 @@ def mask_momentum(entity_state, exists_mask):
return entity_state.set(momentum=momentum)

def step_fn(state, neighbor):
"""Compute the next state of the system
:param state: current state of the system
:param neighbor: neighbor array of the system
:return: new state of the system
"""
exists_mask = (state.entities.exists == 1) # Only existing entities have effect on others
dt_2 = state.dt / 2.
# Compute forces
Expand Down

0 comments on commit d56990c

Please sign in to comment.