Skip to content

Commit 09f99a2

Browse files
committed
Add aquarium env
1 parent 2369589 commit 09f99a2

File tree

2 files changed

+157
-1
lines changed

2 files changed

+157
-1
lines changed
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from functools import partial
2+
3+
import jax.numpy as jnp
4+
import matplotlib.pyplot as plt
5+
6+
from jax import random, jit, vmap
7+
from flax import struct
8+
from matplotlib import colormaps
9+
10+
from simulationsandbox.environments.base_env import BaseEnv, BaseEnvState
11+
12+
N_DIMS = 3
13+
FISH_SPEED = 3.
14+
MOVE_SCALE = 1.
15+
FOOD_SPEED = 1.
16+
17+
18+
@struct.dataclass
19+
class Agents:
20+
pos: jnp.array
21+
velocity: jnp.array
22+
alive: jnp.array
23+
color: jnp.array
24+
obs: jnp.array
25+
26+
27+
@struct.dataclass
28+
class Objects:
29+
pos: jnp.array
30+
velocity: jnp.array
31+
32+
33+
@struct.dataclass
34+
class AquiariumState(BaseEnvState):
35+
time: int
36+
grid_size: int
37+
agents: Agents
38+
objects: Objects
39+
40+
41+
def normal(theta):
42+
return jnp.array([jnp.cos(theta), jnp.sin(theta)])
43+
44+
normal = jit(vmap(normal))
45+
46+
47+
# Change the angle and speed of the agent a bit
48+
def move(obs, key):
49+
return random.normal(key, shape=(3,)) * MOVE_SCALE
50+
return random.uniform(key, shape=(N_DIMS,), minval=-1, maxval=1) * MOVE_SCALE
51+
52+
move = jit(vmap(move, in_axes=(0, 0)))
53+
54+
55+
class Aquarium(BaseEnv):
56+
""" Minimalistic aquarium environmnent"""
57+
def __init__(self, max_agents=10, max_objects=20, grid_size=20):
58+
self.max_agents = max_agents
59+
self.max_objects = max_objects
60+
self.grid_size = grid_size
61+
62+
def init_state(self, num_agents, num_obs, key):
63+
agents_key_pos, agents_key_vel, agents_color_key, objects_key_pos = random.split(key, 4)
64+
fish_velocity = random.uniform(agents_key_vel, shape=(self.max_agents, N_DIMS), minval=-1, maxval=1)
65+
fish_velocity = (fish_velocity / jnp.linalg.norm(fish_velocity)) * FISH_SPEED
66+
# fish_velocity = fish_velocity * FISH_SPEED
67+
68+
fish = Agents(
69+
pos=random.uniform(key=agents_key_pos, shape=(self.max_agents, N_DIMS), minval=0, maxval=self.grid_size),
70+
velocity=fish_velocity,
71+
alive=jnp.hstack((jnp.ones(num_agents), jnp.zeros(self.max_agents - num_agents))),
72+
color=random.uniform(key=agents_color_key, shape=(self.max_agents, 3), minval=0., maxval=1.),
73+
obs=jnp.zeros((self.max_agents, num_obs))
74+
)
75+
76+
# Add food at the surface of the aquarium
77+
x_y_food_pos=random.uniform(key=objects_key_pos, shape=(self.max_objects, 2), minval=0, maxval=self.grid_size)
78+
z_food_pos = jnp.full((self.max_objects, 1), fill_value=self.grid_size)
79+
food_pos = jnp.concatenate((x_y_food_pos, z_food_pos), axis=1)
80+
81+
food = Objects(
82+
pos=food_pos,
83+
velocity=jnp.tile(jnp.array([0., 0., -1]), (self.max_objects, 1)) * FOOD_SPEED,
84+
)
85+
86+
aquarium_env = AquiariumState(
87+
time=0,
88+
grid_size=self.grid_size,
89+
agents=fish,
90+
objects=food
91+
)
92+
93+
return aquarium_env
94+
95+
@partial(jit, static_argnums=(0,))
96+
def step(self, state, key):
97+
keys = random.split(key, self.max_agents)
98+
d_vel = move(state.agents.obs, keys)
99+
velocity = state.agents.velocity + d_vel
100+
velocity = (velocity / jnp.linalg.norm(velocity)) * FISH_SPEED
101+
agents_pos = state.agents.pos + velocity
102+
103+
# Collide with walls
104+
agents_pos = jnp.clip(agents_pos, 0, self.grid_size - 1)
105+
106+
# Update new state
107+
time = state.time + 1
108+
agents = state.agents.replace(pos=agents_pos, velocity=velocity)
109+
state = state.replace(time=time, agents=agents)
110+
return state
111+
112+
def add_agent(self, state, agent_idx):
113+
agents = state.agents.replace(alive=state.agents.alive.at[agent_idx].set(1.0))
114+
state = state.replace(agents=agents)
115+
return state
116+
117+
def remove_agent(self, state, agent_idx):
118+
agents = state.agents.replace(alive=state.agents.alive.at[agent_idx].set(0.0))
119+
state = state.replace(agents=agents)
120+
return state
121+
122+
@staticmethod
123+
def render(state):
124+
if not plt.fignum_exists(1):
125+
plt.ion()
126+
fig = plt.figure(figsize=(10, 10))
127+
ax = fig.add_subplot(111, projection='3d')
128+
129+
plt.clf()
130+
131+
ax = plt.axes(projection='3d')
132+
133+
alive_agents = jnp.where(state.agents.alive != 0.0)
134+
agents_x_pos = state.agents.pos[:, 0][alive_agents]
135+
agents_y_pos = state.agents.pos[:, 1][alive_agents]
136+
agents_z_pos = state.agents.pos[:, 2][alive_agents]
137+
agents_colors = state.agents.color[alive_agents]
138+
139+
# TODO : see how to add cmap=colormaps["gist_rainbow"]
140+
ax.scatter(agents_x_pos, agents_y_pos, agents_z_pos, c=agents_colors, marker="o", label="Fish")
141+
142+
ax.set_title("Multi-Agent Simulation")
143+
ax.set_xlabel("X-axis")
144+
ax.set_ylabel("Y-axis")
145+
ax.set_zlabel("Z-axis")
146+
147+
ax.set_xlim(0, state.grid_size)
148+
ax.set_ylim(0, state.grid_size)
149+
ax.set_zlim(0, state.grid_size)
150+
151+
ax.legend()
152+
153+
plt.draw()
154+
plt.pause(0.001)

simulationsandbox/utils/envs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from simulationsandbox.environments.two_d_example_env import TwoDEnv
22
from simulationsandbox.environments.three_d_example_env import ThreeDEnv
33
from simulationsandbox.environments.lake_env import LakeEnv
4+
from simulationsandbox.environments.aquarium import Aquarium
45

56
ENVS = {"two_d": TwoDEnv,
67
"three_d": ThreeDEnv,
7-
"lake": LakeEnv
8+
"lake": LakeEnv,
9+
"aquarium": Aquarium
810
}

0 commit comments

Comments
 (0)