From 8eb30d6360f262115b5e9bde033c3c37a7118677 Mon Sep 17 00:00:00 2001 From: Corentin <111868204+corentinlger@users.noreply.github.com> Date: Wed, 24 Apr 2024 17:08:08 +0200 Subject: [PATCH] Simplify init fn (#83) --- scripts/run_server.py | 21 +-------------------- scripts/run_simulation.py | 23 +---------------------- tests/test_simulator_init.py | 16 +++++++++++++--- tests/test_simulator_run.py | 23 +---------------------- vivarium/simulator/states.py | 26 +++++++++++++++++++++++++- 5 files changed, 41 insertions(+), 68 deletions(-) diff --git a/scripts/run_server.py b/scripts/run_server.py index feb6f98..43c01e0 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -4,10 +4,6 @@ from omegaconf import DictConfig, OmegaConf from vivarium.simulator import behaviors -from vivarium.simulator.states import init_simulator_state -from vivarium.simulator.states import init_agent_state -from vivarium.simulator.states import init_object_state -from vivarium.simulator.states import init_entities_state from vivarium.simulator.states import init_state from vivarium.simulator.simulator import Simulator from vivarium.simulator.physics_engine import dynamics_rigid @@ -20,22 +16,7 @@ def main(cfg: DictConfig = None) -> None: logging.basicConfig(level=cfg.log_level) args = OmegaConf.merge(cfg.default, cfg.scene) - - simulator_state = init_simulator_state(**args.simulator) - - agents_state = init_agent_state(simulator_state=simulator_state, **args.agents) - - objects_state = init_object_state(simulator_state=simulator_state, **args.objects) - - entities_state = init_entities_state(simulator_state=simulator_state, **args.entities) - - state = init_state( - simulator_state=simulator_state, - agents_state=agents_state, - objects_state=objects_state, - entities_state=entities_state - ) - + state = init_state(args) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) lg.info('Simulator server started') diff --git a/scripts/run_simulation.py b/scripts/run_simulation.py index 4aae95b..c42ef86 100644 --- a/scripts/run_simulation.py +++ b/scripts/run_simulation.py @@ -4,10 +4,6 @@ from omegaconf import DictConfig, OmegaConf from vivarium.simulator import behaviors -from vivarium.simulator.states import init_simulator_state -from vivarium.simulator.states import init_agent_state -from vivarium.simulator.states import init_object_state -from vivarium.simulator.states import init_entities_state from vivarium.simulator.states import init_state from vivarium.simulator.simulator import Simulator from vivarium.simulator.physics_engine import dynamics_rigid @@ -19,28 +15,11 @@ def main(cfg: DictConfig = None) -> None: logging.basicConfig(level=cfg.log_level) args = OmegaConf.merge(cfg.default, cfg.scene) - - simulator_state = init_simulator_state(**args.simulator) - - agents_state = init_agent_state(simulator_state=simulator_state, **args.agents) - - objects_state = init_object_state(simulator_state=simulator_state, **args.objects) - - entities_state = init_entities_state(simulator_state=simulator_state, **args.entities) - - state = init_state( - simulator_state=simulator_state, - agents_state=agents_state, - objects_state=objects_state, - entities_state=entities_state - ) - + state = init_state(args) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) lg.info("Running simulation") - simulator.run(threaded=False, num_steps=cfg.num_steps) - lg.info("Simulation complete") if __name__ == "__main__": diff --git a/tests/test_simulator_init.py b/tests/test_simulator_init.py index 5b6447b..5d5b522 100644 --- a/tests/test_simulator_init.py +++ b/tests/test_simulator_init.py @@ -4,18 +4,28 @@ from vivarium.simulator.states import init_object_state from vivarium.simulator.states import init_entities_state from vivarium.simulator.states import init_state +from vivarium.simulator.states import _init_state from vivarium.simulator.simulator import Simulator from vivarium.simulator.physics_engine import dynamics_rigid -def test_init_simulator_no_args(): + +def test_init_simulator_no_args2(): + """ Test the initialization of state without arguments """ + state = init_state() + simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) + + assert simulator + + +def test_init_simulator_helper_fns(): """ Test the initialization of state without arguments """ simulator_state = init_simulator_state() agents_state = init_agent_state(simulator_state=simulator_state) objects_state = init_object_state(simulator_state=simulator_state) entities_state = init_entities_state(simulator_state=simulator_state) - state = init_state( + state = _init_state( simulator_state=simulator_state, agents_state=agents_state, objects_state=objects_state, @@ -72,7 +82,7 @@ def test_init_simulator_args(): simulator_state, color=color) - state = init_state( + state = _init_state( simulator_state=simulator_state, agents_state=agent_state, objects_state=object_state, diff --git a/tests/test_simulator_run.py b/tests/test_simulator_run.py index f018afd..70d6210 100644 --- a/tests/test_simulator_run.py +++ b/tests/test_simulator_run.py @@ -1,8 +1,4 @@ from vivarium.simulator import behaviors -from vivarium.simulator.states import init_simulator_state -from vivarium.simulator.states import init_agent_state -from vivarium.simulator.states import init_object_state -from vivarium.simulator.states import init_entities_state from vivarium.simulator.states import init_state from vivarium.simulator.simulator import Simulator from vivarium.simulator.physics_engine import dynamics_rigid @@ -10,25 +6,8 @@ NUM_STEPS = 50 def test_simulator_run(): - simulator_state = init_simulator_state() - - agents_state = init_agent_state(simulator_state=simulator_state) - - objects_state = init_object_state(simulator_state=simulator_state) - - entities_state = init_entities_state(simulator_state=simulator_state) - - state = init_state( - simulator_state=simulator_state, - agents_state=agents_state, - objects_state=objects_state, - entities_state=entities_state - ) - - simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) - + state = init_state() simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) - simulator.run(threaded=False, num_steps=NUM_STEPS) assert simulator \ No newline at end of file diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index 40c931f..2dd9210 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -296,7 +296,7 @@ def init_object_state( ) -def init_state( +def _init_state( simulator_state: SimulatorState, agents_state: AgentState, objects_state: ObjectState, @@ -310,6 +310,30 @@ def init_state( entities_state=entities_state ) +def init_state(args=None): + # Use default parameters of functions if user didn't provide input + if not args: + args = {} + + simulator_args = args.get('simulator', {}) + agents_args = args.get('agents', {}) + objects_args = args.get('agents', {}) + entities_args = args.get('agents', {}) + + simulator_state = init_simulator_state(**simulator_args) + agents_state = init_agent_state(simulator_state=simulator_state, **agents_args) + objects_state = init_object_state(simulator_state=simulator_state, **objects_args) + entities_state = init_entities_state(simulator_state=simulator_state, **entities_args) + + state = _init_state( + simulator_state=simulator_state, + agents_state=agents_state, + objects_state=objects_state, + entities_state=entities_state + ) + + return state + def generate_default_config_files(): """