Skip to content

Commit

Permalink
added scene load functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Marsolo1 committed Apr 16, 2024
1 parent 2ce8aa3 commit eca0bb5
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 25 deletions.
27 changes: 8 additions & 19 deletions scripts/run_server.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import logging
import hydra

import hydra.core
import hydra.core.global_hydra
from omegaconf import DictConfig, OmegaConf

from vivarium.simulator import behaviors
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_entities_state
from vivarium.simulator.states import init_state
from vivarium.simulator.states import init_state_from_dict
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.physics_engine import dynamics_rigid
from vivarium.simulator.grpc_server.simulator_server import serve
Expand All @@ -21,23 +19,14 @@ def main(cfg: DictConfig = None) -> None:

args = OmegaConf.merge(cfg.default, cfg.scene)

simulator_state = init_simulator_state(**args.simulator)

agents_state = init_agent_state(simulator_state=simulator_state, **args.agents)

objects_state = init_object_state(simulator_state=simulator_state, **args.objects)

entities_state = init_entities_state(simulator_state=simulator_state, **args.entities)

state = init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
entities_state=entities_state
)
state = init_state_from_dict(args)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

# necessary to be able to load other scenes
glob = hydra.core.global_hydra.GlobalHydra()
glob.clear()

lg.info('Simulator server started')
serve(simulator)

Expand Down
4 changes: 4 additions & 0 deletions vivarium/controllers/simulator_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def get_nve_state(self):
self.state = self.client.get_nve_state()
return self.state

def load_scene(self, scene):
self.client.load_scene(scene)
self.client.state = self.client.get_state()
self.__init__(client=self.client)

if __name__ == "__main__":

Expand Down
5 changes: 5 additions & 0 deletions vivarium/simulator/grpc_server/protos/simulator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ service SimulatorServer {

rpc Stop(google.protobuf.Empty) returns (google.protobuf.Empty) {}

rpc LoadScene(Scene) returns (google.protobuf.Empty) {}
}

message AgentIdx {
Expand Down Expand Up @@ -124,3 +125,7 @@ message AddAgentInput {
message IsStartedState {
bool is_started = 1;
}

message Scene {
string scene = 1;
}
4 changes: 4 additions & 0 deletions vivarium/simulator/grpc_server/simulator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,7 @@ def step(self):

def is_started(self):
return self.stub.IsStarted(Empty()).is_started

def load_scene(self, scene):
message = simulator_pb2.Scene(scene=scene)
return self.stub.LoadScene(message)
8 changes: 5 additions & 3 deletions vivarium/simulator/grpc_server/simulator_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions vivarium/simulator/grpc_server/simulator_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,9 @@ class IsStartedState(_message.Message):
IS_STARTED_FIELD_NUMBER: _ClassVar[int]
is_started: bool
def __init__(self, is_started: bool = ...) -> None: ...

class Scene(_message.Message):
__slots__ = ("scene",)
SCENE_FIELD_NUMBER: _ClassVar[int]
scene: str
def __init__(self, scene: _Optional[str] = ...) -> None: ...
33 changes: 33 additions & 0 deletions vivarium/simulator/grpc_server/simulator_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def __init__(self, channel):
request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
)
self.LoadScene = channel.unary_unary(
'/simulator.SimulatorServer/LoadScene',
request_serializer=simulator__pb2.Scene.SerializeToString,
response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
)


class SimulatorServerServicer(object):
Expand Down Expand Up @@ -121,6 +126,12 @@ def Stop(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def LoadScene(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_SimulatorServerServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand Down Expand Up @@ -169,6 +180,11 @@ def add_SimulatorServerServicer_to_server(servicer, server):
request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
),
'LoadScene': grpc.unary_unary_rpc_method_handler(
servicer.LoadScene,
request_deserializer=simulator__pb2.Scene.FromString,
response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'simulator.SimulatorServer', rpc_method_handlers)
Expand Down Expand Up @@ -332,3 +348,20 @@ def Stop(request,
google_dot_protobuf_dot_empty__pb2.Empty.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def LoadScene(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/simulator.SimulatorServer/LoadScene',
simulator__pb2.Scene.SerializeToString,
google_dot_protobuf_dot_empty__pb2.Empty.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
5 changes: 5 additions & 0 deletions vivarium/simulator/grpc_server/simulator_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ def Start(self, request, context):

def IsStarted(self, request, context):
return simulator_pb2.IsStartedState(is_started=self.simulator.is_started())

def LoadScene(self, request, context):
with self._lock:
self.simulator.load_scene(request.scene)
return Empty()

def Stop(self, request, context):
self.simulator.stop()
Expand Down
17 changes: 15 additions & 2 deletions vivarium/simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
from jax import lax
from jax_md import space, partition, dataclasses

from hydra import compose, initialize
from omegaconf import OmegaConf

from vivarium.controllers import converters
from vivarium.simulator.states import EntityType, SimulatorState
from vivarium.simulator.states import EntityType, SimulatorState, init_state_from_dict

lg = logging.getLogger(__name__)

Expand Down Expand Up @@ -188,6 +191,14 @@ def set_state(self, nested_field, nve_idx, column_idx, value):
if nested_field in (('simulator_state', 'box_size'), ('simulator_state', 'dt'), ('simulator_state', 'to_jit')):
self.update_function_update()

def load_scene(self, scene):
with initialize(version_base=None, config_path="../../conf"):
args = compose(config_name="config", overrides=[f"scene={scene}"])

args = OmegaConf.merge(args.default, args.scene)
state = init_state_from_dict(args)
self. __init__(state, self.behavior_bank, self.dynamics_fn)


# Functions to start, stop, pause

Expand Down Expand Up @@ -245,7 +256,9 @@ def init_state(self, state):

def load_state(self, state):
lg.info('load_state')
self.__init__(state, self.behavior_bank, self.dynamics_fn)
# the pause may be unnecessary
with self.pause():
self.__init__(state, self.behavior_bank, self.dynamics_fn)

# Neighbor functions

Expand Down
16 changes: 15 additions & 1 deletion vivarium/simulator/states.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from enum import Enum
from typing import Optional, List, Union
from collections import OrderedDict

import inspect
import yaml
Expand Down Expand Up @@ -311,6 +310,21 @@ def init_state(
)


def init_state_from_dict(dictionary: dict):
simulator_state = init_simulator_state(**dictionary.simulator)

agents_state = init_agent_state(simulator_state=simulator_state, **dictionary.agents)

objects_state = init_object_state(simulator_state=simulator_state, **dictionary.objects)

entities_state = init_entities_state(simulator_state=simulator_state, **dictionary.entities)

return init_state(simulator_state = simulator_state,
agents_state = agents_state,
objects_state = objects_state,
entities_state = entities_state)


def generate_default_config_files():
"""
Generate a default yaml file with all the default arguments in the init_params_fns (see dict below)
Expand Down

0 comments on commit eca0bb5

Please sign in to comment.