From 4e25ebb18e58b513529c649d051024dae6f4b9f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Moulin-Frier?= Date: Thu, 7 Mar 2024 13:51:49 +0100 Subject: [PATCH 1/4] First working version --- vivarium/controllers/config.py | 4 +++ vivarium/controllers/converters.py | 9 ++++++ vivarium/simulator/grpc_server/converters.py | 4 +++ .../grpc_server/protos/simulator.proto | 2 ++ .../simulator/grpc_server/simulator_pb2.py | 28 +++++++++---------- .../simulator/grpc_server/simulator_pb2.pyi | 8 ++++-- vivarium/simulator/sim_computation.py | 17 +++++++---- 7 files changed, 51 insertions(+), 21 deletions(-) diff --git a/vivarium/controllers/config.py b/vivarium/controllers/config.py index 73c2a19..051285c 100644 --- a/vivarium/controllers/config.py +++ b/vivarium/controllers/config.py @@ -6,6 +6,8 @@ from jax_md.rigid_body import monomer +import numpy as np + mass = monomer.mass() mass_center = float(mass.center[0]) @@ -42,6 +44,8 @@ class AgentConfig(Config): right_motor = param.Number(0., bounds=(0., 1.)) left_prox = param.Number(0., bounds=(0., 1.)) right_prox = param.Number(0., bounds=(0., 1.)) + neighbor_map_dist = param.Array(np.array([0.])) + neighbor_map_theta = param.Array(np.array([0.])) wheel_diameter = param.Number(2.) diameter = param.Number(5.) speed_mul = param.Number(1.) diff --git a/vivarium/controllers/converters.py b/vivarium/controllers/converters.py index e3d91ae..1067e70 100644 --- a/vivarium/controllers/converters.py +++ b/vivarium/controllers/converters.py @@ -53,6 +53,7 @@ class StateFieldInfo: mass_center_s_to_c = lambda x, typ: typ(x) mass_center_c_to_s = lambda x: [x] exists_c_to_s = lambda x: int(x) +neighbor_map_s_to_c = lambda x, typ: x agent_configs_to_state_dict = {'x_position': StateFieldInfo(('nve_state', 'position', 'center'), 0, identity_s_to_c, identity_c_to_s), @@ -66,6 +67,8 @@ class StateFieldInfo: 'right_motor': StateFieldInfo(('agent_state', 'motor',), 1, identity_s_to_c, identity_c_to_s), 'left_prox': StateFieldInfo(('agent_state', 'prox',), 0, identity_s_to_c, identity_c_to_s), 'right_prox': StateFieldInfo(('agent_state', 'prox',), 1, identity_s_to_c, identity_c_to_s), + 'neighbor_map_dist': StateFieldInfo(('agent_state', 'neighbor_map_dist',), slice(None), neighbor_map_s_to_c, identity_c_to_s), + 'neighbor_map_theta': StateFieldInfo(('agent_state', 'neighbor_map_theta',), slice(None), neighbor_map_s_to_c, identity_c_to_s), 'behavior': StateFieldInfo(('agent_state', 'behavior',), None, behavior_s_to_c, behavior_c_to_s), 'color': StateFieldInfo(('agent_state', 'color',), np.arange(3), color_s_to_c, color_c_to_s), 'idx': StateFieldInfo(('agent_state', 'nve_idx',), None, identity_s_to_c, identity_c_to_s), @@ -118,6 +121,8 @@ def get_default_state(n_entities_dict): exists=jnp.ones(n_entities, dtype=int) ), agent_state=AgentState(nve_idx=jnp.zeros(n_agents, dtype=int), + neighbor_map_dist=jnp.zeros((n_agents, 1)), + neighbor_map_theta=jnp.zeros((n_agents, 1)), prox=jnp.zeros((n_agents, 2)), motor=jnp.zeros((n_agents, 2)), behavior=jnp.zeros(n_agents, dtype=int), @@ -202,6 +207,10 @@ def rec_set_dataclass(var, nested_field, row_idx, column_idx, value): if len(nested_field) == 1: field = nested_field[0] + if isinstance(column_idx, slice): + column_idx = np.arange(column_idx.start if column_idx.start is not None else 0, + column_idx.stop if column_idx.stop is not None else getattr(var, field).shape[1], + column_idx.step if column_idx.step is not None else 1, dtype=int) if column_idx is None or len(column_idx) == 0: d = {field: getattr(var, field).at[row_idx].set(value.reshape(-1))} else: diff --git a/vivarium/simulator/grpc_server/converters.py b/vivarium/simulator/grpc_server/converters.py index 5289db8..3435adf 100644 --- a/vivarium/simulator/grpc_server/converters.py +++ b/vivarium/simulator/grpc_server/converters.py @@ -45,6 +45,8 @@ def proto_to_nve_state(nve_state): def proto_to_agent_state(agent_state): return AgentState(nve_idx=proto_to_ndarray(agent_state.nve_idx).astype(int), + neighbor_map_dist=proto_to_ndarray(agent_state.neighbor_map_dist).astype(float), + neighbor_map_theta=proto_to_ndarray(agent_state.neighbor_map_theta).astype(float), prox=proto_to_ndarray(agent_state.prox).astype(float), motor=proto_to_ndarray(agent_state.motor).astype(float), behavior=proto_to_ndarray(agent_state.behavior).astype(int), @@ -104,6 +106,8 @@ def nve_state_to_proto(nve_state): def agent_state_to_proto(agent_state): return simulator_pb2.AgentState(nve_idx=ndarray_to_proto(agent_state.nve_idx), + neighbor_map_dist=ndarray_to_proto(agent_state.neighbor_map_dist), + neighbor_map_theta=ndarray_to_proto(agent_state.neighbor_map_theta), prox=ndarray_to_proto(agent_state.prox), motor=ndarray_to_proto(agent_state.motor), behavior=ndarray_to_proto(agent_state.behavior), diff --git a/vivarium/simulator/grpc_server/protos/simulator.proto b/vivarium/simulator/grpc_server/protos/simulator.proto index ab03f90..8bf1e8c 100644 --- a/vivarium/simulator/grpc_server/protos/simulator.proto +++ b/vivarium/simulator/grpc_server/protos/simulator.proto @@ -91,6 +91,8 @@ message AgentState { NDArray proxs_dist_max = 8; NDArray proxs_cos_min = 9; NDArray color = 10; + NDArray neighbor_map_dist = 11; + NDArray neighbor_map_theta = 12; } message ObjectState { diff --git a/vivarium/simulator/grpc_server/simulator_pb2.py b/vivarium/simulator/grpc_server/simulator_pb2.py index d9b51c9..4652016 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.py +++ b/vivarium/simulator/grpc_server/simulator_pb2.py @@ -14,7 +14,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\x8d\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08n_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tn_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\"\xe4\x02\n\x08NVEState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\x90\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\n \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xbd\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12&\n\tnve_state\x18\x02 \x01(\x0b\x32\x13.simulator.NVEState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\"<\n\rAddAgentInput\x12\x10\n\x08n_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xb6\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12<\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x13.simulator.NVEState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\x8d\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08n_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tn_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\"\xe4\x02\n\x08NVEState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\xef\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12-\n\x11neighbor_map_dist\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12.\n\x12neighbor_map_theta\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xbd\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12&\n\tnve_state\x18\x02 \x01(\x0b\x32\x13.simulator.NVEState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\"<\n\rAddAgentInput\x12\x10\n\x08n_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xb6\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12<\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x13.simulator.NVEState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -33,17 +33,17 @@ _globals['_NVESTATE']._serialized_start=603 _globals['_NVESTATE']._serialized_end=959 _globals['_AGENTSTATE']._serialized_start=962 - _globals['_AGENTSTATE']._serialized_end=1362 - _globals['_OBJECTSTATE']._serialized_start=1364 - _globals['_OBJECTSTATE']._serialized_end=1491 - _globals['_STATE']._serialized_start=1494 - _globals['_STATE']._serialized_end=1683 - _globals['_STATECHANGE']._serialized_start=1685 - _globals['_STATECHANGE']._serialized_end=1789 - _globals['_ADDAGENTINPUT']._serialized_start=1791 - _globals['_ADDAGENTINPUT']._serialized_end=1851 - _globals['_ISSTARTEDSTATE']._serialized_start=1853 - _globals['_ISSTARTEDSTATE']._serialized_end=1889 - _globals['_SIMULATORSERVER']._serialized_start=1892 - _globals['_SIMULATORSERVER']._serialized_end=2458 + _globals['_AGENTSTATE']._serialized_end=1457 + _globals['_OBJECTSTATE']._serialized_start=1459 + _globals['_OBJECTSTATE']._serialized_end=1586 + _globals['_STATE']._serialized_start=1589 + _globals['_STATE']._serialized_end=1778 + _globals['_STATECHANGE']._serialized_start=1780 + _globals['_STATECHANGE']._serialized_end=1884 + _globals['_ADDAGENTINPUT']._serialized_start=1886 + _globals['_ADDAGENTINPUT']._serialized_end=1946 + _globals['_ISSTARTEDSTATE']._serialized_start=1948 + _globals['_ISSTARTEDSTATE']._serialized_end=1984 + _globals['_SIMULATORSERVER']._serialized_start=1987 + _globals['_SIMULATORSERVER']._serialized_end=2553 # @@protoc_insertion_point(module_scope) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.pyi b/vivarium/simulator/grpc_server/simulator_pb2.pyi index 7564e20..e9e2b29 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.pyi +++ b/vivarium/simulator/grpc_server/simulator_pb2.pyi @@ -73,7 +73,7 @@ class NVEState(_message.Message): def __init__(self, position: _Optional[_Union[RigidBody, _Mapping]] = ..., momentum: _Optional[_Union[RigidBody, _Mapping]] = ..., force: _Optional[_Union[RigidBody, _Mapping]] = ..., mass: _Optional[_Union[RigidBody, _Mapping]] = ..., diameter: _Optional[_Union[NDArray, _Mapping]] = ..., entity_type: _Optional[_Union[NDArray, _Mapping]] = ..., entity_idx: _Optional[_Union[NDArray, _Mapping]] = ..., friction: _Optional[_Union[NDArray, _Mapping]] = ..., exists: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class AgentState(_message.Message): - __slots__ = ["nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color"] + __slots__ = ["nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color", "neighbor_map_dist", "neighbor_map_theta"] NVE_IDX_FIELD_NUMBER: _ClassVar[int] PROX_FIELD_NUMBER: _ClassVar[int] MOTOR_FIELD_NUMBER: _ClassVar[int] @@ -84,6 +84,8 @@ class AgentState(_message.Message): PROXS_DIST_MAX_FIELD_NUMBER: _ClassVar[int] PROXS_COS_MIN_FIELD_NUMBER: _ClassVar[int] COLOR_FIELD_NUMBER: _ClassVar[int] + NEIGHBOR_MAP_DIST_FIELD_NUMBER: _ClassVar[int] + NEIGHBOR_MAP_THETA_FIELD_NUMBER: _ClassVar[int] nve_idx: NDArray prox: NDArray motor: NDArray @@ -94,7 +96,9 @@ class AgentState(_message.Message): proxs_dist_max: NDArray proxs_cos_min: NDArray color: NDArray - def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... + neighbor_map_dist: NDArray + neighbor_map_theta: NDArray + def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ..., neighbor_map_dist: _Optional[_Union[NDArray, _Mapping]] = ..., neighbor_map_theta: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class ObjectState(_message.Message): __slots__ = ["nve_idx", "custom_field", "color"] diff --git a/vivarium/simulator/sim_computation.py b/vivarium/simulator/sim_computation.py index 3a03e4c..8b5314b 100644 --- a/vivarium/simulator/sim_computation.py +++ b/vivarium/simulator/sim_computation.py @@ -50,6 +50,8 @@ def velocity(self) -> util.Array: @dataclass class AgentState: nve_idx: util.Array # idx in NVEState + neighbor_map_dist: util.Array + neighbor_map_theta: util.Array prox: util.Array motor: util.Array behavior: util.Array @@ -223,9 +225,15 @@ def compute_prox(state, agent_neighs_idx, target_exists_mask): Ra = body.center[senders] Rb = body.center[receivers] dR = - space.map_bond(displacement)(Ra, Rb) # Looks like it should be opposite, but don't understand why - prox = sensor(dR, body.orientation[senders], state.agent_state.proxs_dist_max[senders], + + dist_theta = proximity_map(dR, body.orientation[senders]) + neighbor_map_dist = jnp.zeros((state.agent_state.nve_idx.shape[0], state.nve_state.entity_idx.shape[0])) + neighbor_map_dist = neighbor_map_dist.at[agent_neighs_idx[0, :], agent_neighs_idx[1, :]].set(dist_theta[:, 0]) + neighbor_map_theta = jnp.zeros((state.agent_state.nve_idx.shape[0], state.nve_state.entity_idx.shape[0])) + neighbor_map_theta = neighbor_map_theta.at[agent_neighs_idx[0, :], agent_neighs_idx[1, :]].set(dist_theta[:, 1]) + prox = sensor(dist_theta[:, 0], dist_theta[:, 1], state.agent_state.proxs_dist_max[senders], state.agent_state.proxs_cos_min[senders], len(state.agent_state.nve_idx), senders, mask) - return state.agent_state.set(prox=prox) + return state.agent_state.set(neighbor_map_dist=neighbor_map_dist, neighbor_map_theta=neighbor_map_theta, prox=prox) def sensorimotor(agent_state): @@ -256,7 +264,7 @@ def dist_theta(displ, theta): norm_displ = displ / dist theta_displ = jnp.arccos(norm_displ[0]) * jnp.sign(jnp.arcsin(norm_displ[1])) relative_theta = theta_displ - theta - return dist, relative_theta + return jnp.array([dist, relative_theta]) proximity_map = vmap(dist_theta, (0, 0)) @@ -283,8 +291,7 @@ def sensor_fn(dist, relative_theta, dist_max, cos_min, target_exists): sensor_fn = vmap(sensor_fn, (0, 0, 0, 0, 0)) -def sensor(displ, theta, dist_max, cos_min, n_agents, senders, target_exists): - dist, relative_theta = proximity_map(displ, theta) +def sensor(dist, relative_theta, dist_max, cos_min, n_agents, senders, target_exists): proxs = ops.segment_max(sensor_fn(dist, relative_theta, dist_max, cos_min, target_exists), senders, n_agents) return proxs From ca23da9d3572cbe7df9ec64d95e2677163510009 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Moulin-Frier?= Date: Mon, 15 Apr 2024 20:38:56 +0200 Subject: [PATCH 2/4] Adapt code to merge new master into this branch --- vivarium/controllers/config.py | 4 +-- vivarium/controllers/converters.py | 8 ++--- vivarium/simulator/grpc_server/converters.py | 8 ++--- .../grpc_server/protos/simulator.proto | 4 +-- .../simulator/grpc_server/simulator_pb2.py | 33 +++++++++---------- .../simulator/grpc_server/simulator_pb2.pyi | 28 +++++++++------- vivarium/simulator/physics_engine.py | 15 ++++++--- vivarium/simulator/states.py | 6 +++- 8 files changed, 59 insertions(+), 47 deletions(-) diff --git a/vivarium/controllers/config.py b/vivarium/controllers/config.py index c34a4d9..6684125 100644 --- a/vivarium/controllers/config.py +++ b/vivarium/controllers/config.py @@ -44,8 +44,8 @@ class AgentConfig(Config): right_motor = param.Number(0., bounds=(0., 1.)) left_prox = param.Number(0., bounds=(0., 1.)) right_prox = param.Number(0., bounds=(0., 1.)) - neighbor_map_dist = param.Array(np.array([0.])) - neighbor_map_theta = param.Array(np.array([0.])) + proximity_map_dist = param.Array(np.array([0.])) + proximity_map_theta = param.Array(np.array([0.])) wheel_diameter = param.Number(2.) diameter = param.Number(5.) speed_mul = param.Number(1.) diff --git a/vivarium/controllers/converters.py b/vivarium/controllers/converters.py index a505772..a275494 100644 --- a/vivarium/controllers/converters.py +++ b/vivarium/controllers/converters.py @@ -65,8 +65,8 @@ class StateFieldInfo: 'right_motor': StateFieldInfo(('agent_state', 'motor',), 1, identity_s_to_c, identity_c_to_s), 'left_prox': StateFieldInfo(('agent_state', 'prox',), 0, identity_s_to_c, identity_c_to_s), 'right_prox': StateFieldInfo(('agent_state', 'prox',), 1, identity_s_to_c, identity_c_to_s), - 'neighbor_map_dist': StateFieldInfo(('agent_state', 'neighbor_map_dist',), slice(None), neighbor_map_s_to_c, identity_c_to_s), - 'neighbor_map_theta': StateFieldInfo(('agent_state', 'neighbor_map_theta',), slice(None), neighbor_map_s_to_c, identity_c_to_s), + 'proximity_map_dist': StateFieldInfo(('agent_state', 'proximity_map_dist',), slice(None), neighbor_map_s_to_c, identity_c_to_s), + 'proximity_map_theta': StateFieldInfo(('agent_state', 'proximity_map_theta',), slice(None), neighbor_map_s_to_c, identity_c_to_s), 'behavior': StateFieldInfo(('agent_state', 'behavior',), None, behavior_s_to_c, behavior_c_to_s), 'color': StateFieldInfo(('agent_state', 'color',), np.arange(3), color_s_to_c, color_c_to_s), 'idx': StateFieldInfo(('agent_state', 'nve_idx',), None, identity_s_to_c, identity_c_to_s), @@ -123,8 +123,8 @@ def get_default_state(n_entities_dict): agent_state=AgentState(nve_idx=jnp.zeros(max_agents, dtype=int), prox=jnp.zeros((max_agents, 2)), motor=jnp.zeros((max_agents, 2)), - neighbor_map_dist=jnp.zeros((max_agents, 1)), - neighbor_map_theta=jnp.zeros((max_agents, 1)), + proximity_map_dist=jnp.zeros((max_agents, 1)), + proximity_map_theta=jnp.zeros((max_agents, 1)), behavior=jnp.zeros(max_agents, dtype=int), wheel_diameter=jnp.zeros(max_agents), speed_mul=jnp.zeros(max_agents), diff --git a/vivarium/simulator/grpc_server/converters.py b/vivarium/simulator/grpc_server/converters.py index 94cc128..2b2a8d5 100644 --- a/vivarium/simulator/grpc_server/converters.py +++ b/vivarium/simulator/grpc_server/converters.py @@ -48,8 +48,8 @@ def proto_to_nve_state(entities_state): def proto_to_agent_state(agent_state): return AgentState(nve_idx=proto_to_ndarray(agent_state.nve_idx).astype(int), - neighbor_map_dist=proto_to_ndarray(agent_state.neighbor_map_dist).astype(float), - neighbor_map_theta=proto_to_ndarray(agent_state.neighbor_map_theta).astype(float), + proximity_map_dist=proto_to_ndarray(agent_state.proximity_map_dist).astype(float), + proximity_map_theta=proto_to_ndarray(agent_state.proximity_map_theta).astype(float), prox=proto_to_ndarray(agent_state.prox).astype(float), motor=proto_to_ndarray(agent_state.motor).astype(float), behavior=proto_to_ndarray(agent_state.behavior).astype(int), @@ -112,8 +112,8 @@ def nve_state_to_proto(entities_state): def agent_state_to_proto(agent_state): return simulator_pb2.AgentState(nve_idx=ndarray_to_proto(agent_state.nve_idx), - neighbor_map_dist=ndarray_to_proto(agent_state.neighbor_map_dist), - neighbor_map_theta=ndarray_to_proto(agent_state.neighbor_map_theta), + proximity_map_dist=ndarray_to_proto(agent_state.proximity_map_dist), + proximity_map_theta=ndarray_to_proto(agent_state.proximity_map_theta), prox=ndarray_to_proto(agent_state.prox), motor=ndarray_to_proto(agent_state.motor), behavior=ndarray_to_proto(agent_state.behavior), diff --git a/vivarium/simulator/grpc_server/protos/simulator.proto b/vivarium/simulator/grpc_server/protos/simulator.proto index ea23632..92f6969 100644 --- a/vivarium/simulator/grpc_server/protos/simulator.proto +++ b/vivarium/simulator/grpc_server/protos/simulator.proto @@ -94,8 +94,8 @@ message AgentState { NDArray proxs_dist_max = 9; NDArray proxs_cos_min = 10; NDArray color = 11; - NDArray neighbor_map_dist = 12; - NDArray neighbor_map_theta = 13; + NDArray proximity_map_dist = 12; + NDArray proximity_map_theta = 13; } message ObjectState { diff --git a/vivarium/simulator/grpc_server/simulator_pb2.py b/vivarium/simulator/grpc_server/simulator_pb2.py index e459f53..4babe02 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.py +++ b/vivarium/simulator/grpc_server/simulator_pb2.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: simulator.proto -# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -15,14 +14,14 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nmax_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0bmax_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x02\n\rEntitiesState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\xb7\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tmax_speed\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xc7\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12\x30\n\x0e\x65ntities_state\x18\x02 \x01(\x0b\x32\x18.simulator.EntitiesState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\">\n\rAddAgentInput\x12\x12\n\nmax_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xbb\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x41\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x18.simulator.EntitiesState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nmax_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0bmax_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x02\n\rEntitiesState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\x98\x04\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tmax_speed\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12.\n\x12proximity_map_dist\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\x12/\n\x13proximity_map_theta\x18\r \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xc7\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12\x30\n\x0e\x65ntities_state\x18\x02 \x01(\x0b\x32\x18.simulator.EntitiesState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\">\n\rAddAgentInput\x12\x12\n\nmax_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xbb\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x41\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x18.simulator.EntitiesState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'simulator_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - _globals['DESCRIPTOR']._options = None - _globals['DESCRIPTOR']._serialized_options = b'\n\032io.grpc.examples.simulatorB\016SimulatorProtoP\001\242\002\003SIM' + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\032io.grpc.examples.simulatorB\016SimulatorProtoP\001\242\002\003SIM' _globals['_AGENTIDX']._serialized_start=59 _globals['_AGENTIDX']._serialized_end=82 _globals['_NDARRAY']._serialized_start=84 @@ -34,17 +33,17 @@ _globals['_ENTITIESSTATE']._serialized_start=695 _globals['_ENTITIESSTATE']._serialized_end=1056 _globals['_AGENTSTATE']._serialized_start=1059 - _globals['_AGENTSTATE']._serialized_end=1498 - _globals['_OBJECTSTATE']._serialized_start=1500 - _globals['_OBJECTSTATE']._serialized_end=1627 - _globals['_STATE']._serialized_start=1630 - _globals['_STATE']._serialized_end=1829 - _globals['_STATECHANGE']._serialized_start=1831 - _globals['_STATECHANGE']._serialized_end=1935 - _globals['_ADDAGENTINPUT']._serialized_start=1937 - _globals['_ADDAGENTINPUT']._serialized_end=1999 - _globals['_ISSTARTEDSTATE']._serialized_start=2001 - _globals['_ISSTARTEDSTATE']._serialized_end=2037 - _globals['_SIMULATORSERVER']._serialized_start=2040 - _globals['_SIMULATORSERVER']._serialized_end=2611 + _globals['_AGENTSTATE']._serialized_end=1595 + _globals['_OBJECTSTATE']._serialized_start=1597 + _globals['_OBJECTSTATE']._serialized_end=1724 + _globals['_STATE']._serialized_start=1727 + _globals['_STATE']._serialized_end=1926 + _globals['_STATECHANGE']._serialized_start=1928 + _globals['_STATECHANGE']._serialized_end=2032 + _globals['_ADDAGENTINPUT']._serialized_start=2034 + _globals['_ADDAGENTINPUT']._serialized_end=2096 + _globals['_ISSTARTEDSTATE']._serialized_start=2098 + _globals['_ISSTARTEDSTATE']._serialized_end=2134 + _globals['_SIMULATORSERVER']._serialized_start=2137 + _globals['_SIMULATORSERVER']._serialized_end=2708 # @@protoc_insertion_point(module_scope) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.pyi b/vivarium/simulator/grpc_server/simulator_pb2.pyi index 269cdd5..0bd4c48 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.pyi +++ b/vivarium/simulator/grpc_server/simulator_pb2.pyi @@ -7,19 +7,19 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map DESCRIPTOR: _descriptor.FileDescriptor class AgentIdx(_message.Message): - __slots__ = ("idx",) + __slots__ = ["idx"] IDX_FIELD_NUMBER: _ClassVar[int] idx: _containers.RepeatedScalarFieldContainer[int] def __init__(self, idx: _Optional[_Iterable[int]] = ...) -> None: ... class NDArray(_message.Message): - __slots__ = ("ndarray",) + __slots__ = ["ndarray"] NDARRAY_FIELD_NUMBER: _ClassVar[int] ndarray: bytes def __init__(self, ndarray: _Optional[bytes] = ...) -> None: ... class RigidBody(_message.Message): - __slots__ = ("center", "orientation") + __slots__ = ["center", "orientation"] CENTER_FIELD_NUMBER: _ClassVar[int] ORIENTATION_FIELD_NUMBER: _ClassVar[int] center: NDArray @@ -27,7 +27,7 @@ class RigidBody(_message.Message): def __init__(self, center: _Optional[_Union[NDArray, _Mapping]] = ..., orientation: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class SimulatorState(_message.Message): - __slots__ = ("idx", "box_size", "max_agents", "max_objects", "num_steps_lax", "dt", "freq", "neighbor_radius", "to_jit", "use_fori_loop", "collision_eps", "collision_alpha") + __slots__ = ["idx", "box_size", "max_agents", "max_objects", "num_steps_lax", "dt", "freq", "neighbor_radius", "to_jit", "use_fori_loop", "collision_eps", "collision_alpha"] IDX_FIELD_NUMBER: _ClassVar[int] BOX_SIZE_FIELD_NUMBER: _ClassVar[int] MAX_AGENTS_FIELD_NUMBER: _ClassVar[int] @@ -55,7 +55,7 @@ class SimulatorState(_message.Message): def __init__(self, idx: _Optional[_Union[NDArray, _Mapping]] = ..., box_size: _Optional[_Union[NDArray, _Mapping]] = ..., max_agents: _Optional[_Union[NDArray, _Mapping]] = ..., max_objects: _Optional[_Union[NDArray, _Mapping]] = ..., num_steps_lax: _Optional[_Union[NDArray, _Mapping]] = ..., dt: _Optional[_Union[NDArray, _Mapping]] = ..., freq: _Optional[_Union[NDArray, _Mapping]] = ..., neighbor_radius: _Optional[_Union[NDArray, _Mapping]] = ..., to_jit: _Optional[_Union[NDArray, _Mapping]] = ..., use_fori_loop: _Optional[_Union[NDArray, _Mapping]] = ..., collision_eps: _Optional[_Union[NDArray, _Mapping]] = ..., collision_alpha: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class EntitiesState(_message.Message): - __slots__ = ("position", "momentum", "force", "mass", "diameter", "entity_type", "entity_idx", "friction", "exists") + __slots__ = ["position", "momentum", "force", "mass", "diameter", "entity_type", "entity_idx", "friction", "exists"] POSITION_FIELD_NUMBER: _ClassVar[int] MOMENTUM_FIELD_NUMBER: _ClassVar[int] FORCE_FIELD_NUMBER: _ClassVar[int] @@ -77,7 +77,7 @@ class EntitiesState(_message.Message): def __init__(self, position: _Optional[_Union[RigidBody, _Mapping]] = ..., momentum: _Optional[_Union[RigidBody, _Mapping]] = ..., force: _Optional[_Union[RigidBody, _Mapping]] = ..., mass: _Optional[_Union[RigidBody, _Mapping]] = ..., diameter: _Optional[_Union[NDArray, _Mapping]] = ..., entity_type: _Optional[_Union[NDArray, _Mapping]] = ..., entity_idx: _Optional[_Union[NDArray, _Mapping]] = ..., friction: _Optional[_Union[NDArray, _Mapping]] = ..., exists: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class AgentState(_message.Message): - __slots__ = ("nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "max_speed", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color") + __slots__ = ["nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "max_speed", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color", "proximity_map_dist", "proximity_map_theta"] NVE_IDX_FIELD_NUMBER: _ClassVar[int] PROX_FIELD_NUMBER: _ClassVar[int] MOTOR_FIELD_NUMBER: _ClassVar[int] @@ -89,6 +89,8 @@ class AgentState(_message.Message): PROXS_DIST_MAX_FIELD_NUMBER: _ClassVar[int] PROXS_COS_MIN_FIELD_NUMBER: _ClassVar[int] COLOR_FIELD_NUMBER: _ClassVar[int] + PROXIMITY_MAP_DIST_FIELD_NUMBER: _ClassVar[int] + PROXIMITY_MAP_THETA_FIELD_NUMBER: _ClassVar[int] nve_idx: NDArray prox: NDArray motor: NDArray @@ -100,10 +102,12 @@ class AgentState(_message.Message): proxs_dist_max: NDArray proxs_cos_min: NDArray color: NDArray - def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., max_speed: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... + proximity_map_dist: NDArray + proximity_map_theta: NDArray + def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., max_speed: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ..., proximity_map_dist: _Optional[_Union[NDArray, _Mapping]] = ..., proximity_map_theta: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class ObjectState(_message.Message): - __slots__ = ("nve_idx", "custom_field", "color") + __slots__ = ["nve_idx", "custom_field", "color"] NVE_IDX_FIELD_NUMBER: _ClassVar[int] CUSTOM_FIELD_FIELD_NUMBER: _ClassVar[int] COLOR_FIELD_NUMBER: _ClassVar[int] @@ -113,7 +117,7 @@ class ObjectState(_message.Message): def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., custom_field: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class State(_message.Message): - __slots__ = ("simulator_state", "entities_state", "agent_state", "object_state") + __slots__ = ["simulator_state", "entities_state", "agent_state", "object_state"] SIMULATOR_STATE_FIELD_NUMBER: _ClassVar[int] ENTITIES_STATE_FIELD_NUMBER: _ClassVar[int] AGENT_STATE_FIELD_NUMBER: _ClassVar[int] @@ -125,7 +129,7 @@ class State(_message.Message): def __init__(self, simulator_state: _Optional[_Union[SimulatorState, _Mapping]] = ..., entities_state: _Optional[_Union[EntitiesState, _Mapping]] = ..., agent_state: _Optional[_Union[AgentState, _Mapping]] = ..., object_state: _Optional[_Union[ObjectState, _Mapping]] = ...) -> None: ... class StateChange(_message.Message): - __slots__ = ("nve_idx", "col_idx", "nested_field", "value") + __slots__ = ["nve_idx", "col_idx", "nested_field", "value"] NVE_IDX_FIELD_NUMBER: _ClassVar[int] COL_IDX_FIELD_NUMBER: _ClassVar[int] NESTED_FIELD_FIELD_NUMBER: _ClassVar[int] @@ -137,7 +141,7 @@ class StateChange(_message.Message): def __init__(self, nve_idx: _Optional[_Iterable[int]] = ..., col_idx: _Optional[_Iterable[int]] = ..., nested_field: _Optional[_Iterable[str]] = ..., value: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class AddAgentInput(_message.Message): - __slots__ = ("max_agents", "serialized_config") + __slots__ = ["max_agents", "serialized_config"] MAX_AGENTS_FIELD_NUMBER: _ClassVar[int] SERIALIZED_CONFIG_FIELD_NUMBER: _ClassVar[int] max_agents: int @@ -145,7 +149,7 @@ class AddAgentInput(_message.Message): def __init__(self, max_agents: _Optional[int] = ..., serialized_config: _Optional[str] = ...) -> None: ... class IsStartedState(_message.Message): - __slots__ = ("is_started",) + __slots__ = ["is_started"] IS_STARTED_FIELD_NUMBER: _ClassVar[int] is_started: bool def __init__(self, is_started: bool = ...) -> None: ... diff --git a/vivarium/simulator/physics_engine.py b/vivarium/simulator/physics_engine.py index f052633..c254bd4 100644 --- a/vivarium/simulator/physics_engine.py +++ b/vivarium/simulator/physics_engine.py @@ -187,7 +187,7 @@ def dist_theta(displ, theta): norm_displ = displ / dist theta_displ = jnp.arccos(norm_displ[0]) * jnp.sign(jnp.arcsin(norm_displ[1])) relative_theta = theta_displ - theta - return dist, relative_theta + return jnp.array([dist, relative_theta]) proximity_map = vmap(dist_theta, (0, 0)) @@ -212,8 +212,7 @@ def sensor_fn(dist, relative_theta, dist_max, cos_min, target_exists): sensor_fn = vmap(sensor_fn, (0, 0, 0, 0, 0)) -def sensor(displ, theta, dist_max, cos_min, max_agents, senders, target_exists): - dist, relative_theta = proximity_map(displ, theta) +def sensor(dist, relative_theta, dist_max, cos_min, max_agents, senders, target_exists): proxs = ops.segment_max(sensor_fn(dist, relative_theta, dist_max, cos_min, target_exists), senders, max_agents) return proxs @@ -276,9 +275,15 @@ def compute_prox(state, agent_neighs_idx, target_exists_mask): Ra = body.center[senders] Rb = body.center[receivers] dR = - space.map_bond(displacement)(Ra, Rb) # Looks like it should be opposite, but don't understand why - prox = sensor(dR, body.orientation[senders], state.agent_state.proxs_dist_max[senders], + + dist_theta = proximity_map(dR, body.orientation[senders]) + proximity_map_dist = jnp.zeros((state.agent_state.nve_idx.shape[0], state.entities_state.entity_idx.shape[0])) + proximity_map_dist = proximity_map_dist.at[agent_neighs_idx[0, :], agent_neighs_idx[1, :]].set(dist_theta[:, 0]) + proximity_map_theta = jnp.zeros((state.agent_state.nve_idx.shape[0], state.entities_state.entity_idx.shape[0])) + proximity_map_theta = proximity_map_theta.at[agent_neighs_idx[0, :], agent_neighs_idx[1, :]].set(dist_theta[:, 1]) + prox = sensor(dist_theta[:, 0], dist_theta[:, 1], state.agent_state.proxs_dist_max[senders], state.agent_state.proxs_cos_min[senders], len(state.agent_state.nve_idx), senders, mask) - return state.agent_state.set(prox=prox) + return state.agent_state.set(proximity_map_dist=proximity_map_dist, proximity_map_theta=proximity_map_theta, prox=prox) def sensorimotor(agent_state): motor = multi_switch(agent_state.behavior, agent_state.prox, agent_state.motor) diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index 40c931f..24c7ade 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -33,7 +33,7 @@ def to_entity_type(self): return EntityType(self.value) -# No need to define position, momentum, force, and mass (i.e already in simulate.EntitiesState) +# No need to define position, momentum, force, and mass (i.e already in jax_md.simulate.NVEState) @dataclass class EntitiesState(simulate.NVEState): entity_type: util.Array @@ -52,6 +52,8 @@ class AgentState: nve_idx: util.Array # idx in EntitiesState prox: util.Array motor: util.Array + proximity_map_dist: util.Array + proximity_map_theta: util.Array behavior: util.Array wheel_diameter: util.Array speed_mul: util.Array @@ -269,6 +271,8 @@ def init_agent_state( nve_idx=jnp.arange(max_agents, dtype=int), prox=jnp.zeros((max_agents, 2)), motor=jnp.zeros((max_agents, 2)), + proximity_map_dist=jnp.zeros((max_agents, 1)), + proximity_map_theta=jnp.zeros((max_agents, 1)), behavior=jnp.full((max_agents), behavior), wheel_diameter=jnp.full((max_agents), wheel_diameter), speed_mul=jnp.full((max_agents), speed_mul), From 8b02763af712440cdae6fcc14f303d75726cdf8b Mon Sep 17 00:00:00 2001 From: corentinlger Date: Fri, 26 Apr 2024 15:02:10 +0200 Subject: [PATCH 3/4] Recompile grpc and add doc --- .../simulator/grpc_server/simulator_pb2.py | 33 ++++++++++--------- .../simulator/grpc_server/simulator_pb2.pyi | 18 +++++----- vivarium/simulator/physics_engine.py | 8 +++-- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.py b/vivarium/simulator/grpc_server/simulator_pb2.py index ace76cd..1f0ae8e 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.py +++ b/vivarium/simulator/grpc_server/simulator_pb2.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: simulator.proto +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -14,14 +15,14 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nmax_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0bmax_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe7\x02\n\x0b\x45ntityState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\xb7\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tmax_speed\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xc3\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12,\n\x0c\x65ntity_state\x18\x02 \x01(\x0b\x32\x16.simulator.EntityState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\">\n\rAddAgentInput\x12\x12\n\nmax_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xb9\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12?\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.EntityState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nmax_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0bmax_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe7\x02\n\x0b\x45ntityState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\x98\x04\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tmax_speed\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12.\n\x12proximity_map_dist\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\x12/\n\x13proximity_map_theta\x18\r \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xc3\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12,\n\x0c\x65ntity_state\x18\x02 \x01(\x0b\x32\x16.simulator.EntityState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\">\n\rAddAgentInput\x12\x12\n\nmax_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xb9\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12?\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.EntityState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'simulator_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n\032io.grpc.examples.simulatorB\016SimulatorProtoP\001\242\002\003SIM' + _globals['DESCRIPTOR']._options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\032io.grpc.examples.simulatorB\016SimulatorProtoP\001\242\002\003SIM' _globals['_AGENTIDX']._serialized_start=59 _globals['_AGENTIDX']._serialized_end=82 _globals['_NDARRAY']._serialized_start=84 @@ -33,17 +34,17 @@ _globals['_ENTITYSTATE']._serialized_start=695 _globals['_ENTITYSTATE']._serialized_end=1054 _globals['_AGENTSTATE']._serialized_start=1057 - _globals['_AGENTSTATE']._serialized_end=1496 - _globals['_OBJECTSTATE']._serialized_start=1498 - _globals['_OBJECTSTATE']._serialized_end=1625 - _globals['_STATE']._serialized_start=1628 - _globals['_STATE']._serialized_end=1823 - _globals['_STATECHANGE']._serialized_start=1825 - _globals['_STATECHANGE']._serialized_end=1929 - _globals['_ADDAGENTINPUT']._serialized_start=1931 - _globals['_ADDAGENTINPUT']._serialized_end=1993 - _globals['_ISSTARTEDSTATE']._serialized_start=1995 - _globals['_ISSTARTEDSTATE']._serialized_end=2031 - _globals['_SIMULATORSERVER']._serialized_start=2034 - _globals['_SIMULATORSERVER']._serialized_end=2603 + _globals['_AGENTSTATE']._serialized_end=1593 + _globals['_OBJECTSTATE']._serialized_start=1595 + _globals['_OBJECTSTATE']._serialized_end=1722 + _globals['_STATE']._serialized_start=1725 + _globals['_STATE']._serialized_end=1920 + _globals['_STATECHANGE']._serialized_start=1922 + _globals['_STATECHANGE']._serialized_end=2026 + _globals['_ADDAGENTINPUT']._serialized_start=2028 + _globals['_ADDAGENTINPUT']._serialized_end=2090 + _globals['_ISSTARTEDSTATE']._serialized_start=2092 + _globals['_ISSTARTEDSTATE']._serialized_end=2128 + _globals['_SIMULATORSERVER']._serialized_start=2131 + _globals['_SIMULATORSERVER']._serialized_end=2700 # @@protoc_insertion_point(module_scope) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.pyi b/vivarium/simulator/grpc_server/simulator_pb2.pyi index 959ac4a..be74bf5 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.pyi +++ b/vivarium/simulator/grpc_server/simulator_pb2.pyi @@ -7,19 +7,19 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map DESCRIPTOR: _descriptor.FileDescriptor class AgentIdx(_message.Message): - __slots__ = ["idx"] + __slots__ = ("idx",) IDX_FIELD_NUMBER: _ClassVar[int] idx: _containers.RepeatedScalarFieldContainer[int] def __init__(self, idx: _Optional[_Iterable[int]] = ...) -> None: ... class NDArray(_message.Message): - __slots__ = ["ndarray"] + __slots__ = ("ndarray",) NDARRAY_FIELD_NUMBER: _ClassVar[int] ndarray: bytes def __init__(self, ndarray: _Optional[bytes] = ...) -> None: ... class RigidBody(_message.Message): - __slots__ = ["center", "orientation"] + __slots__ = ("center", "orientation") CENTER_FIELD_NUMBER: _ClassVar[int] ORIENTATION_FIELD_NUMBER: _ClassVar[int] center: NDArray @@ -27,7 +27,7 @@ class RigidBody(_message.Message): def __init__(self, center: _Optional[_Union[NDArray, _Mapping]] = ..., orientation: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class SimulatorState(_message.Message): - __slots__ = ["idx", "box_size", "max_agents", "max_objects", "num_steps_lax", "dt", "freq", "neighbor_radius", "to_jit", "use_fori_loop", "collision_eps", "collision_alpha"] + __slots__ = ("idx", "box_size", "max_agents", "max_objects", "num_steps_lax", "dt", "freq", "neighbor_radius", "to_jit", "use_fori_loop", "collision_eps", "collision_alpha") IDX_FIELD_NUMBER: _ClassVar[int] BOX_SIZE_FIELD_NUMBER: _ClassVar[int] MAX_AGENTS_FIELD_NUMBER: _ClassVar[int] @@ -77,7 +77,7 @@ class EntityState(_message.Message): def __init__(self, position: _Optional[_Union[RigidBody, _Mapping]] = ..., momentum: _Optional[_Union[RigidBody, _Mapping]] = ..., force: _Optional[_Union[RigidBody, _Mapping]] = ..., mass: _Optional[_Union[RigidBody, _Mapping]] = ..., diameter: _Optional[_Union[NDArray, _Mapping]] = ..., entity_type: _Optional[_Union[NDArray, _Mapping]] = ..., entity_idx: _Optional[_Union[NDArray, _Mapping]] = ..., friction: _Optional[_Union[NDArray, _Mapping]] = ..., exists: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class AgentState(_message.Message): - __slots__ = ["nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "max_speed", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color", "proximity_map_dist", "proximity_map_theta"] + __slots__ = ("nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "max_speed", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color", "proximity_map_dist", "proximity_map_theta") NVE_IDX_FIELD_NUMBER: _ClassVar[int] PROX_FIELD_NUMBER: _ClassVar[int] MOTOR_FIELD_NUMBER: _ClassVar[int] @@ -107,7 +107,7 @@ class AgentState(_message.Message): def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., max_speed: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ..., proximity_map_dist: _Optional[_Union[NDArray, _Mapping]] = ..., proximity_map_theta: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class ObjectState(_message.Message): - __slots__ = ["nve_idx", "custom_field", "color"] + __slots__ = ("nve_idx", "custom_field", "color") NVE_IDX_FIELD_NUMBER: _ClassVar[int] CUSTOM_FIELD_FIELD_NUMBER: _ClassVar[int] COLOR_FIELD_NUMBER: _ClassVar[int] @@ -129,7 +129,7 @@ class State(_message.Message): def __init__(self, simulator_state: _Optional[_Union[SimulatorState, _Mapping]] = ..., entity_state: _Optional[_Union[EntityState, _Mapping]] = ..., agent_state: _Optional[_Union[AgentState, _Mapping]] = ..., object_state: _Optional[_Union[ObjectState, _Mapping]] = ...) -> None: ... class StateChange(_message.Message): - __slots__ = ["nve_idx", "col_idx", "nested_field", "value"] + __slots__ = ("nve_idx", "col_idx", "nested_field", "value") NVE_IDX_FIELD_NUMBER: _ClassVar[int] COL_IDX_FIELD_NUMBER: _ClassVar[int] NESTED_FIELD_FIELD_NUMBER: _ClassVar[int] @@ -141,7 +141,7 @@ class StateChange(_message.Message): def __init__(self, nve_idx: _Optional[_Iterable[int]] = ..., col_idx: _Optional[_Iterable[int]] = ..., nested_field: _Optional[_Iterable[str]] = ..., value: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class AddAgentInput(_message.Message): - __slots__ = ["max_agents", "serialized_config"] + __slots__ = ("max_agents", "serialized_config") MAX_AGENTS_FIELD_NUMBER: _ClassVar[int] SERIALIZED_CONFIG_FIELD_NUMBER: _ClassVar[int] max_agents: int @@ -149,7 +149,7 @@ class AddAgentInput(_message.Message): def __init__(self, max_agents: _Optional[int] = ..., serialized_config: _Optional[str] = ...) -> None: ... class IsStartedState(_message.Message): - __slots__ = ["is_started"] + __slots__ = ("is_started",) IS_STARTED_FIELD_NUMBER: _ClassVar[int] is_started: bool def __init__(self, is_started: bool = ...) -> None: ... diff --git a/vivarium/simulator/physics_engine.py b/vivarium/simulator/physics_engine.py index 9c19a31..ee8cf1c 100644 --- a/vivarium/simulator/physics_engine.py +++ b/vivarium/simulator/physics_engine.py @@ -276,11 +276,15 @@ def compute_prox(state, agent_neighs_idx, target_exists_mask): Rb = body.center[receivers] dR = - space.map_bond(displacement)(Ra, Rb) # Looks like it should be opposite, but don't understand why + # dist_theta[:, 0] = dist array, dist_theta[:, 1] = theta array dist_theta = proximity_map(dR, body.orientation[senders]) - proximity_map_dist = jnp.zeros((state.agent_state.nve_idx.shape[0], state.entities_state.entity_idx.shape[0])) + # Create distance map between entities + proximity_map_dist = jnp.zeros((state.agent_state.nve_idx.shape[0], state.entity_state.entity_idx.shape[0])) proximity_map_dist = proximity_map_dist.at[agent_neighs_idx[0, :], agent_neighs_idx[1, :]].set(dist_theta[:, 0]) - proximity_map_theta = jnp.zeros((state.agent_state.nve_idx.shape[0], state.entities_state.entity_idx.shape[0])) + # Create theta map between entities + proximity_map_theta = jnp.zeros((state.agent_state.nve_idx.shape[0], state.entity_state.entity_idx.shape[0])) proximity_map_theta = proximity_map_theta.at[agent_neighs_idx[0, :], agent_neighs_idx[1, :]].set(dist_theta[:, 1]) + prox = sensor(dist_theta[:, 0], dist_theta[:, 1], state.agent_state.proxs_dist_max[senders], state.agent_state.proxs_cos_min[senders], len(state.agent_state.nve_idx), senders, mask) return state.agent_state.set(proximity_map_dist=proximity_map_dist, proximity_map_theta=proximity_map_theta, prox=prox) From 9a656ff8e419be91f6113b22bef10ecd07f6577d Mon Sep 17 00:00:00 2001 From: corentinlger Date: Fri, 26 Apr 2024 15:08:32 +0200 Subject: [PATCH 4/4] Remove testing cells from notebook --- notebooks/quickstart_tutorial.ipynb | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/notebooks/quickstart_tutorial.ipynb b/notebooks/quickstart_tutorial.ipynb index 37f7a2f..c33b10c 100644 --- a/notebooks/quickstart_tutorial.ipynb +++ b/notebooks/quickstart_tutorial.ipynb @@ -40,7 +40,6 @@ "outputs": [], "source": [ "from vivarium.controllers.notebook_controller import NotebookController\n", - "import numpy as np\n", "controller = NotebookController()" ] }, @@ -258,7 +257,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.12" } }, "nbformat": 4,