Skip to content

Commit

Permalink
Simplify init fn (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger authored Apr 24, 2024
1 parent 45fbdbf commit 8eb30d6
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 68 deletions.
21 changes: 1 addition & 20 deletions scripts/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
from omegaconf import DictConfig, OmegaConf

from vivarium.simulator import behaviors
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_entities_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.physics_engine import dynamics_rigid
Expand All @@ -20,22 +16,7 @@ def main(cfg: DictConfig = None) -> None:
logging.basicConfig(level=cfg.log_level)

args = OmegaConf.merge(cfg.default, cfg.scene)

simulator_state = init_simulator_state(**args.simulator)

agents_state = init_agent_state(simulator_state=simulator_state, **args.agents)

objects_state = init_object_state(simulator_state=simulator_state, **args.objects)

entities_state = init_entities_state(simulator_state=simulator_state, **args.entities)

state = init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
entities_state=entities_state
)

state = init_state(args)
simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

lg.info('Simulator server started')
Expand Down
23 changes: 1 addition & 22 deletions scripts/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
from omegaconf import DictConfig, OmegaConf

from vivarium.simulator import behaviors
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_entities_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.physics_engine import dynamics_rigid
Expand All @@ -19,28 +15,11 @@ def main(cfg: DictConfig = None) -> None:
logging.basicConfig(level=cfg.log_level)

args = OmegaConf.merge(cfg.default, cfg.scene)

simulator_state = init_simulator_state(**args.simulator)

agents_state = init_agent_state(simulator_state=simulator_state, **args.agents)

objects_state = init_object_state(simulator_state=simulator_state, **args.objects)

entities_state = init_entities_state(simulator_state=simulator_state, **args.entities)

state = init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
entities_state=entities_state
)

state = init_state(args)
simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

lg.info("Running simulation")

simulator.run(threaded=False, num_steps=cfg.num_steps)

lg.info("Simulation complete")

if __name__ == "__main__":
Expand Down
16 changes: 13 additions & 3 deletions tests/test_simulator_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,28 @@
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_entities_state
from vivarium.simulator.states import init_state
from vivarium.simulator.states import _init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.physics_engine import dynamics_rigid


def test_init_simulator_no_args():

def test_init_simulator_no_args2():
""" Test the initialization of state without arguments """
state = init_state()
simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

assert simulator


def test_init_simulator_helper_fns():
""" Test the initialization of state without arguments """
simulator_state = init_simulator_state()
agents_state = init_agent_state(simulator_state=simulator_state)
objects_state = init_object_state(simulator_state=simulator_state)
entities_state = init_entities_state(simulator_state=simulator_state)

state = init_state(
state = _init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
Expand Down Expand Up @@ -72,7 +82,7 @@ def test_init_simulator_args():
simulator_state,
color=color)

state = init_state(
state = _init_state(
simulator_state=simulator_state,
agents_state=agent_state,
objects_state=object_state,
Expand Down
23 changes: 1 addition & 22 deletions tests/test_simulator_run.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,13 @@
from vivarium.simulator import behaviors
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_entities_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.physics_engine import dynamics_rigid

NUM_STEPS = 50

def test_simulator_run():
simulator_state = init_simulator_state()

agents_state = init_agent_state(simulator_state=simulator_state)

objects_state = init_object_state(simulator_state=simulator_state)

entities_state = init_entities_state(simulator_state=simulator_state)

state = init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
entities_state=entities_state
)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

state = init_state()
simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

simulator.run(threaded=False, num_steps=NUM_STEPS)

assert simulator
26 changes: 25 additions & 1 deletion vivarium/simulator/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def init_object_state(
)


def init_state(
def _init_state(
simulator_state: SimulatorState,
agents_state: AgentState,
objects_state: ObjectState,
Expand All @@ -310,6 +310,30 @@ def init_state(
entities_state=entities_state
)

def init_state(args=None):
# Use default parameters of functions if user didn't provide input
if not args:
args = {}

simulator_args = args.get('simulator', {})
agents_args = args.get('agents', {})
objects_args = args.get('agents', {})
entities_args = args.get('agents', {})

simulator_state = init_simulator_state(**simulator_args)
agents_state = init_agent_state(simulator_state=simulator_state, **agents_args)
objects_state = init_object_state(simulator_state=simulator_state, **objects_args)
entities_state = init_entities_state(simulator_state=simulator_state, **entities_args)

state = _init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
entities_state=entities_state
)

return state


def generate_default_config_files():
"""
Expand Down

0 comments on commit 8eb30d6

Please sign in to comment.