Skip to content

Commit

Permalink
Add tests for simulation wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Feb 14, 2024
1 parent 0889fa0 commit bb01f48
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 68 deletions.
51 changes: 51 additions & 0 deletions simulationsandbox/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import threading
import time

from jax import random

class SimulationWrapper:

def __init__(self, simulation, state, key, step_delay=0.1, print_data=False):
self.running = False
self.paused = False
self.stop_requested = False
self.update_thread = None
self.print_data = print_data
self.step_delay = step_delay
# simulation_dependent
self.simulation = simulation
self.state = state
self.key = key

def start(self):
if not self.running:
self.running = True
self.stop_requested = False
self.update_thread = threading.Thread(target=self.simulation_loop)
self.update_thread.start()

def pause(self):
self.paused = True

def resume(self):
self.paused = False

def stop(self):
self.stop_requested = True

def simulation_loop(self):
while not self.stop_requested:
if self.paused:
time.sleep(0.1)
continue

self.state = self._update_simulation()
if self.print_data:
print(f"{self.state = }")

time.sleep(self.step_delay)

def _update_simulation(self):
self.key, a_key, step_key = random.split(self.key, 3)
actions = self.simulation.choose_action(self.state.obs, a_key)
return self.simulation.step(self.state, actions, step_key)
31 changes: 31 additions & 0 deletions tests/test_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import time

import jax
import jax.numpy as jnp
from jax import random

from simulationsandbox.two_d_simulation import SimpleSimulation
from simulationsandbox.three_d_simulation import ThreeDSimulation

NUM_AGENTS = 5
MAX_AGENTS = 10
NUM_OBS = 3
GRID_SIZE = 20
NUM_STEPS = 50
VIZUALIZE = True
STEP_DELAY = 0.000001
SEED = 0


def test_simulation_init():
key = random.PRNGKey(SEED)

sim = SimpleSimulation(MAX_AGENTS, GRID_SIZE)
state = sim.init_state(NUM_AGENTS, NUM_OBS, key)

assert sim.max_agents == MAX_AGENTS
assert sim.grid_size == GRID_SIZE
assert state.x_pos.shape == (MAX_AGENTS,)
assert jnp.sum(state.alive) == NUM_AGENTS
assert state.grid.shape == (GRID_SIZE, GRID_SIZE)

74 changes: 74 additions & 0 deletions tests/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import time

import jax
import jax.numpy as jnp
from jax import random

from simulationsandbox.two_d_simulation import SimpleSimulation
from simulationsandbox.three_d_simulation import ThreeDSimulation
from simulationsandbox.wrapper import SimulationWrapper

NUM_AGENTS = 5
MAX_AGENTS = 10
NUM_OBS = 3
GRID_SIZE = 20
NUM_STEPS = 50
VIZUALIZE = True
SLEEP_TIME = 2
STEP_DELAY = 0.01
PRINT_DATA = True
SEED = 0


def test_wrapper_two_d_sim():
key = random.PRNGKey(SEED)
env = SimpleSimulation(MAX_AGENTS, GRID_SIZE)
state = env.init_state(NUM_AGENTS, NUM_OBS, key)

# Example usage:
sim = SimulationWrapper(env, state, key, step_delay=STEP_DELAY, print_data=PRINT_DATA)

print('Started')
sim.start()
time.sleep(SLEEP_TIME)
assert sim.running == True

sim.pause()
print('Paused')
time.sleep(SLEEP_TIME)
assert sim.paused == True

print('Resumed')
sim.resume()
time.sleep(SLEEP_TIME)
assert sim.running == True

sim.stop()
print('stopped')


# def test_wrapper_three_d_sim():
# key = random.PRNGKey(SEED)
# env = ThreeDSimulation(MAX_AGENTS, GRID_SIZE)
# state = env.init_state(NUM_AGENTS, NUM_OBS, key)

# # Example usage:
# sim = SimulationWrapper(env, state, key, print_data=True)

