Skip to content

Commit

Permalink
Improve server-client files and create a first minimalistic example o…
Browse files Browse the repository at this point in the history
…f plotting data on a distant client
  • Loading branch information
corentinlger committed Feb 6, 2024
1 parent 3b2e98b commit cb01597
Show file tree
Hide file tree
Showing 9 changed files with 254 additions and 49 deletions.
21 changes: 19 additions & 2 deletions MultiAgentsSim/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,29 @@

import jax.numpy as jnp
from jax import random, jit
from flax import struct
import matplotlib.pyplot as plt


# import matplotlib
# matplotlib.use('agg')

# TODO :
@struct.dataclass
class SimState:
time: int

@struct.dataclass
class SimParams:
max_agents: int
grid_size: int



class Simulation:
def __init__(self, max_agents, grid_size):
self.grid_size = grid_size
self.grid = self.init_grid(grid_size)
self.max_agents = max_agents

def init_grid(self, grid_size):
Expand Down Expand Up @@ -56,7 +73,7 @@ def remove_agent(self, num_agents):
return num_agents

@staticmethod
def visualize_sim(grid, agents_pos, num_agents, delay=0.1, color="red"):
def visualize_sim(grid, agents_pos, num_agents, color="red"):
if not plt.fignum_exists(1):
plt.ion()
plt.figure(figsize=(10, 10))
Expand All @@ -73,7 +90,7 @@ def visualize_sim(grid, agents_pos, num_agents, delay=0.1, color="red"):
plt.legend()

plt.draw()
plt.pause(delay)
plt.pause(0.001)


def get_env_params(self):
Expand Down
43 changes: 43 additions & 0 deletions client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import socket
import pickle
import threading

from MultiAgentsSim.simulation import Simulation

SERVER = '10.204.2.189'
# SERVER = '192.168.1.24'
PORT = 5050
ADDR = (SERVER, PORT)

DATA_SIZE = 4096


client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(ADDR)
print(f"Connected to {ADDR}")

msg = client.recv(1024).decode()
print(f"server message: {msg}")
response = "RECEIVE"
client.send(response.encode())
print(f"responded: {response}")

def receive():
while True:
try:
raw_data = client.recv(DATA_SIZE)
data = pickle.loads(raw_data)
print(f"data received: {data}")
timestep, grid, agents_pos, agents_states, num_agents, color, key = data
Simulation.visualize_sim(grid, agents_pos, num_agents, color)

except socket.error as e:
print(e)
client.close()
break

# Matplotlib intreactive doesn't work outside the main thread
# receive_thread = threading.Thread(target=receive)
# receive_thread.start()

receive()
2 changes: 1 addition & 1 deletion conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ params:
num_steps: 50
random_seed: 0
visualize: True
viz_delay: 0.1
step_delay: 0.1
7 changes: 5 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import time

import hydra
from omegaconf import DictConfig, OmegaConf
from jax import random
Expand All @@ -15,7 +17,7 @@ def main(cfg: DictConfig):
grid_size = cfg.params.grid_size
num_steps = cfg.params.num_steps
visualize = cfg.params.visualize
viz_delay = cfg.params.viz_delay
step_delay = cfg.params.step_delay

key = random.PRNGKey(cfg.params.random_seed)

Expand All @@ -30,6 +32,7 @@ def main(cfg: DictConfig):

color = "red"
for step in range(num_steps):
time.sleep(step_delay)
key, a_key, add_key = random.split(key, 3)

if step % 10 == 0:
Expand All @@ -50,7 +53,7 @@ def main(cfg: DictConfig):
agents_states += 0.1

if visualize:
Simulation.visualize_sim(grid, agents_pos, num_agents, viz_delay, color)
Simulation.visualize_sim(grid, agents_pos, num_agents, color)

print("\nSimulation ended")

Expand Down
130 changes: 130 additions & 0 deletions server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import time
import socket
import pickle
import threading

import jax

from MultiAgentsSim.simulation import Simulation
from MultiAgentsSim.agents import Agents

SERVER = '10.204.2.189'
# SERVER = '192.168.1.24'
PORT = 5050
ADDR = (SERVER, PORT)

server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(ADDR)
server.listen()
print(f"Server started and listening ...")



# SIM PART
NUM_AGENTS = 5
MAX_AGENTS = 10
GRID_SIZE = 20
NUM_STEPS = 50
VIZUALIZE = True
STEP_DELAY = 0.3
SEED = 0

DATA_SIZE = 4096

key = jax.random.PRNGKey(SEED)

sim = Simulation(MAX_AGENTS, GRID_SIZE)
agents = Agents(MAX_AGENTS, GRID_SIZE)

grid = sim.init_grid(GRID_SIZE)
agents_pos, agents_states, num_agents = agents.init_agents(NUM_AGENTS, MAX_AGENTS, key)
color = "red"



def get_new_timestep_data(timestep, grid, agents_pos, agents_states, num_agents, color, key):
key, a_key, add_key = jax.random.split(key, 3)
actions = agents.choose_action(agents_pos, a_key)
agents_pos = sim.move_agents(agents_pos, actions)
agents_states += 0.1
return (timestep, grid, agents_pos, agents_states, num_agents, color, key)


