Skip to content

Commit

Permalink
Add remove agents functionality and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Jan 31, 2024
1 parent 16f6e8e commit 1f2d6aa
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 36 deletions.
23 changes: 15 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def main(cfg: DictConfig):
num_steps = cfg.params.num_steps
visualize = cfg.params.visualize
viz_delay = cfg.params.viz_delay

rng_key = random.PRNGKey(cfg.params.random_seed)

sim = Simulation(num_agents, max_agents, grid_size, rng_key)
Expand All @@ -29,14 +30,20 @@ def main(cfg: DictConfig):
print(f"step {step}")

if step == 20:
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)

if step == 40:
sim.remove_agent()
sim.remove_agent()
sim.remove_agent()
sim.remove_agent()

key, a_key = random.split(key)

Expand Down
12 changes: 8 additions & 4 deletions simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def init_agents(self, num_agents, max_agents, grid_size, key):

# TODO : Only move existing agents
def move_agents(self, agents_pos, grid_size, key):
# Shouldn't be able to do this when jit because of the +=
agents_pos += random.randint(key, agents_pos.shape, -1, 2)
return jnp.clip(agents_pos, 0, grid_size - 1)

Expand All @@ -50,10 +51,13 @@ def add_agent(self, agents_pos, agents_states):

return agents_pos, agents_states


# TODO:
def remove_agent(idx=None):
pass
def remove_agent(self):
if self.num_agents <= 0:
print("There is no agents to remove")
else:
self.num_agents -= 1
print(f"Removed agent {self.num_agents + 1}")


def visualize(self, grid, agents_pos, delay=0.1):
if not plt.fignum_exists(1):
Expand Down
15 changes: 0 additions & 15 deletions test_black.py

This file was deleted.

46 changes: 37 additions & 9 deletions test_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,22 @@ def test_simulation_run():
print(f"step {step}")

if step == 20:
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
# Add more agents than permitted and expect to reach max_agents = 10
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)

if step == 40:
# Remove 4 agents and expect to have 6 left
sim.remove_agent()
sim.remove_agent()
sim.remove_agent()
sim.remove_agent()

key, a_key = random.split(key)

Expand All @@ -54,5 +62,25 @@ def test_simulation_run():
if VIZUALIZE:
sim.visualize(grid, agents_pos, VIZ_DELAY)

assert sim.num_agents == 6
assert agents_pos.shape == (MAX_AGENTS, 2)


def test_add_remove_agents():
key = jax.random.PRNGKey(SEED)

sim = Simulation(num_agents=NUM_AGENTS, max_agents=MAX_AGENTS, grid_size=GRID_SIZE, key=key)
grid, agents_pos, agents_states, key = sim.get_env_state()

assert sim.num_agents == 5

for i in range(15):
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)

assert sim.num_agents == MAX_AGENTS
assert agents_pos.shape == (MAX_AGENTS, 2)

for i in range(15):
sim.remove_agent()

assert sim.num_agents == 0

0 comments on commit 1f2d6aa

Please sign in to comment.