Skip to content

Commit

Permalink
Add notebook controller
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Feb 13, 2024
1 parent 7a28b31 commit ce1f8cd
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 25 deletions.
15 changes: 0 additions & 15 deletions MultiAgentsSim/three_d_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

N_DIMS = 3

# TODO : Add colors
# Add another element in the flax dataclass so we can prove its rly ez
@struct.dataclass
class ThreeDSimState(SimState):
time: int
Expand All @@ -29,7 +27,6 @@ def __init__(self, max_agents, grid_size):
self.max_agents = max_agents
self.grid_size = grid_size


@partial(jit, static_argnums=(0, 1, 2))
def init_state(self, num_agents, num_obs, key):
x_key, y_key = random.split(key)
Expand All @@ -42,13 +39,11 @@ def init_state(self, num_agents, num_obs, key):
obs = jnp.zeros((self.max_agents, num_obs)),
colors=jnp.full(shape=(self.max_agents, 3), fill_value=jnp.array([1.0, 0.0, 0.0]))
)


@partial(jit, static_argnums=(0,))
def choose_action(self, obs, key):
return random.randint(key, shape=(obs.shape[0], N_DIMS), minval=-1, maxval=2)


@partial(jit, static_argnums=(0,))
def step(self, sim_state, actions, key):
x_pos = jnp.clip(sim_state.x_pos + actions[:, 0], 0, self.grid_size - 1)
Expand All @@ -57,7 +52,6 @@ def step(self, sim_state, actions, key):
time = sim_state.time + 1
sim_state = sim_state.replace(time=time, x_pos=x_pos, y_pos=y_pos, z_pos=z_pos)
return sim_state


def add_agent(self, sim_state, agent_idx):
sim_state = sim_state.replace(alive=sim_state.alive.at[agent_idx].set(1.0))
Expand All @@ -69,19 +63,10 @@ def remove_agent(self, sim_state, agent_idx):
sim_state = sim_state.replace(alive=sim_state.alive.at[agent_idx].set(0.0))
print(f"agent {agent_idx} removed")
return sim_state


def get_env_params(self):
return self.grid_size, self.max_agents

@staticmethod
def encode(state):
pass

@staticmethod
def decode(state):
pass

@staticmethod
def visualize_sim(state, grid_size):
if not plt.fignum_exists(1):
Expand Down
170 changes: 170 additions & 0 deletions notebook_controller.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Notebook controller to update the state of Simulation\n"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import socket\n",
"import pickle\n",
"\n",
"import numpy as np\n",
"import jax.numpy as jnp\n",
"from flax import serialization\n",
"\n",
"from MultiAgentsSim.two_d_simulation import SimpleSimulation\n",
"from MultiAgentsSim.utils.network import SERVER"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"PORT = 5050\n",
"ADDR = (SERVER, PORT)\n",
"DATA_SIZE = 10835\n",
"EVAL_TIME = 10\n",
"\n",
"color_map = {\"r\": (1.0, 0.0, 0.0),\n",
" \"g\": (0.0, 1.0, 0.0),\n",
" \"b\": (0.0, 0.0, 1.0)}"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"# Start the server and intialize connection\n",
"\n",
"def connect_client():\n",
" client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
" client.connect(ADDR)\n",
" print(f\"Connected to {ADDR}\")\n",
"\n",
" msg = client.recv(1024).decode()\n",
" state_example = pickle.loads(client.recv(DATA_SIZE))\n",
" state_bytes_size = len(serialization.to_bytes(state_example))\n",
" response = \"NOTEBOOK\"\n",
" client.send(response.encode())\n",
" time.sleep(1)\n",
"\n",
" return client, state_example, state_bytes_size"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"def close_client(client):\n",
" client.send(\"CLOSE_CONNECTION\".encode())\n",
"\n",
"def get_state(client, state_example, state_bytes_size):\n",
" client.send(\"GET_STATE\".encode())\n",
" response = client.recv(state_bytes_size)\n",
" return serialization.from_bytes(state_example, response)\n",
"\n",
"def change_state_agent_color(state, idx, color):\n",
" colors = np.array(state.colors)\n",
" colors[idx] = color_map[color]\n",
" state = state.replace(colors=colors) \n",
" return state\n",
"\n",
"def set_color(client, agent_idx, color, state_example, state_bytes_size):\n",
" client.send(\"SET_STATE\".encode())\n",
" current_state = serialization.from_bytes(state_example, client.recv(state_bytes_size))\n",
" response_state = change_state_agent_color(current_state, agent_idx, color)\n",
" client.send(serialization.to_bytes(response_state))\n",
" return "
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Connected to ('localhost', 5050)\n"
]
}
],
"source": [
"client, state_example, state_bytes_size = connect_client()"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"state = get_state(client, state_example, state_bytes_size)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"state = get_state()"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"set_color(client, 2, 'g', state_example, state_bytes_size)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"close_client(client)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "myvenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
54 changes: 44 additions & 10 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,13 @@
key = random.PRNGKey(SEED)
sim = SimpleSimulation(MAX_AGENTS, GRID_SIZE)
state = sim.init_state(NUM_AGENTS, NUM_OBS, key)
state_byte_size = len(serialization.to_bytes(state))

# TODO : Rethink the code to delete the latest data var and only use state
# Shared variables and locks
latest_data = state
data_lock = threading.Lock()
state_lock = threading.Lock()
new_data_event = threading.Event()


Expand All @@ -54,12 +57,13 @@ def update_latest_data():
global key

while True:
key, a_key, step_key = random.split(key, 3)
actions = sim.choose_action(state.obs, a_key)
state = sim.step(state, actions, step_key)
with data_lock:
latest_data = state
new_data_event.set()
with state_lock:
key, a_key, step_key = random.split(key, 3)
actions = sim.choose_action(state.obs, a_key)
state = sim.step(state, actions, step_key)
with data_lock:
latest_data = state
new_data_event.set()
time.sleep(STEP_DELAY)


Expand All @@ -79,16 +83,46 @@ def establish_connection(client, addr):


# Define how to communicate with a client
def communicate_with_client(connection_type):
def communicate_with_client(client, addr, connection_type):
global state

if connection_type == "RECEIVE":
while True:
try:
new_data_event.wait()
with data_lock:
serialized_data = serialization.to_bytes(latest_data)
client.send(serialized_data)
client.send(serialization.to_bytes(latest_data))
new_data_event.clear()

except socket.error as e:
print(f"error: {e}")
client.close()
print(f"Client {addr} disconnected")
break

elif connection_type == "NOTEBOOK":
while True:
try:
request = client.recv(DATA_SIZE).decode()
print(f"{request}")
if request == "CLOSE_CONNECTION":
client.close()
print(f"Client {addr} disconnected")

elif request == "GET_STATE":
with data_lock:
client.send(serialization.to_bytes(latest_data))

elif request == "SET_STATE":
with data_lock:
client.send(serialization.to_bytes(latest_data))
updated_state = serialization.from_bytes(state, client.recv(state_byte_size))
with state_lock:
state = updated_state

else:
print(f"Unknow request type {request}")

except socket.error as e:
print(f"error: {e}")
client.close()
Expand All @@ -102,7 +136,7 @@ def communicate_with_client(connection_type):
# Function to handle a client when it connects to the server
def handle_client(client, addr):
connection_type = establish_connection(client, addr)
communicate_with_client(connection_type)
communicate_with_client(client, addr, connection_type)


# Create a thread to continuously update the data
Expand Down

0 comments on commit ce1f8cd

Please sign in to comment.