Skip to content

Commit

Permalink
Renamed entities to entity (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger authored Apr 25, 2024
1 parent 2637b1c commit 14ee329
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 144 deletions.
10 changes: 5 additions & 5 deletions tests/test_simulator_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_entities_state
from vivarium.simulator.states import init_entity_state
from vivarium.simulator.states import init_state
from vivarium.simulator.states import _init_state
from vivarium.simulator.simulator import Simulator
Expand All @@ -23,13 +23,13 @@ def test_init_simulator_helper_fns():
simulator_state = init_simulator_state()
agents_state = init_agent_state(simulator_state=simulator_state)
objects_state = init_object_state(simulator_state=simulator_state)
entities_state = init_entities_state(simulator_state=simulator_state)
entity_state = init_entity_state(simulator_state=simulator_state)

state = _init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
entities_state=entities_state
entity_state=entity_state
)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)
Expand Down Expand Up @@ -63,7 +63,7 @@ def test_init_simulator_args():
collision_eps=col_eps,
collision_alpha=col_alpha)

entities_state = init_entities_state(
entity_state = init_entity_state(
simulator_state,
diameter=diameter,
friction=friction)
Expand All @@ -86,7 +86,7 @@ def test_init_simulator_args():
simulator_state=simulator_state,
agents_state=agent_state,
objects_state=object_state,
entities_state=entities_state)
entity_state=entity_state)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

Expand Down
52 changes: 26 additions & 26 deletions vivarium/controllers/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jax_md.rigid_body import RigidBody

from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig, stype_to_config, config_to_stype
from vivarium.simulator.states import State, SimulatorState, EntitiesState, AgentState, ObjectState, EntityType, StateType
from vivarium.simulator.states import State, SimulatorState, EntityState, AgentState, ObjectState, EntityType, StateType
from vivarium.simulator.behaviors import behavior_name_map, reversed_behavior_name_map


Expand Down Expand Up @@ -53,35 +53,35 @@ class StateFieldInfo:
exists_c_to_s = lambda x: int(x)


