Skip to content

Commit

Permalink
Add wrapper to pause and run the simulation and rename package
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Feb 14, 2024
1 parent c333e02 commit 0889fa0
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 27 deletions.
5 changes: 0 additions & 5 deletions MultiAgentsSim/sim_types.py

This file was deleted.

5 changes: 2 additions & 3 deletions simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from omegaconf import DictConfig, OmegaConf
from jax import random

from MultiAgentsSim.two_d_simulation import SimpleSimulation
from MultiAgentsSim.three_d_simulation import ThreeDSimulation
from simulationsandbox.two_d_simulation import SimpleSimulation
from simulationsandbox.three_d_simulation import ThreeDSimulation


@hydra.main(version_base=None, config_path="conf", config_name="config")
Expand Down Expand Up @@ -34,7 +34,6 @@ def main(cfg: DictConfig):
sim = Simulation(max_agents, grid_size)
state = sim.init_state(num_agents, num_obs, key)


# Launch a simulation
print("Simulation started")
for timestep in range(num_steps):
Expand Down
File renamed without changes.
8 changes: 1 addition & 7 deletions MultiAgentsSim/base.py → simulationsandbox/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,28 @@
class SimState:
time: int




class Simulation:
def __init__(self):
raise(NotImplementedError)


@partial(jit, static_argnums=(0, 1, 2))
def init_state(self) -> SimState:
raise(NotImplementedError)


# Should be moved in another class
@partial(jit, static_argnums=(0,))
def choose_action(self, obs):
raise(NotImplementedError)


# Should also return new obs for agents
@partial(jit, static_argnums=(0,))
def step(self, state: SimState, actions, key) -> SimState:
raise(NotImplementedError)


def get_env_params(self):
raise(NotImplementedError)


@staticmethod
def encode(state):
raise(NotImplementedError)
Expand Down
5 changes: 5 additions & 0 deletions simulationsandbox/sim_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from simulationsandbox.two_d_simulation import SimpleSimulation
from simulationsandbox.three_d_simulation import ThreeDSimulation

SIMULATIONS = {"two_d": SimpleSimulation,
"three_d": ThreeDSimulation}
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from flax import struct
import matplotlib.pyplot as plt

from MultiAgentsSim.base import Simulation, SimState
from simulationsandbox.base import Simulation, SimState

N_DIMS = 3

@struct.dataclass
class ThreeDSimState(SimState):
time: int
grid: jnp.array
grid_size: int
alive: jnp.array
x_pos: jnp.array
y_pos: jnp.array
Expand All @@ -32,6 +33,7 @@ def init_state(self, num_agents, num_obs, key):
x_key, y_key = random.split(key)
return ThreeDSimState(time=0,
grid=jnp.zeros((self.grid_size, self.grid_size, self.grid_size), dtype=jnp.float32),
grid_size=self.grid_size,
alive = jnp.hstack((jnp.ones(num_agents), jnp.zeros(self.max_agents - num_agents))),
x_pos=random.randint(key=x_key, shape=(self.max_agents,), minval=0, maxval=self.grid_size),
y_pos=random.randint(key=y_key, shape=(self.max_agents,), minval=0, maxval=self.grid_size),
Expand Down Expand Up @@ -68,7 +70,7 @@ def get_env_params(self):
return self.grid_size, self.max_agents

@staticmethod
def visualize_sim(state, grid_size):
def visualize_sim(state):
if not plt.fignum_exists(1):
plt.ion()
fig = plt.figure(figsize=(10, 10))
Expand All @@ -90,9 +92,9 @@ def visualize_sim(state, grid_size):
ax.set_ylabel("Y-axis")
ax.set_zlabel("Z-axis")

ax.set_xlim(0, grid_size)
ax.set_ylim(0, grid_size)
ax.set_zlim(0, grid_size)
ax.set_xlim(0, state.grid_size)
ax.set_ylim(0, state.grid_size)
ax.set_zlim(0, state.grid_size)

ax.legend()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import jax.numpy as jnp
from jax import random, jit
from flax import struct
import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt

from MultiAgentsSim.base import Simulation, SimState
from simulationsandbox.base import Simulation, SimState

N_DIMS = 2

Expand Down Expand Up @@ -39,6 +37,7 @@ def init_state(self, num_agents, num_obs, key):
colors=jnp.full(shape=(self.max_agents, 3), fill_value=jnp.array([1.0, 0.0, 0.0]))
)

