Skip to content

Commit

Permalink
Add food dropping from the top of aquarium
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Apr 30, 2024
1 parent 09f99a2 commit c1b1a65
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions simulationsandbox/environments/aquarium.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

from jax import random, jit, vmap
from flax import struct
from matplotlib import colormaps

from simulationsandbox.environments.base_env import BaseEnv, BaseEnvState

N_DIMS = 3
FISH_SPEED = 3.
FISH_SPEED = 4.
FISH_COLOR = [0., 0.5, 1.]
MOVE_SCALE = 1.
FOOD_SPEED = 1.
FOOD_COLOR = [0.8, 0.5, 0.]


@struct.dataclass
Expand All @@ -28,6 +29,8 @@ class Agents:
class Objects:
pos: jnp.array
velocity: jnp.array
color: jnp.array
exist: jnp.array


@struct.dataclass
Expand All @@ -47,14 +50,13 @@ def normal(theta):
# Change the angle and speed of the agent a bit
def move(obs, key):
return random.normal(key, shape=(3,)) * MOVE_SCALE
return random.uniform(key, shape=(N_DIMS,), minval=-1, maxval=1) * MOVE_SCALE

move = jit(vmap(move, in_axes=(0, 0)))


class Aquarium(BaseEnv):
""" Minimalistic aquarium environmnent"""
def __init__(self, max_agents=10, max_objects=20, grid_size=20):
def __init__(self, max_agents=10, max_objects=20, grid_size=50):
self.max_agents = max_agents
self.max_objects = max_objects
self.grid_size = grid_size
Expand All @@ -69,7 +71,8 @@ def init_state(self, num_agents, num_obs, key):
pos=random.uniform(key=agents_key_pos, shape=(self.max_agents, N_DIMS), minval=0, maxval=self.grid_size),
velocity=fish_velocity,
alive=jnp.hstack((jnp.ones(num_agents), jnp.zeros(self.max_agents - num_agents))),
color=random.uniform(key=agents_color_key, shape=(self.max_agents, 3), minval=0., maxval=1.),
# color=random.uniform(key=agents_color_key, shape=(self.max_agents, 3), minval=0., maxval=1.),
color=jnp.full((self.max_agents, 3), jnp.array(FISH_COLOR)),
obs=jnp.zeros((self.max_agents, num_obs))
)

Expand All @@ -81,6 +84,8 @@ def init_state(self, num_agents, num_obs, key):
food = Objects(
pos=food_pos,
velocity=jnp.tile(jnp.array([0., 0., -1]), (self.max_objects, 1)) * FOOD_SPEED,
color=jnp.full((self.max_objects, 3), jnp.array(FOOD_COLOR)),
exist=jnp.full((self.max_objects), 1.)
)

aquarium_env = AquiariumState(
Expand All @@ -94,19 +99,25 @@ def init_state(self, num_agents, num_obs, key):

@partial(jit, static_argnums=(0,))
def step(self, state, key):
# Update agents positions
keys = random.split(key, self.max_agents)
d_vel = move(state.agents.obs, keys)
velocity = state.agents.velocity + d_vel
velocity = (velocity / jnp.linalg.norm(velocity)) * FISH_SPEED
agents_pos = state.agents.pos + velocity

# Collide with walls
agents_pos = jnp.clip(agents_pos, 0, self.grid_size - 1)
agents_pos = jnp.clip(agents_pos, 0, self.grid_size)
agents = state.agents.replace(pos=agents_pos, velocity=velocity)

# Update food position
food_vel = state.objects.velocity
food_pos = state.objects.pos + food_vel
food_pos = jnp.clip(food_pos, 0, self.grid_size)
objects = state.objects.replace(pos=food_pos)

# Update new state
time = state.time + 1
agents = state.agents.replace(pos=agents_pos, velocity=velocity)
state = state.replace(time=time, agents=agents)
state = state.replace(time=time, agents=agents, objects=objects)
return state

def add_agent(self, state, agent_idx):
Expand Down Expand Up @@ -136,8 +147,16 @@ def render(state):
agents_z_pos = state.agents.pos[:, 2][alive_agents]
agents_colors = state.agents.color[alive_agents]

# TODO : see how to add cmap=colormaps["gist_rainbow"]
exist_object = jnp.where(state.objects.exist != 0.0)
objects_x_pos = state.objects.pos[:, 0][exist_object]
objects_y_pos = state.objects.pos[:, 1][exist_object]
objects_z_pos = state.objects.pos[:, 2][exist_object]
objects_colors = state.objects.color[exist_object]
print(objects_x_pos)

# TODO : see how to add cmap=colormaps["gist_rainbow"] for fish colors
ax.scatter(agents_x_pos, agents_y_pos, agents_z_pos, c=agents_colors, marker="o", label="Fish")
ax.scatter(objects_x_pos, objects_y_pos, objects_z_pos, c=objects_colors, marker="o", label="Fish")

ax.set_title("Multi-Agent Simulation")
ax.set_xlabel("X-axis")
Expand Down

0 comments on commit c1b1a65

Please sign in to comment.