agent_configs_to_state_dict = {'x_position': StateFieldInfo(('entities_state', 'position', 'center'), 0, identity_s_to_c, identity_c_to_s),
'y_position': StateFieldInfo(('entities_state', 'position', 'center'), 1, identity_s_to_c, identity_c_to_s),
'orientation': StateFieldInfo(('entities_state', 'position', 'orientation'), None, identity_s_to_c, identity_c_to_s),
'mass_center': StateFieldInfo(('entities_state', 'mass', 'center'), np.array([0]), mass_center_s_to_c, mass_center_c_to_s),
'mass_orientation': StateFieldInfo(('entities_state', 'mass', 'orientation'), None, identity_s_to_c, identity_c_to_s),
'diameter': StateFieldInfo(('entities_state', 'diameter'), None, identity_s_to_c, identity_c_to_s),
'friction': StateFieldInfo(('entities_state', 'friction'), None, identity_s_to_c, identity_c_to_s),
agent_configs_to_state_dict = {'x_position': StateFieldInfo(('entity_state', 'position', 'center'), 0, identity_s_to_c, identity_c_to_s),
'y_position': StateFieldInfo(('entity_state', 'position', 'center'), 1, identity_s_to_c, identity_c_to_s),
'orientation': StateFieldInfo(('entity_state', 'position', 'orientation'), None, identity_s_to_c, identity_c_to_s),
'mass_center': StateFieldInfo(('entity_state', 'mass', 'center'), np.array([0]), mass_center_s_to_c, mass_center_c_to_s),
'mass_orientation': StateFieldInfo(('entity_state', 'mass', 'orientation'), None, identity_s_to_c, identity_c_to_s),
'diameter': StateFieldInfo(('entity_state', 'diameter'), None, identity_s_to_c, identity_c_to_s),
'friction': StateFieldInfo(('entity_state', 'friction'), None, identity_s_to_c, identity_c_to_s),
'left_motor': StateFieldInfo(('agent_state', 'motor',), 0, identity_s_to_c, identity_c_to_s),
'right_motor': StateFieldInfo(('agent_state', 'motor',), 1, identity_s_to_c, identity_c_to_s),
'left_prox': StateFieldInfo(('agent_state', 'prox',), 0, identity_s_to_c, identity_c_to_s),
'right_prox': StateFieldInfo(('agent_state', 'prox',), 1, identity_s_to_c, identity_c_to_s),
'behavior': StateFieldInfo(('agent_state', 'behavior',), None, behavior_s_to_c, behavior_c_to_s),
'color': StateFieldInfo(('agent_state', 'color',), np.arange(3), color_s_to_c, color_c_to_s),
'idx': StateFieldInfo(('agent_state', 'nve_idx',), None, identity_s_to_c, identity_c_to_s),
'exists': StateFieldInfo(('entities_state', 'exists'), None, identity_s_to_c, exists_c_to_s)
'exists': StateFieldInfo(('entity_state', 'exists'), None, identity_s_to_c, exists_c_to_s)
}

agent_configs_to_state_dict.update({f: StateFieldInfo(('agent_state', f,), None, identity_s_to_c, identity_c_to_s) for f in agent_common_fields if f not in agent_configs_to_state_dict})

object_configs_to_state_dict = {'x_position': StateFieldInfo(('entities_state', 'position', 'center'), 0, identity_s_to_c, identity_c_to_s),
'y_position': StateFieldInfo(('entities_state', 'position', 'center'), 1, identity_s_to_c, identity_c_to_s),
'orientation': StateFieldInfo(('entities_state', 'position', 'orientation'), None, identity_s_to_c, identity_c_to_s),
'mass_center': StateFieldInfo(('entities_state', 'mass', 'center'), np.array([0]), mass_center_s_to_c, mass_center_c_to_s),
'mass_orientation': StateFieldInfo(('entities_state', 'mass', 'orientation'), None, identity_s_to_c, identity_c_to_s),
'diameter': StateFieldInfo(('entities_state', 'diameter'), None, identity_s_to_c, identity_c_to_s),
'friction': StateFieldInfo(('entities_state', 'friction'), None, identity_s_to_c, identity_c_to_s),
object_configs_to_state_dict = {'x_position': StateFieldInfo(('entity_state', 'position', 'center'), 0, identity_s_to_c, identity_c_to_s),
'y_position': StateFieldInfo(('entity_state', 'position', 'center'), 1, identity_s_to_c, identity_c_to_s),
'orientation': StateFieldInfo(('entity_state', 'position', 'orientation'), None, identity_s_to_c, identity_c_to_s),
'mass_center': StateFieldInfo(('entity_state', 'mass', 'center'), np.array([0]), mass_center_s_to_c, mass_center_c_to_s),
'mass_orientation': StateFieldInfo(('entity_state', 'mass', 'orientation'), None, identity_s_to_c, identity_c_to_s),
'diameter': StateFieldInfo(('entity_state', 'diameter'), None, identity_s_to_c, identity_c_to_s),
'friction': StateFieldInfo(('entity_state', 'friction'), None, identity_s_to_c, identity_c_to_s),
'color': StateFieldInfo(('object_state', 'color',), np.arange(3), color_s_to_c, color_c_to_s),
'idx': StateFieldInfo(('object_state', 'nve_idx',), None, identity_s_to_c, identity_c_to_s),
'exists': StateFieldInfo(('entities_state', 'exists'), None, identity_s_to_c, exists_c_to_s)
'exists': StateFieldInfo(('entity_state', 'exists'), None, identity_s_to_c, exists_c_to_s)

}

Expand All @@ -107,7 +107,7 @@ def get_default_state(n_entities_dict):
to_jit= jnp.array([1]), use_fori_loop=jnp.array([0]),
collision_alpha=jnp.array([0.]),
collision_eps=jnp.array([0.])),
entities_state=EntitiesState(position=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)),
entity_state=EntityState(position=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)),
momentum=None,
force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)),
mass=RigidBody(center=jnp.zeros((n_entities, 1)), orientation=jnp.zeros(n_entities)),
Expand All @@ -131,7 +131,7 @@ def get_default_state(n_entities_dict):
object_state=ObjectState(nve_idx=jnp.zeros(max_objects, dtype=int), color=jnp.zeros((max_objects, 3))))


EntitiesTuple = namedtuple('EntitiesTuple', ['idx', 'col', 'val'])
EntityTuple = namedtuple('EntityTuple', ['idx', 'col', 'val'])
ValueTuple = namedtuple('ValueData', ['nve_idx', 'col_idx', 'row_map', 'col_map', 'val'])
StateChangeTuple = namedtuple('StateChange', ['nested_field', 'nve_idx', 'column_idx', 'value'])

