Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a proximity map raw sensor #76

Merged
merged 6 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions notebooks/quickstart_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
"outputs": [],
"source": [
"from vivarium.controllers.notebook_controller import NotebookController\n",
"import numpy as np\n",
"controller = NotebookController()"
]
},
Expand Down Expand Up @@ -258,7 +257,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
4 changes: 4 additions & 0 deletions vivarium/controllers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from jax_md.rigid_body import monomer

import numpy as np


mass = monomer.mass()
mass_center = float(mass.center[0])
Expand Down Expand Up @@ -42,6 +44,8 @@ class AgentConfig(Config):
right_motor = param.Number(0., bounds=(0., 1.))
left_prox = param.Number(0., bounds=(0., 1.))
right_prox = param.Number(0., bounds=(0., 1.))
proximity_map_dist = param.Array(np.array([0.]))
proximity_map_theta = param.Array(np.array([0.]))
wheel_diameter = param.Number(2.)
diameter = param.Number(5.)
speed_mul = param.Number(1.)
Expand Down
9 changes: 9 additions & 0 deletions vivarium/controllers/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class StateFieldInfo:
mass_center_s_to_c = lambda x, typ: typ(x)
mass_center_c_to_s = lambda x: [x]
exists_c_to_s = lambda x: int(x)
neighbor_map_s_to_c = lambda x, typ: x


