Skip to content

Commit

Permalink
Merge pull request #4 from corentinlger/improve_aquarium_food_system
Browse files Browse the repository at this point in the history
Add food recharging mechanism in the environment
  • Loading branch information
corentinlger authored May 28, 2024
2 parents 2ddfcba + 920266b commit 9d16dcc
Showing 1 changed file with 71 additions and 20 deletions.
91 changes: 71 additions & 20 deletions simulationsandbox/environments/aquarium.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Fish
FISH_SPEED = 4.
FISH_COLOR = [0., 0.5, 1.]
FISH_EATING_RANGE = 5.
FISH_EATING_RANGE = 3.
MOVE_SCALE = 1.
# Food
FOOD_SPEED = 1.
Expand All @@ -40,6 +40,7 @@ class Objects:
exist: jnp.array


# TODO : Add key in the state
@struct.dataclass
class AquiariumState(BaseEnvState):
time: int
Expand All @@ -48,6 +49,7 @@ class AquiariumState(BaseEnvState):
objects: Objects


# TODO : see if we don't just add all the functions in the class instead of some here and some below
# Helper functions
def normal(theta):
    """Return the 2D unit vector ``[cos(theta), sin(theta)]`` for angle ``theta``."""
    x_component = jnp.cos(theta)
    y_component = jnp.sin(theta)
    return jnp.array([x_component, y_component])
Expand All @@ -70,6 +72,7 @@ def move(obs, key):

# Batch `move` over axis 0 of both of its arguments (in_axes=(0, 0): one obs row and
# one key per call), then wrap the batched function in `jit`. The name is rebound, so
# from here on `move` refers to the batched version.
move = jit(vmap(move, in_axes=(0, 0)))

# TODO : add short documentation and comments to the code below
def eat(agent_idx, state):
agent_pos = state.agents.pos[agent_idx]
alive = state.agents.alive[agent_idx]
Expand All @@ -91,55 +94,62 @@ def eat(agent_idx, state):

# Batch `eat` over its first argument (agent_idx) only; `state` is passed unmapped
# (in_axes=None) to every call. The name is rebound to the batched version.
eat = vmap(eat, in_axes=(0, None))


