Skip to content

Commit

Permalink
Merge pull request #42 from clement-moulin-frier/corentin/refactor_si…
Browse files Browse the repository at this point in the history
…mulator

Refactor simulation.py file
  • Loading branch information
corentinlger authored Mar 15, 2024
2 parents 7eeb2dd + 092118a commit 1f592d8
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 55 deletions.
4 changes: 2 additions & 2 deletions scripts/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def parse_args():
parser = argparse.ArgumentParser(description='Simulator Configuration')
# Experiment run arguments
parser.add_argument('--num_loops', type=int, default=10, help='Number of simulation loops')
parser.add_argument('--num_steps', type=int, default=10, help='Number of simulation loops')
# Simulator config arguments
parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box')
parser.add_argument('--n_agents', type=int, default=10, help='Number of agents')
Expand Down Expand Up @@ -77,6 +77,6 @@ def parse_args():

lg.info("Running simulation")

simulator.run(threaded=False, num_loops=10)
simulator.run(threaded=False, num_steps=args.num_steps)

lg.info("Simulation complete")
4 changes: 2 additions & 2 deletions tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from vivarium.controllers import converters


NUM_LOOPS = 50
NUM_STEPS = 50


# First smoke test, we could split it into different parts later (initialization, run, ...)
Expand Down Expand Up @@ -41,6 +41,6 @@ def test_simulator_run():

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

simulator.run(threaded=False, num_loops=NUM_LOOPS)
simulator.run(threaded=False, num_steps=NUM_STEPS)

assert simulator
187 changes: 136 additions & 51 deletions vivarium/simulator/simulator.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
from jax import jit
from jax import lax
import time
import threading
import math
import logging

from functools import partial
from contextlib import contextmanager

import jax
import jax.numpy as jnp
import numpy as np

from jax import jit
from jax import lax
from jax_md import space, partition, dataclasses

from contextlib import contextmanager

from vivarium.simulator.sim_computation import dynamics_rigid, EntityType, StateType, SimulatorState
from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig
from vivarium.simulator.sim_computation import EntityType, SimulatorState
from vivarium.controllers import converters
import vivarium.simulator.behaviors as behaviors

import time
import threading
import math
import logging

lg = logging.getLogger(__name__)


class Simulator:
def __init__(self, state, behavior_bank, dynamics_fn):

self.state = state
self.behavior_bank = behavior_bank
self.dynamics_fn = dynamics_fn

# TODO: explicitely copy the attributes of simulator_state (prevents linting errors and easier to understand which element is an attriute of the class)
all_attrs = [f.name for f in dataclasses.fields(SimulatorState)]
for attr in all_attrs:
self.update_attr(attr, SimulatorState.get_type(attr))
Expand All @@ -35,54 +35,134 @@ def __init__(self, state, behavior_bank, dynamics_fn):
self._to_stop = False
self.key = jax.random.PRNGKey(0)

# TODO: Define which attributes are affected but these functions
self.update_space(self.box_size)
self.update_function_update()
self.init_state(state)
self.update_neighbor_fn(self.box_size, self.neighbor_radius)
self.allocate_neighbors()
self.simulation_loop = self.select_simulation_loop_type()


def classic_simulation_loop(self, state, neighbors, num_iterations):
"""Update the state and the neighbors on a few iterations with a classic python loop
def run(self, threaded=False, num_loops=math.inf):
:param state: current_state of the simulation
:param neighbors: array of neighbors for simulation entities
:return: state, neighbors
"""
for i in range(0, num_iterations):
state, neighbors = self.update_fn(i, (state, neighbors))
return state, neighbors

def lax_simulation_loop(self, state, neighbors, num_iterations):
"""Update the state and the neighbors on a few iterations with lax loop
:param state: current_state of the simulation
:param neighbors: array of neighbors for simulation entities
:return: state, neighbors
"""
state, neighbors = lax.fori_loop(0, num_iterations, self.update_fn, (state, neighbors))
return state, neighbors

def select_simulation_loop_type(self):
"""Choose wether to use a lax or a classic simulation loop in function step
:return: appropriate simulation loop
"""
if self.state.simulator_state.use_fori_loop:
return self.lax_simulation_loop
else:
return self.classic_simulation_loop

def step(self, state, neighbors):
"""Do a step in the simulation by applying the update function a few iterations on the state and the neighbors
:param state: current simulation state
:param neighbors: current simulation neighbors array
:return: updated state and neighbors
"""
# Create a copy of the current state in case of neighbor buffer overflow
current_state = state
# TODO : find a more explicit name than num_steps_lax and modify it in all the pipeline
new_state, neighbors = self.simulation_loop(state=current_state, neighbors=neighbors, num_iterations=self.num_steps_lax)

# If the neighbor list can't fit in the allocation, rebuild it but bigger.
if neighbors.did_buffer_overflow:
lg.warning('REBUILDING NEIGHBORS ARRAY')
neighbors = self.allocate_neighbors(current_state.nve_state.position.center)
# Because there was an error, we need to re-run this simulation loop from the copy of the current_state we created
new_state, neighbors = self.simulation_loop(state=current_state, neighbors=neighbors, num_iterations=self.num_steps_lax)
# Check that neighbors array is now ok but should be the case (allocate neighbors tries to compute a new list that is large enough according to the simulation state)
assert not neighbors.did_buffer_overflow

return new_state, neighbors