Expand All @@ -148,12 +148,12 @@ def events_to_nve_data(events, state):
val = state_field_info.config_to_state(e.new)

if state_field_info.column_idx is None:
nve_data[nested_field].append(EntitiesTuple(idx, None, val))
nve_data[nested_field].append(EntityTuple(idx, None, val))
elif isinstance(state_field_info.column_idx, int):
nve_data[nested_field].append(EntitiesTuple(idx, state_field_info.column_idx, val))
nve_data[nested_field].append(EntityTuple(idx, state_field_info.column_idx, val))
else:
for c, v in zip(state_field_info.column_idx, val):
nve_data[nested_field].append(EntitiesTuple(idx, c, v))
nve_data[nested_field].append(EntityTuple(idx, c, v))

return nve_data

Expand Down Expand Up @@ -222,15 +222,15 @@ def set_state_from_config_dict(config_dict, state=None):
params = configs[0].param_names()
for p in params:
state_field_info = configs_to_state_dict[stype][p]
nve_idx = [c.idx for c in configs] if state_field_info.nested_field[0] == 'entities_state' else range(len(configs))
nve_idx = [c.idx for c in configs] if state_field_info.nested_field[0] == 'entity_state' else range(len(configs))
change = rec_set_dataclass(state, state_field_info.nested_field, jnp.array(nve_idx), state_field_info.column_idx,
jnp.array([state_field_info.config_to_state(getattr(c, p)) for c in configs]))
state = state.set(**change)
if stype.is_entity():
e_idx.at[state.field(stype).nve_idx].set(jnp.array(range(n_entities_dict[stype])))

# TODO: something weird with the to lines below, the second one will have no effect (would need state = state.set(.)), but if we fix it we get only zeros in entities_state.entitiy_idx. As it is it seems to get correct values though
change = rec_set_dataclass(state, ('entities_state', 'entity_idx'), jnp.array(range(sum(n_entities_dict.values()))), None, e_idx)
# TODO: something weird with the to lines below, the second one will have no effect (would need state = state.set(.)), but if we fix it we get only zeros in entity_state.entitiy_idx. As it is it seems to get correct values though
change = rec_set_dataclass(state, ('entity_state', 'entity_idx'), jnp.array(range(sum(n_entities_dict.values()))), None, e_idx)
state.set(**change)

return state
Expand All @@ -239,7 +239,7 @@ def set_state_from_config_dict(config_dict, state=None):
def set_configs_from_state(state, config_dict=None):
if config_dict is None:
config_dict = {stype: [] for stype in list(StateType)}
for idx, stype_int in enumerate(state.entities_state.entity_type):
for idx, stype_int in enumerate(state.entity_state.entity_type):
stype = StateType(stype_int)
config_dict[stype].append(stype_to_config[stype](idx=idx))
config_dict[StateType.SIMULATOR].append(SimulatorConfig())
Expand Down
62 changes: 31 additions & 31 deletions vivarium/simulator/grpc_server/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import simulator_pb2

from vivarium.simulator.grpc_server.numproto.numproto import proto_to_ndarray, ndarray_to_proto
from vivarium.simulator.states import State, SimulatorState, EntitiesState, AgentState, ObjectState
from vivarium.simulator.states import State, SimulatorState, EntityState, AgentState, ObjectState


def proto_to_state(state):
return State(simulator_state=proto_to_simulator_state(state.simulator_state),
entities_state=proto_to_nve_state(state.entities_state),
entity_state=proto_to_nve_state(state.entity_state),
agent_state=proto_to_agent_state(state.agent_state),
object_state=proto_to_object_state(state.object_state))

Expand All @@ -29,20 +29,20 @@ def proto_to_simulator_state(simulator_state):
)