# print('Started')
# sim.start()
# time.sleep(SLEEP_TIME)
# assert sim.running == True

# sim.pause()
# print('Paused')
# time.sleep(SLEEP_TIME)
# assert sim.paused == True

# print('Resumed')
# sim.resume()
# time.sleep(SLEEP_TIME)
# assert sim.running == True

# sim.stop()
# print('stopped')
16 changes: 2 additions & 14 deletions tests/test_pytest.py → tests/test_zrun_sims.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,6 @@
SEED = 0


def test_simulation_init():
key = random.PRNGKey(SEED)

sim = SimpleSimulation(MAX_AGENTS, GRID_SIZE)
state = sim.init_state(NUM_AGENTS, NUM_OBS, key)

assert sim.max_agents == MAX_AGENTS
assert sim.grid_size == GRID_SIZE
assert state.x_pos.shape == (MAX_AGENTS,)
assert jnp.sum(state.alive) == NUM_AGENTS
assert state.grid.shape == (GRID_SIZE, GRID_SIZE)


def test_simple_simulation_run():
key = random.PRNGKey(SEED)
sim = SimpleSimulation(MAX_AGENTS, GRID_SIZE)
Expand Down Expand Up @@ -64,14 +51,15 @@ def test_simple_simulation_run():
state = sim.step(state, actions, step_key)

if VIZUALIZE:
SimpleSimulation.visualize_sim(state, grid_size=None)
SimpleSimulation.visualize_sim(state)
print("\nSimulation ended")

assert jnp.sum(state.alive) == 5
assert state.x_pos.shape == (MAX_AGENTS,)
assert state.time == NUM_STEPS



def test_three_d_simulation_run():
key = random.PRNGKey(SEED)
sim = ThreeDSimulation(MAX_AGENTS, GRID_SIZE)
Expand Down
60 changes: 6 additions & 54 deletions wrapp_simulation.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,9 @@
import threading
import time

from jax import random

from simulationsandbox.two_d_simulation import SimpleSimulation

# Independant of any simulation environment
class SimulationWrapper:
"""
Simulation wrapper to start, pause, resume and stop a simulation.
Independant of any simulation type.
"""
def __init__(self, simulation, state, key):
self.running = False
self.paused = False
self.stop_requested = False
self.update_thread = None
# simulation_dependent
self.simulation = simulation
self.state = state
self.key = key

def start(self):
if not self.running:
self.running = True
self.stop_requested = False
self.update_thread = threading.Thread(target=self.simulation_loop)
self.update_thread.start()

def pause(self):
self.paused = True

def resume(self):
self.paused = False

def stop(self):
self.stop_requested = True

def simulation_loop(self):
while not self.stop_requested:
if self.paused:
time.sleep(0.1)
continue

self.state = self._update_simulation()
print(f"{self.state = }")

time.sleep(0.1)

def _update_simulation(self):
self.key, a_key, step_key = random.split(self.key, 3)
actions = self.simulation.choose_action(self.state.obs, a_key)
return self.simulation.step(self.state, actions, step_key)
from simulationsandbox.wrapper import SimulationWrapper

NUM_AGENTS = 5
MAX_AGENTS = 10
Expand All @@ -65,19 +17,19 @@ def _update_simulation(self):
state = env.init_state(NUM_AGENTS, NUM_OBS, key)

# Example usage:
sim = SimulationWrapper(env, state, key)
sim = SimulationWrapper(env, state, key, print_data=True)

print('Started')
sim.start_simulation()
sim.start()
time.sleep(SLEEP_TIME)

sim.pause_simulation()
sim.pause()
print('Paused')
time.sleep(SLEEP_TIME)

print('Resumed')
sim.resume_simulation()
sim.resume()
time.sleep(SLEEP_TIME)

sim.stop_simulation()
sim.stop()
print('stopped')

0 comments on commit bb01f48

Please sign in to comment.