Skip to content

Commit

Permalink
Merge pull request #6 from corentinlger/fix_cond_aquarium_step
Browse files Browse the repository at this point in the history
Fix unjitted drop food function in aquarium step
  • Loading branch information
corentinlger authored Jul 16, 2024
2 parents b525c03 + 4b3b5f1 commit 15f2e8d
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions simulationsandbox/environments/aquarium.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,19 @@ def init_state(self, num_agents=None, num_obs=2, seed=0):
# TODO : split the code into different functions to make it more digestible
@partial(jit, static_argnums=(0,))
def _step(self, state, key):
# Drop new food if there is not enough in the aquarium, else keep the same state
key, drop_food_key = random.split(key)
drop_food = jnp.sum(state.objects.exist) < self.drop_food_threshold
non_ex_food_idx = jnp.where(state.objects.exist == 0, 1., 0.)
state = jax.lax.cond(drop_food, self.drop_food, self.dont_change, state, non_ex_food_idx, drop_food_key)

# Update agents positions
keys = random.split(key, self.max_agents)
d_vel = move(state.agents.obs, keys)
velocity = state.agents.velocity + d_vel
velocity = (velocity / jnp.linalg.norm(velocity)) * FISH_SPEED
agents_pos = state.agents.pos + velocity

# Collide with walls
agents_pos = jnp.clip(agents_pos, 0, self.grid_size)

Expand All @@ -169,46 +176,41 @@ def _step(self, state, key):
food_pos = jnp.where(exist_food, new_food_pos, cur_food_pos)
food_pos = jnp.clip(food_pos, 0, self.grid_size)

agent_idx = jnp.arange(0, self.max_agents)
# Compute the exist food idx after each agent has eaten nearby food
agent_idx = jnp.arange(0, self.max_agents)
food_exist, eaten = eat(agent_idx, state)
agents_energy = state.agents.energy + eaten

# Multiply the masks between them to get the final exist array
food_exist = reduce(multiply_masks, food_exist)

objects = state.objects.replace(pos=food_pos, exist=food_exist)
agents = state.agents.replace(pos=agents_pos, velocity=velocity, energy=agents_energy)

# Update new state
time = state.time + 1
objects = state.objects.replace(pos=food_pos, exist=food_exist)
agents = state.agents.replace(pos=agents_pos, velocity=velocity, energy=agents_energy)
state = state.replace(time=time, agents=agents, objects=objects)
return state

# TODO : Same comment on integrating key in state
def step(self, state, key):
    """Advance the environment by one step and return the new state.

    Food dropping is decided inside the jitted `_step` (through a
    `jax.lax.cond`), so this wrapper must not drop food itself and must
    call `_step` exactly once per invocation.

    Args:
        state: current simulation state, forwarded to `_step` unchanged.
        key: PRNG key, forwarded to `_step` unchanged.

    Returns:
        The state returned by `_step`.
    """
    # At the moment only call the _step but could add new methods in the future
    state = self._step(state, key)
    return state

@partial(jit, static_argnums=(0,))
def drop_food(self, state, non_ex_food_idx, key):
    """Modify the state in the `lax.cond` used for food dropping.

    Picks `self.n_dropped_food` food slots, marks them as existing and
    moves them to fresh positions from `place_food_at_surface`.

    Args:
        state: simulation state; `state.objects.exist` and
            `state.objects.pos` are the arrays being updated.
        non_ex_food_idx: one weight per food slot, used as the sampling
            probabilities. The caller in `_step` passes 1. for slots whose
            food does not exist and 0. for the others.
        key: PRNG key, split into one key for choosing the slots and one
            for choosing the new positions.

    Returns:
        A copy of `state` whose `objects` carry the updated `pos` and
        `exist` arrays.
    """
    idx_key, pos_key = random.split(key)

    # Sample slot indices (not mask values), weighted by the mask, so the
    # computation keeps a static shape and stays usable inside jit/lax.cond.
    # NOTE(review): if fewer than n_dropped_food slots have a non-zero
    # weight, already-existing slots may presumably be selected - confirm.
    slot_ids = jnp.arange(len(non_ex_food_idx))
    idx = random.choice(
        idx_key,
        a=slot_ids,
        p=non_ex_food_idx,
        shape=(self.n_dropped_food,),
        replace=False,
    )

    # Assign new positions and existing values to the food objects at these idx
    food_exists = state.objects.exist.at[idx].set(1.)
    new_pos = self.place_food_at_surface(self.n_dropped_food, pos_key)
    food_pos = state.objects.pos.at[idx].set(new_pos)
    objects = state.objects.replace(pos=food_pos, exist=food_exists)
    return state.replace(objects=objects)

def dont_change(self, state, non_ex_food_idx, key):
    """Return `state` untouched: the no-op branch of the food-dropping `lax.cond`.

    `non_ex_food_idx` and `key` are accepted but ignored; they are part of
    the signature because the `lax.cond` call hands the same operands to
    this function and to `drop_food`.
    """
    unchanged_state = state
    return unchanged_state

def place_food_at_surface(self, n_food, key):
x_y_food_pos=random.uniform(key=key, shape=(n_food, 2), minval=0, maxval=self.grid_size)
Expand Down

0 comments on commit 15f2e8d

Please sign in to comment.