Skip to content

Commit

Permalink
Increase agent's size when eat food and fix can_eat_idx bug
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed May 3, 2024
1 parent 69c9bd4 commit 0d80b63
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions simulationsandbox/environments/aquarium.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Agents:
alive: jnp.array
color: jnp.array
obs: jnp.array
# energy: jnp.array
energy: jnp.array


@struct.dataclass
Expand Down Expand Up @@ -70,7 +70,6 @@ def move(obs, key):

move = jit(vmap(move, in_axes=(0, 0)))


def eat(agent_idx, state):
agent_pos = state.agents.pos[agent_idx]
alive = state.agents.alive[agent_idx]
Expand All @@ -80,12 +79,15 @@ def eat(agent_idx, state):

food_dist = distance(agent_pos, food_position)
can_eat_idx = jnp.where(food_dist < FISH_EATING_RANGE, 1, 0)
can_eat_idx = jnp.where(can_eat_idx, food_exist, 0)

n_eaten = jnp.sum(can_eat_idx)

mask = jnp.where(can_eat_idx, 0, 1)
updated_food_exist = food_exist * mask

# Return the old mask if the agent isn't alive
return jnp.where(alive, updated_food_exist, food_exist)
return jnp.where(alive, updated_food_exist, food_exist), jnp.where(alive, n_eaten, 0)

eat = vmap(eat, in_axes=(0, None))

Expand All @@ -110,7 +112,8 @@ def init_state(self, num_agents=10, num_obs=2, seed=0):
alive=jnp.hstack((jnp.ones(num_agents), jnp.zeros(self.max_agents - num_agents))),
# color=random.uniform(key=agents_color_key, shape=(self.max_agents, 3), minval=0., maxval=1.),
color=jnp.full((self.max_agents, 3), jnp.array(FISH_COLOR)),
obs=jnp.zeros((self.max_agents, num_obs))
obs=jnp.zeros((self.max_agents, num_obs)),
energy=jnp.zeros((self.max_agents,))
)

# Add food at the surface of the aquarium
Expand Down Expand Up @@ -145,21 +148,23 @@ def step(self, state, key):
agents_pos = state.agents.pos + velocity
# Collide with walls
agents_pos = jnp.clip(agents_pos, 0, self.grid_size)
agents = state.agents.replace(pos=agents_pos, velocity=velocity)

# TODO : Only move the food that exists
# Update food position
food_vel = state.objects.velocity
food_pos = state.objects.pos + food_vel
food_pos = jnp.clip(food_pos, 0, self.grid_size)

# TODO : Also increase energy level of agents
agent_idx = jnp.arange(0, self.max_agents)
# Compute the exist food idx after each agent has eaten nearby food
food_exist = eat(agent_idx, state)
food_exist, eaten = eat(agent_idx, state)
agents_energy = state.agents.energy + eaten
# Multiply the masks between them to get the final exist array
food_exist = reduce(multiply_masks, food_exist)

objects = state.objects.replace(pos=food_pos, exist=food_exist)
agents = state.agents.replace(pos=agents_pos, velocity=velocity, energy=agents_energy)
jax.debug.print("energy {x}", x=agents_energy)

# Update new state
time = state.time + 1
Expand Down Expand Up @@ -192,15 +197,17 @@ def render(state):
agents_y_pos = state.agents.pos[:, 1][alive_agents]
agents_z_pos = state.agents.pos[:, 2][alive_agents]
agents_colors = state.agents.color[alive_agents]
agents_energy = state.agents.energy[alive_agents]

exist_object = jnp.where(state.objects.exist != 0.0)
objects_x_pos = state.objects.pos[:, 0][exist_object]
objects_y_pos = state.objects.pos[:, 1][exist_object]
objects_z_pos = state.objects.pos[:, 2][exist_object]
objects_colors = state.objects.color[exist_object]

SCALE = 15
# TODO : see how to add cmap=colormaps["gist_rainbow"] for fish colors
ax.scatter(agents_x_pos, agents_y_pos, agents_z_pos, c=agents_colors, marker="o", label="Fish")
ax.scatter(agents_x_pos, agents_y_pos, agents_z_pos, c=agents_colors, s=(1 +agents_energy)*SCALE, marker="o", label="Fish")
ax.scatter(objects_x_pos, objects_y_pos, objects_z_pos, c=objects_colors, marker="o", label="Fish")

ax.set_title("Multi-Agent Simulation")
Expand Down

0 comments on commit 0d80b63

Please sign in to comment.