# Could even be implemented in the step function because we do not have RL agents choosing actions
@partial(jit, static_argnums=(0,))
def choose_action(self, obs, key):
return random.randint(key, shape=(obs.shape[0], N_DIMS), minval=-1, maxval=2)
Expand All @@ -65,7 +64,7 @@ def get_env_params(self):
return self.grid_size, self.max_agents

@staticmethod
def visualize_sim(state, grid_size=None):
def visualize_sim(state):
if not plt.fignum_exists(1):
plt.ion()
plt.figure(figsize=(10, 10))
Expand Down
File renamed without changes.
7 changes: 4 additions & 3 deletions tests/test_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import jax.numpy as jnp
from jax import random

from MultiAgentsSim.two_d_simulation import SimpleSimulation
from MultiAgentsSim.three_d_simulation import ThreeDSimulation
from simulationsandbox.two_d_simulation import SimpleSimulation
from simulationsandbox.three_d_simulation import ThreeDSimulation

NUM_AGENTS = 5
MAX_AGENTS = 10
Expand Down Expand Up @@ -95,6 +95,7 @@ def test_three_d_simulation_run():
state = state.replace(colors=state.colors.at[0, 2].set(1.0))

if timestep == 40:

state = sim.remove_agent(state, 2)
state = sim.remove_agent(state, 1)
state = sim.remove_agent(state, 4)
Expand All @@ -104,6 +105,6 @@ def test_three_d_simulation_run():
state = sim.step(state, actions, step_key)

if VIZUALIZE:
ThreeDSimulation.visualize_sim(state, grid_size=GRID_SIZE)
ThreeDSimulation.visualize_sim(state)
print("\nSimulation ended")

83 changes: 83 additions & 0 deletions wrapp_simulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import threading
import time

from jax import random

from simulationsandbox.two_d_simulation import SimpleSimulation

# Independant of any simulation environment
class SimulationWrapper:
"""
Simulation wrapper to start, pause, resume and stop a simulation.
Independant of any simulation type.
"""
def __init__(self, simulation, state, key):
self.running = False
self.paused = False
self.stop_requested = False
self.update_thread = None
# simulation_dependent
self.simulation = simulation
self.state = state
self.key = key

def start(self):
if not self.running:
self.running = True
self.stop_requested = False
self.update_thread = threading.Thread(target=self.simulation_loop)
self.update_thread.start()

def pause(self):
self.paused = True

def resume(self):
self.paused = False

def stop(self):
self.stop_requested = True

def simulation_loop(self):
while not self.stop_requested:
if self.paused:
time.sleep(0.1)
continue

self.state = self._update_simulation()
print(f"{self.state = }")

time.sleep(0.1)

def _update_simulation(self):
self.key, a_key, step_key = random.split(self.key, 3)
actions = self.simulation.choose_action(self.state.obs, a_key)
return self.simulation.step(self.state, actions, step_key)

NUM_AGENTS = 5
MAX_AGENTS = 10
NUM_OBS = 3
GRID_SIZE = 20
SLEEP_TIME = 5
SEED = 0

key = random.PRNGKey(SEED)
env = SimpleSimulation(MAX_AGENTS, GRID_SIZE)
state = env.init_state(NUM_AGENTS, NUM_OBS, key)

# Example usage:
sim = SimulationWrapper(env, state, key)

print('Started')
sim.start_simulation()
time.sleep(SLEEP_TIME)

sim.pause_simulation()
print('Paused')
time.sleep(SLEEP_TIME)

print('Resumed')
sim.resume_simulation()
time.sleep(SLEEP_TIME)

sim.stop_simulation()
print('stopped')

0 comments on commit 0889fa0

Please sign in to comment.