From f8161c16df0457bce098abf34d838e4c3910464f Mon Sep 17 00:00:00 2001 From: corentinlger Date: Tue, 30 Jan 2024 15:08:36 +0100 Subject: [PATCH] Move simulation loop in main.py and update tests --- conf/config.yaml | 4 +++- main.py | 33 +++++++++++++++++++++++----- simulation.py | 36 +++--------------------------- test_pytest.py | 57 +++++++++++++++++++++++++++++++++++------------- 4 files changed, 76 insertions(+), 54 deletions(-) diff --git a/conf/config.yaml b/conf/config.yaml index 26598b0..2c4d4ef 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -3,4 +3,6 @@ params: max_agents: 10 grid_size: 20 num_steps: 50 - random_seed: 0 \ No newline at end of file + random_seed: 0 + visualize: True + viz_delay: 0.1 \ No newline at end of file diff --git a/main.py b/main.py index 44d0987..81d27b5 100644 --- a/main.py +++ b/main.py @@ -13,17 +13,40 @@ def main(cfg: DictConfig): max_agents = cfg.params.max_agents grid_size = cfg.params.grid_size num_steps = cfg.params.num_steps + visualize = cfg.params.visualize + viz_delay = cfg.params.viz_delay rng_key = random.PRNGKey(cfg.params.random_seed) - sim = Simulation(num_agents, max_agents,grid_size, rng_key) + sim = Simulation(num_agents, max_agents, grid_size, rng_key) + # Launch a simulation print("\nSimulation started") + grid, agents_pos, agents_states, key = sim.get_env_state() - grid, final_agent_positions, final_agent_states = sim.simulate( - grid, agents_pos, agents_states, num_steps, grid_size, key - ) - print("\nSimulation ended") + for step in range(num_steps): + if step % 10 == 0: + print(f"step {step}") + + if step == 20: + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + + key, a_key = random.split(key) + + agents_pos = sim.move_agents(agents_pos, grid_size, a_key) + agents_states += 0.1 + + if visualize: + sim.visualize(grid, agents_pos, viz_delay) + + print("\nSimulation ended") if __name__ == "__main__": main() diff --git a/simulation.py b/simulation.py index c0d0222..7641ead 100644 --- a/simulation.py +++ b/simulation.py @@ -43,6 +43,7 @@ def add_agent(self, agents_pos, agents_states): agents_pos = agents_pos.at[self.num_agents].set(*random.randint(self.key, (1, 2), 0, self.grid_size)) agents_states = agents_states.at[self.num_agents].set(1) self.num_agents += 1 + print(f"Added agent {self.num_agents}") else: print("Impossible to add more agents") @@ -54,7 +55,7 @@ def add_agent(self, agents_pos, agents_states): def remove_agent(idx=None): pass - def visualize(self, grid, agents_pos): + def visualize(self, grid, agents_pos, delay=0.1): if not plt.fignum_exists(1): plt.ion() plt.figure(figsize=(10, 10)) @@ -71,39 +72,8 @@ def visualize(self, grid, agents_pos): plt.legend() plt.draw() - plt.pause(0.1) + plt.pause(delay) - # TODO : move in the main class and have a thing like init, step functions like in gym - def simulate( - self, grid, agents_pos, agents_states, num_steps, grid_size, key, visualize=True - ): - - # use a fori_loop after - for step in range(num_steps): - - if step % 10 == 0: - print(f"step {step}") - if step == 20: - agents_pos, agents_states = self.add_agent(agents_pos, agents_states) - agents_pos, agents_states = self.add_agent(agents_pos, agents_states) - agents_pos, agents_states = self.add_agent(agents_pos, agents_states) - agents_pos, agents_states = self.add_agent(agents_pos, agents_states) - agents_pos, agents_states = self.add_agent(agents_pos, agents_states) - agents_pos, agents_states = self.add_agent(agents_pos, agents_states) - agents_pos, agents_states = self.add_agent(agents_pos, agents_states) - agents_pos, agents_states = self.add_agent(agents_pos, agents_states) - print(self.num_agents) - - key, a_key = random.split(key) - - agents_pos = self.move_agents(agents_pos, grid_size, a_key) - - agents_states += 0.1 - - if visualize: - self.visualize(grid, agents_pos) - - return grid, agents_pos, agents_states def get_env_state(self): return self.grid, self.agents_pos, self.agents_states, self.key diff --git a/test_pytest.py b/test_pytest.py index 5608177..2e2911f 100644 --- a/test_pytest.py +++ b/test_pytest.py @@ -1,31 +1,58 @@ import jax +from jax import random from simulation import Simulation +NUM_AGENTS = 5 +MAX_AGENTS = 10 +GRID_SIZE = 20 +NUM_STEPS = 50 +VIZUALIZE = True +VIZ_DELAY = 0.001 +SEED = 0 + def test_simulation_init(): - num_agents = 5 - grid_size = 20 - key = jax.random.PRNGKey(0) + key = jax.random.PRNGKey(SEED) - sim = Simulation(num_agents=num_agents, grid_size=grid_size, key=key) + sim = Simulation(num_agents=NUM_AGENTS, max_agents=MAX_AGENTS, grid_size=GRID_SIZE, key=key) - assert sim.agents_pos.shape == (num_agents, 2) - assert sim.grid.shape == (grid_size, grid_size) + assert sim.num_agents == NUM_AGENTS + assert sim.agents_pos.shape == (MAX_AGENTS, 2) + assert sim.grid.shape == (GRID_SIZE, GRID_SIZE) def test_simulation_run(): - num_agents = 5 - grid_size = 20 - key = jax.random.PRNGKey(0) + key = jax.random.PRNGKey(SEED) num_steps = 100 - sim = Simulation(num_agents=num_agents, grid_size=grid_size, key=key) + sim = Simulation(num_agents=NUM_AGENTS, max_agents=MAX_AGENTS, grid_size=GRID_SIZE, key=key) + + assert sim.num_agents == 5 grid, agents_pos, agents_states, key = sim.get_env_state() - grid, final_agent_positions, final_agent_states = sim.simulate( - grid, agents_pos, agents_states, num_steps, grid_size, key, visualize=False - ) - assert final_agent_positions.shape == (num_agents, 2) - assert final_agent_states.shape == (num_agents,) + for step in range(num_steps): + if step % 10 == 0: + print(f"step {step}") + + if step == 20: + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + agents_pos, agents_states = sim.add_agent(agents_pos, agents_states) + + key, a_key = random.split(key) + + agents_pos = sim.move_agents(agents_pos, GRID_SIZE, a_key) + agents_states += 0.1 + + if VIZUALIZE: + sim.visualize(grid, agents_pos, VIZ_DELAY) + + assert sim.num_agents == MAX_AGENTS + assert agents_pos.shape == (MAX_AGENTS, 2) \ No newline at end of file