Skip to content

Commit

Permalink
Add real minimalistic tests
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Jan 24, 2024
1 parent 12f16c3 commit 863cab1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
7 changes: 4 additions & 3 deletions simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,16 @@ def visualize(self, grid, agents_pos):
plt.draw()
plt.pause(0.1)

def simulate(self, grid, agents_pos, agents_states, num_steps, grid_size, key):
def simulate(self, grid, agents_pos, agents_states, num_steps, grid_size, key, visualize=True):
# use a fori_loop after
for step in range(num_steps):
key, a_key = random.split(key)
agents_pos = self.move_agents(agents_pos, grid_size, a_key)

agents_states += 0.1

self.visualize(grid, agents_pos)

if visualize:
self.visualize(grid, agents_pos)

return grid, agents_pos, agents_states

Expand Down
39 changes: 28 additions & 11 deletions test_pytest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,31 @@
# test_assert_examples.py
import jax

def test_uppercase():
assert "loud noises".upper() == "LOUD NOISES"
from simulation import Simulation

def test_reversed():
assert list(reversed([1, 2, 3, 4])) == [4, 3, 2, 1]

def test_some_primes():
assert 37 in {
num
for num in range(2, 50)
if not any(num % div == 0 for div in range(2, num))
}
def test_simulation_init():
num_agents = 5
grid_size = 20
key = jax.random.PRNGKey(0)

sim = Simulation(num_agents=num_agents, grid_size=grid_size, key=key)

assert sim.agents_pos.shape == (num_agents, 2)
assert sim.grid.shape == (grid_size, grid_size)


def test_simulation_run():
num_agents = 5
grid_size = 20
key = jax.random.PRNGKey(0)
num_steps = 100

sim = Simulation(num_agents=num_agents, grid_size=grid_size, key=key)

grid, agents_pos, agents_states, key = sim.get_env_state()
grid, final_agent_positions, final_agent_states = sim.simulate(
grid, agents_pos, agents_states, num_steps, grid_size, key, visualize=False
)

assert final_agent_positions.shape == (num_agents, 2)
assert final_agent_states.shape == (num_agents,)

0 comments on commit 863cab1

Please sign in to comment.