Skip to content

Commit

Permalink
Add fish eating nearby food
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Apr 30, 2024
1 parent 4fcf3b2 commit 927b9b0
Showing 1 changed file with 49 additions and 4 deletions.
53 changes: 49 additions & 4 deletions simulationsandbox/environments/aquarium.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from functools import partial
from functools import reduce

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

Expand All @@ -8,10 +10,14 @@

from simulationsandbox.environments.base_env import BaseEnv, BaseEnvState

# Env
N_DIMS = 3
# Fish
FISH_SPEED = 4.
FISH_COLOR = [0., 0.5, 1.]
FISH_EATING_RANGE = 5.
MOVE_SCALE = 1.
# Food
FOOD_SPEED = 1.
FOOD_COLOR = [0.8, 0.5, 0.]

Expand All @@ -23,6 +29,7 @@ class Agents:
alive: jnp.array
color: jnp.array
obs: jnp.array
# energy: jnp.array


@struct.dataclass
Expand All @@ -41,11 +48,21 @@ class AquiariumState(BaseEnvState):
objects: Objects


# Helper functions
def normal(theta):
return jnp.array([jnp.cos(theta), jnp.sin(theta)])

normal = jit(vmap(normal))

def multiply_masks(mask1, mask2):
return mask1 * mask2

def distance(point1, point2):
diff = point1 - point2
squared_diff = jnp.sum(jnp.square(diff))
return jnp.sqrt(squared_diff)

distance = jit(vmap(distance, in_axes=(None, 0)))

# Change the angle and speed of the agent a bit
def move(obs, key):
Expand All @@ -54,14 +71,34 @@ def move(obs, key):
move = jit(vmap(move, in_axes=(0, 0)))


def eat(agent_idx, state):
agent_pos = state.agents.pos[agent_idx]
alive = state.agents.alive[agent_idx]

food_position = state.objects.pos
food_exist = state.objects.exist

food_dist = distance(agent_pos, food_position)
can_eat_idx = jnp.where(food_dist < FISH_EATING_RANGE, 1, 0)

mask = jnp.where(can_eat_idx, 0, 1)
updated_food_exist = food_exist * mask

# Return the old mask if the agent isn't alive
return jnp.where(alive, updated_food_exist, food_exist)

eat = vmap(eat, in_axes=(0, None))


class Aquarium(BaseEnv):
""" Minimalistic aquarium environmnent"""
def __init__(self, max_agents=10, max_objects=20, grid_size=50):
self.max_agents = max_agents
self.max_objects = max_objects
self.grid_size = grid_size

def init_state(self, num_agents, num_obs, key):
def init_state(self, num_agents=10, num_obs=2, seed=0):
key = random.PRNGKey(seed)
agents_key_pos, agents_key_vel, agents_color_key, objects_key_pos = random.split(key, 4)
fish_velocity = random.uniform(agents_key_vel, shape=(self.max_agents, N_DIMS), minval=-1, maxval=1)
fish_velocity = (fish_velocity / jnp.linalg.norm(fish_velocity)) * FISH_SPEED
Expand All @@ -81,6 +118,7 @@ def init_state(self, num_agents, num_obs, key):
z_food_pos = jnp.full((self.max_objects, 1), fill_value=self.grid_size)
food_pos = jnp.concatenate((x_y_food_pos, z_food_pos), axis=1)

# TODO : Do not set all the food to exist, make it spawn locally instead of uniformly at the surface
food = Objects(
pos=food_pos,
velocity=jnp.tile(jnp.array([0., 0., -1]), (self.max_objects, 1)) * FOOD_SPEED,
Expand Down Expand Up @@ -109,12 +147,20 @@ def step(self, state, key):
agents_pos = jnp.clip(agents_pos, 0, self.grid_size)
agents = state.agents.replace(pos=agents_pos, velocity=velocity)

# TODO : Only move the food that exists
# Update food position
food_vel = state.objects.velocity
food_pos = state.objects.pos + food_vel
food_pos = jnp.clip(food_pos, 0, self.grid_size)
objects = state.objects.replace(pos=food_pos)


# TODO : Also increase energy level of agents
agent_idx = jnp.arange(0, self.max_agents)
# Compute the exist food idx after each agent has eaten nearby food
food_exist = eat(agent_idx, state)
# Multiply the masks between them to get the final exist array
food_exist = reduce(multiply_masks, food_exist)
objects = state.objects.replace(pos=food_pos, exist=food_exist)

# Update new state
time = state.time + 1
state = state.replace(time=time, agents=agents, objects=objects)
Expand Down Expand Up @@ -152,7 +198,6 @@ def render(state):
objects_y_pos = state.objects.pos[:, 1][exist_object]
objects_z_pos = state.objects.pos[:, 2][exist_object]
objects_colors = state.objects.color[exist_object]
print(objects_x_pos)

# TODO : see how to add cmap=colormaps["gist_rainbow"] for fish colors
ax.scatter(agents_x_pos, agents_y_pos, agents_z_pos, c=agents_colors, marker="o", label="Fish")
Expand Down

0 comments on commit 927b9b0

Please sign in to comment.