Skip to content

Commit

Permalink
Move simulation loop in main.py and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Jan 30, 2024
1 parent 8a7de2a commit f8161c1
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 54 deletions.
4 changes: 3 additions & 1 deletion conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@ params:
max_agents: 10
grid_size: 20
num_steps: 50
random_seed: 0
random_seed: 0
visualize: True
viz_delay: 0.1
33 changes: 28 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,40 @@ def main(cfg: DictConfig):
max_agents = cfg.params.max_agents
grid_size = cfg.params.grid_size
num_steps = cfg.params.num_steps
visualize = cfg.params.visualize
viz_delay = cfg.params.viz_delay
rng_key = random.PRNGKey(cfg.params.random_seed)

sim = Simulation(num_agents, max_agents,grid_size, rng_key)
sim = Simulation(num_agents, max_agents, grid_size, rng_key)

# Launch a simulation
print("\nSimulation started")

grid, agents_pos, agents_states, key = sim.get_env_state()
grid, final_agent_positions, final_agent_states = sim.simulate(
grid, agents_pos, agents_states, num_steps, grid_size, key
)
print("\nSimulation ended")

for step in range(num_steps):
if step % 10 == 0:
print(f"step {step}")

if step == 20:
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)

key, a_key = random.split(key)

agents_pos = sim.move_agents(agents_pos, grid_size, a_key)
agents_states += 0.1

if visualize:
sim.visualize(grid, agents_pos, viz_delay)

print("\nSimulation ended")

if __name__ == "__main__":
main()
36 changes: 3 additions & 33 deletions simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def add_agent(self, agents_pos, agents_states):
agents_pos = agents_pos.at[self.num_agents].set(*random.randint(self.key, (1, 2), 0, self.grid_size))
agents_states = agents_states.at[self.num_agents].set(1)
self.num_agents += 1
print(f"Added agent {self.num_agents}")

else:
print("Impossible to add more agents")
Expand All @@ -54,7 +55,7 @@ def add_agent(self, agents_pos, agents_states):
def remove_agent(idx=None):
pass

def visualize(self, grid, agents_pos):
def visualize(self, grid, agents_pos, delay=0.1):
if not plt.fignum_exists(1):
plt.ion()
plt.figure(figsize=(10, 10))
Expand All @@ -71,39 +72,8 @@ def visualize(self, grid, agents_pos):
plt.legend()

plt.draw()
plt.pause(0.1)
plt.pause(delay)

# TODO : move in the main class and have a thing like init, step functions like in gym
def simulate(
self, grid, agents_pos, agents_states, num_steps, grid_size, key, visualize=True
):

# use a fori_loop after
for step in range(num_steps):

if step % 10 == 0:
print(f"step {step}")
if step == 20:
agents_pos, agents_states = self.add_agent(agents_pos, agents_states)
agents_pos, agents_states = self.add_agent(agents_pos, agents_states)
agents_pos, agents_states = self.add_agent(agents_pos, agents_states)
agents_pos, agents_states = self.add_agent(agents_pos, agents_states)
agents_pos, agents_states = self.add_agent(agents_pos, agents_states)
agents_pos, agents_states = self.add_agent(agents_pos, agents_states)
agents_pos, agents_states = self.add_agent(agents_pos, agents_states)
agents_pos, agents_states = self.add_agent(agents_pos, agents_states)
print(self.num_agents)

key, a_key = random.split(key)

agents_pos = self.move_agents(agents_pos, grid_size, a_key)

agents_states += 0.1

if visualize:
self.visualize(grid, agents_pos)

return grid, agents_pos, agents_states

def get_env_state(self):
return self.grid, self.agents_pos, self.agents_states, self.key
57 changes: 42 additions & 15 deletions test_pytest.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,58 @@
import jax
from jax import random

from simulation import Simulation

NUM_AGENTS = 5
MAX_AGENTS = 10
GRID_SIZE = 20
NUM_STEPS = 50
VIZUALIZE = True
VIZ_DELAY = 0.001
SEED = 0


def test_simulation_init():
num_agents = 5
grid_size = 20
key = jax.random.PRNGKey(0)
key = jax.random.PRNGKey(SEED)

sim = Simulation(num_agents=num_agents, grid_size=grid_size, key=key)
sim = Simulation(num_agents=NUM_AGENTS, max_agents=MAX_AGENTS, grid_size=GRID_SIZE, key=key)

assert sim.agents_pos.shape == (num_agents, 2)
assert sim.grid.shape == (grid_size, grid_size)
assert sim.num_agents == NUM_AGENTS
assert sim.agents_pos.shape == (MAX_AGENTS, 2)
assert sim.grid.shape == (GRID_SIZE, GRID_SIZE)


def test_simulation_run():
num_agents = 5
grid_size = 20
key = jax.random.PRNGKey(0)
key = jax.random.PRNGKey(SEED)
num_steps = 100

sim = Simulation(num_agents=num_agents, grid_size=grid_size, key=key)
sim = Simulation(num_agents=NUM_AGENTS, max_agents=MAX_AGENTS, grid_size=GRID_SIZE, key=key)

assert sim.num_agents == 5

grid, agents_pos, agents_states, key = sim.get_env_state()
grid, final_agent_positions, final_agent_states = sim.simulate(
grid, agents_pos, agents_states, num_steps, grid_size, key, visualize=False
)

assert final_agent_positions.shape == (num_agents, 2)
assert final_agent_states.shape == (num_agents,)
for step in range(num_steps):
if step % 10 == 0:
print(f"step {step}")

if step == 20:
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)

key, a_key = random.split(key)

agents_pos = sim.move_agents(agents_pos, GRID_SIZE, a_key)
agents_states += 0.1

if VIZUALIZE:
sim.visualize(grid, agents_pos, VIZ_DELAY)

assert sim.num_agents == MAX_AGENTS
assert agents_pos.shape == (MAX_AGENTS, 2)

0 comments on commit f8161c1

Please sign in to comment.