Skip to content

Commit

Permalink
Simplify _step function in prey_predator env
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Aug 2, 2024
1 parent 0b0a66c commit 00f4019
Showing 1 changed file with 72 additions and 6 deletions.
78 changes: 72 additions & 6 deletions vivarium/experimental/notebooks/prey_predator_braitenberg.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-07-04 11:03:15.059320: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
"2024-08-02 15:52:56.509482: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
]
}
],
Expand Down Expand Up @@ -258,6 +258,72 @@
" return state, neighbors"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class PreyPredBraitenbergEnv(BraitenbergEnv):\n",
" def __init__(\n",
" self,\n",
" state,\n",
" pred_eating_range\n",
" ): \n",
" super().__init__(state=state)\n",
" # Add idx utils to simplify conversions between entities and agent states\n",
" self.agents_idx = jnp.where(state.entities.entity_type == EntityType.AGENT.value)\n",
" self.prey_idx = jnp.where(state.agents.agent_type == AgentType.PREY.value)\n",
" self.pred_idx = jnp.where(state.agents.agent_type == AgentType.PREDATOR.value)\n",
" self.pred_eating_range = pred_eating_range\n",
"\n",
" # Add a function to detect if a prey will be eaten by a predator in the current step\n",
" def can_all_be_eaten(self, R_prey, R_predators, predator_exist):\n",
" # Could maybe create this as a method in the class, or above idk\n",
" distance_to_all_preds = vmap(self.distance, in_axes=(None, 0))\n",
"\n",
" # Same for this, the only pb is that the fn above needs the displacement arg, so can't define it in the cell above \n",
" def can_be_eaten(R_prey, R_predators, predator_exist):\n",
" dist_to_preds = distance_to_all_preds(R_prey, R_predators)\n",
" in_range = jnp.where(dist_to_preds < self.pred_eating_range, 1, 0)\n",
" # Could also return which agent ate the other one (e.g to increase their energy) \n",
" will_be_eaten_by = in_range * predator_exist\n",
" eaten_or_not = jnp.where(jnp.sum(will_be_eaten_by) > 0., 1, 0)\n",
" return eaten_or_not\n",
" \n",
" can_be_eaten = vmap(can_be_eaten, in_axes=(0, None, None))\n",
" \n",
" return can_be_eaten(R_prey, R_predators, predator_exist)\n",
" \n",
" # Add functions so predators eat preys\n",
" def eat_preys(self, state):\n",
" # See which preys can be eaten by predators and update the exists array accordingly\n",
" R = state.entities.position.center\n",
" exist = state.entities.exists\n",
" prey_idx = self.prey_idx\n",
" pred_idx = self.pred_idx\n",
"\n",
" agents_ent_idx = state.agents.ent_idx\n",
" predator_exist = exist[agents_ent_idx][pred_idx]\n",
" can_be_eaten_idx = self.can_all_be_eaten(R[prey_idx], R[pred_idx], predator_exist)\n",
"\n",
" # Kill the agents that are being eaten\n",
" exist_prey = exist[agents_ent_idx[prey_idx]]\n",
" new_exists_prey = jnp.where(can_be_eaten_idx == 1, 0, exist_prey)\n",
" exist = exist.at[agents_ent_idx[prey_idx]].set(new_exists_prey)\n",
"\n",
" return exist\n",
"\n",
" # Add the eat_preys function in the _step loop\n",
" @partial(jit, static_argnums=(0,))\n",
" def _step(self, state: State, neighbors: jnp.array, agents_neighs_idx: jnp.array) -> Tuple[State, jnp.array]:\n",
" # 1 Compute which agents are being eaten\n",
" exist = self.eat_preys(state)\n",
" entities = state.entities.replace(exists=exist)\n",
" state = state.replace(entities=entities)\n",
" return super()._step(state, neighbors, agents_neighs_idx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -273,7 +339,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -282,7 +348,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -294,7 +360,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -321,7 +387,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -335,7 +401,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand Down

0 comments on commit 00f4019

Please sign in to comment.