def proto_to_nve_state(entities_state):
return EntitiesState(position=RigidBody(center=proto_to_ndarray(entities_state.position.center).astype(float),
orientation=proto_to_ndarray(entities_state.position.orientation).astype(float)),
momentum=RigidBody(center=proto_to_ndarray(entities_state.momentum.center).astype(float),
orientation=proto_to_ndarray(entities_state.momentum.orientation).astype(float)),
force=RigidBody(center=proto_to_ndarray(entities_state.force.center).astype(float),
orientation=proto_to_ndarray(entities_state.force.orientation).astype(float)),
mass=RigidBody(center=proto_to_ndarray(entities_state.mass.center).astype(float),
orientation=proto_to_ndarray(entities_state.mass.orientation).astype(float)),
entity_type=proto_to_ndarray(entities_state.entity_type).astype(int),
entity_idx=proto_to_ndarray(entities_state.entity_idx).astype(int),
diameter=proto_to_ndarray(entities_state.diameter).astype(float),
friction=proto_to_ndarray(entities_state.friction).astype(float),
exists=proto_to_ndarray(entities_state.exists).astype(int)
def proto_to_nve_state(entity_state):
return EntityState(position=RigidBody(center=proto_to_ndarray(entity_state.position.center).astype(float),
orientation=proto_to_ndarray(entity_state.position.orientation).astype(float)),
momentum=RigidBody(center=proto_to_ndarray(entity_state.momentum.center).astype(float),
orientation=proto_to_ndarray(entity_state.momentum.orientation).astype(float)),
force=RigidBody(center=proto_to_ndarray(entity_state.force.center).astype(float),
orientation=proto_to_ndarray(entity_state.force.orientation).astype(float)),
mass=RigidBody(center=proto_to_ndarray(entity_state.mass.center).astype(float),
orientation=proto_to_ndarray(entity_state.mass.orientation).astype(float)),
entity_type=proto_to_ndarray(entity_state.entity_type).astype(int),
entity_idx=proto_to_ndarray(entity_state.entity_idx).astype(int),
diameter=proto_to_ndarray(entity_state.diameter).astype(float),
friction=proto_to_ndarray(entity_state.friction).astype(float),
exists=proto_to_ndarray(entity_state.exists).astype(int)
)


Expand All @@ -69,7 +69,7 @@ def proto_to_object_state(object_state):

def state_to_proto(state):
return simulator_pb2.State(simulator_state=simulator_state_to_proto(state.simulator_state),
entities_state=nve_state_to_proto(state.entities_state),
entity_state=nve_state_to_proto(state.entity_state),
agent_state=agent_state_to_proto(state.agent_state),
object_state=object_state_to_proto(state.object_state))

Expand All @@ -91,20 +91,20 @@ def simulator_state_to_proto(simulator_state):
)


def nve_state_to_proto(entities_state):
return simulator_pb2.EntitiesState(position=simulator_pb2.RigidBody(center=ndarray_to_proto(entities_state.position.center),
orientation=ndarray_to_proto(entities_state.position.orientation)),
momentum=simulator_pb2.RigidBody(center=ndarray_to_proto(entities_state.momentum.center),
orientation=ndarray_to_proto(entities_state.momentum.orientation)),
force=simulator_pb2.RigidBody(center=ndarray_to_proto(entities_state.force.center),
orientation=ndarray_to_proto(entities_state.force.orientation)),
mass=simulator_pb2.RigidBody(center=ndarray_to_proto(entities_state.mass.center),
orientation=ndarray_to_proto(entities_state.mass.orientation)),
entity_type=ndarray_to_proto(entities_state.entity_type),
entity_idx=ndarray_to_proto(entities_state.entity_idx),
diameter=ndarray_to_proto(entities_state.diameter),
friction=ndarray_to_proto(entities_state.friction),
exists=ndarray_to_proto(entities_state.exists)
def nve_state_to_proto(entity_state):
return simulator_pb2.EntityState(position=simulator_pb2.RigidBody(center=ndarray_to_proto(entity_state.position.center),
orientation=ndarray_to_proto(entity_state.position.orientation)),
momentum=simulator_pb2.RigidBody(center=ndarray_to_proto(entity_state.momentum.center),
orientation=ndarray_to_proto(entity_state.momentum.orientation)),
force=simulator_pb2.RigidBody(center=ndarray_to_proto(entity_state.force.center),
orientation=ndarray_to_proto(entity_state.force.orientation)),
mass=simulator_pb2.RigidBody(center=ndarray_to_proto(entity_state.mass.center),
orientation=ndarray_to_proto(entity_state.mass.orientation)),
entity_type=ndarray_to_proto(entity_state.entity_type),
entity_idx=ndarray_to_proto(entity_state.entity_idx),
diameter=ndarray_to_proto(entity_state.diameter),
friction=ndarray_to_proto(entity_state.friction),
exists=ndarray_to_proto(entity_state.exists)
)


Expand Down
Loading

0 comments on commit 14ee329

Please sign in to comment.