agent_configs_to_state_dict = {'x_position': StateFieldInfo(('entity_state', 'position', 'center'), 0, identity_s_to_c, identity_c_to_s),
Expand All @@ -64,6 +65,8 @@ class StateFieldInfo:
'right_motor': StateFieldInfo(('agent_state', 'motor',), 1, identity_s_to_c, identity_c_to_s),
'left_prox': StateFieldInfo(('agent_state', 'prox',), 0, identity_s_to_c, identity_c_to_s),
'right_prox': StateFieldInfo(('agent_state', 'prox',), 1, identity_s_to_c, identity_c_to_s),
'proximity_map_dist': StateFieldInfo(('agent_state', 'proximity_map_dist',), slice(None), neighbor_map_s_to_c, identity_c_to_s),
'proximity_map_theta': StateFieldInfo(('agent_state', 'proximity_map_theta',), slice(None), neighbor_map_s_to_c, identity_c_to_s),
'behavior': StateFieldInfo(('agent_state', 'behavior',), None, behavior_s_to_c, behavior_c_to_s),
'color': StateFieldInfo(('agent_state', 'color',), np.arange(3), color_s_to_c, color_c_to_s),
'idx': StateFieldInfo(('agent_state', 'nve_idx',), None, identity_s_to_c, identity_c_to_s),
Expand Down Expand Up @@ -120,6 +123,8 @@ def get_default_state(n_entities_dict):
agent_state=AgentState(nve_idx=jnp.zeros(max_agents, dtype=int),
prox=jnp.zeros((max_agents, 2)),
motor=jnp.zeros((max_agents, 2)),
proximity_map_dist=jnp.zeros((max_agents, 1)),
proximity_map_theta=jnp.zeros((max_agents, 1)),
behavior=jnp.zeros(max_agents, dtype=int),
wheel_diameter=jnp.zeros(max_agents),
speed_mul=jnp.zeros(max_agents),
Expand Down Expand Up @@ -203,6 +208,10 @@ def rec_set_dataclass(var, nested_field, row_idx, column_idx, value):

if len(nested_field) == 1:
field = nested_field[0]
if isinstance(column_idx, slice):
column_idx = np.arange(column_idx.start if column_idx.start is not None else 0,
column_idx.stop if column_idx.stop is not None else getattr(var, field).shape[1],
column_idx.step if column_idx.step is not None else 1, dtype=int)
if column_idx is None or len(column_idx) == 0:
d = {field: getattr(var, field).at[row_idx].set(value.reshape(-1))}
else:
Expand Down
4 changes: 4 additions & 0 deletions vivarium/simulator/grpc_server/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def proto_to_nve_state(entity_state):

def proto_to_agent_state(agent_state):
return AgentState(nve_idx=proto_to_ndarray(agent_state.nve_idx).astype(int),
proximity_map_dist=proto_to_ndarray(agent_state.proximity_map_dist).astype(float),
proximity_map_theta=proto_to_ndarray(agent_state.proximity_map_theta).astype(float),
prox=proto_to_ndarray(agent_state.prox).astype(float),
motor=proto_to_ndarray(agent_state.motor).astype(float),
behavior=proto_to_ndarray(agent_state.behavior).astype(int),
Expand Down Expand Up @@ -110,6 +112,8 @@ def nve_state_to_proto(entity_state):

def agent_state_to_proto(agent_state):
return simulator_pb2.AgentState(nve_idx=ndarray_to_proto(agent_state.nve_idx),
proximity_map_dist=ndarray_to_proto(agent_state.proximity_map_dist),
proximity_map_theta=ndarray_to_proto(agent_state.proximity_map_theta),
prox=ndarray_to_proto(agent_state.prox),
motor=ndarray_to_proto(agent_state.motor),
behavior=ndarray_to_proto(agent_state.behavior),
Expand Down
2 changes: 2 additions & 0 deletions vivarium/simulator/grpc_server/protos/simulator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ message AgentState {
NDArray proxs_dist_max = 9;
NDArray proxs_cos_min = 10;
NDArray color = 11;
NDArray proximity_map_dist = 12;
NDArray proximity_map_theta = 13;
}

message ObjectState {
Expand Down
28 changes: 14 additions & 14 deletions vivarium/simulator/grpc_server/simulator_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions vivarium/simulator/grpc_server/simulator_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class EntityState(_message.Message):
def __init__(self, position: _Optional[_Union[RigidBody, _Mapping]] = ..., momentum: _Optional[_Union[RigidBody, _Mapping]] = ..., force: _Optional[_Union[RigidBody, _Mapping]] = ..., mass: _Optional[_Union[RigidBody, _Mapping]] = ..., diameter: _Optional[_Union[NDArray, _Mapping]] = ..., entity_type: _Optional[_Union[NDArray, _Mapping]] = ..., entity_idx: _Optional[_Union[NDArray, _Mapping]] = ..., friction: _Optional[_Union[NDArray, _Mapping]] = ..., exists: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ...

class AgentState(_message.Message):
__slots__ = ("nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "max_speed", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color")
__slots__ = ("nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "max_speed", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color", "proximity_map_dist", "proximity_map_theta")
NVE_IDX_FIELD_NUMBER: _ClassVar[int]
PROX_FIELD_NUMBER: _ClassVar[int]
MOTOR_FIELD_NUMBER: _ClassVar[int]
Expand All @@ -89,6 +89,8 @@ class AgentState(_message.Message):
PROXS_DIST_MAX_FIELD_NUMBER: _ClassVar[int]
PROXS_COS_MIN_FIELD_NUMBER: _ClassVar[int]
COLOR_FIELD_NUMBER: _ClassVar[int]
PROXIMITY_MAP_DIST_FIELD_NUMBER: _ClassVar[int]
PROXIMITY_MAP_THETA_FIELD_NUMBER: _ClassVar[int]
nve_idx: NDArray
prox: NDArray
motor: NDArray
Expand All @@ -100,7 +102,9 @@ class AgentState(_message.Message):
proxs_dist_max: NDArray
proxs_cos_min: NDArray
color: NDArray
def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., max_speed: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ...
proximity_map_dist: NDArray
proximity_map_theta: NDArray
def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., max_speed: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ..., proximity_map_dist: _Optional[_Union[NDArray, _Mapping]] = ..., proximity_map_theta: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ...

class ObjectState(_message.Message):
__slots__ = ("nve_idx", "custom_field", "color")
Expand Down
Loading
Loading