diff --git a/client.py b/client.py index b876ce1..29387e5 100644 --- a/client.py +++ b/client.py @@ -35,7 +35,7 @@ def receive_loop(): i += 1 raw_data = client.recv(state_bytes_size) state = serialization.from_bytes(state_example, raw_data) - Simulation.visualize_sim(state) + Simulation.render(state) except socket.error as e: print(e) @@ -50,7 +50,7 @@ def test(): i += 1 raw_data = client.recv(state_bytes_size) state = serialization.from_bytes(state_example, raw_data) - Simulation.visualize_sim(state) + Simulation.render(state) client.close() print(f"{i = } : {i / EVAL_TIME } data received per second") diff --git a/simulate.py b/simulate.py index 6c99c2c..22750e8 100644 --- a/simulate.py +++ b/simulate.py @@ -40,7 +40,7 @@ def main(): state = sim.step(state, step_key) if args.visualize: - Simulation.visualize_sim(state) + Simulation.render(state) print("\nSimulation ended") if __name__ == "__main__": diff --git a/simulationsandbox/environments/lake_env.py b/simulationsandbox/environments/lake_env.py index d9c8cef..b5724b7 100644 --- a/simulationsandbox/environments/lake_env.py +++ b/simulationsandbox/environments/lake_env.py @@ -1,9 +1,4 @@ -# TODO : Créer un aquarium env en jax -# TODO : Ensure the agents do not spawn in the lake -# TODO : Change the init of lake # TODO : Add more elements to the environment -# TODO : Maybe test particle based with different speeds on earth / lake - from functools import partial @@ -28,6 +23,10 @@ class Agents: obs: jnp.array +class Object: + pos: jnp.array + + @struct.dataclass class LakeEnvState(BaseEnvState): time: int @@ -135,7 +134,7 @@ def remove_agent(self, state, agent_idx): return state @staticmethod - def visualize_sim(state): + def render(state): if not plt.fignum_exists(1): plt.ion() plt.figure(figsize=(10, 10)) diff --git a/simulationsandbox/environments/three_d_example_env.py b/simulationsandbox/environments/three_d_example_env.py index dafcb8d..efb2802 100644 --- a/simulationsandbox/environments/three_d_example_env.py +++ b/simulationsandbox/environments/three_d_example_env.py @@ -69,7 +69,7 @@ def get_env_params(self): return self.grid_size, self.max_agents @staticmethod - def visualize_sim(state): + def render(state): if not plt.fignum_exists(1): plt.ion() fig = plt.figure(figsize=(10, 10)) diff --git a/simulationsandbox/environments/two_d_example_env.py b/simulationsandbox/environments/two_d_example_env.py index 89c59bd..766a1be 100644 --- a/simulationsandbox/environments/two_d_example_env.py +++ b/simulationsandbox/environments/two_d_example_env.py @@ -65,7 +65,7 @@ def get_env_params(self): return self.grid_size, self.max_agents @staticmethod - def visualize_sim(state): + def render(state): if not plt.fignum_exists(1): plt.ion() plt.figure(figsize=(10, 10)) diff --git a/tests/test_run_sims.py b/tests/test_run_sims.py index 0b9f362..e98f5a5 100644 --- a/tests/test_run_sims.py +++ b/tests/test_run_sims.py @@ -51,7 +51,7 @@ def test_simple_simulation_run(): state = sim.step(state, actions, step_key) if VIZUALIZE: - TwoDEnv.visualize_sim(state) + TwoDEnv.render(state) print("\nSimulation ended") assert jnp.sum(state.alive) == 5 @@ -93,7 +93,7 @@ def test_three_d_simulation_run(): state = sim.step(state, actions, step_key) if VIZUALIZE: - ThreeDEnv.visualize_sim(state) + ThreeDEnv.render(state) plt.close() print("\nSimulation ended")