Skip to content

Commit

Permalink
Update and run simulation with SimulationWrapper in the server
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Feb 14, 2024
1 parent bb01f48 commit e1ac2fe
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 46 deletions.
11 changes: 6 additions & 5 deletions client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from flax import serialization

from MultiAgentsSim.two_d_simulation import SimpleSimulation
from MultiAgentsSim.utils.network import SERVER
from MultiAgentsSim.sim_types import SIMULATIONS
from simulationsandbox.two_d_simulation import SimpleSimulation
from simulationsandbox.utils.network import SERVER
from simulationsandbox.sim_types import SIMULATIONS


PORT = 5050
ADDR = (SERVER, PORT)
Expand Down Expand Up @@ -35,7 +36,7 @@ def receive_loop():
i += 1
raw_data = client.recv(state_bytes_size)
state = serialization.from_bytes(state_example, raw_data)
Simulation.visualize_sim(state, grid_size=GRID_SIZE)
Simulation.visualize_sim(state)

except socket.error as e:
print(e)
Expand All @@ -50,7 +51,7 @@ def test():
i += 1
raw_data = client.recv(state_bytes_size)
state = serialization.from_bytes(state_example, raw_data)
Simulation.visualize_sim(state, grid_size=GRID_SIZE)
Simulation.visualize_sim(state)
client.close()

print(f"{i = } : {i / EVAL_TIME } data received per second")
Expand Down
60 changes: 21 additions & 39 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from jax import random
from flax import serialization

from MultiAgentsSim.utils.network import SERVER
from MultiAgentsSim.sim_types import SIMULATIONS
from simulationsandbox.wrapper import SimulationWrapper
from simulationsandbox.utils.network import SERVER
from simulationsandbox.sim_types import SIMULATIONS

parser = argparse.ArgumentParser()
parser.add_argument('--sim_type', type=str, default="two_d")
Expand All @@ -28,7 +29,6 @@
MAX_AGENTS = 10
NUM_OBS = 3
GRID_SIZE = 20
VIZUALIZE = True

# Initialize server
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -44,37 +44,21 @@
state = sim.init_state(NUM_AGENTS, NUM_OBS, key)
state_byte_size = len(serialization.to_bytes(state))

# TODO : Rethink the code to delete the latest data var and only use state
# Shared variables and locks
latest_data = state
data_lock = threading.Lock()
state_lock = threading.Lock()
new_data_event = threading.Event()

print(f"{len(pickle.dumps(latest_data))}")
# Shared variables and locks
sim_lock = threading.Lock()
update_event = threading.Event()

# Continuously update the state of the simulation
def update_latest_data():
global latest_data
global state
global key

while True:
with state_lock:
key, a_key, step_key = random.split(key, 3)
actions = sim.choose_action(state.obs, a_key)
state = sim.step(state, actions, step_key)
with data_lock:
latest_data = state
new_data_event.set()
time.sleep(STEP_DELAY)
simulation = SimulationWrapper(sim, state, key, step_delay=STEP_DELAY, update_event=update_event)
print(f"{len(pickle.dumps(simulation.state))}")


# Establish a connection with a client
def establish_connection(client, addr):
try:
client.send(SIM_TYPE.encode())
client.send(pickle.dumps(latest_data))
client.send(pickle.dumps(simulation.state))
connection_type = client.recv(DATA_SIZE).decode()
print(f"{connection_type} connection established with {addr}")
return connection_type
Expand All @@ -92,10 +76,10 @@ def communicate_with_client(client, addr, connection_type):
if connection_type == "RECEIVE":
while True:
try:
new_data_event.wait()
with data_lock:
client.send(serialization.to_bytes(latest_data))
new_data_event.clear()
update_event.wait()
# with data_lock:
client.send(serialization.to_bytes(simulation.state))
update_event.clear()

except socket.error as e:
print(f"error: {e}")
Expand All @@ -113,15 +97,14 @@ def communicate_with_client(client, addr, connection_type):
print(f"Client {addr} disconnected")

elif request == "GET_STATE":
with data_lock:
client.send(serialization.to_bytes(latest_data))
with sim_lock:
client.send(serialization.to_bytes(simulation.state))

elif request == "SET_STATE":
with data_lock:
client.send(serialization.to_bytes(latest_data))
with sim_lock:
client.send(serialization.to_bytes(simulation.state))
updated_state = serialization.from_bytes(state, client.recv(state_byte_size))
with state_lock:
state = updated_state
simulation.state = updated_state

else:
print(f"Unknow request type {request}")
Expand All @@ -142,10 +125,9 @@ def handle_client(client, addr):
communicate_with_client(client, addr, connection_type)


# Create a thread to continuously update the data
update_data_thread = threading.Thread(target=update_latest_data)
update_data_thread.start()

# Start the simulation
simulation.start()
print("Simulation started")

# Start listening to clients and launch their threads
while True:
Expand Down
8 changes: 6 additions & 2 deletions simulationsandbox/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

class SimulationWrapper:

def __init__(self, simulation, state, key, step_delay=0.1, print_data=False):
def __init__(self, simulation, state, key, step_delay=0.1, update_event=None, print_data=False):
self.running = False
self.paused = False
self.stop_requested = False
self.update_thread = None
self.print_data = print_data
self.step_delay = step_delay
# simulation_dependent
self.update_event = update_event
self.simulation = simulation
self.state = state
self.key = key
Expand Down Expand Up @@ -40,6 +40,10 @@ def simulation_loop(self):
continue

self.state = self._update_simulation()

if self.update_event:
self.update_event.set()

if self.print_data:
print(f"{self.state = }")

Expand Down

0 comments on commit e1ac2fe

Please sign in to comment.