From 863cab1aacf1f716d93c4b437eec03246c743600 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Wed, 24 Jan 2024 12:04:58 +0100 Subject: [PATCH] Add real minimalistic tests --- simulation.py | 7 ++++--- test_pytest.py | 39 ++++++++++++++++++++++++++++----------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/simulation.py b/simulation.py index 09825df..d49533f 100644 --- a/simulation.py +++ b/simulation.py @@ -44,15 +44,16 @@ def visualize(self, grid, agents_pos): plt.draw() plt.pause(0.1) - def simulate(self, grid, agents_pos, agents_states, num_steps, grid_size, key): + def simulate(self, grid, agents_pos, agents_states, num_steps, grid_size, key, visualize=True): # use a fori_loop after for step in range(num_steps): key, a_key = random.split(key) agents_pos = self.move_agents(agents_pos, grid_size, a_key) agents_states += 0.1 - - self.visualize(grid, agents_pos) + + if visualize: + self.visualize(grid, agents_pos) return grid, agents_pos, agents_states diff --git a/test_pytest.py b/test_pytest.py index 56165b8..5608177 100644 --- a/test_pytest.py +++ b/test_pytest.py @@ -1,14 +1,31 @@ -# test_assert_examples.py +import jax -def test_uppercase(): - assert "loud noises".upper() == "LOUD NOISES" +from simulation import Simulation -def test_reversed(): - assert list(reversed([1, 2, 3, 4])) == [4, 3, 2, 1] -def test_some_primes(): - assert 37 in { - num - for num in range(2, 50) - if not any(num % div == 0 for div in range(2, num)) - } \ No newline at end of file +def test_simulation_init(): + num_agents = 5 + grid_size = 20 + key = jax.random.PRNGKey(0) + + sim = Simulation(num_agents=num_agents, grid_size=grid_size, key=key) + + assert sim.agents_pos.shape == (num_agents, 2) + assert sim.grid.shape == (grid_size, grid_size) + + +def test_simulation_run(): + num_agents = 5 + grid_size = 20 + key = jax.random.PRNGKey(0) + num_steps = 100 + + sim = Simulation(num_agents=num_agents, grid_size=grid_size, key=key) + + grid, agents_pos, agents_states, key = sim.get_env_state() + grid, final_agent_positions, final_agent_states = sim.simulate( + grid, agents_pos, agents_states, num_steps, grid_size, key, visualize=False + ) + + assert final_agent_positions.shape == (num_agents, 2) + assert final_agent_states.shape == (num_agents,)