From 41882d0319d0e65e4ef30c70a1c29a032e6092ab Mon Sep 17 00:00:00 2001 From: Marsolo1 Date: Wed, 10 Apr 2024 16:57:59 +0200 Subject: [PATCH 1/5] added max_speed attribute --- conf/scene/default.yaml | 1 + tests/test_simulator_init.py | 2 ++ vivarium/controllers/config.py | 1 + vivarium/controllers/converters.py | 1 + vivarium/simulator/grpc_server/converters.py | 2 ++ .../grpc_server/protos/simulator.proto | 9 +++--- .../simulator/grpc_server/simulator_pb2.py | 28 +++++++++---------- .../simulator/grpc_server/simulator_pb2.pyi | 6 ++-- vivarium/simulator/sim_computation.py | 2 ++ vivarium/simulator/states.py | 3 ++ 10 files changed, 35 insertions(+), 20 deletions(-) diff --git a/conf/scene/default.yaml b/conf/scene/default.yaml index e810bd9..806d061 100644 --- a/conf/scene/default.yaml +++ b/conf/scene/default.yaml @@ -26,6 +26,7 @@ agents: behavior: 1 wheel_diameter: 2. speed_mul: 1. + max_speed: 10. theta_mul: 1. prox_dist_max: 40. prox_cos_min: 0. diff --git a/tests/test_simulator_init.py b/tests/test_simulator_init.py index 7ea3718..10abe9b 100644 --- a/tests/test_simulator_init.py +++ b/tests/test_simulator_init.py @@ -40,6 +40,7 @@ def test_init_simulator_args(): behavior = 1 wheel_diameter = 2.0 speed_mul = 1.0 + max_speed = 10.0 theta_mul = 1.0 prox_dist_max = 20.0 prox_cos_min = 0.0 @@ -62,6 +63,7 @@ def test_init_simulator_args(): behavior=behavior, wheel_diameter=wheel_diameter, speed_mul=speed_mul, + max_speed=max_speed, theta_mul=theta_mul, prox_dist_max=prox_dist_max, prox_cos_min=prox_cos_min) diff --git a/vivarium/controllers/config.py b/vivarium/controllers/config.py index fa280a5..cf53097 100644 --- a/vivarium/controllers/config.py +++ b/vivarium/controllers/config.py @@ -45,6 +45,7 @@ class AgentConfig(Config): wheel_diameter = param.Number(2.) diameter = param.Number(5.) speed_mul = param.Number(1.) + max_speed = param.Number(10.) theta_mul = param.Number(1.) proxs_dist_max = param.Number(100., bounds=(0, None)) proxs_cos_min = param.Number(0., bounds=(-1., 1.)) diff --git a/vivarium/controllers/converters.py b/vivarium/controllers/converters.py index aa34ad7..58c4702 100644 --- a/vivarium/controllers/converters.py +++ b/vivarium/controllers/converters.py @@ -123,6 +123,7 @@ def get_default_state(n_entities_dict): behavior=jnp.zeros(n_agents, dtype=int), wheel_diameter=jnp.zeros(n_agents), speed_mul=jnp.zeros(n_agents), + max_speed=jnp.zeros(n_agents), theta_mul=jnp.zeros(n_agents), proxs_dist_max=jnp.zeros(n_agents), proxs_cos_min=jnp.zeros(n_agents), diff --git a/vivarium/simulator/grpc_server/converters.py b/vivarium/simulator/grpc_server/converters.py index 96cbd1f..13447b2 100644 --- a/vivarium/simulator/grpc_server/converters.py +++ b/vivarium/simulator/grpc_server/converters.py @@ -53,6 +53,7 @@ def proto_to_agent_state(agent_state): behavior=proto_to_ndarray(agent_state.behavior).astype(int), wheel_diameter=proto_to_ndarray(agent_state.wheel_diameter).astype(float), speed_mul=proto_to_ndarray(agent_state.speed_mul).astype(float), + max_speed=proto_to_ndarray(agent_state.max_speed).astype(float), theta_mul=proto_to_ndarray(agent_state.theta_mul).astype(float), proxs_dist_max=proto_to_ndarray(agent_state.proxs_dist_max).astype(float), proxs_cos_min=proto_to_ndarray(agent_state.proxs_cos_min).astype(float), @@ -114,6 +115,7 @@ def agent_state_to_proto(agent_state): behavior=ndarray_to_proto(agent_state.behavior), wheel_diameter=ndarray_to_proto(agent_state.wheel_diameter), speed_mul=ndarray_to_proto(agent_state.speed_mul), + max_speed=ndarray_to_proto(agent_state.max_speed), theta_mul=ndarray_to_proto(agent_state.theta_mul), proxs_dist_max=ndarray_to_proto(agent_state.proxs_dist_max), proxs_cos_min=ndarray_to_proto(agent_state.proxs_cos_min), diff --git a/vivarium/simulator/grpc_server/protos/simulator.proto b/vivarium/simulator/grpc_server/protos/simulator.proto index afd6c8f..ec684d7 100644 --- a/vivarium/simulator/grpc_server/protos/simulator.proto +++ b/vivarium/simulator/grpc_server/protos/simulator.proto @@ -89,10 +89,11 @@ message AgentState { NDArray behavior = 4; NDArray wheel_diameter = 5; NDArray speed_mul = 6; - NDArray theta_mul = 7; - NDArray proxs_dist_max = 8; - NDArray proxs_cos_min = 9; - NDArray color = 10; + NDArray max_speed = 7; + NDArray theta_mul = 8; + NDArray proxs_dist_max = 9; + NDArray proxs_cos_min = 10; + NDArray color = 11; } message ObjectState { diff --git a/vivarium/simulator/grpc_server/simulator_pb2.py b/vivarium/simulator/grpc_server/simulator_pb2.py index c6dcdcc..3ff2094 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.py +++ b/vivarium/simulator/grpc_server/simulator_pb2.py @@ -15,7 +15,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe5\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08n_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tn_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe4\x02\n\x08NVEState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\x90\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\n \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xbd\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12&\n\tnve_state\x18\x02 \x01(\x0b\x32\x13.simulator.NVEState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\"<\n\rAddAgentInput\x12\x10\n\x08n_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xb6\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12<\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x13.simulator.NVEState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe5\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08n_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tn_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe4\x02\n\x08NVEState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\xb7\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tmax_speed\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xbd\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12&\n\tnve_state\x18\x02 \x01(\x0b\x32\x13.simulator.NVEState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\"<\n\rAddAgentInput\x12\x10\n\x08n_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xb6\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12<\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x13.simulator.NVEState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -34,17 +34,17 @@ _globals['_NVESTATE']._serialized_start=691 _globals['_NVESTATE']._serialized_end=1047 _globals['_AGENTSTATE']._serialized_start=1050 - _globals['_AGENTSTATE']._serialized_end=1450 - _globals['_OBJECTSTATE']._serialized_start=1452 - _globals['_OBJECTSTATE']._serialized_end=1579 - _globals['_STATE']._serialized_start=1582 - _globals['_STATE']._serialized_end=1771 - _globals['_STATECHANGE']._serialized_start=1773 - _globals['_STATECHANGE']._serialized_end=1877 - _globals['_ADDAGENTINPUT']._serialized_start=1879 - _globals['_ADDAGENTINPUT']._serialized_end=1939 - _globals['_ISSTARTEDSTATE']._serialized_start=1941 - _globals['_ISSTARTEDSTATE']._serialized_end=1977 - _globals['_SIMULATORSERVER']._serialized_start=1980 - _globals['_SIMULATORSERVER']._serialized_end=2546 + _globals['_AGENTSTATE']._serialized_end=1489 + _globals['_OBJECTSTATE']._serialized_start=1491 + _globals['_OBJECTSTATE']._serialized_end=1618 + _globals['_STATE']._serialized_start=1621 + _globals['_STATE']._serialized_end=1810 + _globals['_STATECHANGE']._serialized_start=1812 + _globals['_STATECHANGE']._serialized_end=1916 + _globals['_ADDAGENTINPUT']._serialized_start=1918 + _globals['_ADDAGENTINPUT']._serialized_end=1978 + _globals['_ISSTARTEDSTATE']._serialized_start=1980 + _globals['_ISSTARTEDSTATE']._serialized_end=2016 + _globals['_SIMULATORSERVER']._serialized_start=2019 + _globals['_SIMULATORSERVER']._serialized_end=2585 # @@protoc_insertion_point(module_scope) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.pyi b/vivarium/simulator/grpc_server/simulator_pb2.pyi index 26b6a56..fcf56fe 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.pyi +++ b/vivarium/simulator/grpc_server/simulator_pb2.pyi @@ -77,13 +77,14 @@ class NVEState(_message.Message): def __init__(self, position: _Optional[_Union[RigidBody, _Mapping]] = ..., momentum: _Optional[_Union[RigidBody, _Mapping]] = ..., force: _Optional[_Union[RigidBody, _Mapping]] = ..., mass: _Optional[_Union[RigidBody, _Mapping]] = ..., diameter: _Optional[_Union[NDArray, _Mapping]] = ..., entity_type: _Optional[_Union[NDArray, _Mapping]] = ..., entity_idx: _Optional[_Union[NDArray, _Mapping]] = ..., friction: _Optional[_Union[NDArray, _Mapping]] = ..., exists: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class AgentState(_message.Message): - __slots__ = ("nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color") + __slots__ = ("nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "max_speed", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color") NVE_IDX_FIELD_NUMBER: _ClassVar[int] PROX_FIELD_NUMBER: _ClassVar[int] MOTOR_FIELD_NUMBER: _ClassVar[int] BEHAVIOR_FIELD_NUMBER: _ClassVar[int] WHEEL_DIAMETER_FIELD_NUMBER: _ClassVar[int] SPEED_MUL_FIELD_NUMBER: _ClassVar[int] + MAX_SPEED_FIELD_NUMBER: _ClassVar[int] THETA_MUL_FIELD_NUMBER: _ClassVar[int] PROXS_DIST_MAX_FIELD_NUMBER: _ClassVar[int] PROXS_COS_MIN_FIELD_NUMBER: _ClassVar[int] @@ -94,11 +95,12 @@ class AgentState(_message.Message): behavior: NDArray wheel_diameter: NDArray speed_mul: NDArray + max_speed: NDArray theta_mul: NDArray proxs_dist_max: NDArray proxs_cos_min: NDArray color: NDArray - def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... + def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., max_speed: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class ObjectState(_message.Message): __slots__ = ("nve_idx", "custom_field", "color") diff --git a/vivarium/simulator/sim_computation.py b/vivarium/simulator/sim_computation.py index 243c68a..3422310 100644 --- a/vivarium/simulator/sim_computation.py +++ b/vivarium/simulator/sim_computation.py @@ -138,6 +138,8 @@ def motor_force(state, exists_mask): cur_vel = state.nve_state.momentum.center[agent_idx] / state.nve_state.mass.center[agent_idx] cur_fwd_vel = vmap(jnp.dot)(cur_vel, n) + # `a_max` arg is deprecated in recent versions of jax, replaced by `max` + cur_fwd_vel = jnp.clip(cur_fwd_vel, a_max=state.agent_state.max_speed) cur_rot_vel = state.nve_state.momentum.orientation[agent_idx] / state.nve_state.mass.orientation[agent_idx] fwd_delta = fwd - cur_fwd_vel diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index b53c9da..cf6e043 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -52,6 +52,7 @@ class AgentState: behavior: util.Array wheel_diameter: util.Array speed_mul: util.Array + max_speed: util.Array theta_mul: util.Array proxs_dist_max: util.Array proxs_cos_min: util.Array @@ -256,6 +257,7 @@ def init_agent_state( behavior: int = 1, wheel_diameter: float = 2., speed_mul: float = 1., + max_speed: float = 10., theta_mul: float = 1., prox_dist_max: float = 40., prox_cos_min: float = 0., @@ -273,6 +275,7 @@ def init_agent_state( behavior=jnp.full((n_agents), behavior), wheel_diameter=jnp.full((n_agents), wheel_diameter), speed_mul=jnp.full((n_agents), speed_mul), + max_speed=jnp.full((n_agents), max_speed), theta_mul=jnp.full((n_agents), theta_mul), proxs_dist_max=jnp.full((n_agents), prox_dist_max), proxs_cos_min=jnp.full((n_agents), prox_cos_min), From 99a8dbe65ea99e4741b03d8641df23d6ec72e48b Mon Sep 17 00:00:00 2001 From: Marsolo1 Date: Thu, 11 Apr 2024 14:48:38 +0200 Subject: [PATCH 2/5] finished max_speed --- vivarium/simulator/sim_computation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vivarium/simulator/sim_computation.py b/vivarium/simulator/sim_computation.py index 3422310..6b7ec66 100644 --- a/vivarium/simulator/sim_computation.py +++ b/vivarium/simulator/sim_computation.py @@ -135,11 +135,11 @@ def motor_force(state, exists_mask): state.nve_state.diameter[agent_idx], state.agent_state.wheel_diameter ) + # `a_max` arg is deprecated in recent versions of jax, replaced by `max` + fwd = jnp.clip(fwd, a_max=state.agent_state.max_speed) cur_vel = state.nve_state.momentum.center[agent_idx] / state.nve_state.mass.center[agent_idx] cur_fwd_vel = vmap(jnp.dot)(cur_vel, n) - # `a_max` arg is deprecated in recent versions of jax, replaced by `max` - cur_fwd_vel = jnp.clip(cur_fwd_vel, a_max=state.agent_state.max_speed) cur_rot_vel = state.nve_state.momentum.orientation[agent_idx] / state.nve_state.mass.orientation[agent_idx] fwd_delta = fwd - cur_fwd_vel From a87ebdb80d81e83d9f1f7541d720ec7645ded3ef Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 11 Apr 2024 16:48:46 +0200 Subject: [PATCH 3/5] Fix auto-merge problems --- vivarium/controllers/converters.py | 24 +- .../simulator/grpc_server/simulator_pb2.py | 36 +-- vivarium/simulator/physics_engine.py | 4 + vivarium/simulator/sim_computation.py | 295 ------------------ vivarium/simulator/states.py | 2 +- 5 files changed, 35 insertions(+), 326 deletions(-) delete mode 100644 vivarium/simulator/sim_computation.py diff --git a/vivarium/controllers/converters.py b/vivarium/controllers/converters.py index d644169..82c7654 100644 --- a/vivarium/controllers/converters.py +++ b/vivarium/controllers/converters.py @@ -117,18 +117,18 @@ def get_default_state(n_entities_dict): friction=jnp.zeros(n_entities), exists=jnp.ones(n_entities, dtype=int) ), - agent_state=AgentState(nve_idx=jnp.zeros(n_agents, dtype=int), - prox=jnp.zeros((n_agents, 2)), - motor=jnp.zeros((n_agents, 2)), - behavior=jnp.zeros(n_agents, dtype=int), - wheel_diameter=jnp.zeros(n_agents), - speed_mul=jnp.zeros(n_agents), - max_speed=jnp.zeros(n_agents), - theta_mul=jnp.zeros(n_agents), - proxs_dist_max=jnp.zeros(n_agents), - proxs_cos_min=jnp.zeros(n_agents), - color=jnp.zeros((n_agents, 3))), - object_state=ObjectState(nve_idx=jnp.zeros(n_objects, dtype=int), color=jnp.zeros((n_objects, 3)))) + agent_state=AgentState(nve_idx=jnp.zeros(max_agents, dtype=int), + prox=jnp.zeros((max_agents, 2)), + motor=jnp.zeros((max_agents, 2)), + behavior=jnp.zeros(max_agents, dtype=int), + wheel_diameter=jnp.zeros(max_agents), + speed_mul=jnp.zeros(max_agents), + max_speed=jnp.zeros(max_agents), + theta_mul=jnp.zeros(max_agents), + proxs_dist_max=jnp.zeros(max_agents), + proxs_cos_min=jnp.zeros(max_agents), + color=jnp.zeros((max_agents, 3))), + object_state=ObjectState(nve_idx=jnp.zeros(max_objects, dtype=int), color=jnp.zeros((max_objects, 3)))) EntitiesTuple = namedtuple('EntitiesTuple', ['idx', 'col', 'val']) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.py b/vivarium/simulator/grpc_server/simulator_pb2.py index 17bc61a..e459f53 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.py +++ b/vivarium/simulator/grpc_server/simulator_pb2.py @@ -15,7 +15,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe5\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08n_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tn_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe4\x02\n\x08NVEState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\xb7\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tmax_speed\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xbd\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12&\n\tnve_state\x18\x02 \x01(\x0b\x32\x13.simulator.NVEState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\"<\n\rAddAgentInput\x12\x10\n\x08n_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xb6\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12<\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x13.simulator.NVEState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nmax_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0bmax_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x02\n\rEntitiesState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\xb7\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tmax_speed\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xc7\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12\x30\n\x0e\x65ntities_state\x18\x02 \x01(\x0b\x32\x18.simulator.EntitiesState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\">\n\rAddAgentInput\x12\x12\n\nmax_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xbb\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x41\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x18.simulator.EntitiesState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -30,21 +30,21 @@ _globals['_RIGIDBODY']._serialized_start=112 _globals['_RIGIDBODY']._serialized_end=200 _globals['_SIMULATORSTATE']._serialized_start=203 - _globals['_SIMULATORSTATE']._serialized_end=688 - _globals['_NVESTATE']._serialized_start=691 - _globals['_NVESTATE']._serialized_end=1047 - _globals['_AGENTSTATE']._serialized_start=1050 - _globals['_AGENTSTATE']._serialized_end=1450 - _globals['_OBJECTSTATE']._serialized_start=1452 - _globals['_OBJECTSTATE']._serialized_end=1579 - _globals['_STATE']._serialized_start=1582 - _globals['_STATE']._serialized_end=1771 - _globals['_STATECHANGE']._serialized_start=1773 - _globals['_STATECHANGE']._serialized_end=1877 - _globals['_ADDAGENTINPUT']._serialized_start=1879 - _globals['_ADDAGENTINPUT']._serialized_end=1939 - _globals['_ISSTARTEDSTATE']._serialized_start=1941 - _globals['_ISSTARTEDSTATE']._serialized_end=1977 - _globals['_SIMULATORSERVER']._serialized_start=1980 - _globals['_SIMULATORSERVER']._serialized_end=2546 + _globals['_SIMULATORSTATE']._serialized_end=692 + _globals['_ENTITIESSTATE']._serialized_start=695 + _globals['_ENTITIESSTATE']._serialized_end=1056 + _globals['_AGENTSTATE']._serialized_start=1059 + _globals['_AGENTSTATE']._serialized_end=1498 + _globals['_OBJECTSTATE']._serialized_start=1500 + _globals['_OBJECTSTATE']._serialized_end=1627 + _globals['_STATE']._serialized_start=1630 + _globals['_STATE']._serialized_end=1829 + _globals['_STATECHANGE']._serialized_start=1831 + _globals['_STATECHANGE']._serialized_end=1935 + _globals['_ADDAGENTINPUT']._serialized_start=1937 + _globals['_ADDAGENTINPUT']._serialized_end=1999 + _globals['_ISSTARTEDSTATE']._serialized_start=2001 + _globals['_ISSTARTEDSTATE']._serialized_end=2037 + _globals['_SIMULATORSERVER']._serialized_start=2040 + _globals['_SIMULATORSERVER']._serialized_end=2611 # @@protoc_insertion_point(module_scope) diff --git a/vivarium/simulator/physics_engine.py b/vivarium/simulator/physics_engine.py index 8fa4042..62a431b 100644 --- a/vivarium/simulator/physics_engine.py +++ b/vivarium/simulator/physics_engine.py @@ -135,6 +135,10 @@ def motor_force(state, exists_mask): state.entities_state.diameter[agent_idx], state.agent_state.wheel_diameter ) + # `a_max` arg is deprecated in recent versions of jax, replaced by `max` + fwd = jnp.clip(fwd, a_max=state.agent_state.max_speed) + jax.debug.print("max_speed {max_speed}", max_speed=state.agent_state.max_speed) + jax.debug.print("fwd {fwd}", fwd=fwd) cur_vel = state.entities_state.momentum.center[agent_idx] / state.entities_state.mass.center[agent_idx] cur_fwd_vel = vmap(jnp.dot)(cur_vel, n) diff --git a/vivarium/simulator/sim_computation.py b/vivarium/simulator/sim_computation.py deleted file mode 100644 index f052633..0000000 --- a/vivarium/simulator/sim_computation.py +++ /dev/null @@ -1,295 +0,0 @@ -from functools import partial - -import jax -import jax.numpy as jnp - -from jax import ops, vmap, lax -from jax_md import space, rigid_body, util, simulate, energy, quantity -f32 = util.f32 - - -# Only work on 2D environments atm -SPACE_NDIMS = 2 - -@vmap -def normal(theta): - return jnp.array([jnp.cos(theta), jnp.sin(theta)]) - -def switch_fn(fn_list): - def switch(index, *operands): - return jax.lax.switch(index, fn_list, *operands) - return switch - - -# Helper functions for collisions - -def collision_energy(displacement_fn, r_a, r_b, l_a, l_b, epsilon, alpha, mask): - """Compute the collision energy between a pair of particles - - :param displacement_fn: displacement function of jax_md - :param r_a: position of particle a - :param r_b: position of particle b - :param l_a: diameter of particle a - :param l_b: diameter of particle b - :param epsilon: interaction energy scale - :param alpha: interaction stiffness - :param mask: set the energy to 0 if one of the particles is masked - :return: collision energy between both particles - """ - dist = jnp.linalg.norm(displacement_fn(r_a, r_b)) - sigma = (l_a + l_b) / 2 - e = energy.soft_sphere(dist, sigma=sigma, epsilon=epsilon, alpha=f32(alpha)) - return jnp.where(mask, e, 0.) - -collision_energy = vmap(collision_energy, (None, 0, 0, 0, 0, None, None, 0)) - - -def total_collision_energy(positions, diameter, neighbor, displacement, exists_mask, epsilon, alpha): - """Compute the collision energy between all neighboring pairs of particles in the system - - :param positions: positions of all the particles - :param diameter: diameters of all the particles - :param neighbor: neighbor array of the system - :param displacement: dipalcement function of jax_md - :param exists_mask: mask to specify which particles exist - :param epsilon: interaction energy scale between two particles - :param alpha: interaction stiffness between two particles - :return: sum of all collisions energies of the system - """ - diameter = lax.stop_gradient(diameter) - senders, receivers = neighbor.idx - - r_senders = positions[senders] - r_receivers = positions[receivers] - l_senders = diameter[senders] - l_receivers = diameter[receivers] - - # Set collision energy to zero if the sender or receiver is non existing - mask = exists_mask[senders] * exists_mask[receivers] - energies = collision_energy(displacement, - r_senders, - r_receivers, - l_senders, - l_receivers, - epsilon, - alpha, - mask) - return jnp.sum(energies) - - -# Helper functions for motor function - -def lr_2_fwd_rot(left_spd, right_spd, base_length, wheel_diameter): - fwd = (wheel_diameter / 4.) * (left_spd + right_spd) - rot = 0.5 * (wheel_diameter / base_length) * (right_spd - left_spd) - return fwd, rot - - -def fwd_rot_2_lr(fwd, rot, base_length, wheel_diameter): - left = ((2.0 * fwd) - (rot * base_length)) / wheel_diameter - right = ((2.0 * fwd) + (rot * base_length)) / wheel_diameter - return left, right - - -def motor_command(wheel_activation, base_length, wheel_diameter): - fwd, rot = lr_2_fwd_rot(wheel_activation[0], wheel_activation[1], base_length, wheel_diameter) - return fwd, rot - -motor_command = vmap(motor_command, (0, 0, 0)) - - -# Functions to compute the verlet force on the whole system - -def get_verlet_force_fn(displacement): - coll_force_fn = quantity.force(partial(total_collision_energy, displacement=displacement)) - - def collision_force(state, neighbor, exists_mask): - return coll_force_fn( - state.entities_state.position.center, - neighbor=neighbor, - exists_mask=exists_mask, - diameter=state.entities_state.diameter, - epsilon=state.simulator_state.collision_eps, - alpha=state.simulator_state.collision_alpha - ) - - def friction_force(state, exists_mask): - cur_vel = state.entities_state.momentum.center / state.entities_state.mass.center - # stack the mask to give it the same shape as cur_vel (that has 2 rows for forward and angular velocities) - mask = jnp.stack([exists_mask] * 2, axis=1) - cur_vel = jnp.where(mask, cur_vel, 0.) - return - jnp.tile(state.entities_state.friction, (SPACE_NDIMS, 1)).T * cur_vel - - def motor_force(state, exists_mask): - agent_idx = state.agent_state.nve_idx - - body = rigid_body.RigidBody( - center=state.entities_state.position.center[agent_idx], - orientation=state.entities_state.position.orientation[agent_idx] - ) - - n = normal(body.orientation) - - fwd, rot = motor_command( - state.agent_state.motor, - state.entities_state.diameter[agent_idx], - state.agent_state.wheel_diameter - ) - # `a_max` arg is deprecated in recent versions of jax, replaced by `max` - fwd = jnp.clip(fwd, a_max=state.agent_state.max_speed) - - cur_vel = state.entities_state.momentum.center[agent_idx] / state.entities_state.mass.center[agent_idx] - cur_fwd_vel = vmap(jnp.dot)(cur_vel, n) - cur_rot_vel = state.entities_state.momentum.orientation[agent_idx] / state.entities_state.mass.orientation[agent_idx] - - fwd_delta = fwd - cur_fwd_vel - rot_delta = rot - cur_rot_vel - - fwd_force = n * jnp.tile(fwd_delta, (SPACE_NDIMS, 1)).T * jnp.tile(state.agent_state.speed_mul, (SPACE_NDIMS, 1)).T - rot_force = rot_delta * state.agent_state.theta_mul - - center=jnp.zeros_like(state.entities_state.position.center).at[agent_idx].set(fwd_force) - orientation=jnp.zeros_like(state.entities_state.position.orientation).at[agent_idx].set(rot_force) - - # apply mask to make non existing agents stand still - orientation = jnp.where(exists_mask, orientation, 0.) - # Because position has SPACE_NDMS dims, need to stack the mask to give it the same shape as center - exists_mask = jnp.stack([exists_mask] * SPACE_NDIMS, axis=1) - center = jnp.where(exists_mask, center, 0.) - - - return rigid_body.RigidBody(center=center, - orientation=orientation) - - def force_fn(state, neighbor, exists_mask): - mf = motor_force(state, exists_mask) - cf = collision_force(state, neighbor, exists_mask) - ff = friction_force(state, exists_mask) - - center = cf + ff + mf.center - orientation = mf.orientation - return rigid_body.RigidBody(center=center, orientation=orientation) - - return force_fn - - -# Helper functions for sensors - -def dist_theta(displ, theta): - """ - Compute the relative distance and angle from a source agent to a target agent - :param displ: Displacement vector (jnp arrray with shape (2,) from source to target - :param theta: Orientation of the source agent (in the reference frame of the map) - :return: dist: distance from source to target. - relative_theta: relative angle of the target in the reference frame of the source agent (front direction at angle 0) - """ - dist = jnp.linalg.norm(displ) - norm_displ = displ / dist - theta_displ = jnp.arccos(norm_displ[0]) * jnp.sign(jnp.arcsin(norm_displ[1])) - relative_theta = theta_displ - theta - return dist, relative_theta - -proximity_map = vmap(dist_theta, (0, 0)) - - -def sensor_fn(dist, relative_theta, dist_max, cos_min, target_exists): - """ - Compute the proximeter activations (left, right) induced by the presence of an entity - :param dist: distance from the agent to the entity - :param relative_theta: angle of the entity in the reference frame of the agent (front direction at angle 0) - :param dist_max: Max distance of the proximiter (will return 0. above this distance) - :param cos_min: Field of view as a cosinus (e.g. cos_min = 0 means a pi/4 FoV on each proximeter, so pi/2 in total) - :return: left and right proximeter activation in a jnp array with shape (2,) - """ - cos_dir = jnp.cos(relative_theta) - prox = 1. - (dist / dist_max) - in_view = jnp.logical_and(dist < dist_max, cos_dir > cos_min) - at_left = jnp.logical_and(True, jnp.sin(relative_theta) >= 0) - left = in_view * at_left * prox - right = in_view * (1. - at_left) * prox - return jnp.array([left, right]) * target_exists # i.e. 0 if target does not exist - -sensor_fn = vmap(sensor_fn, (0, 0, 0, 0, 0)) - - -def sensor(displ, theta, dist_max, cos_min, max_agents, senders, target_exists): - dist, relative_theta = proximity_map(displ, theta) - proxs = ops.segment_max(sensor_fn(dist, relative_theta, dist_max, cos_min, target_exists), - senders, max_agents) - return proxs - - -# Functions to compute the dynamics of the whole system - -def dynamics_rigid(displacement, shift, behavior_bank, force_fn=None): - force_fn = force_fn or get_verlet_force_fn(displacement) - multi_switch = jax.vmap(switch_fn(behavior_bank), (0, 0, 0)) - # shape = rigid_body.monomer - - def init_fn(state, key, kT=0.): - key, _ = jax.random.split(key) - assert state.entities_state.momentum is None - assert not jnp.any(state.entities_state.force.center) and not jnp.any(state.entities_state.force.orientation) - - state = state.set(entities_state=simulate.initialize_momenta(state.entities_state, key, kT)) - return state - - def mask_momentum(entities_state, exists_mask): - """ - Set the momentum values to zeros for non existing entities - :param entities_state: entities_state - :param exists_mask: bool array specifying which entities exist or not - :return: entities_state: new entities state state with masked momentum values - """ - orientation = jnp.where(exists_mask, entities_state.momentum.orientation, 0) - exists_mask = jnp.stack([exists_mask] * SPACE_NDIMS, axis=1) - center = jnp.where(exists_mask, entities_state.momentum.center, 0) - momentum = rigid_body.RigidBody(center=center, orientation=orientation) - return entities_state.set(momentum=momentum) - - def physics_fn(state, force, shift_fn, dt, neighbor, mask): - """Apply a single step of velocity Verlet integration to a state.""" - # dt = f32(dt) - dt_2 = dt / 2. # f32(dt / 2) - # state = sensorimotor(state, neighbor) # now in step_fn - entities_state = simulate.momentum_step(state.entities_state, dt_2) - entities_state = simulate.position_step(entities_state, shift_fn, dt, neighbor=neighbor) - entities_state = entities_state.set(force=force) - entities_state = simulate.momentum_step(entities_state, dt_2) - entities_state = mask_momentum(entities_state, mask) - - return state.set(entities_state=entities_state) - - def compute_prox(state, agent_neighs_idx, target_exists_mask): - """ - Set agents' proximeter activations - :param state: full simulation State - :param agent_neighs_idx: Neighbor representation, where sources are only agents. Matrix of shape (2, n_pairs), - where n_pairs is the number of neighbor entity pairs where sources (first row) are agent indexes. - :param target_exists_mask: Specify which target entities exist. Vector with shape (n_entities,). - target_exists_mask[i] is True (resp. False) if entity of index i in state.entities_state exists (resp. don't exist). - :return: - """ - body = state.entities_state.position - mask = target_exists_mask[agent_neighs_idx[1, :]] - senders, receivers = agent_neighs_idx - Ra = body.center[senders] - Rb = body.center[receivers] - dR = - space.map_bond(displacement)(Ra, Rb) # Looks like it should be opposite, but don't understand why - prox = sensor(dR, body.orientation[senders], state.agent_state.proxs_dist_max[senders], - state.agent_state.proxs_cos_min[senders], len(state.agent_state.nve_idx), senders, mask) - return state.agent_state.set(prox=prox) - - def sensorimotor(agent_state): - motor = multi_switch(agent_state.behavior, agent_state.prox, agent_state.motor) - return agent_state.set(motor=motor) - - def step_fn(state, neighbor, agent_neighs_idx): - exists_mask = (state.entities_state.exists == 1) # Only existing entities have effect on others - state = state.set(agent_state=compute_prox(state, agent_neighs_idx, target_exists_mask=exists_mask)) - state = state.set(agent_state=sensorimotor(state.agent_state)) - force = force_fn(state, neighbor, exists_mask) - state = physics_fn(state, force, shift, state.simulator_state.dt[0], neighbor=neighbor, mask=exists_mask) - return state - - return init_fn, step_fn diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index 22f6606..1df4ad0 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -158,7 +158,7 @@ def init_simulator_state( """ return SimulatorState( idx=jnp.array([0]), - box_size=jnp.array([box_size]), + box_size=jnp.array([box_size]), max_agents=jnp.array([max_agents]), max_objects=jnp.array([max_objects]), num_steps_lax=jnp.array([num_steps_lax], dtype=int), From 56751ee96d7d783430a57e792e44da0ce2a55128 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 11 Apr 2024 16:50:35 +0200 Subject: [PATCH 4/5] Remove debug print lines --- vivarium/simulator/physics_engine.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vivarium/simulator/physics_engine.py b/vivarium/simulator/physics_engine.py index 62a431b..f052633 100644 --- a/vivarium/simulator/physics_engine.py +++ b/vivarium/simulator/physics_engine.py @@ -137,8 +137,6 @@ def motor_force(state, exists_mask): ) # `a_max` arg is deprecated in recent versions of jax, replaced by `max` fwd = jnp.clip(fwd, a_max=state.agent_state.max_speed) - jax.debug.print("max_speed {max_speed}", max_speed=state.agent_state.max_speed) - jax.debug.print("fwd {fwd}", fwd=fwd) cur_vel = state.entities_state.momentum.center[agent_idx] / state.entities_state.mass.center[agent_idx] cur_fwd_vel = vmap(jnp.dot)(cur_vel, n) From d1498a8ed54867adc8e34ed1cbdf420d4dedb514 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 11 Apr 2024 16:53:11 +0200 Subject: [PATCH 5/5] Rename sim computation physics engine in tests --- tests/test_simulator_init.py | 2 +- tests/test_simulator_run.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_simulator_init.py b/tests/test_simulator_init.py index fa841cb..5b6447b 100644 --- a/tests/test_simulator_init.py +++ b/tests/test_simulator_init.py @@ -5,7 +5,7 @@ from vivarium.simulator.states import init_entities_state from vivarium.simulator.states import init_state from vivarium.simulator.simulator import Simulator -from vivarium.simulator.sim_computation import dynamics_rigid +from vivarium.simulator.physics_engine import dynamics_rigid def test_init_simulator_no_args(): diff --git a/tests/test_simulator_run.py b/tests/test_simulator_run.py index 6d12175..f018afd 100644 --- a/tests/test_simulator_run.py +++ b/tests/test_simulator_run.py @@ -5,7 +5,7 @@ from vivarium.simulator.states import init_entities_state from vivarium.simulator.states import init_state from vivarium.simulator.simulator import Simulator -from vivarium.simulator.sim_computation import dynamics_rigid +from vivarium.simulator.physics_engine import dynamics_rigid NUM_STEPS = 50