Skip to content

Commit

Permalink
Remove changing arrays from Simulation class variables
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Jan 31, 2024
1 parent 1f2d6aa commit 99a8749
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 100 deletions.
32 changes: 12 additions & 20 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,42 +16,34 @@ def main(cfg: DictConfig):
visualize = cfg.params.visualize
viz_delay = cfg.params.viz_delay

rng_key = random.PRNGKey(cfg.params.random_seed)
key = random.PRNGKey(cfg.params.random_seed)

sim = Simulation(num_agents, max_agents, grid_size, rng_key)
sim = Simulation(max_agents, grid_size)
grid = sim.init_grid(grid_size)
agents_pos, agents_states, num_agents = sim.init_agents(num_agents, max_agents, key)

# Launch a simulation
print("\nSimulation started")

grid, agents_pos, agents_states, key = sim.get_env_state()

for step in range(num_steps):
key, a_key, add_key = random.split(key, 3)

if step % 10 == 0:
print(f"step {step}")

if step == 20:
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)
agents_pos, agents_states = sim.add_agent(agents_pos, agents_states)

for _ in range(8):
agents_pos, agents_states, num_agents = sim.add_agent(agents_pos, agents_states, num_agents, add_key)

if step == 40:
sim.remove_agent()
sim.remove_agent()
sim.remove_agent()
sim.remove_agent()

key, a_key = random.split(key)
for _ in range(4):
num_agents = sim.remove_agent(num_agents)

agents_pos = sim.move_agents(agents_pos, grid_size, a_key)
agents_states += 0.1

if visualize:
sim.visualize(grid, agents_pos, viz_delay)
sim.visualize(grid, agents_pos, num_agents, viz_delay)

print("\nSimulation ended")