def update_latest_data():
global latest_data
global timestep
while True:
with data_lock:
timestep, grid, agents_pos, agents_states, num_agents, color, key = latest_data
latest_data = get_new_timestep_data(timestep, grid, agents_pos, agents_states, num_agents, color, key)
new_data_event.set()
timestep += 1
time.sleep(STEP_DELAY)

def update_array_size(new_size):
global array_size
array_size = new_size

def handle_client(client, addr):
try:
client.send("RECEIVE_OR_UPDATE".encode())
response = client.recv(DATA_SIZE).decode()
print(f"{response} established with {addr}")

if response == "RECEIVE":
while True:
try:
new_data_event.wait()
with data_lock:
data = latest_data
client.send(pickle.dumps(data))
new_data_event.clear()
except socket.error as e:
print(f"error: {e}")
client.close()
print(f"Client {client} disconnected")
break

elif response == "UPDATE":
while True:
try:
new_size = pickle.loads(client.recv(DATA_SIZE))
update_array_size(new_size)
except socket.error as e:
print(f"error: {e}")
client.close()
print(f"Client {client} disconnected")
break

else:
print(f"Unknown connection type {response} detected")


except socket.error as e:
print(f"error: {e}")
client.close()
print(f"Client {client} disconnected")


# Create a global variable to store the current array size and data + lock and event to access it
timestep = 0
latest_data = get_new_timestep_data(timestep, grid, agents_pos, agents_states, num_agents, color, key)
data_lock = threading.Lock()
new_data_event = threading.Event()

# Create a thread to continuously update the data
update_data_thread = threading.Thread(target=update_latest_data)
update_data_thread.start()


# Start listening to clients and launch their threads
while True:
try:
client, addr = server.accept()
print(f"Connected with {addr}")

client_thread = threading.Thread(target=handle_client, args=(client, addr))
client_thread.start()
except socket.error as e:
print(f"error: {e}")

6 changes: 6 additions & 0 deletions server_client_connection/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
client.connect(ADDR)
print(f"Connected to {ADDR}")

msg = client.recv(1024).decode()
print(f"server message: {msg}")
response = "RECEIVE"
client.send(response.encode())
print(f"responded: {response}")

def receive():
while True:
try:
Expand Down
80 changes: 39 additions & 41 deletions server_client_connection/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
SERVER = '10.204.2.189'
# SERVER = '192.168.1.24'
PORT = 5050
UPDATE_PORT = 5051
ADDR = (SERVER, PORT)

server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -32,35 +31,47 @@ def update_array_size(new_size):
global array_size
array_size = new_size

def handle_client(client):
while True:
try:
new_data_event.wait()
with data_lock:
data = latest_data
client.send(pickle.dumps(data))
new_data_event.clear()

except socket.error as e:
print(f"error: {e}")
client.close()
print(f"Client {client} disconnected")
break

def handle_update_client(update_client):
while True:
try:
new_size = pickle.loads(update_client.recv(1024))
update_array_size(new_size)
except socket.error as e:
print(f"error: {e}")
update_client.close()
print("Update client disconnected")
break
def handle_client(client, addr):
try:
client.send("RECEIVE_OR_UPDATE".encode())
response = client.recv(1024).decode()
print(f"{response} established with {addr}")

if response == "RECEIVE":
while True:
try:
new_data_event.wait()
with data_lock:
data = latest_data
client.send(pickle.dumps(data))
new_data_event.clear()
except socket.error as e:
print(f"error: {e}")
client.close()
print(f"Client {client} disconnected")

elif response == "UPDATE":
while True:
try:
new_size = pickle.loads(client.recv(1024))
update_array_size(new_size)
except socket.error as e:
print(f"error: {e}")
client.close()
print(f"Client {client} disconnected")

else:
print(f"Unknown connection type {response} detected")


except socket.error as e:
print(f"error: {e}")
client.close()
print(f"Client {client} disconnected")


# Create a global variable to store the current array size and data + lock and event to access it
array_size = (5, 5)
# Create a global variable to store the current data + lock and event to access it
latest_data = generate_random_array(array_size)
data_lock = threading.Lock()
new_data_event = threading.Event()
Expand All @@ -70,26 +81,13 @@ def handle_update_client(update_client):
update_data_thread.start()


# UPDATE SERVER PART : NOT PRETTY BUT WILL CHANGE IT
update_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
update_server.bind((SERVER, UPDATE_PORT))
update_server.listen()
print(f"Update server started and listening ...")

update_client_socket, _ = update_server.accept()
print("Update client connected")

update_client_thread = threading.Thread(target=handle_update_client, args=(update_client_socket,))
update_client_thread.start()


# Start listening to clients and launch their threads
while True:
try:
client, addr = server.accept()
print(f"Connected with {addr}")

client_thread = threading.Thread(target=handle_client, args=(client, ))
client_thread = threading.Thread(target=handle_client, args=(client, addr))
client_thread.start()
except socket.error as e:
print(f"error: {e}")
Expand Down
Loading

0 comments on commit cb01597

Please sign in to comment.