Skip to content

Commit

Permalink
Clean server and client files
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Feb 13, 2024
1 parent e8cd72c commit 7a28b31
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 54 deletions.
7 changes: 3 additions & 4 deletions client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@

from flax import serialization

from MultiAgentsSim.simple_simulation import SimpleSimulation
from MultiAgentsSim.two_d_simulation import SimpleSimulation
from MultiAgentsSim.utils.network import SERVER

PORT = 5050
ADDR = (SERVER, PORT)
DATA_SIZE = 10835
EVAL_TIME = 10
color = "red"

client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(ADDR)
Expand All @@ -32,7 +31,7 @@ def receive_loop():
i += 1
raw_data = client.recv(state_bytes_size)
state = serialization.from_bytes(state_example, raw_data)
SimpleSimulation.visualize_sim(state, color, grid_size=None)
SimpleSimulation.visualize_sim(state, grid_size=None)

except socket.error as e:
print(e)
Expand All @@ -47,7 +46,7 @@ def test():
i += 1
raw_data = client.recv(state_bytes_size)
state = serialization.from_bytes(state_example, raw_data)
SimpleSimulation.visualize_sim(state, color, grid_size=None)
SimpleSimulation.visualize_sim(state, grid_size=None)
client.close()

print(f"{i = } : {i / EVAL_TIME } data received per second")
Expand Down
91 changes: 41 additions & 50 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,47 +7,47 @@
from jax import random
from flax import serialization

from MultiAgentsSim.simple_simulation import SimpleSimulation
from MultiAgentsSim.two_d_simulation import SimpleSimulation
from MultiAgentsSim.three_d_simulation import ThreeDSimulation
from MultiAgentsSim.utils.network import SERVER

parser = argparse.ArgumentParser()
parser.add_argument('--step_delay', type=float, default=0.5)
parser.add_argument('--step_delay', type=float, default=0.1)
args = parser.parse_args()


# Initialize server parameters
# Networking constants
PORT = 5050
ADDR = (SERVER, PORT)
DATA_SIZE = 10835 # size of data that is being transfered at each timestep
DATA_SIZE = 10835
STEP_DELAY = args.step_delay
# Simulation constants
SEED = 0
NUM_AGENTS = 5
MAX_AGENTS = 10
NUM_OBS = 3
GRID_SIZE = 20
VIZUALIZE = True

# Initialize server
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(ADDR)
server.listen()
print(f"Server started and listening ...")


# Initialize simulation parameters
NUM_AGENTS = 5
MAX_AGENTS = 10
NUM_OBS = 3
GRID_SIZE = 20
VIZUALIZE = True
STEP_DELAY = args.step_delay
print(f"{STEP_DELAY = }")

SEED = 0
# Initialize simulation
key = random.PRNGKey(SEED)
color = (1.0, 0.0, 0.0)

sim = SimpleSimulation(MAX_AGENTS, GRID_SIZE)
state = sim.init_state(NUM_AGENTS, NUM_OBS, key)

# Create a global variable to store the current array size and data + lock and event to access it
# Shared variables and locks
latest_data = state
data_lock = threading.Lock()
new_data_event = threading.Event()

# Create a function to continuously update the state of the simulation

# Continuously update the state of the simulation
def update_latest_data():
global latest_data
global state
Expand All @@ -62,56 +62,47 @@ def update_latest_data():
new_data_event.set()
time.sleep(STEP_DELAY)

# def update_color(new_color):
# global latest_data
# latest_data[1] = new_color

def handle_client(client, addr):
# Establish a connection with a client
def establish_connection(client, addr):
try:
client.send("RECEIVE_OR_UPDATE".encode())
client.send(pickle.dumps(latest_data))
print(f"{len(pickle.dumps(latest_data)) = }")
response = client.recv(DATA_SIZE).decode()
print(f"{response} connection established with {addr}")
connection_type = client.recv(DATA_SIZE).decode()
print(f"{connection_type} connection established with {addr}")
return connection_type

except socket.error as e:
print(f"error: {e}")
client.close()
print(f"Client {client} disconnected")


if response == "RECEIVE":
# Define how to communicate with a client
def communicate_with_client(connection_type):
if connection_type == "RECEIVE":
while True:
try:
new_data_event.wait()
with data_lock:
sent = serialization.to_bytes(latest_data)
client.send(sent)
# print(f"{len(sent) = }")
test = serialization.from_bytes(latest_data, sent)
# print(f"{test = }")
serialized_data = serialization.to_bytes(latest_data)
client.send(serialized_data)
new_data_event.clear()

except socket.error as e:
print(f"error: {e}")
client.close()
print(f"Client {client} disconnected")
break

# elif response == "UPDATE":
# while True:
# try:
# new_color = pickle.loads(client.recv(DATA_SIZE))
# print(f"received new color {new_color} from client")
# update_color(new_color)
# except socket.error as e:
# print(f"error: {e}")
# client.close()
# print(f"Client {client} disconnected")
# break

else:
print(f"Unknown connection type {response} detected")
else:
print(f"Unknown connection type {connection_type} detected")


except socket.error as e:
print(f"error: {e}")
client.close()
print(f"Client {client} disconnected")

# Function to handle a client when it connects to the server
def handle_client(client, addr):
connection_type = establish_connection(client, addr)
communicate_with_client(connection_type)


# Create a thread to continuously update the data
Expand Down

0 comments on commit 7a28b31

Please sign in to comment.