Expand Down
33 changes: 20 additions & 13 deletions packages_training/panel/biker_walker.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 52,
"execution_count": 82,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -89,12 +89,12 @@
"data": {
"application/vnd.holoviews_exec.v0+json": "",
"text/html": [
"<div id='7ad8e236-5494-4afa-9d3b-2e2fc76f76ba'>\n",
" <div id=\"e4c6313e-0c31-4ca3-b0b3-e214b1934ca7\" data-root-id=\"7ad8e236-5494-4afa-9d3b-2e2fc76f76ba\" style=\"display: contents;\"></div>\n",
"<div id='986f7705-ea5f-403f-991c-c58fd8daf686'>\n",
" <div id=\"bb5404fd-2ab8-4283-95c2-65dc6e6de4f8\" data-root-id=\"986f7705-ea5f-403f-991c-c58fd8daf686\" style=\"display: contents;\"></div>\n",
"</div>\n",
"<script type=\"application/javascript\">(function(root) {\n",
" var docs_json = {\"3920800d-2f86-4dfe-9e0c-7db8292212a8\":{\"version\":\"3.3.4\",\"title\":\"Bokeh Application\",\"roots\":[{\"type\":\"object\",\"name\":\"panel.models.browser.BrowserInfo\",\"id\":\"7ad8e236-5494-4afa-9d3b-2e2fc76f76ba\"},{\"type\":\"object\",\"name\":\"panel.models.comm_manager.CommManager\",\"id\":\"3c6e0399-d911-4d43-b94b-a408d007f7d7\",\"attributes\":{\"plot_id\":\"7ad8e236-5494-4afa-9d3b-2e2fc76f76ba\",\"comm_id\":\"26cd1095def44a0dbbd835f7de8c12de\",\"client_comm_id\":\"8afe73b3b19b41a0a944265c212f29d7\"}}],\"defs\":[{\"type\":\"model\",\"name\":\"ReactiveHTML1\"},{\"type\":\"model\",\"name\":\"FlexBox1\",\"properties\":[{\"name\":\"align_content\",\"kind\":\"Any\",\"default\":\"flex-start\"},{\"name\":\"align_items\",\"kind\":\"Any\",\"default\":\"flex-start\"},{\"name\":\"flex_direction\",\"kind\":\"Any\",\"default\":\"row\"},{\"name\":\"flex_wrap\",\"kind\":\"Any\",\"default\":\"wrap\"},{\"name\":\"justify_content\",\"kind\":\"Any\",\"default\":\"flex-start\"}]},{\"type\":\"model\",\"name\":\"FloatPanel1\",\"properties\":[{\"name\":\"config\",\"kind\":\"Any\",\"default\":{\"type\":\"map\"}},{\"name\":\"contained\",\"kind\":\"Any\",\"default\":true},{\"name\":\"position\",\"kind\":\"Any\",\"default\":\"right-top\"},{\"name\":\"offsetx\",\"kind\":\"Any\",\"default\":null},{\"name\":\"offsety\",\"kind\":\"Any\",\"default\":null},{\"name\":\"theme\",\"kind\":\"Any\",\"default\":\"primary\"},{\"name\":\"status\",\"kind\":\"Any\",\"default\":\"normalized\"}]},{\"type\":\"model\",\"name\":\"GridStack1\",\"properties\":[{\"name\":\"mode\",\"kind\":\"Any\",\"default\":\"warn\"},{\"name\":\"ncols\",\"kind\":\"Any\",\"default\":null},{\"name\":\"nrows\",\"kind\":\"Any\",\"default\":null},{\"name\":\"allow_resize\",\"kind\":\"Any\",\"default\":true},{\"name\":\"allow_drag\",\"kind\":\"Any\",\"default\":true},{\"name\":\"state\",\"kind\":\"Any\",\"default\":[]}]},{\"type\":\"model\",\"name\":\"drag1\",\"properties\":[{\"name\":\"slider_width\",\"kind\":\"Any\",\"default\":5},{\"name\":\"slider_color\",\"kind\":\"Any\",\"default\":\"black\"},{\"name\":\"value\",\"kind\":\"Any\",\"default\":50}]},{\"type\":\"model\",\"name\":\"click1\",\"properties\":[{\"name\":\"terminal_output\",\"kind\":\"Any\",\"default\":\"\"},{\"name\":\"debug_name\",\"kind\":\"Any\",\"default\":\"\"},{\"name\":\"clears\",\"kind\":\"Any\",\"default\":0}]},{\"type\":\"model\",\"name\":\"copy_to_clipboard1\",\"properties\":[{\"name\":\"fill\",\"kind\":\"Any\",\"default\":\"none\"},{\"name\":\"value\",\"kind\":\"Any\",\"default\":null}]},{\"type\":\"model\",\"name\":\"FastWrapper1\",\"properties\":[{\"name\":\"object\",\"kind\":\"Any\",\"default\":null},{\"name\":\"style\",\"kind\":\"Any\",\"default\":null}]},{\"type\":\"model\",\"name\":\"NotificationAreaBase1\",\"properties\":[{\"name\":\"js_events\",\"kind\":\"Any\",\"default\":{\"type\":\"map\"}},{\"name\":\"position\",\"kind\":\"Any\",\"default\":\"bottom-right\"},{\"name\":\"_clear\",\"kind\":\"Any\",\"default\":0}]},{\"type\":\"model\",\"name\":\"NotificationArea1\",\"properties\":[{\"name\":\"js_events\",\"kind\":\"Any\",\"default\":{\"type\":\"map\"}},{\"name\":\"notifications\",\"kind\":\"Any\",\"default\":[]},{\"name\":\"position\",\"kind\":\"Any\",\"default\":\"bottom-right\"},{\"name\":\"_clear\",\"kind\":\"Any\",\"default\":0},{\"name\":\"types\",\"kind\":\"Any\",\"default\":[{\"type\":\"map\",\"entries\":[[\"type\",\"warning\"],[\"background\",\"#ffc107\"],[\"icon\",{\"type\":\"map\",\"entries\":[[\"className\",\"fas fa-exclamation-triangle\"],[\"tagName\",\"i\"],[\"color\",\"white\"]]}]]},{\"type\":\"map\",\"entries\":[[\"type\",\"info\"],[\"background\",\"#007bff\"],[\"icon\",{\"type\":\"map\",\"entries\":[[\"className\",\"fas fa-info-circle\"],[\"tagName\",\"i\"],[\"color\",\"white\"]]}]]}]}]},{\"type\":\"model\",\"name\":\"Notification\",\"properties\":[{\"name\":\"background\",\"kind\":\"Any\",\"default\":null},{\"name\":\"duration\",\"kind\":\"Any\",\"default\":3000},{\"name\":\"icon\",\"kind\":\"Any\",\"default\":null},{\"name\":\"message\",\"kind\":\"Any\",\"default\":\"\"},{\"name\":\"notification_type\",\"kind\":\"Any\",\"default\":null},{\"name\":\"_destroyed\",\"kind\":\"Any\",\"default\":false}]},{\"type\":\"model\",\"name\":\"TemplateActions1\",\"properties\":[{\"name\":\"open_modal\",\"kind\":\"Any\",\"default\":0},{\"name\":\"close_modal\",\"kind\":\"Any\",\"default\":0}]},{\"type\":\"model\",\"name\":\"BootstrapTemplateActions1\",\"properties\":[{\"name\":\"open_modal\",\"kind\":\"Any\",\"default\":0},{\"name\":\"close_modal\",\"kind\":\"Any\",\"default\":0}]},{\"type\":\"model\",\"name\":\"MaterialTemplateActions1\",\"properties\":[{\"name\":\"open_modal\",\"kind\":\"Any\",\"default\":0},{\"name\":\"close_modal\",\"kind\":\"Any\",\"default\":0}]}]}};\n",
" var render_items = [{\"docid\":\"3920800d-2f86-4dfe-9e0c-7db8292212a8\",\"roots\":{\"7ad8e236-5494-4afa-9d3b-2e2fc76f76ba\":\"e4c6313e-0c31-4ca3-b0b3-e214b1934ca7\"},\"root_ids\":[\"7ad8e236-5494-4afa-9d3b-2e2fc76f76ba\"]}];\n",
" var docs_json = {\"36b22d77-95e9-49e6-8125-99d5192de67d\":{\"version\":\"3.3.4\",\"title\":\"Bokeh Application\",\"roots\":[{\"type\":\"object\",\"name\":\"panel.models.browser.BrowserInfo\",\"id\":\"986f7705-ea5f-403f-991c-c58fd8daf686\"},{\"type\":\"object\",\"name\":\"panel.models.comm_manager.CommManager\",\"id\":\"38f01f27-681e-4fc5-ad33-1bd67f577ebd\",\"attributes\":{\"plot_id\":\"986f7705-ea5f-403f-991c-c58fd8daf686\",\"comm_id\":\"2df62da64e4d4436ac7b73250c4c41d4\",\"client_comm_id\":\"55c122ad305246a6becf178ac8e446bc\"}}],\"defs\":[{\"type\":\"model\",\"name\":\"ReactiveHTML1\"},{\"type\":\"model\",\"name\":\"FlexBox1\",\"properties\":[{\"name\":\"align_content\",\"kind\":\"Any\",\"default\":\"flex-start\"},{\"name\":\"align_items\",\"kind\":\"Any\",\"default\":\"flex-start\"},{\"name\":\"flex_direction\",\"kind\":\"Any\",\"default\":\"row\"},{\"name\":\"flex_wrap\",\"kind\":\"Any\",\"default\":\"wrap\"},{\"name\":\"justify_content\",\"kind\":\"Any\",\"default\":\"flex-start\"}]},{\"type\":\"model\",\"name\":\"FloatPanel1\",\"properties\":[{\"name\":\"config\",\"kind\":\"Any\",\"default\":{\"type\":\"map\"}},{\"name\":\"contained\",\"kind\":\"Any\",\"default\":true},{\"name\":\"position\",\"kind\":\"Any\",\"default\":\"right-top\"},{\"name\":\"offsetx\",\"kind\":\"Any\",\"default\":null},{\"name\":\"offsety\",\"kind\":\"Any\",\"default\":null},{\"name\":\"theme\",\"kind\":\"Any\",\"default\":\"primary\"},{\"name\":\"status\",\"kind\":\"Any\",\"default\":\"normalized\"}]},{\"type\":\"model\",\"name\":\"GridStack1\",\"properties\":[{\"name\":\"mode\",\"kind\":\"Any\",\"default\":\"warn\"},{\"name\":\"ncols\",\"kind\":\"Any\",\"default\":null},{\"name\":\"nrows\",\"kind\":\"Any\",\"default\":null},{\"name\":\"allow_resize\",\"kind\":\"Any\",\"default\":true},{\"name\":\"allow_drag\",\"kind\":\"Any\",\"default\":true},{\"name\":\"state\",\"kind\":\"Any\",\"default\":[]}]},{\"type\":\"model\",\"name\":\"drag1\",\"properties\":[{\"name\":\"slider_width\",\"kind\":\"Any\",\"default\":5},{\"name\":\"slider_color\",\"kind\":\"Any\",\"default\":\"black\"},{\"name\":\"value\",\"kind\":\"Any\",\"default\":50}]},{\"type\":\"model\",\"name\":\"click1\",\"properties\":[{\"name\":\"terminal_output\",\"kind\":\"Any\",\"default\":\"\"},{\"name\":\"debug_name\",\"kind\":\"Any\",\"default\":\"\"},{\"name\":\"clears\",\"kind\":\"Any\",\"default\":0}]},{\"type\":\"model\",\"name\":\"copy_to_clipboard1\",\"properties\":[{\"name\":\"fill\",\"kind\":\"Any\",\"default\":\"none\"},{\"name\":\"value\",\"kind\":\"Any\",\"default\":null}]},{\"type\":\"model\",\"name\":\"FastWrapper1\",\"properties\":[{\"name\":\"object\",\"kind\":\"Any\",\"default\":null},{\"name\":\"style\",\"kind\":\"Any\",\"default\":null}]},{\"type\":\"model\",\"name\":\"NotificationAreaBase1\",\"properties\":[{\"name\":\"js_events\",\"kind\":\"Any\",\"default\":{\"type\":\"map\"}},{\"name\":\"position\",\"kind\":\"Any\",\"default\":\"bottom-right\"},{\"name\":\"_clear\",\"kind\":\"Any\",\"default\":0}]},{\"type\":\"model\",\"name\":\"NotificationArea1\",\"properties\":[{\"name\":\"js_events\",\"kind\":\"Any\",\"default\":{\"type\":\"map\"}},{\"name\":\"notifications\",\"kind\":\"Any\",\"default\":[]},{\"name\":\"position\",\"kind\":\"Any\",\"default\":\"bottom-right\"},{\"name\":\"_clear\",\"kind\":\"Any\",\"default\":0},{\"name\":\"types\",\"kind\":\"Any\",\"default\":[{\"type\":\"map\",\"entries\":[[\"type\",\"warning\"],[\"background\",\"#ffc107\"],[\"icon\",{\"type\":\"map\",\"entries\":[[\"className\",\"fas fa-exclamation-triangle\"],[\"tagName\",\"i\"],[\"color\",\"white\"]]}]]},{\"type\":\"map\",\"entries\":[[\"type\",\"info\"],[\"background\",\"#007bff\"],[\"icon\",{\"type\":\"map\",\"entries\":[[\"className\",\"fas fa-info-circle\"],[\"tagName\",\"i\"],[\"color\",\"white\"]]}]]}]}]},{\"type\":\"model\",\"name\":\"Notification\",\"properties\":[{\"name\":\"background\",\"kind\":\"Any\",\"default\":null},{\"name\":\"duration\",\"kind\":\"Any\",\"default\":3000},{\"name\":\"icon\",\"kind\":\"Any\",\"default\":null},{\"name\":\"message\",\"kind\":\"Any\",\"default\":\"\"},{\"name\":\"notification_type\",\"kind\":\"Any\",\"default\":null},{\"name\":\"_destroyed\",\"kind\":\"Any\",\"default\":false}]},{\"type\":\"model\",\"name\":\"TemplateActions1\",\"properties\":[{\"name\":\"open_modal\",\"kind\":\"Any\",\"default\":0},{\"name\":\"close_modal\",\"kind\":\"Any\",\"default\":0}]},{\"type\":\"model\",\"name\":\"BootstrapTemplateActions1\",\"properties\":[{\"name\":\"open_modal\",\"kind\":\"Any\",\"default\":0},{\"name\":\"close_modal\",\"kind\":\"Any\",\"default\":0}]},{\"type\":\"model\",\"name\":\"MaterialTemplateActions1\",\"properties\":[{\"name\":\"open_modal\",\"kind\":\"Any\",\"default\":0},{\"name\":\"close_modal\",\"kind\":\"Any\",\"default\":0}]}]}};\n",
" var render_items = [{\"docid\":\"36b22d77-95e9-49e6-8125-99d5192de67d\",\"roots\":{\"986f7705-ea5f-403f-991c-c58fd8daf686\":\"bb5404fd-2ab8-4283-95c2-65dc6e6de4f8\"},\"root_ids\":[\"986f7705-ea5f-403f-991c-c58fd8daf686\"]}];\n",
" var docs = Object.values(docs_json)\n",
" if (!docs) {\n",
" return\n",
Expand Down Expand Up @@ -158,7 +158,7 @@
},
"metadata": {
"application/vnd.holoviews_exec.v0+json": {
"id": "7ad8e236-5494-4afa-9d3b-2e2fc76f76ba"
"id": "986f7705-ea5f-403f-991c-c58fd8daf686"
}
},
"output_type": "display_data"
Expand All @@ -173,7 +173,7 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -216,7 +216,7 @@
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -252,7 +252,7 @@
},
{
"cell_type": "code",
"execution_count": 71,
"execution_count": 85,
"metadata": {},
"outputs": [
{
Expand All @@ -279,21 +279,21 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 86,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c3fb840e8afa498ba1d2d7d39674a8ba",
"model_id": "75e2c08e943e446d86a70748161ef1ff",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"BokehModel(combine_events=True, render_bundle={'docs_json': {'32d89942-4de7-4a6f-a0d4-52fad3e489c3': {'version…"
"BokehModel(combine_events=True, render_bundle={'docs_json': {'7816836c-c2a4-46da-a074-d05334756b09': {'version…"
]
},
"execution_count": 77,
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -319,6 +319,13 @@
" pn.pane.Markdown(bound_fn, styles=pn.bind(styles, background))\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
50 changes: 23 additions & 27 deletions simulation.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,17 @@
import jax.numpy as jnp
from jax import random

import matplotlib.pyplot as plt

# TODO : CHANGE ALL THE FILE SO WE EITHER USE SELF.AGENT8POS OR JUST USE IT IN THE SIMULATION BUTT NOT BOTH

# Will change this later : take the habit of using less classes and more pure functions without self argument
class Simulation:
def __init__(self, num_agents, max_agents, grid_size, key):
self.grid = self.init_grid(grid_size)
def __init__(self, max_agents, grid_size):
self.grid_size = grid_size
self.max_agents = max_agents
self.agents_pos, self.agents_states, self.num_agents = self.init_agents(
num_agents, max_agents, grid_size, key
)
self.key = key

def init_grid(self, grid_size):
return jnp.zeros((grid_size, grid_size), dtype=jnp.float32)

def init_agents(self, num_agents, max_agents, grid_size, key):
def init_agents(self, num_agents, max_agents, key):
if num_agents > max_agents:
raise(ValueError("num_agents cannot exceed max_agents"))

Expand All @@ -28,38 +20,42 @@ def init_agents(self, num_agents, max_agents, grid_size, key):
agents_states = jnp.zeros((max_agents,), dtype=jnp.float32)

# only update the ones of agents that actually exist
agents_pos = agents_pos.at[:num_agents].set(random.randint(key, (num_agents, 2), 0, grid_size))
agents_pos = agents_pos.at[:num_agents].set(random.randint(key, (num_agents, 2), 0, self.grid_size))
agents_states = agents_states.at[:num_agents].set(jnp.ones((num_agents,), dtype=jnp.float32))

return agents_pos, agents_states, num_agents


def choose_random_action(self, key_a):
return random.randint(key_a, (1, 2))

# TODO : Only move existing agents
def move_agents(self, agents_pos, grid_size, key):
# Shouldn't be able to do this when jit because of the +=
agents_pos += random.randint(key, agents_pos.shape, -1, 2)
return jnp.clip(agents_pos, 0, grid_size - 1)

def add_agent(self, agents_pos, agents_states):
if self.num_agents < self.max_agents:
agents_pos = agents_pos.at[self.num_agents].set(*random.randint(self.key, (1, 2), 0, self.grid_size))
agents_states = agents_states.at[self.num_agents].set(1)
self.num_agents += 1
print(f"Added agent {self.num_agents}")
def add_agent(self, agents_pos, agents_states, num_agents, key):
if num_agents < self.max_agents:
agents_pos = agents_pos.at[num_agents].set(*random.randint(key, (1, 2), 0, self.grid_size))
agents_states = agents_states.at[num_agents].set(1)
num_agents += 1
print(f"Added agent {num_agents}")

else:
print("Impossible to add more agents")

return agents_pos, agents_states
return agents_pos, agents_states, num_agents

def remove_agent(self):
if self.num_agents <= 0:
def remove_agent(self, num_agents):
if num_agents <= 0:
print("There is no agents to remove")
else:
self.num_agents -= 1
print(f"Removed agent {self.num_agents + 1}")
num_agents -= 1
print(f"Removed agent {num_agents + 1}")
return num_agents


def visualize(self, grid, agents_pos, delay=0.1):
def visualize(self, grid, agents_pos, num_agents, delay=0.1):
if not plt.fignum_exists(1):
plt.ion()
plt.figure(figsize=(10, 10))
Expand All @@ -68,7 +64,7 @@ def visualize(self, grid, agents_pos, delay=0.1):

plt.imshow(grid, cmap="viridis", origin="upper")
plt.scatter(
agents_pos[:self.num_agents, 0], agents_pos[:self.num_agents, 1], color="red", marker="o", label="Agents"
agents_pos[:num_agents, 0], agents_pos[:num_agents, 1], color="red", marker="o", label="Agents"
)
plt.title("Multi-Agent Simulation")
plt.xlabel("X-axis")
Expand All @@ -79,5 +75,5 @@ def visualize(self, grid, agents_pos, delay=0.1):
plt.pause(delay)


def get_env_state(self):
return self.grid, self.agents_pos, self.agents_states, self.key
def get_env_params(self):
return self.grid_size, self.max_agents
Loading

0 comments on commit 99a8749

Please sign in to comment.