class Aquarium(BaseEnv):
""" Minimalistic aquarium environmnent"""
def __init__(self, max_agents=10, max_objects=50, grid_size=50):
    """Configure the aquarium capacity and its food-recharging policy.

    Args:
        max_agents: Number of agent (fish) slots allocated in the state.
        max_objects: Number of object (food) slots allocated in the state.
        grid_size: Upper bound of the coordinates used to place and clip entities.
    """
    self.max_agents = max_agents
    self.max_objects = max_objects
    self.grid_size = grid_size
    # Recharge food when fewer than ~10% of the food slots are occupied.
    # Floor at 1 so that small aquariums (max_objects < 10) still get refilled
    # once they are empty instead of never; cap at max_objects so an aquarium
    # without any food slot never triggers a refill.
    self.drop_food_threshold = min(self.max_objects, max(1, self.max_objects // 10))
    # Add back half of the max food when recharging (at least one item when
    # there is at least one slot).
    self.n_dropped_food = min(self.max_objects, max(1, self.max_objects // 2))

def init_state(self, num_agents=None, num_obs=2, seed=0):
    """Build the initial aquarium state.

    Args:
        num_agents: Number of fish alive at start. ``None`` means
            ``self.max_agents``; ``0`` gives an aquarium with no living fish.
        num_obs: Length of each agent's (zero-initialised) observation vector.
        seed: Seed of the PRNG key used to sample positions and velocities.

    Returns:
        An ``AquiariumState`` at ``time=0`` holding the fish and the food.

    Raises:
        ValueError: If ``num_agents`` is not within ``[0, self.max_agents]``.
    """
    # Compare against None (not truthiness) so that num_agents=0 is honoured.
    num_agents = self.max_agents if num_agents is None else num_agents
    if not 0 <= num_agents <= self.max_agents:
        raise ValueError(
            f"num_agents must be in [0, {self.max_agents}], got {num_agents}"
        )

    key = random.PRNGKey(seed)
    # agents_color_key is currently unused (fish share FISH_COLOR) but stays in
    # the split so the other keys, and thus the sampled values, are unchanged.
    agents_key_pos, agents_key_vel, agents_color_key, objects_key_pos = random.split(key, 4)

    fish_velocity = random.uniform(agents_key_vel, shape=(self.max_agents, N_DIMS), minval=-1, maxval=1)
    # Normalise each fish's velocity by its own norm so every fish moves at
    # FISH_SPEED; a norm without `axis` would be the norm of the whole matrix.
    fish_speed = jnp.linalg.norm(fish_velocity, axis=1, keepdims=True)
    fish_velocity = (fish_velocity / fish_speed) * FISH_SPEED

    # Don't add fish too close to the surface / border of the aquarium
    min_border_distance = self.grid_size // 10

    fish = Agents(
        pos=random.uniform(
            key=agents_key_pos,
            shape=(self.max_agents, N_DIMS),
            minval=min_border_distance,
            maxval=self.grid_size - min_border_distance,
        ),
        velocity=fish_velocity,
        # The first num_agents slots are alive, the remaining ones are not.
        alive=jnp.hstack((jnp.ones(num_agents), jnp.zeros(self.max_agents - num_agents))),
        color=jnp.full((self.max_agents, 3), jnp.array(FISH_COLOR)),
        obs=jnp.zeros((self.max_agents, num_obs)),
        energy=jnp.zeros((self.max_agents,)),
    )

    # Add food at the surface of the aquarium
    food_pos = self.place_food_at_surface(self.max_objects, objects_key_pos)
    # Only half of the food exists at the beginning. The second part is sized
    # from the remainder so the mask has max_objects entries even when
    # max_objects is odd.
    n_initial_food = self.max_objects // 2
    exists_food = jnp.hstack(
        (jnp.ones(n_initial_food), jnp.zeros(self.max_objects - n_initial_food))
    )

    # TODO : make the food spawn locally instead of uniformly at the surface
    food = Objects(
        pos=food_pos,
        # Every food item gets the same velocity along the negative last axis.
        velocity=jnp.tile(jnp.array([0., 0., -1]), (self.max_objects, 1)) * FOOD_SPEED,
        color=jnp.full((self.max_objects, 3), jnp.array(FOOD_COLOR)),
        exist=exists_food,
    )

    aquarium_state = AquiariumState(
        time=0,
        grid_size=self.grid_size,
        agents=fish,
        objects=food,
    )

    return aquarium_state

# TODO : add key in the state to remove it from those arguments
# TODO : split the code into different functions to make it more digestible
@partial(jit, static_argnums=(0,))
def step(self, state, key):
def _step(self, state, key):
# Update agents positions
keys = random.split(key, self.max_agents)
d_vel = move(state.agents.obs, keys)
Expand All @@ -149,10 +159,14 @@ def step(self, state, key):
# Collide with walls
agents_pos = jnp.clip(agents_pos, 0, self.grid_size)

# TODO : Only move the food that exists
# Update food position
cur_food_pos = state.objects.pos
exist_food = jnp.where(state.objects.exist != 0.0, 1, 0)
# Adapt mask to pos shape
exist_food = jnp.broadcast_to(jnp.expand_dims(exist_food, 1), cur_food_pos.shape)
food_vel = state.objects.velocity
food_pos = state.objects.pos + food_vel
new_food_pos = cur_food_pos + food_vel
food_pos = jnp.where(exist_food, new_food_pos, cur_food_pos)
food_pos = jnp.clip(food_pos, 0, self.grid_size)

agent_idx = jnp.arange(0, self.max_agents)
Expand All @@ -164,12 +178,48 @@ def step(self, state, key):

objects = state.objects.replace(pos=food_pos, exist=food_exist)
agents = state.agents.replace(pos=agents_pos, velocity=velocity, energy=agents_energy)

# Update new state
time = state.time + 1
state = state.replace(time=time, agents=agents, objects=objects)
return state

# TODO : Same comment on integrating key in state
def step(self, state, key):
    """Advance the environment by one step and refill the food when it runs low.

    The key is split in two: one half drives ``self._step``, the other is used
    by ``self.drop_food`` when a refill happens.

    Args:
        state: Current environment state.
        key: PRNG key for this step.

    Returns:
        The state returned by ``self._step``, passed through ``self.drop_food``
        when the number of existing food items is below
        ``self.drop_food_threshold``.
    """
    dynamics_key, refill_key = random.split(key)
    next_state = self._step(state, dynamics_key)

    # This check runs in plain Python (outside of jit), so the comparison is
    # evaluated on a concrete value.
    # TODO : See if better to do this with an if statement or do everything within jitted step function
    food_is_low = jnp.sum(next_state.objects.exist) < self.drop_food_threshold
    if not food_is_low:
        return next_state

    # Single-argument where returns a tuple of index arrays; [0] takes the array itself
    missing_food_idx = jnp.where(next_state.objects.exist == 0)[0]
    return self.drop_food(next_state, missing_food_idx, refill_key)

@partial(jit, static_argnums=(0,))
def drop_food(self, state, non_ex_food_idx, key):
    """Make ``self.n_dropped_food`` non-existing food items reappear at the surface.

    Args:
        state: Current environment state; ``state.objects.exist`` and
            ``state.objects.pos`` are read and replaced.
        non_ex_food_idx: 1-D array with the indices of the food objects that do
            not exist. It must hold at least ``self.n_dropped_food`` entries
            because the indices are sampled without replacement.
        key: PRNG key, split to sample the revived indices and their positions.

    Returns:
        A new state whose selected food objects exist and sit at the positions
        given by ``self.place_food_at_surface``.
    """
    idx_key, pos_key = random.split(key)
    # Get the idx of the new food that will become existing
    idx = random.choice(idx_key, non_ex_food_idx, shape=(self.n_dropped_food,), replace=False)
    # Assign new positions and existing values to the food objects at these idx
    food_exists = state.objects.exist.at[idx].set(1.)
    new_pos = self.place_food_at_surface(self.n_dropped_food, pos_key)
    food_pos = state.objects.pos.at[idx].set(new_pos)
    objects = state.objects.replace(pos=food_pos, exist=food_exists)
    return state.replace(objects=objects)

def place_food_at_surface(self, n_food, key):
    """Sample ``n_food`` positions lying on the ``z = self.grid_size`` plane.

    The first two coordinates are drawn uniformly between 0 and
    ``self.grid_size``; the third one is fixed to ``self.grid_size``.

    Args:
        n_food: Number of positions to generate.
        key: PRNG key used for the uniform sampling.

    Returns:
        An array of shape ``(n_food, 3)``.
    """
    surface_xy = random.uniform(key=key, shape=(n_food, 2), minval=0, maxval=self.grid_size)
    surface_z = jnp.full((n_food, 1), fill_value=self.grid_size)
    return jnp.concatenate((surface_xy, surface_z), axis=1)

# TODO : Check these unused functions
def add_agent(self, state, agent_idx):
agents = state.agents.replace(alive=state.agents.alive.at[agent_idx].set(1.0))
state = state.replace(agents=agents)
Expand All @@ -180,6 +230,7 @@ def remove_agent(self, state, agent_idx):
state = state.replace(agents=agents)
return state

# TODO : See how to potentially render the env with PIL / VideoWriter like in other jax libs
@staticmethod
def render(state):
if not plt.fignum_exists(1):
Expand Down Expand Up @@ -209,7 +260,7 @@ def render(state):
ax.scatter(agents_x_pos, agents_y_pos, agents_z_pos, c=agents_colors, s=(1 +agents_energy)*SCALE, marker="o", label="Fish")
ax.scatter(objects_x_pos, objects_y_pos, objects_z_pos, c=objects_colors, marker="o", label="Food")

ax.set_title("Multi-Agent Simulation")
ax.set_title("Aquarium Simulation")
ax.set_xlabel("X-axis")
ax.set_ylabel("Y-axis")
ax.set_zlabel("Z-axis")
Expand Down

0 comments on commit 9d16dcc

Please sign in to comment.