From eca0bb52113f292b3cac88850b74303480b3b893 Mon Sep 17 00:00:00 2001 From: Marsolo1 Date: Tue, 16 Apr 2024 16:44:24 +0200 Subject: [PATCH] added scene load functions --- scripts/run_server.py | 27 +++++---------- vivarium/controllers/simulator_controller.py | 4 +++ .../grpc_server/protos/simulator.proto | 5 +++ .../simulator/grpc_server/simulator_client.py | 4 +++ .../simulator/grpc_server/simulator_pb2.py | 8 +++-- .../simulator/grpc_server/simulator_pb2.pyi | 6 ++++ .../grpc_server/simulator_pb2_grpc.py | 33 +++++++++++++++++++ .../simulator/grpc_server/simulator_server.py | 5 +++ vivarium/simulator/simulator.py | 17 ++++++++-- vivarium/simulator/states.py | 16 ++++++++- 10 files changed, 100 insertions(+), 25 deletions(-) diff --git a/scripts/run_server.py b/scripts/run_server.py index feb6f98..240a400 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -1,14 +1,12 @@ import logging import hydra +import hydra.core +import hydra.core.global_hydra from omegaconf import DictConfig, OmegaConf from vivarium.simulator import behaviors -from vivarium.simulator.states import init_simulator_state -from vivarium.simulator.states import init_agent_state -from vivarium.simulator.states import init_object_state -from vivarium.simulator.states import init_entities_state -from vivarium.simulator.states import init_state +from vivarium.simulator.states import init_state_from_dict from vivarium.simulator.simulator import Simulator from vivarium.simulator.physics_engine import dynamics_rigid from vivarium.simulator.grpc_server.simulator_server import serve @@ -21,23 +19,14 @@ def main(cfg: DictConfig = None) -> None: args = OmegaConf.merge(cfg.default, cfg.scene) - simulator_state = init_simulator_state(**args.simulator) - - agents_state = init_agent_state(simulator_state=simulator_state, **args.agents) - - objects_state = init_object_state(simulator_state=simulator_state, **args.objects) - - entities_state = init_entities_state(simulator_state=simulator_state, **args.entities) - - state = init_state( - simulator_state=simulator_state, - agents_state=agents_state, - objects_state=objects_state, - entities_state=entities_state - ) + state = init_state_from_dict(args) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) + # necessary to be able to load other scenes + glob = hydra.core.global_hydra.GlobalHydra() + glob.clear() + lg.info('Simulator server started') serve(simulator) diff --git a/vivarium/controllers/simulator_controller.py b/vivarium/controllers/simulator_controller.py index f2bceae..6b5ec91 100644 --- a/vivarium/controllers/simulator_controller.py +++ b/vivarium/controllers/simulator_controller.py @@ -104,6 +104,10 @@ def get_nve_state(self): self.state = self.client.get_nve_state() return self.state + def load_scene(self, scene): + self.client.load_scene(scene) + self.client.state = self.client.get_state() + self.__init__(client=self.client) if __name__ == "__main__": diff --git a/vivarium/simulator/grpc_server/protos/simulator.proto b/vivarium/simulator/grpc_server/protos/simulator.proto index dc92806..b1824be 100644 --- a/vivarium/simulator/grpc_server/protos/simulator.proto +++ b/vivarium/simulator/grpc_server/protos/simulator.proto @@ -40,6 +40,7 @@ service SimulatorServer { rpc Stop(google.protobuf.Empty) returns (google.protobuf.Empty) {} + rpc LoadScene(Scene) returns (google.protobuf.Empty) {} } message AgentIdx { @@ -124,3 +125,7 @@ message AddAgentInput { message IsStartedState { bool is_started = 1; } + +message Scene { + string scene = 1; +} \ No newline at end of file diff --git a/vivarium/simulator/grpc_server/simulator_client.py b/vivarium/simulator/grpc_server/simulator_client.py index a43e7a7..c95b7b4 100644 --- a/vivarium/simulator/grpc_server/simulator_client.py +++ b/vivarium/simulator/grpc_server/simulator_client.py @@ -57,3 +57,7 @@ def step(self): def is_started(self): return self.stub.IsStarted(Empty()).is_started + + def load_scene(self, scene): + message = simulator_pb2.Scene(scene=scene) + return self.stub.LoadScene(message) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.py b/vivarium/simulator/grpc_server/simulator_pb2.py index e459f53..c53a362 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.py +++ b/vivarium/simulator/grpc_server/simulator_pb2.py @@ -15,7 +15,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nmax_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0bmax_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x02\n\rEntitiesState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\xb7\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tmax_speed\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xc7\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12\x30\n\x0e\x65ntities_state\x18\x02 \x01(\x0b\x32\x18.simulator.EntitiesState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\">\n\rAddAgentInput\x12\x12\n\nmax_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xbb\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x41\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x18.simulator.EntitiesState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nmax_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0bmax_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x02\n\rEntitiesState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\xb7\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tmax_speed\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xc7\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12\x30\n\x0e\x65ntities_state\x18\x02 \x01(\x0b\x32\x18.simulator.EntitiesState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\">\n\rAddAgentInput\x12\x12\n\nmax_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\"\x16\n\x05Scene\x12\r\n\x05scene\x18\x01 \x01(\t2\xf4\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x41\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x18.simulator.EntitiesState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x37\n\tLoadScene\x12\x10.simulator.Scene\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -45,6 +45,8 @@ _globals['_ADDAGENTINPUT']._serialized_end=1999 _globals['_ISSTARTEDSTATE']._serialized_start=2001 _globals['_ISSTARTEDSTATE']._serialized_end=2037 - _globals['_SIMULATORSERVER']._serialized_start=2040 - _globals['_SIMULATORSERVER']._serialized_end=2611 + _globals['_SCENE']._serialized_start=2039 + _globals['_SCENE']._serialized_end=2061 + _globals['_SIMULATORSERVER']._serialized_start=2064 + _globals['_SIMULATORSERVER']._serialized_end=2692 # @@protoc_insertion_point(module_scope) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.pyi b/vivarium/simulator/grpc_server/simulator_pb2.pyi index 269cdd5..233dd55 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.pyi +++ b/vivarium/simulator/grpc_server/simulator_pb2.pyi @@ -149,3 +149,9 @@ class IsStartedState(_message.Message): IS_STARTED_FIELD_NUMBER: _ClassVar[int] is_started: bool def __init__(self, is_started: bool = ...) -> None: ... + +class Scene(_message.Message): + __slots__ = ("scene",) + SCENE_FIELD_NUMBER: _ClassVar[int] + scene: str + def __init__(self, scene: _Optional[str] = ...) -> None: ... diff --git a/vivarium/simulator/grpc_server/simulator_pb2_grpc.py b/vivarium/simulator/grpc_server/simulator_pb2_grpc.py index f11b97c..509fa15 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2_grpc.py +++ b/vivarium/simulator/grpc_server/simulator_pb2_grpc.py @@ -61,6 +61,11 @@ def __init__(self, channel): request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, ) + self.LoadScene = channel.unary_unary( + '/simulator.SimulatorServer/LoadScene', + request_serializer=simulator__pb2.Scene.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) class SimulatorServerServicer(object): @@ -121,6 +126,12 @@ def Stop(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def LoadScene(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_SimulatorServerServicer_to_server(servicer, server): rpc_method_handlers = { @@ -169,6 +180,11 @@ def add_SimulatorServerServicer_to_server(servicer, server): request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, ), + 'LoadScene': grpc.unary_unary_rpc_method_handler( + servicer.LoadScene, + request_deserializer=simulator__pb2.Scene.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'simulator.SimulatorServer', rpc_method_handlers) @@ -332,3 +348,20 @@ def Stop(request, google_dot_protobuf_dot_empty__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def LoadScene(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/simulator.SimulatorServer/LoadScene', + simulator__pb2.Scene.SerializeToString, + google_dot_protobuf_dot_empty__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/vivarium/simulator/grpc_server/simulator_server.py b/vivarium/simulator/grpc_server/simulator_server.py index aa919a1..06b3e02 100644 --- a/vivarium/simulator/grpc_server/simulator_server.py +++ b/vivarium/simulator/grpc_server/simulator_server.py @@ -62,6 +62,11 @@ def Start(self, request, context): def IsStarted(self, request, context): return simulator_pb2.IsStartedState(is_started=self.simulator.is_started()) + + def LoadScene(self, request, context): + with self._lock: + self.simulator.load_scene(request.scene) + return Empty() def Stop(self, request, context): self.simulator.stop() diff --git a/vivarium/simulator/simulator.py b/vivarium/simulator/simulator.py index d16892f..f12b716 100644 --- a/vivarium/simulator/simulator.py +++ b/vivarium/simulator/simulator.py @@ -13,8 +13,11 @@ from jax import lax from jax_md import space, partition, dataclasses +from hydra import compose, initialize +from omegaconf import OmegaConf + from vivarium.controllers import converters -from vivarium.simulator.states import EntityType, SimulatorState +from vivarium.simulator.states import EntityType, SimulatorState, init_state_from_dict lg = logging.getLogger(__name__) @@ -188,6 +191,14 @@ def set_state(self, nested_field, nve_idx, column_idx, value): if nested_field in (('simulator_state', 'box_size'), ('simulator_state', 'dt'), ('simulator_state', 'to_jit')): self.update_function_update() + def load_scene(self, scene): + with initialize(version_base=None, config_path="../../conf"): + args = compose(config_name="config", overrides=[f"scene={scene}"]) + + args = OmegaConf.merge(args.default, args.scene) + state = init_state_from_dict(args) + self. __init__(state, self.behavior_bank, self.dynamics_fn) + # Functions to start, stop, pause @@ -245,7 +256,9 @@ def init_state(self, state): def load_state(self, state): lg.info('load_state') - self.__init__(state, self.behavior_bank, self.dynamics_fn) + # the pause may be unnecessary + with self.pause(): + self.__init__(state, self.behavior_bank, self.dynamics_fn) # Neighbor functions diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index 40c931f..09c6281 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -1,6 +1,5 @@ from enum import Enum from typing import Optional, List, Union -from collections import OrderedDict import inspect import yaml @@ -311,6 +310,21 @@ def init_state( ) +def init_state_from_dict(dictionary: dict): + simulator_state = init_simulator_state(**dictionary.simulator) + + agents_state = init_agent_state(simulator_state=simulator_state, **dictionary.agents) + + objects_state = init_object_state(simulator_state=simulator_state, **dictionary.objects) + + entities_state = init_entities_state(simulator_state=simulator_state, **dictionary.entities) + + return init_state(simulator_state = simulator_state, + agents_state = agents_state, + objects_state = objects_state, + entities_state = entities_state) + + def generate_default_config_files(): """ Generate a default yaml file with all the default arguments in the init_params_fns (see dict below)