diff --git a/requirements.txt b/requirements.txt index 47b6f12..a000c53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,8 @@ param==2.0.2 matplotlib==3.8.2 # Client-server communication -grpcio==1.60.0 +grpcio +grpcio-tools # Tests and code formatting pytest==8.0.0 diff --git a/scripts/run_server.py b/scripts/run_server.py index e59ce50..b340b97 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -26,6 +26,8 @@ def parse_args(): parser.add_argument('--to_jit', action='store_false', help='Whether to use JIT compilation') parser.add_argument('--use_fori_loop', action='store_true', help='Whether to use fori loop') parser.add_argument('--log_level', type=str, default='INFO', help='Logging level') + parser.add_argument('--collision_eps', type=float, required=False, default=0.1) + parser.add_argument('--collision_alpha', type=float, required=False, default=0.5) return parser.parse_args() @@ -44,7 +46,9 @@ def parse_args(): freq=args.freq, neighbor_radius=args.neighbor_radius, to_jit=args.to_jit, - use_fori_loop=args.use_fori_loop + use_fori_loop=args.use_fori_loop, + collision_eps=args.collision_eps, + collision_alpha=args.collision_alpha ) agent_configs = [AgentConfig(idx=i, @@ -65,5 +69,6 @@ def parse_args(): }) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) + lg.info('Simulator server started') serve(simulator) diff --git a/scripts/run_simulation.py b/scripts/run_simulation.py index e9ba40d..ca64892 100644 --- a/scripts/run_simulation.py +++ b/scripts/run_simulation.py @@ -14,7 +14,8 @@ def parse_args(): parser = argparse.ArgumentParser(description='Simulator Configuration') # Experiment run arguments - parser.add_argument('--num_steps', type=int, default=10, help='Number of simulation loops') + parser.add_argument('--log_level', type=str, default='INFO', help='Logging level') + parser.add_argument('--num_steps', type=int, default=10, help='Number of simulation steps') # Simulator config arguments parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box') parser.add_argument('--n_agents', type=int, default=10, help='Number of agents') @@ -26,7 +27,8 @@ def parse_args(): # By default jit compile the code and use normal python loops parser.add_argument('--to_jit', action='store_false', help='Whether to use JIT compilation') parser.add_argument('--use_fori_loop', action='store_true', help='Whether to use fori loop') - parser.add_argument('--log_level', type=str, default='INFO', help='Logging level') + parser.add_argument('--collision_eps', type=float, required=False, default=0.3) + parser.add_argument('--collision_alpha', type=float, required=False, default=0.7) return parser.parse_args() @@ -45,7 +47,9 @@ def parse_args(): freq=args.freq, neighbor_radius=args.neighbor_radius, to_jit=args.to_jit, - use_fori_loop=args.use_fori_loop + use_fori_loop=args.use_fori_loop, + collision_eps=args.collision_eps, + collision_alpha=args.collision_alpha ) agent_configs = [ diff --git a/vivarium/controllers/config.py b/vivarium/controllers/config.py index 73c2a19..5943517 100644 --- a/vivarium/controllers/config.py +++ b/vivarium/controllers/config.py @@ -65,7 +65,7 @@ class ObjectConfig(Config): mass_orientation = param.Number(mass_orientation) diameter = param.Number(5.) color = param.Color('red') - friction = param.Number(10.) + friction = param.Number(0.1) exists = param.Boolean(True) def __init__(self, **params): @@ -83,6 +83,8 @@ class SimulatorConfig(Config): neighbor_radius = param.Number(100., bounds=(0, None)) to_jit = param.Boolean(True) use_fori_loop = param.Boolean(False) + collision_eps = param.Number(0.1) + collision_alpha = param.Number(0.5) def __init__(self, **params): super().__init__(**params) diff --git a/vivarium/controllers/converters.py b/vivarium/controllers/converters.py index e3d91ae..69b485d 100644 --- a/vivarium/controllers/converters.py +++ b/vivarium/controllers/converters.py @@ -25,7 +25,7 @@ object_state_fields = [f.name for f in jax_md.dataclasses.fields(ObjectState)] object_common_fields = [f for f in object_config_fields if f in object_state_fields] -# + simulator_config_fields = SimulatorConfig.param.objects().keys() simulator_state_fields = [f.name for f in jax_md.dataclasses.fields(SimulatorState)] @@ -106,7 +106,9 @@ def get_default_state(n_entities_dict): n_agents=jnp.array([n_agents]), n_objects=jnp.array([n_objects]), num_steps_lax=jnp.array([1]), dt=jnp.array([1.]), freq=jnp.array([1.]), neighbor_radius=jnp.array([1.]), - to_jit= jnp.array([1]), use_fori_loop=jnp.array([0])), + to_jit= jnp.array([1]), use_fori_loop=jnp.array([0]), + collision_alpha=jnp.array([0.]), + collision_eps=jnp.array([0.])), nve_state=NVEState(position=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), momentum=None, force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), diff --git a/vivarium/simulator/grpc_server/converters.py b/vivarium/simulator/grpc_server/converters.py index 5289db8..11c2683 100644 --- a/vivarium/simulator/grpc_server/converters.py +++ b/vivarium/simulator/grpc_server/converters.py @@ -13,16 +13,18 @@ def proto_to_state(state): def proto_to_simulator_state(simulator_state): - return SimulatorState(idx=proto_to_ndarray(simulator_state.idx), - box_size=proto_to_ndarray(simulator_state.box_size), - n_agents=proto_to_ndarray(simulator_state.n_agents), - n_objects=proto_to_ndarray(simulator_state.n_objects), - num_steps_lax=proto_to_ndarray(simulator_state.num_steps_lax), - dt=proto_to_ndarray(simulator_state.dt), - freq=proto_to_ndarray(simulator_state.freq), - neighbor_radius=proto_to_ndarray(simulator_state.neighbor_radius), - to_jit=proto_to_ndarray(simulator_state.to_jit), - use_fori_loop=proto_to_ndarray(simulator_state.use_fori_loop) + return SimulatorState(idx=proto_to_ndarray(simulator_state.idx).astype(int), + box_size=proto_to_ndarray(simulator_state.box_size).astype(float), + n_agents=proto_to_ndarray(simulator_state.n_agents).astype(int), + n_objects=proto_to_ndarray(simulator_state.n_objects).astype(int), + num_steps_lax=proto_to_ndarray(simulator_state.num_steps_lax).astype(int), + dt=proto_to_ndarray(simulator_state.dt).astype(float), + freq=proto_to_ndarray(simulator_state.freq).astype(float), + neighbor_radius=proto_to_ndarray(simulator_state.neighbor_radius).astype(float), + to_jit=proto_to_ndarray(simulator_state.to_jit).astype(int), + use_fori_loop=proto_to_ndarray(simulator_state.use_fori_loop).astype(int), + collision_eps=proto_to_ndarray(simulator_state.collision_eps).astype(float), + collision_alpha=proto_to_ndarray(simulator_state.collision_alpha).astype(float) ) @@ -81,7 +83,9 @@ def simulator_state_to_proto(simulator_state): freq=ndarray_to_proto(simulator_state.freq), neighbor_radius=ndarray_to_proto(simulator_state.neighbor_radius), to_jit=ndarray_to_proto(simulator_state.to_jit), - use_fori_loop=ndarray_to_proto(simulator_state.use_fori_loop) + use_fori_loop=ndarray_to_proto(simulator_state.use_fori_loop), + collision_eps=ndarray_to_proto(simulator_state.collision_eps), + collision_alpha=ndarray_to_proto(simulator_state.collision_alpha) ) diff --git a/vivarium/simulator/grpc_server/protos/simulator.proto b/vivarium/simulator/grpc_server/protos/simulator.proto index ab03f90..afd6c8f 100644 --- a/vivarium/simulator/grpc_server/protos/simulator.proto +++ b/vivarium/simulator/grpc_server/protos/simulator.proto @@ -66,6 +66,8 @@ message SimulatorState { NDArray neighbor_radius = 8; NDArray to_jit = 9; NDArray use_fori_loop = 10; + NDArray collision_eps = 11; + NDArray collision_alpha = 12; } message NVEState { diff --git a/vivarium/simulator/grpc_server/simulator_pb2.py b/vivarium/simulator/grpc_server/simulator_pb2.py index d9b51c9..c6dcdcc 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.py +++ b/vivarium/simulator/grpc_server/simulator_pb2.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: simulator.proto +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -14,14 +15,14 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\x8d\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08n_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tn_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\"\xe4\x02\n\x08NVEState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\x90\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\n \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xbd\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12&\n\tnve_state\x18\x02 \x01(\x0b\x32\x13.simulator.NVEState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\"<\n\rAddAgentInput\x12\x10\n\x08n_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xb6\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12<\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x13.simulator.NVEState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe5\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08n_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tn_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe4\x02\n\x08NVEState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\x90\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\n \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xbd\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12&\n\tnve_state\x18\x02 \x01(\x0b\x32\x13.simulator.NVEState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\"<\n\rAddAgentInput\x12\x10\n\x08n_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xb6\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12<\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x13.simulator.NVEState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'simulator_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n\032io.grpc.examples.simulatorB\016SimulatorProtoP\001\242\002\003SIM' + _globals['DESCRIPTOR']._options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\032io.grpc.examples.simulatorB\016SimulatorProtoP\001\242\002\003SIM' _globals['_AGENTIDX']._serialized_start=59 _globals['_AGENTIDX']._serialized_end=82 _globals['_NDARRAY']._serialized_start=84 @@ -29,21 +30,21 @@ _globals['_RIGIDBODY']._serialized_start=112 _globals['_RIGIDBODY']._serialized_end=200 _globals['_SIMULATORSTATE']._serialized_start=203 - _globals['_SIMULATORSTATE']._serialized_end=600 - _globals['_NVESTATE']._serialized_start=603 - _globals['_NVESTATE']._serialized_end=959 - _globals['_AGENTSTATE']._serialized_start=962 - _globals['_AGENTSTATE']._serialized_end=1362 - _globals['_OBJECTSTATE']._serialized_start=1364 - _globals['_OBJECTSTATE']._serialized_end=1491 - _globals['_STATE']._serialized_start=1494 - _globals['_STATE']._serialized_end=1683 - _globals['_STATECHANGE']._serialized_start=1685 - _globals['_STATECHANGE']._serialized_end=1789 - _globals['_ADDAGENTINPUT']._serialized_start=1791 - _globals['_ADDAGENTINPUT']._serialized_end=1851 - _globals['_ISSTARTEDSTATE']._serialized_start=1853 - _globals['_ISSTARTEDSTATE']._serialized_end=1889 - _globals['_SIMULATORSERVER']._serialized_start=1892 - _globals['_SIMULATORSERVER']._serialized_end=2458 + _globals['_SIMULATORSTATE']._serialized_end=688 + _globals['_NVESTATE']._serialized_start=691 + _globals['_NVESTATE']._serialized_end=1047 + _globals['_AGENTSTATE']._serialized_start=1050 + _globals['_AGENTSTATE']._serialized_end=1450 + _globals['_OBJECTSTATE']._serialized_start=1452 + _globals['_OBJECTSTATE']._serialized_end=1579 + _globals['_STATE']._serialized_start=1582 + _globals['_STATE']._serialized_end=1771 + _globals['_STATECHANGE']._serialized_start=1773 + _globals['_STATECHANGE']._serialized_end=1877 + _globals['_ADDAGENTINPUT']._serialized_start=1879 + _globals['_ADDAGENTINPUT']._serialized_end=1939 + _globals['_ISSTARTEDSTATE']._serialized_start=1941 + _globals['_ISSTARTEDSTATE']._serialized_end=1977 + _globals['_SIMULATORSERVER']._serialized_start=1980 + _globals['_SIMULATORSERVER']._serialized_end=2546 # @@protoc_insertion_point(module_scope) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.pyi b/vivarium/simulator/grpc_server/simulator_pb2.pyi index 7564e20..26b6a56 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.pyi +++ b/vivarium/simulator/grpc_server/simulator_pb2.pyi @@ -7,19 +7,19 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map DESCRIPTOR: _descriptor.FileDescriptor class AgentIdx(_message.Message): - __slots__ = ["idx"] + __slots__ = ("idx",) IDX_FIELD_NUMBER: _ClassVar[int] idx: _containers.RepeatedScalarFieldContainer[int] def __init__(self, idx: _Optional[_Iterable[int]] = ...) -> None: ... class NDArray(_message.Message): - __slots__ = ["ndarray"] + __slots__ = ("ndarray",) NDARRAY_FIELD_NUMBER: _ClassVar[int] ndarray: bytes def __init__(self, ndarray: _Optional[bytes] = ...) -> None: ... class RigidBody(_message.Message): - __slots__ = ["center", "orientation"] + __slots__ = ("center", "orientation") CENTER_FIELD_NUMBER: _ClassVar[int] ORIENTATION_FIELD_NUMBER: _ClassVar[int] center: NDArray @@ -27,7 +27,7 @@ class RigidBody(_message.Message): def __init__(self, center: _Optional[_Union[NDArray, _Mapping]] = ..., orientation: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class SimulatorState(_message.Message): - __slots__ = ["idx", "box_size", "n_agents", "n_objects", "num_steps_lax", "dt", "freq", "neighbor_radius", "to_jit", "use_fori_loop"] + __slots__ = ("idx", "box_size", "n_agents", "n_objects", "num_steps_lax", "dt", "freq", "neighbor_radius", "to_jit", "use_fori_loop", "collision_eps", "collision_alpha") IDX_FIELD_NUMBER: _ClassVar[int] BOX_SIZE_FIELD_NUMBER: _ClassVar[int] N_AGENTS_FIELD_NUMBER: _ClassVar[int] @@ -38,6 +38,8 @@ class SimulatorState(_message.Message): NEIGHBOR_RADIUS_FIELD_NUMBER: _ClassVar[int] TO_JIT_FIELD_NUMBER: _ClassVar[int] USE_FORI_LOOP_FIELD_NUMBER: _ClassVar[int] + COLLISION_EPS_FIELD_NUMBER: _ClassVar[int] + COLLISION_ALPHA_FIELD_NUMBER: _ClassVar[int] idx: NDArray box_size: NDArray n_agents: NDArray @@ -48,10 +50,12 @@ class SimulatorState(_message.Message): neighbor_radius: NDArray to_jit: NDArray use_fori_loop: NDArray - def __init__(self, idx: _Optional[_Union[NDArray, _Mapping]] = ..., box_size: _Optional[_Union[NDArray, _Mapping]] = ..., n_agents: _Optional[_Union[NDArray, _Mapping]] = ..., n_objects: _Optional[_Union[NDArray, _Mapping]] = ..., num_steps_lax: _Optional[_Union[NDArray, _Mapping]] = ..., dt: _Optional[_Union[NDArray, _Mapping]] = ..., freq: _Optional[_Union[NDArray, _Mapping]] = ..., neighbor_radius: _Optional[_Union[NDArray, _Mapping]] = ..., to_jit: _Optional[_Union[NDArray, _Mapping]] = ..., use_fori_loop: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... + collision_eps: NDArray + collision_alpha: NDArray + def __init__(self, idx: _Optional[_Union[NDArray, _Mapping]] = ..., box_size: _Optional[_Union[NDArray, _Mapping]] = ..., n_agents: _Optional[_Union[NDArray, _Mapping]] = ..., n_objects: _Optional[_Union[NDArray, _Mapping]] = ..., num_steps_lax: _Optional[_Union[NDArray, _Mapping]] = ..., dt: _Optional[_Union[NDArray, _Mapping]] = ..., freq: _Optional[_Union[NDArray, _Mapping]] = ..., neighbor_radius: _Optional[_Union[NDArray, _Mapping]] = ..., to_jit: _Optional[_Union[NDArray, _Mapping]] = ..., use_fori_loop: _Optional[_Union[NDArray, _Mapping]] = ..., collision_eps: _Optional[_Union[NDArray, _Mapping]] = ..., collision_alpha: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class NVEState(_message.Message): - __slots__ = ["position", "momentum", "force", "mass", "diameter", "entity_type", "entity_idx", "friction", "exists"] + __slots__ = ("position", "momentum", "force", "mass", "diameter", "entity_type", "entity_idx", "friction", "exists") POSITION_FIELD_NUMBER: _ClassVar[int] MOMENTUM_FIELD_NUMBER: _ClassVar[int] FORCE_FIELD_NUMBER: _ClassVar[int] @@ -73,7 +77,7 @@ class NVEState(_message.Message): def __init__(self, position: _Optional[_Union[RigidBody, _Mapping]] = ..., momentum: _Optional[_Union[RigidBody, _Mapping]] = ..., force: _Optional[_Union[RigidBody, _Mapping]] = ..., mass: _Optional[_Union[RigidBody, _Mapping]] = ..., diameter: _Optional[_Union[NDArray, _Mapping]] = ..., entity_type: _Optional[_Union[NDArray, _Mapping]] = ..., entity_idx: _Optional[_Union[NDArray, _Mapping]] = ..., friction: _Optional[_Union[NDArray, _Mapping]] = ..., exists: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class AgentState(_message.Message): - __slots__ = ["nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color"] + __slots__ = ("nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color") NVE_IDX_FIELD_NUMBER: _ClassVar[int] PROX_FIELD_NUMBER: _ClassVar[int] MOTOR_FIELD_NUMBER: _ClassVar[int] @@ -97,7 +101,7 @@ class AgentState(_message.Message): def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class ObjectState(_message.Message): - __slots__ = ["nve_idx", "custom_field", "color"] + __slots__ = ("nve_idx", "custom_field", "color") NVE_IDX_FIELD_NUMBER: _ClassVar[int] CUSTOM_FIELD_FIELD_NUMBER: _ClassVar[int] COLOR_FIELD_NUMBER: _ClassVar[int] @@ -107,7 +111,7 @@ class ObjectState(_message.Message): def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., custom_field: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class State(_message.Message): - __slots__ = ["simulator_state", "nve_state", "agent_state", "object_state"] + __slots__ = ("simulator_state", "nve_state", "agent_state", "object_state") SIMULATOR_STATE_FIELD_NUMBER: _ClassVar[int] NVE_STATE_FIELD_NUMBER: _ClassVar[int] AGENT_STATE_FIELD_NUMBER: _ClassVar[int] @@ -119,7 +123,7 @@ class State(_message.Message): def __init__(self, simulator_state: _Optional[_Union[SimulatorState, _Mapping]] = ..., nve_state: _Optional[_Union[NVEState, _Mapping]] = ..., agent_state: _Optional[_Union[AgentState, _Mapping]] = ..., object_state: _Optional[_Union[ObjectState, _Mapping]] = ...) -> None: ... class StateChange(_message.Message): - __slots__ = ["nve_idx", "col_idx", "nested_field", "value"] + __slots__ = ("nve_idx", "col_idx", "nested_field", "value") NVE_IDX_FIELD_NUMBER: _ClassVar[int] COL_IDX_FIELD_NUMBER: _ClassVar[int] NESTED_FIELD_FIELD_NUMBER: _ClassVar[int] @@ -131,7 +135,7 @@ class StateChange(_message.Message): def __init__(self, nve_idx: _Optional[_Iterable[int]] = ..., col_idx: _Optional[_Iterable[int]] = ..., nested_field: _Optional[_Iterable[str]] = ..., value: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class AddAgentInput(_message.Message): - __slots__ = ["n_agents", "serialized_config"] + __slots__ = ("n_agents", "serialized_config") N_AGENTS_FIELD_NUMBER: _ClassVar[int] SERIALIZED_CONFIG_FIELD_NUMBER: _ClassVar[int] n_agents: int @@ -139,7 +143,7 @@ class AddAgentInput(_message.Message): def __init__(self, n_agents: _Optional[int] = ..., serialized_config: _Optional[str] = ...) -> None: ... class IsStartedState(_message.Message): - __slots__ = ["is_started"] + __slots__ = ("is_started",) IS_STARTED_FIELD_NUMBER: _ClassVar[int] is_started: bool def __init__(self, is_started: bool = ...) -> None: ... diff --git a/vivarium/simulator/sim_computation.py b/vivarium/simulator/sim_computation.py index 6f47701..d50e33f 100644 --- a/vivarium/simulator/sim_computation.py +++ b/vivarium/simulator/sim_computation.py @@ -7,12 +7,11 @@ from jax import ops, vmap, lax from jax_md import space, rigid_body, util, simulate, energy, quantity from jax_md.dataclasses import dataclass - - f32 = util.f32 SPACE_NDIMS = 2 + class EntityType(Enum): AGENT = 0 OBJECT = 1 @@ -20,6 +19,7 @@ class EntityType(Enum): def to_state_type(self): return StateType(self.value) + class StateType(Enum): AGENT = 0 OBJECT = 1 @@ -33,7 +33,6 @@ def to_entity_type(self): return EntityType(self.value) - @dataclass class NVEState(simulate.NVEState): entity_type: util.Array @@ -46,6 +45,7 @@ class NVEState(simulate.NVEState): def velocity(self) -> util.Array: return self.momentum / self.mass + @dataclass class AgentState: nve_idx: util.Array # idx in NVEState @@ -65,6 +65,7 @@ class ObjectState: nve_idx: util.Array # idx in NVEState color: util.Array + @dataclass class SimulatorState: idx: util.Array @@ -77,18 +78,21 @@ class SimulatorState: neighbor_radius: util.Array to_jit: util.Array use_fori_loop: util.Array + collision_alpha: util.Array + collision_eps: util.Array @staticmethod def get_type(attr): if attr in ['idx', 'n_agents', 'n_objects', 'num_steps_lax']: return int - elif attr in ['box_size', 'dt', 'freq', 'neighbor_radius']: + elif attr in ['box_size', 'dt', 'freq', 'neighbor_radius', 'collision_alpha', 'collision_eps']: return float elif attr in ['to_jit', 'use_fori_loop']: return bool else: raise ValueError() + @dataclass class State: simulator_state: SimulatorState @@ -109,7 +113,7 @@ def field(self, stype_or_nested_fields): return res - # Should we keep this function because it is duplicated below ? + # TODO : Should we keep this function because it is duplicated below ? # def nve_idx(self, etype): # cond = self.e_cond(etype) # return compress(range(len(cond)), cond) # https://stackoverflow.com/questions/21448225/getting-indices-of-true-values-in-a-boolean-list @@ -136,42 +140,73 @@ def wrapper(e_type): return value[self.e_cond(e_type)] return wrapper - +@vmap def normal(theta): return jnp.array([jnp.cos(theta), jnp.sin(theta)]) -normal = vmap(normal) - def switch_fn(fn_list): def switch(index, *operands): return jax.lax.switch(index, fn_list, *operands) return switch -""" Helper functions for collisions """ +# Helper functions for collisions + +def collision_energy(displacement_fn, r_a, r_b, l_a, l_b, epsilon, alpha, mask): + """Compute the collision energy between a pair of particles -def collision_energy(displacement_fn, r_a, r_b, l_a, l_b, epsilon, alpha): + :param displacement_fn: displacement function of jax_md + :param r_a: position of particle a + :param r_b: position of particle b + :param l_a: diameter of particle a + :param l_b: diameter of particle b + :param epsilon: interaction energy scale + :param alpha: interaction stiffness + :param mask: set the energy to 0 if one of the particles is masked + :return: collision energy between both particles + """ dist = jnp.linalg.norm(displacement_fn(r_a, r_b)) - sigma = l_a + l_b - return energy.soft_sphere(dist, sigma=sigma, epsilon=epsilon, alpha=alpha) + sigma = (l_a + l_b) / 2 + e = energy.soft_sphere(dist, sigma=sigma, epsilon=epsilon, alpha=f32(alpha)) + return jnp.where(mask, e, 0.) + +collision_energy = vmap(collision_energy, (None, 0, 0, 0, 0, None, None, 0)) -collision_energy = vmap(collision_energy, (None, 0, 0, 0, 0, None, None)) +def total_collision_energy(positions, diameter, neighbor, displacement, exists_mask, epsilon, alpha): + """Compute the collision energy between all neighboring pairs of particles in the system -def total_collision_energy(positions, diameter, neighbor, displacement, exists_mask, epsilon=1e-2, alpha=2, **kwargs): + :param positions: positions of all the particles + :param diameter: diameters of all the particles + :param neighbor: neighbor array of the system + :param displacement: dipalcement function of jax_md + :param exists_mask: mask to specify which particles exist + :param epsilon: interaction energy scale between two particles + :param alpha: interaction stiffness between two particles + :return: sum of all collisions energies of the system + """ diameter = lax.stop_gradient(diameter) senders, receivers = neighbor.idx - Ra = positions[senders] - Rb = positions[receivers] - l_a = diameter[senders] - l_b = diameter[receivers] - e = collision_energy(displacement, Ra, Rb, l_a, l_b, epsilon, alpha) + + r_senders = positions[senders] + r_receivers = positions[receivers] + l_senders = diameter[senders] + l_receivers = diameter[receivers] + # Set collision energy to zero if the sender or receiver is non existing - e = jnp.where(exists_mask[senders] * exists_mask[receivers], e, 0.) - return jnp.sum(e) + mask = exists_mask[senders] * exists_mask[receivers] + energies = collision_energy(displacement, + r_senders, + r_receivers, + l_senders, + l_receivers, + epsilon, + alpha, + mask) + return jnp.sum(energies) -""" Helper functions for motor function """ +# Helper functions for motor function def lr_2_fwd_rot(left_spd, right_spd, base_length, wheel_diameter): fwd = (wheel_diameter / 4.) * (left_spd + right_spd) @@ -186,42 +221,57 @@ def fwd_rot_2_lr(fwd, rot, base_length, wheel_diameter): def motor_command(wheel_activation, base_length, wheel_diameter): - fwd, rot = lr_2_fwd_rot(wheel_activation[0], wheel_activation[1], base_length, wheel_diameter) - return fwd, rot + fwd, rot = lr_2_fwd_rot(wheel_activation[0], wheel_activation[1], base_length, wheel_diameter) + return fwd, rot motor_command = vmap(motor_command, (0, 0, 0)) -""" Functions to compute the verlet force on the whole system""" +# Functions to compute the verlet force on the whole system def get_verlet_force_fn(displacement): - coll_force_fn = quantity.force(partial(total_collision_energy, displacement=displacement, - epsilon=10., alpha=12)) - - def collision_force(nve_state, neighbor, exists_mask): - return coll_force_fn(nve_state.position.center, neighbor=neighbor, exists_mask=exists_mask, diameter=nve_state.diameter) - - def friction_force(nve_state, exists_mask): - cur_vel = nve_state.momentum.center / nve_state.mass.center + coll_force_fn = quantity.force(partial(total_collision_energy, displacement=displacement)) + + def collision_force(state, neighbor, exists_mask): + return coll_force_fn( + state.nve_state.position.center, + neighbor=neighbor, + exists_mask=exists_mask, + diameter=state.nve_state.diameter, + epsilon=state.simulator_state.collision_eps, + alpha=state.simulator_state.collision_alpha + ) + + def friction_force(state, exists_mask): + cur_vel = state.nve_state.momentum.center / state.nve_state.mass.center # stack the mask to give it the same shape as cur_vel (that has 2 rows for forward and angular velocities) mask = jnp.stack([exists_mask] * 2, axis=1) cur_vel = jnp.where(mask, cur_vel, 0.) - return - jnp.tile(nve_state.friction, (SPACE_NDIMS, 1)).T * cur_vel + return - jnp.tile(state.nve_state.friction, (SPACE_NDIMS, 1)).T * cur_vel def motor_force(state, exists_mask): agent_idx = state.agent_state.nve_idx - body = rigid_body.RigidBody(center=state.nve_state.position.center[agent_idx], - orientation=state.nve_state.position.orientation[agent_idx]) - fwd, rot = motor_command(state.agent_state.motor, - state.nve_state.diameter[agent_idx], - state.agent_state.wheel_diameter) + + body = rigid_body.RigidBody( + center=state.nve_state.position.center[agent_idx], + orientation=state.nve_state.position.orientation[agent_idx] + ) + n = normal(body.orientation) + fwd, rot = motor_command( + state.agent_state.motor, + state.nve_state.diameter[agent_idx], + state.agent_state.wheel_diameter + ) + cur_vel = state.nve_state.momentum.center[agent_idx] / state.nve_state.mass.center[agent_idx] cur_fwd_vel = vmap(jnp.dot)(cur_vel, n) cur_rot_vel = state.nve_state.momentum.orientation[agent_idx] / state.nve_state.mass.orientation[agent_idx] + fwd_delta = fwd - cur_fwd_vel rot_delta = rot - cur_rot_vel + fwd_force = n * jnp.tile(fwd_delta, (SPACE_NDIMS, 1)).T * jnp.tile(state.agent_state.speed_mul, (SPACE_NDIMS, 1)).T rot_force = rot_delta * state.agent_state.theta_mul @@ -240,14 +290,17 @@ def motor_force(state, exists_mask): def force_fn(state, neighbor, exists_mask): mf = motor_force(state, exists_mask) - center = collision_force(state.nve_state, neighbor, exists_mask) + friction_force(state.nve_state, exists_mask) + mf.center + cf = collision_force(state, neighbor, exists_mask) + ff = friction_force(state, exists_mask) + + center = cf + ff + mf.center orientation = mf.orientation return rigid_body.RigidBody(center=center, orientation=orientation) return force_fn -""" Helper functions for sensors """ +# Helper functions for sensors def dist_theta(displ, theta): """ @@ -293,7 +346,7 @@ def sensor(displ, theta, dist_max, cos_min, n_agents, senders, target_exists): return proxs -""" Functions to compute the dynamics of the whole system """ +# Functions to compute the dynamics of the whole system def dynamics_rigid(displacement, shift, behavior_bank, force_fn=None): force_fn = force_fn or get_verlet_force_fn(displacement)