diff --git a/simulationsandbox/wrapper.py b/simulationsandbox/wrapper.py new file mode 100644 index 0000000..9b8d2b3 --- /dev/null +++ b/simulationsandbox/wrapper.py @@ -0,0 +1,51 @@ +import threading +import time + +from jax import random + +class SimulationWrapper: + + def __init__(self, simulation, state, key, step_delay=0.1, print_data=False): + self.running = False + self.paused = False + self.stop_requested = False + self.update_thread = None + self.print_data = print_data + self.step_delay = step_delay + # simulation_dependent + self.simulation = simulation + self.state = state + self.key = key + + def start(self): + if not self.running: + self.running = True + self.stop_requested = False + self.update_thread = threading.Thread(target=self.simulation_loop) + self.update_thread.start() + + def pause(self): + self.paused = True + + def resume(self): + self.paused = False + + def stop(self): + self.stop_requested = True + + def simulation_loop(self): + while not self.stop_requested: + if self.paused: + time.sleep(0.1) + continue + + self.state = self._update_simulation() + if self.print_data: + print(f"{self.state = }") + + time.sleep(self.step_delay) + + def _update_simulation(self): + self.key, a_key, step_key = random.split(self.key, 3) + actions = self.simulation.choose_action(self.state.obs, a_key) + return self.simulation.step(self.state, actions, step_key) diff --git a/tests/test_init.py b/tests/test_init.py new file mode 100644 index 0000000..6c93870 --- /dev/null +++ b/tests/test_init.py @@ -0,0 +1,31 @@ +import time + +import jax +import jax.numpy as jnp +from jax import random + +from simulationsandbox.two_d_simulation import SimpleSimulation +from simulationsandbox.three_d_simulation import ThreeDSimulation + +NUM_AGENTS = 5 +MAX_AGENTS = 10 +NUM_OBS = 3 +GRID_SIZE = 20 +NUM_STEPS = 50 +VIZUALIZE = True +STEP_DELAY = 0.000001 +SEED = 0 + + +def test_simulation_init(): + key = random.PRNGKey(SEED) + + sim = SimpleSimulation(MAX_AGENTS, GRID_SIZE) + state = sim.init_state(NUM_AGENTS, NUM_OBS, key) + + assert sim.max_agents == MAX_AGENTS + assert sim.grid_size == GRID_SIZE + assert state.x_pos.shape == (MAX_AGENTS,) + assert jnp.sum(state.alive) == NUM_AGENTS + assert state.grid.shape == (GRID_SIZE, GRID_SIZE) + diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py new file mode 100644 index 0000000..6254c7a --- /dev/null +++ b/tests/test_wrapper.py @@ -0,0 +1,74 @@ +import time + +import jax +import jax.numpy as jnp +from jax import random + +from simulationsandbox.two_d_simulation import SimpleSimulation +from simulationsandbox.three_d_simulation import ThreeDSimulation +from simulationsandbox.wrapper import SimulationWrapper + +NUM_AGENTS = 5 +MAX_AGENTS = 10 +NUM_OBS = 3 +GRID_SIZE = 20 +NUM_STEPS = 50 +VIZUALIZE = True +SLEEP_TIME = 2 +STEP_DELAY = 0.01 +PRINT_DATA = True +SEED = 0 + + +def test_wrapper_two_d_sim(): + key = random.PRNGKey(SEED) + env = SimpleSimulation(MAX_AGENTS, GRID_SIZE) + state = env.init_state(NUM_AGENTS, NUM_OBS, key) + + # Example usage: + sim = SimulationWrapper(env, state, key, step_delay=STEP_DELAY, print_data=PRINT_DATA) + + print('Started') + sim.start() + time.sleep(SLEEP_TIME) + assert sim.running == True + + sim.pause() + print('Paused') + time.sleep(SLEEP_TIME) + assert sim.paused == True + + print('Resumed') + sim.resume() + time.sleep(SLEEP_TIME) + assert sim.running == True + + sim.stop() + print('stopped') + + +# def test_wrapper_three_d_sim(): +# key = random.PRNGKey(SEED) +# env = ThreeDSimulation(MAX_AGENTS, GRID_SIZE) +# state = env.init_state(NUM_AGENTS, NUM_OBS, key) + +# # Example usage: +# sim = SimulationWrapper(env, state, key, print_data=True) + +# print('Started') +# sim.start() +# time.sleep(SLEEP_TIME) +# assert sim.running == True + +# sim.pause() +# print('Paused') +# time.sleep(SLEEP_TIME) +# assert sim.paused == True + +# print('Resumed') +# sim.resume() +# time.sleep(SLEEP_TIME) +# assert sim.running == True + +# sim.stop() +# print('stopped') \ No newline at end of file diff --git a/tests/test_pytest.py b/tests/test_zrun_sims.py similarity index 85% rename from tests/test_pytest.py rename to tests/test_zrun_sims.py index a609638..9c5d460 100644 --- a/tests/test_pytest.py +++ b/tests/test_zrun_sims.py @@ -17,19 +17,6 @@ SEED = 0 -def test_simulation_init(): - key = random.PRNGKey(SEED) - - sim = SimpleSimulation(MAX_AGENTS, GRID_SIZE) - state = sim.init_state(NUM_AGENTS, NUM_OBS, key) - - assert sim.max_agents == MAX_AGENTS - assert sim.grid_size == GRID_SIZE - assert state.x_pos.shape == (MAX_AGENTS,) - assert jnp.sum(state.alive) == NUM_AGENTS - assert state.grid.shape == (GRID_SIZE, GRID_SIZE) - - def test_simple_simulation_run(): key = random.PRNGKey(SEED) sim = SimpleSimulation(MAX_AGENTS, GRID_SIZE) @@ -64,7 +51,7 @@ def test_simple_simulation_run(): state = sim.step(state, actions, step_key) if VIZUALIZE: - SimpleSimulation.visualize_sim(state, grid_size=None) + SimpleSimulation.visualize_sim(state) print("\nSimulation ended") assert jnp.sum(state.alive) == 5 @@ -72,6 +59,7 @@ def test_simple_simulation_run(): assert state.time == NUM_STEPS + def test_three_d_simulation_run(): key = random.PRNGKey(SEED) sim = ThreeDSimulation(MAX_AGENTS, GRID_SIZE) diff --git a/wrapp_simulation.py b/wrapp_simulation.py index 72950fe..a0100b4 100644 --- a/wrapp_simulation.py +++ b/wrapp_simulation.py @@ -1,57 +1,9 @@ -import threading import time from jax import random from simulationsandbox.two_d_simulation import SimpleSimulation - -# Independant of any simulation environment -class SimulationWrapper: - """ - Simulation wrapper to start, pause, resume and stop a simulation. - Independant of any simulation type. - """ - def __init__(self, simulation, state, key): - self.running = False - self.paused = False - self.stop_requested = False - self.update_thread = None - # simulation_dependent - self.simulation = simulation - self.state = state - self.key = key - - def start(self): - if not self.running: - self.running = True - self.stop_requested = False - self.update_thread = threading.Thread(target=self.simulation_loop) - self.update_thread.start() - - def pause(self): - self.paused = True - - def resume(self): - self.paused = False - - def stop(self): - self.stop_requested = True - - def simulation_loop(self): - while not self.stop_requested: - if self.paused: - time.sleep(0.1) - continue - - self.state = self._update_simulation() - print(f"{self.state = }") - - time.sleep(0.1) - - def _update_simulation(self): - self.key, a_key, step_key = random.split(self.key, 3) - actions = self.simulation.choose_action(self.state.obs, a_key) - return self.simulation.step(self.state, actions, step_key) +from simulationsandbox.wrapper import SimulationWrapper NUM_AGENTS = 5 MAX_AGENTS = 10 @@ -65,19 +17,19 @@ def _update_simulation(self): state = env.init_state(NUM_AGENTS, NUM_OBS, key) # Example usage: -sim = SimulationWrapper(env, state, key) +sim = SimulationWrapper(env, state, key, print_data=True) print('Started') -sim.start_simulation() +sim.start() time.sleep(SLEEP_TIME) -sim.pause_simulation() +sim.pause() print('Paused') time.sleep(SLEEP_TIME) print('Resumed') -sim.resume_simulation() +sim.resume() time.sleep(SLEEP_TIME) -sim.stop_simulation() +sim.stop() print('stopped')