def run(self, threaded=False, num_steps=math.inf):
"""Run the simulator for the desired number of timesteps, either in a separate thread or not
:param threaded: wether to run the simulation in a thread or not, defaults to False
:param num_steps: number of step loops before stopping the simulation run, defaults to math.inf
:raises ValueError: raise an error if the simulator is already running
"""
# Check is the simulator isn't already running
if self._is_started:
raise Exception("Simulator is already started")
raise ValueError("Simulator is already started")
# Else run it either in a thread or not
if threaded:
threading.Thread(target=self._run).start()
# Set the num_loops attribute with a partial func to launch _run in a thread
_run = partial(self._run, num_steps=num_steps)
threading.Thread(target=_run).start()
else:
return self._run(num_loops)
self._run(num_steps)

def _run(self, num_steps):
"""Function that runs the simulator for the desired number of steps. Used to be called either normally or in a thread.
def _run(self, num_loops=math.inf):
:param num_steps: number of simulation steps
"""
# Encode that the simulation is started in the class
self._is_started = True
lg.info('Run starts')

loop_count = 0
while loop_count < num_loops:
sleep_time = 0

# Update the simulation with step for num_steps
while loop_count < num_steps:
start = time.time()
if self._to_stop:
self._to_stop = False
break
if float(self.freq) > 0.:
time.sleep(1. / float(self.freq))
new_state = self.state
neighbors = self.neighbors
if self.state.simulator_state.use_fori_loop:
new_state, neighbors = lax.fori_loop(0, self.num_steps_lax, self.update_fn,
(new_state, neighbors))
else:

for i in range(0, self.num_steps_lax):
new_state, neighbors = self.update_fn(i, (new_state, neighbors))
# If the neighbor list can't fit in the allocation, rebuild it but bigger.
if neighbors.did_buffer_overflow:
lg.warning('REBUILDING')
neighbors = self.allocate_neighbors(new_state.nve_state.position.center)
# new_state, neighbors = lax.fori_loop(0, self.simulation_config.num_lax_loops, self.update_fn, (self.state, neighbors))
for i in range(0, self.num_steps_lax):
new_state, neighbors = self.update_fn(i, (self.state, neighbors))
assert not neighbors.did_buffer_overflow
self.state = new_state
self.neighbors = neighbors
# lg.info(self.state)

self.state, self.neighbors = self.step(state=self.state, neighbors=self.neighbors)
loop_count += 1

# Sleep for updated sleep_time seconds
end = time.time()
sleep_time = self.update_sleep_time(frequency=self.freq, elapsed_time=end-start)
time.sleep(sleep_time)

# Encode that the simulation isn't started anymore
self._is_started = False
lg.info('Run stops')

def update_sleep_time(self, frequency, elapsed_time):
"""Compute the time we need to sleep to respect the update frequency
:param frequency: update state frequency
:param elapsed_time: time already used to compute the state
:return: time needed to sleep in addition to elapsed time to respect the frequency
"""
# if we use the freq, compute the correct sleep time
if float(frequency) > 0.:
perfect_time = 1. / float(frequency)
sleep_time = max(perfect_time - elapsed_time, 0)
# Else set it to zero
else:
sleep_time = 0
return sleep_time

def set_state(self, nested_field, nve_idx, column_idx, value):
lg.info(f'set_state {nested_field} {nve_idx} {column_idx} {value}')
row_idx = self.state.row_idx(nested_field[0], jnp.array(nve_idx))
Expand All @@ -96,14 +176,15 @@ def set_state(self, nested_field, nve_idx, column_idx, value):
if nested_field == ('simulator_state', 'box_size'):
self.update_space(self.box_size)

if nested_field == ('simulator_state', 'box_size') or nested_field == ('simulator_state', 'neighbor_radius'):
self.update_neighbor_fn(box_size=self.box_size,
neighbor_radius=self.neighbor_radius)
if nested_field in (('simulator_state', 'box_size'), ('simulator_state', 'neighbor_radius')):
self.update_neighbor_fn(box_size=self.box_size, neighbor_radius=self.neighbor_radius)

if nested_field == ('simulator_state', 'box_size') or nested_field == ('simulator_state', 'dt') or \
nested_field == ('simulator_state', 'to_jit'):
if nested_field in (('simulator_state', 'box_size'), ('simulator_state', 'dt'), ('simulator_state', 'to_jit')):
self.update_function_update()


# Functions to start, stop, pause

def start(self):
self.run(threaded=True)

Expand All @@ -118,11 +199,6 @@ def stop(self, blocking=True):
def is_started(self):
return self._is_started

def step(self):
assert not self._is_started
self.run(threaded=False, num_loops=1)
return self.state

@contextmanager
def pause(self):
self.stop(blocking=True)
Expand All @@ -131,6 +207,9 @@ def pause(self):
finally:
self.run(threaded=True)


# Other update functions

def update_attr(self, attr, type_):
lg.info('update_attr')
setattr(self, attr, type_(getattr(self.state.simulator_state, attr)[0]))
Expand Down Expand Up @@ -158,6 +237,9 @@ def init_state(self, state):
lg.info('init_state')
self.state = self.init_fn(state, self.key)


# Neighbor functions

def update_neighbor_fn(self, box_size, neighbor_radius):
lg.info('update_neighbor_fn')
self.neighbor_fn = partition.neighbor_list(self.displacement, box_size,
Expand All @@ -174,6 +256,9 @@ def allocate_neighbors(self, position=None):
mask = self.state.nve_state.entity_type[self.neighbors.idx[0]] == EntityType.AGENT.value
self.agent_neighs_idx = self.neighbors.idx[:, mask]
return self.neighbors


# Other functions

def get_change_time(self):
return 0
Expand Down

0 comments on commit 1f592d8

Please sign in to comment.