From 5acf3de74af1f9da85afd4b85c22ec1ba0f97fe3 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 4 Apr 2024 19:22:52 +0200 Subject: [PATCH 1/6] Rename n_agents and n_objects to max_* --- scripts/run_server.py | 8 +-- scripts/run_simulation.py | 8 +-- tests/test_simulator_init.py | 8 +-- vivarium/controllers/config.py | 4 +- vivarium/controllers/converters.py | 32 +++++----- vivarium/simulator/grpc_server/converters.py | 8 +-- .../grpc_server/protos/simulator.proto | 6 +- .../simulator/grpc_server/simulator_pb2.py | 36 +++++------ .../simulator/grpc_server/simulator_pb2.pyi | 20 +++--- vivarium/simulator/sim_computation.py | 4 +- vivarium/simulator/states.py | 62 +++++++++---------- 11 files changed, 98 insertions(+), 98 deletions(-) diff --git a/scripts/run_server.py b/scripts/run_server.py index 5009bcd..5cc2134 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -16,9 +16,9 @@ def parse_args(): parser = argparse.ArgumentParser(description='Simulator Configuration') parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box') - parser.add_argument('--n_agents', type=int, default=10, help='Number of agents') + parser.add_argument('--max_agents', type=int, default=10, help='Number of agents') parser.add_argument('--n_existing_agents', type=int, default=10, help='Number of agents') - parser.add_argument('--n_objects', type=int, default=2, help='Number of objects') + parser.add_argument('--max_objects', type=int, default=2, help='Number of objects') parser.add_argument('--n_existing_objects', type=int, default=2, help='Number of agents') parser.add_argument('--num_steps-lax', type=int, default=4, help='Number of lax steps per loop') parser.add_argument('--dt', type=float, default=0.1, help='Time step size') @@ -41,8 +41,8 @@ def parse_args(): simulator_state = init_simulator_state( box_size=args.box_size, - n_agents=args.n_agents, - n_objects=args.n_objects, + max_agents=args.max_agents, + max_objects=args.max_objects, num_steps_lax=args.num_steps_lax, neighbor_radius=args.neighbor_radius, dt=args.dt, diff --git a/scripts/run_simulation.py b/scripts/run_simulation.py index 1279853..df26e3e 100644 --- a/scripts/run_simulation.py +++ b/scripts/run_simulation.py @@ -19,8 +19,8 @@ def parse_args(): parser.add_argument('--num_steps', type=int, default=10, help='Number of simulation steps') # Simulator config arguments parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box') - parser.add_argument('--n_agents', type=int, default=10, help='Number of agents') - parser.add_argument('--n_objects', type=int, default=2, help='Number of objects') + parser.add_argument('--max_agents', type=int, default=10, help='Number of agents') + parser.add_argument('--max_objects', type=int, default=2, help='Number of objects') parser.add_argument('--num_steps_lax', type=int, default=4, help='Number of lax steps per loop') parser.add_argument('--dt', type=float, default=0.1, help='Time step size') parser.add_argument('--freq', type=float, default=40.0, help='Frequency parameter') @@ -41,8 +41,8 @@ def parse_args(): simulator_state = init_simulator_state( box_size=args.box_size, - n_agents=args.n_agents, - n_objects=args.n_objects, + max_agents=args.max_agents, + max_objects=args.max_objects, num_steps_lax=args.num_steps_lax, neighbor_radius=args.neighbor_radius, dt=args.dt, diff --git a/tests/test_simulator_init.py b/tests/test_simulator_init.py index 7ea3718..bb7f30a 100644 --- a/tests/test_simulator_init.py +++ b/tests/test_simulator_init.py @@ -30,8 +30,8 @@ def test_init_simulator_no_args(): def test_init_simulator_args(): """ Test the initialization of state with arguments """ box_size = 100.0 - n_agents = 10 - n_objects = 2 + max_agents = 10 + max_objects = 2 col_eps = 0.1 col_alpha = 0.5 @@ -47,8 +47,8 @@ def test_init_simulator_args(): simulator_state = init_simulator_state( box_size=box_size, - n_agents=n_agents, - n_objects=n_objects, + max_agents=max_agents, + max_objects=max_objects, collision_eps=col_eps, collision_alpha=col_alpha) diff --git a/vivarium/controllers/config.py b/vivarium/controllers/config.py index fa280a5..bea4b53 100644 --- a/vivarium/controllers/config.py +++ b/vivarium/controllers/config.py @@ -75,8 +75,8 @@ def __init__(self, **params): class SimulatorConfig(Config): idx = param.Integer(0, constant=True) box_size = param.Number(100., bounds=(0, None)) - n_agents = param.Integer(10) - n_objects = param.Integer(2) + max_agents = param.Integer(10) + max_objects = param.Integer(2) num_steps_lax = param.Integer(4) dt = param.Number(0.1) freq = param.Number(40., allow_None=True) diff --git a/vivarium/controllers/converters.py b/vivarium/controllers/converters.py index aa34ad7..987f5d2 100644 --- a/vivarium/controllers/converters.py +++ b/vivarium/controllers/converters.py @@ -97,11 +97,11 @@ class StateFieldInfo: def get_default_state(n_entities_dict): - n_agents = n_entities_dict[StateType.AGENT] - n_objects = n_entities_dict[StateType.OBJECT] + max_agents = n_entities_dict[StateType.AGENT] + max_objects = n_entities_dict[StateType.OBJECT] n_entities = sum(n_entities_dict.values()) return State(simulator_state=SimulatorState(idx=jnp.array([0]), box_size=jnp.array([100.]), - n_agents=jnp.array([n_agents]), n_objects=jnp.array([n_objects]), + max_agents=jnp.array([max_agents]), max_objects=jnp.array([max_objects]), num_steps_lax=jnp.array([1]), dt=jnp.array([1.]), freq=jnp.array([1.]), neighbor_radius=jnp.array([1.]), to_jit= jnp.array([1]), use_fori_loop=jnp.array([0]), @@ -111,23 +111,23 @@ def get_default_state(n_entities_dict): momentum=None, force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), mass=RigidBody(center=jnp.zeros((n_entities, 1)), orientation=jnp.zeros(n_entities)), - entity_type=jnp.array([EntityType.AGENT.value] * n_agents + [EntityType.OBJECT.value] * n_objects, dtype=int), - entity_idx = jnp.array(list(range(n_agents)) + list(range(n_objects))), + entity_type=jnp.array([EntityType.AGENT.value] * max_agents + [EntityType.OBJECT.value] * max_objects, dtype=int), + entity_idx = jnp.array(list(range(max_agents)) + list(range(max_objects))), diameter=jnp.zeros(n_entities), friction=jnp.zeros(n_entities), exists=jnp.ones(n_entities, dtype=int) ), - agent_state=AgentState(nve_idx=jnp.zeros(n_agents, dtype=int), - prox=jnp.zeros((n_agents, 2)), - motor=jnp.zeros((n_agents, 2)), - behavior=jnp.zeros(n_agents, dtype=int), - wheel_diameter=jnp.zeros(n_agents), - speed_mul=jnp.zeros(n_agents), - theta_mul=jnp.zeros(n_agents), - proxs_dist_max=jnp.zeros(n_agents), - proxs_cos_min=jnp.zeros(n_agents), - color=jnp.zeros((n_agents, 3))), - object_state=ObjectState(nve_idx=jnp.zeros(n_objects, dtype=int), color=jnp.zeros((n_objects, 3)))) + agent_state=AgentState(nve_idx=jnp.zeros(max_agents, dtype=int), + prox=jnp.zeros((max_agents, 2)), + motor=jnp.zeros((max_agents, 2)), + behavior=jnp.zeros(max_agents, dtype=int), + wheel_diameter=jnp.zeros(max_agents), + speed_mul=jnp.zeros(max_agents), + theta_mul=jnp.zeros(max_agents), + proxs_dist_max=jnp.zeros(max_agents), + proxs_cos_min=jnp.zeros(max_agents), + color=jnp.zeros((max_agents, 3))), + object_state=ObjectState(nve_idx=jnp.zeros(max_objects, dtype=int), color=jnp.zeros((max_objects, 3)))) NVETuple = namedtuple('NVETuple', ['idx', 'col', 'val']) diff --git a/vivarium/simulator/grpc_server/converters.py b/vivarium/simulator/grpc_server/converters.py index 96cbd1f..b3c09a1 100644 --- a/vivarium/simulator/grpc_server/converters.py +++ b/vivarium/simulator/grpc_server/converters.py @@ -16,8 +16,8 @@ def proto_to_state(state): def proto_to_simulator_state(simulator_state): return SimulatorState(idx=proto_to_ndarray(simulator_state.idx).astype(int), box_size=proto_to_ndarray(simulator_state.box_size).astype(float), - n_agents=proto_to_ndarray(simulator_state.n_agents).astype(int), - n_objects=proto_to_ndarray(simulator_state.n_objects).astype(int), + max_agents=proto_to_ndarray(simulator_state.max_agents).astype(int), + max_objects=proto_to_ndarray(simulator_state.max_objects).astype(int), num_steps_lax=proto_to_ndarray(simulator_state.num_steps_lax).astype(int), dt=proto_to_ndarray(simulator_state.dt).astype(float), freq=proto_to_ndarray(simulator_state.freq).astype(float), @@ -77,8 +77,8 @@ def simulator_state_to_proto(simulator_state): return simulator_pb2.SimulatorState( idx=ndarray_to_proto(simulator_state.idx), box_size=ndarray_to_proto(simulator_state.box_size), - n_agents=ndarray_to_proto(simulator_state.n_agents), - n_objects=ndarray_to_proto(simulator_state.n_objects), + max_agents=ndarray_to_proto(simulator_state.max_agents), + max_objects=ndarray_to_proto(simulator_state.max_objects), num_steps_lax=ndarray_to_proto(simulator_state.num_steps_lax), dt=ndarray_to_proto(simulator_state.dt), freq=ndarray_to_proto(simulator_state.freq), diff --git a/vivarium/simulator/grpc_server/protos/simulator.proto b/vivarium/simulator/grpc_server/protos/simulator.proto index afd6c8f..d69fc7d 100644 --- a/vivarium/simulator/grpc_server/protos/simulator.proto +++ b/vivarium/simulator/grpc_server/protos/simulator.proto @@ -58,8 +58,8 @@ message RigidBody { message SimulatorState { NDArray idx = 1; NDArray box_size = 2; - NDArray n_agents = 3; - NDArray n_objects = 4; + NDArray max_agents = 3; + NDArray max_objects = 4; NDArray num_steps_lax = 5; NDArray dt = 6; NDArray freq = 7; @@ -116,7 +116,7 @@ message StateChange { } message AddAgentInput { - int32 n_agents = 1; + int32 max_agents = 1; string serialized_config =2; } diff --git a/vivarium/simulator/grpc_server/simulator_pb2.py b/vivarium/simulator/grpc_server/simulator_pb2.py index c6dcdcc..caee8e3 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.py +++ b/vivarium/simulator/grpc_server/simulator_pb2.py @@ -15,7 +15,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe5\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08n_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tn_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe4\x02\n\x08NVEState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\x90\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\n \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xbd\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12&\n\tnve_state\x18\x02 \x01(\x0b\x32\x13.simulator.NVEState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\"<\n\rAddAgentInput\x12\x10\n\x08n_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xb6\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12<\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x13.simulator.NVEState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nmax_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0bmax_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe4\x02\n\x08NVEState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\x90\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\n \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xbd\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12&\n\tnve_state\x18\x02 \x01(\x0b\x32\x13.simulator.NVEState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\">\n\rAddAgentInput\x12\x12\n\nmax_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xb6\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12<\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x13.simulator.NVEState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -30,21 +30,21 @@ _globals['_RIGIDBODY']._serialized_start=112 _globals['_RIGIDBODY']._serialized_end=200 _globals['_SIMULATORSTATE']._serialized_start=203 - _globals['_SIMULATORSTATE']._serialized_end=688 - _globals['_NVESTATE']._serialized_start=691 - _globals['_NVESTATE']._serialized_end=1047 - _globals['_AGENTSTATE']._serialized_start=1050 - _globals['_AGENTSTATE']._serialized_end=1450 - _globals['_OBJECTSTATE']._serialized_start=1452 - _globals['_OBJECTSTATE']._serialized_end=1579 - _globals['_STATE']._serialized_start=1582 - _globals['_STATE']._serialized_end=1771 - _globals['_STATECHANGE']._serialized_start=1773 - _globals['_STATECHANGE']._serialized_end=1877 - _globals['_ADDAGENTINPUT']._serialized_start=1879 - _globals['_ADDAGENTINPUT']._serialized_end=1939 - _globals['_ISSTARTEDSTATE']._serialized_start=1941 - _globals['_ISSTARTEDSTATE']._serialized_end=1977 - _globals['_SIMULATORSERVER']._serialized_start=1980 - _globals['_SIMULATORSERVER']._serialized_end=2546 + _globals['_SIMULATORSTATE']._serialized_end=692 + _globals['_NVESTATE']._serialized_start=695 + _globals['_NVESTATE']._serialized_end=1051 + _globals['_AGENTSTATE']._serialized_start=1054 + _globals['_AGENTSTATE']._serialized_end=1454 + _globals['_OBJECTSTATE']._serialized_start=1456 + _globals['_OBJECTSTATE']._serialized_end=1583 + _globals['_STATE']._serialized_start=1586 + _globals['_STATE']._serialized_end=1775 + _globals['_STATECHANGE']._serialized_start=1777 + _globals['_STATECHANGE']._serialized_end=1881 + _globals['_ADDAGENTINPUT']._serialized_start=1883 + _globals['_ADDAGENTINPUT']._serialized_end=1945 + _globals['_ISSTARTEDSTATE']._serialized_start=1947 + _globals['_ISSTARTEDSTATE']._serialized_end=1983 + _globals['_SIMULATORSERVER']._serialized_start=1986 + _globals['_SIMULATORSERVER']._serialized_end=2552 # @@protoc_insertion_point(module_scope) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.pyi b/vivarium/simulator/grpc_server/simulator_pb2.pyi index 26b6a56..8e05925 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.pyi +++ b/vivarium/simulator/grpc_server/simulator_pb2.pyi @@ -27,11 +27,11 @@ class RigidBody(_message.Message): def __init__(self, center: _Optional[_Union[NDArray, _Mapping]] = ..., orientation: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class SimulatorState(_message.Message): - __slots__ = ("idx", "box_size", "n_agents", "n_objects", "num_steps_lax", "dt", "freq", "neighbor_radius", "to_jit", "use_fori_loop", "collision_eps", "collision_alpha") + __slots__ = ("idx", "box_size", "max_agents", "max_objects", "num_steps_lax", "dt", "freq", "neighbor_radius", "to_jit", "use_fori_loop", "collision_eps", "collision_alpha") IDX_FIELD_NUMBER: _ClassVar[int] BOX_SIZE_FIELD_NUMBER: _ClassVar[int] - N_AGENTS_FIELD_NUMBER: _ClassVar[int] - N_OBJECTS_FIELD_NUMBER: _ClassVar[int] + MAX_AGENTS_FIELD_NUMBER: _ClassVar[int] + MAX_OBJECTS_FIELD_NUMBER: _ClassVar[int] NUM_STEPS_LAX_FIELD_NUMBER: _ClassVar[int] DT_FIELD_NUMBER: _ClassVar[int] FREQ_FIELD_NUMBER: _ClassVar[int] @@ -42,8 +42,8 @@ class SimulatorState(_message.Message): COLLISION_ALPHA_FIELD_NUMBER: _ClassVar[int] idx: NDArray box_size: NDArray - n_agents: NDArray - n_objects: NDArray + max_agents: NDArray + max_objects: NDArray num_steps_lax: NDArray dt: NDArray freq: NDArray @@ -52,7 +52,7 @@ class SimulatorState(_message.Message): use_fori_loop: NDArray collision_eps: NDArray collision_alpha: NDArray - def __init__(self, idx: _Optional[_Union[NDArray, _Mapping]] = ..., box_size: _Optional[_Union[NDArray, _Mapping]] = ..., n_agents: _Optional[_Union[NDArray, _Mapping]] = ..., n_objects: _Optional[_Union[NDArray, _Mapping]] = ..., num_steps_lax: _Optional[_Union[NDArray, _Mapping]] = ..., dt: _Optional[_Union[NDArray, _Mapping]] = ..., freq: _Optional[_Union[NDArray, _Mapping]] = ..., neighbor_radius: _Optional[_Union[NDArray, _Mapping]] = ..., to_jit: _Optional[_Union[NDArray, _Mapping]] = ..., use_fori_loop: _Optional[_Union[NDArray, _Mapping]] = ..., collision_eps: _Optional[_Union[NDArray, _Mapping]] = ..., collision_alpha: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... + def __init__(self, idx: _Optional[_Union[NDArray, _Mapping]] = ..., box_size: _Optional[_Union[NDArray, _Mapping]] = ..., max_agents: _Optional[_Union[NDArray, _Mapping]] = ..., max_objects: _Optional[_Union[NDArray, _Mapping]] = ..., num_steps_lax: _Optional[_Union[NDArray, _Mapping]] = ..., dt: _Optional[_Union[NDArray, _Mapping]] = ..., freq: _Optional[_Union[NDArray, _Mapping]] = ..., neighbor_radius: _Optional[_Union[NDArray, _Mapping]] = ..., to_jit: _Optional[_Union[NDArray, _Mapping]] = ..., use_fori_loop: _Optional[_Union[NDArray, _Mapping]] = ..., collision_eps: _Optional[_Union[NDArray, _Mapping]] = ..., collision_alpha: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class NVEState(_message.Message): __slots__ = ("position", "momentum", "force", "mass", "diameter", "entity_type", "entity_idx", "friction", "exists") @@ -135,12 +135,12 @@ class StateChange(_message.Message): def __init__(self, nve_idx: _Optional[_Iterable[int]] = ..., col_idx: _Optional[_Iterable[int]] = ..., nested_field: _Optional[_Iterable[str]] = ..., value: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class AddAgentInput(_message.Message): - __slots__ = ("n_agents", "serialized_config") - N_AGENTS_FIELD_NUMBER: _ClassVar[int] + __slots__ = ("max_agents", "serialized_config") + MAX_AGENTS_FIELD_NUMBER: _ClassVar[int] SERIALIZED_CONFIG_FIELD_NUMBER: _ClassVar[int] - n_agents: int + max_agents: int serialized_config: str - def __init__(self, n_agents: _Optional[int] = ..., serialized_config: _Optional[str] = ...) -> None: ... + def __init__(self, max_agents: _Optional[int] = ..., serialized_config: _Optional[str] = ...) -> None: ... class IsStartedState(_message.Message): __slots__ = ("is_started",) diff --git a/vivarium/simulator/sim_computation.py b/vivarium/simulator/sim_computation.py index 243c68a..d664afe 100644 --- a/vivarium/simulator/sim_computation.py +++ b/vivarium/simulator/sim_computation.py @@ -210,10 +210,10 @@ def sensor_fn(dist, relative_theta, dist_max, cos_min, target_exists): sensor_fn = vmap(sensor_fn, (0, 0, 0, 0, 0)) -def sensor(displ, theta, dist_max, cos_min, n_agents, senders, target_exists): +def sensor(displ, theta, dist_max, cos_min, max_agents, senders, target_exists): dist, relative_theta = proximity_map(displ, theta) proxs = ops.segment_max(sensor_fn(dist, relative_theta, dist_max, cos_min, target_exists), - senders, n_agents) + senders, max_agents) return proxs diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index b53c9da..f3be3c8 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -69,8 +69,8 @@ class ObjectState: class SimulatorState: idx: util.Array box_size: util.Array - n_agents: util.Array - n_objects: util.Array + max_agents: util.Array + max_objects: util.Array num_steps_lax: util.Array dt: util.Array freq: util.Array @@ -82,7 +82,7 @@ class SimulatorState: @staticmethod def get_type(attr): - if attr in ['idx', 'n_agents', 'n_objects', 'num_steps_lax']: + if attr in ['idx', 'max_agents', 'max_objects', 'num_steps_lax']: return int elif attr in ['box_size', 'dt', 'freq', 'neighbor_radius', 'collision_alpha', 'collision_eps']: return float @@ -147,8 +147,8 @@ def _string_to_rgb(color_str): def init_simulator_state( box_size: float = 100., - n_agents: int = 10, - n_objects: int = 2, + max_agents: int = 10, + max_objects: int = 2, num_steps_lax: int = 4, dt: float = 0.1, freq: float = 40., @@ -164,8 +164,8 @@ def init_simulator_state( return SimulatorState( idx=jnp.array([0]), box_size=jnp.array([box_size]), - n_agents=jnp.array([n_agents]), - n_objects=jnp.array([n_objects]), + max_agents=jnp.array([max_agents]), + max_objects=jnp.array([max_objects]), num_steps_lax=jnp.array([num_steps_lax], dtype=int), dt=jnp.array([dt], dtype=float), freq=jnp.array([freq], dtype=float), @@ -213,29 +213,29 @@ def init_nve_state( """ Initialize nve state with given parameters """ - n_agents = simulator_state.n_agents[0] - n_objects = simulator_state.n_objects[0] - n_entities = n_agents + n_objects + max_agents = simulator_state.max_agents[0] + max_objects = simulator_state.max_objects[0] + n_entities = max_agents + max_objects key = random.PRNGKey(seed) key_pos, key_or = random.split(key) key_ag, key_obj = random.split(key_pos) # If we have a list of agents or objects positions, transform it into a jax array, else initialize random positions - agents_positions = _init_positions(key_ag, agents_positions, n_agents, simulator_state.box_size) - objects_positions = _init_positions(key_obj, objects_positions, n_objects, simulator_state.box_size) + agents_positions = _init_positions(key_ag, agents_positions, max_agents, simulator_state.box_size) + objects_positions = _init_positions(key_obj, objects_positions, max_objects, simulator_state.box_size) # Assign their positions to each entities positions = jnp.concatenate((agents_positions, objects_positions)) # Assign random orientations between 0 and 2*pi orientations = random.uniform(key_or, (n_entities,)) * 2 * jnp.pi - agents_entities = jnp.full(n_agents, EntityType.AGENT.value) - object_entities = jnp.full(n_objects, EntityType.OBJECT.value) + agents_entities = jnp.full(max_agents, EntityType.AGENT.value) + object_entities = jnp.full(max_objects, EntityType.OBJECT.value) entity_types = jnp.concatenate((agents_entities, object_entities), dtype=int) - existing_agents = _init_existing(existing_agents, n_agents) - existing_objects = _init_existing(existing_objects, n_objects) + existing_agents = _init_existing(existing_agents, max_agents) + existing_objects = _init_existing(existing_objects, max_objects) exists = jnp.concatenate((existing_agents, existing_objects), dtype=int) return NVEState( @@ -244,7 +244,7 @@ def init_nve_state( force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), mass=RigidBody(center=jnp.full((n_entities, 1), mass_center), orientation=jnp.full((n_entities), mass_orientation)), entity_type=entity_types, - entity_idx = jnp.array(list(range(n_agents)) + list(range(n_objects))), + entity_idx = jnp.array(list(range(max_agents)) + list(range(max_objects))), diameter=jnp.full((n_entities), diameter), friction=jnp.full((n_entities), friction), exists=exists @@ -264,19 +264,19 @@ def init_agent_state( """ Initialize agent state with given parameters """ - n_agents = simulator_state.n_agents[0] + max_agents = simulator_state.max_agents[0] return AgentState( - nve_idx=jnp.arange(n_agents, dtype=int), - prox=jnp.zeros((n_agents, 2)), - motor=jnp.zeros((n_agents, 2)), - behavior=jnp.full((n_agents), behavior), - wheel_diameter=jnp.full((n_agents), wheel_diameter), - speed_mul=jnp.full((n_agents), speed_mul), - theta_mul=jnp.full((n_agents), theta_mul), - proxs_dist_max=jnp.full((n_agents), prox_dist_max), - proxs_cos_min=jnp.full((n_agents), prox_cos_min), - color=jnp.tile(_string_to_rgb(color), (n_agents, 1)) + nve_idx=jnp.arange(max_agents, dtype=int), + prox=jnp.zeros((max_agents, 2)), + motor=jnp.zeros((max_agents, 2)), + behavior=jnp.full((max_agents), behavior), + wheel_diameter=jnp.full((max_agents), wheel_diameter), + speed_mul=jnp.full((max_agents), speed_mul), + theta_mul=jnp.full((max_agents), theta_mul), + proxs_dist_max=jnp.full((max_agents), prox_dist_max), + proxs_cos_min=jnp.full((max_agents), prox_cos_min), + color=jnp.tile(_string_to_rgb(color), (max_agents, 1)) ) @@ -287,12 +287,12 @@ def init_object_state( """ Initialize object state with given parameters """ - n_agents, n_objects = simulator_state.n_agents[0], simulator_state.n_objects[0] - start_idx, stop_idx = n_agents, n_agents + n_objects + max_agents, max_objects = simulator_state.max_agents[0], simulator_state.max_objects[0] + start_idx, stop_idx = max_agents, max_agents + max_objects objects_nve_idx = jnp.arange(start_idx, stop_idx, dtype=int) return ObjectState( nve_idx=objects_nve_idx, - color=jnp.tile(_string_to_rgb(color), (n_objects, 1)) + color=jnp.tile(_string_to_rgb(color), (max_objects, 1)) ) From c030d1c3fc77c85456e1c8f5065940404faca05f Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 4 Apr 2024 19:31:44 +0200 Subject: [PATCH 2/6] Rename nve_state to entities_state --- scripts/run_server.py | 4 +- scripts/run_simulation.py | 4 +- tests/test_simulator_init.py | 8 +-- tests/test_simulator_run.py | 4 +- vivarium/controllers/converters.py | 52 ++++++++-------- vivarium/simulator/grpc_server/converters.py | 62 +++++++++---------- .../grpc_server/protos/simulator.proto | 6 +- .../simulator/grpc_server/simulator_client.py | 4 +- .../simulator/grpc_server/simulator_pb2.py | 34 +++++----- .../simulator/grpc_server/simulator_pb2.pyi | 10 +-- .../grpc_server/simulator_pb2_grpc.py | 6 +- .../simulator/grpc_server/simulator_server.py | 4 +- vivarium/simulator/sim_computation.py | 58 ++++++++--------- vivarium/simulator/simulator.py | 8 +-- vivarium/simulator/states.py | 28 ++++----- 15 files changed, 146 insertions(+), 146 deletions(-) diff --git a/scripts/run_server.py b/scripts/run_server.py index 5cc2134..87c9bd2 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -57,7 +57,7 @@ def parse_args(): objects_state = init_object_state(simulator_state=simulator_state) - nve_state = init_nve_state( + entities_state = init_nve_state( simulator_state=simulator_state, existing_agents=args.n_existing_agents, existing_objects=args.n_existing_objects, @@ -67,7 +67,7 @@ def parse_args(): simulator_state=simulator_state, agents_state=agents_state, objects_state=objects_state, - nve_state=nve_state + entities_state=entities_state ) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) diff --git a/scripts/run_simulation.py b/scripts/run_simulation.py index df26e3e..6898b4c 100644 --- a/scripts/run_simulation.py +++ b/scripts/run_simulation.py @@ -57,13 +57,13 @@ def parse_args(): objects_state = init_object_state(simulator_state=simulator_state) - nve_state = init_nve_state(simulator_state=simulator_state) + entities_state = init_nve_state(simulator_state=simulator_state) state = init_state( simulator_state=simulator_state, agents_state=agents_state, objects_state=objects_state, - nve_state=nve_state + entities_state=entities_state ) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) diff --git a/tests/test_simulator_init.py b/tests/test_simulator_init.py index bb7f30a..a1cd46a 100644 --- a/tests/test_simulator_init.py +++ b/tests/test_simulator_init.py @@ -13,13 +13,13 @@ def test_init_simulator_no_args(): simulator_state = init_simulator_state() agents_state = init_agent_state(simulator_state=simulator_state) objects_state = init_object_state(simulator_state=simulator_state) - nve_state = init_nve_state(simulator_state=simulator_state) + entities_state = init_nve_state(simulator_state=simulator_state) state = init_state( simulator_state=simulator_state, agents_state=agents_state, objects_state=objects_state, - nve_state=nve_state + entities_state=entities_state ) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) @@ -52,7 +52,7 @@ def test_init_simulator_args(): collision_eps=col_eps, collision_alpha=col_alpha) - nve_state = init_nve_state( + entities_state = init_nve_state( simulator_state, diameter=diameter, friction=friction) @@ -74,7 +74,7 @@ def test_init_simulator_args(): simulator_state=simulator_state, agents_state=agent_state, objects_state=object_state, - nve_state=nve_state) + entities_state=entities_state) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) diff --git a/tests/test_simulator_run.py b/tests/test_simulator_run.py index 5e9c23b..b4ffbbd 100644 --- a/tests/test_simulator_run.py +++ b/tests/test_simulator_run.py @@ -16,13 +16,13 @@ def test_simulator_run(): objects_state = init_object_state(simulator_state=simulator_state) - nve_state = init_nve_state(simulator_state=simulator_state) + entities_state = init_nve_state(simulator_state=simulator_state) state = init_state( simulator_state=simulator_state, agents_state=agents_state, objects_state=objects_state, - nve_state=nve_state + entities_state=entities_state ) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) diff --git a/vivarium/controllers/converters.py b/vivarium/controllers/converters.py index 987f5d2..f3db9fc 100644 --- a/vivarium/controllers/converters.py +++ b/vivarium/controllers/converters.py @@ -10,7 +10,7 @@ from jax_md.rigid_body import RigidBody from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig, stype_to_config, config_to_stype -from vivarium.simulator.states import State, SimulatorState, NVEState, AgentState, ObjectState, EntityType, StateType +from vivarium.simulator.states import State, SimulatorState, EntitiesState, AgentState, ObjectState, EntityType, StateType from vivarium.simulator.behaviors import behavior_name_map, reversed_behavior_name_map @@ -53,13 +53,13 @@ class StateFieldInfo: exists_c_to_s = lambda x: int(x) -agent_configs_to_state_dict = {'x_position': StateFieldInfo(('nve_state', 'position', 'center'), 0, identity_s_to_c, identity_c_to_s), - 'y_position': StateFieldInfo(('nve_state', 'position', 'center'), 1, identity_s_to_c, identity_c_to_s), - 'orientation': StateFieldInfo(('nve_state', 'position', 'orientation'), None, identity_s_to_c, identity_c_to_s), - 'mass_center': StateFieldInfo(('nve_state', 'mass', 'center'), np.array([0]), mass_center_s_to_c, mass_center_c_to_s), - 'mass_orientation': StateFieldInfo(('nve_state', 'mass', 'orientation'), None, identity_s_to_c, identity_c_to_s), - 'diameter': StateFieldInfo(('nve_state', 'diameter'), None, identity_s_to_c, identity_c_to_s), - 'friction': StateFieldInfo(('nve_state', 'friction'), None, identity_s_to_c, identity_c_to_s), +agent_configs_to_state_dict = {'x_position': StateFieldInfo(('entities_state', 'position', 'center'), 0, identity_s_to_c, identity_c_to_s), + 'y_position': StateFieldInfo(('entities_state', 'position', 'center'), 1, identity_s_to_c, identity_c_to_s), + 'orientation': StateFieldInfo(('entities_state', 'position', 'orientation'), None, identity_s_to_c, identity_c_to_s), + 'mass_center': StateFieldInfo(('entities_state', 'mass', 'center'), np.array([0]), mass_center_s_to_c, mass_center_c_to_s), + 'mass_orientation': StateFieldInfo(('entities_state', 'mass', 'orientation'), None, identity_s_to_c, identity_c_to_s), + 'diameter': StateFieldInfo(('entities_state', 'diameter'), None, identity_s_to_c, identity_c_to_s), + 'friction': StateFieldInfo(('entities_state', 'friction'), None, identity_s_to_c, identity_c_to_s), 'left_motor': StateFieldInfo(('agent_state', 'motor',), 0, identity_s_to_c, identity_c_to_s), 'right_motor': StateFieldInfo(('agent_state', 'motor',), 1, identity_s_to_c, identity_c_to_s), 'left_prox': StateFieldInfo(('agent_state', 'prox',), 0, identity_s_to_c, identity_c_to_s), @@ -67,21 +67,21 @@ class StateFieldInfo: 'behavior': StateFieldInfo(('agent_state', 'behavior',), None, behavior_s_to_c, behavior_c_to_s), 'color': StateFieldInfo(('agent_state', 'color',), np.arange(3), color_s_to_c, color_c_to_s), 'idx': StateFieldInfo(('agent_state', 'nve_idx',), None, identity_s_to_c, identity_c_to_s), - 'exists': StateFieldInfo(('nve_state', 'exists'), None, identity_s_to_c, exists_c_to_s) + 'exists': StateFieldInfo(('entities_state', 'exists'), None, identity_s_to_c, exists_c_to_s) } agent_configs_to_state_dict.update({f: StateFieldInfo(('agent_state', f,), None, identity_s_to_c, identity_c_to_s) for f in agent_common_fields if f not in agent_configs_to_state_dict}) -object_configs_to_state_dict = {'x_position': StateFieldInfo(('nve_state', 'position', 'center'), 0, identity_s_to_c, identity_c_to_s), - 'y_position': StateFieldInfo(('nve_state', 'position', 'center'), 1, identity_s_to_c, identity_c_to_s), - 'orientation': StateFieldInfo(('nve_state', 'position', 'orientation'), None, identity_s_to_c, identity_c_to_s), - 'mass_center': StateFieldInfo(('nve_state', 'mass', 'center'), np.array([0]), mass_center_s_to_c, mass_center_c_to_s), - 'mass_orientation': StateFieldInfo(('nve_state', 'mass', 'orientation'), None, identity_s_to_c, identity_c_to_s), - 'diameter': StateFieldInfo(('nve_state', 'diameter'), None, identity_s_to_c, identity_c_to_s), - 'friction': StateFieldInfo(('nve_state', 'friction'), None, identity_s_to_c, identity_c_to_s), +object_configs_to_state_dict = {'x_position': StateFieldInfo(('entities_state', 'position', 'center'), 0, identity_s_to_c, identity_c_to_s), + 'y_position': StateFieldInfo(('entities_state', 'position', 'center'), 1, identity_s_to_c, identity_c_to_s), + 'orientation': StateFieldInfo(('entities_state', 'position', 'orientation'), None, identity_s_to_c, identity_c_to_s), + 'mass_center': StateFieldInfo(('entities_state', 'mass', 'center'), np.array([0]), mass_center_s_to_c, mass_center_c_to_s), + 'mass_orientation': StateFieldInfo(('entities_state', 'mass', 'orientation'), None, identity_s_to_c, identity_c_to_s), + 'diameter': StateFieldInfo(('entities_state', 'diameter'), None, identity_s_to_c, identity_c_to_s), + 'friction': StateFieldInfo(('entities_state', 'friction'), None, identity_s_to_c, identity_c_to_s), 'color': StateFieldInfo(('object_state', 'color',), np.arange(3), color_s_to_c, color_c_to_s), 'idx': StateFieldInfo(('object_state', 'nve_idx',), None, identity_s_to_c, identity_c_to_s), - 'exists': StateFieldInfo(('nve_state', 'exists'), None, identity_s_to_c, exists_c_to_s) + 'exists': StateFieldInfo(('entities_state', 'exists'), None, identity_s_to_c, exists_c_to_s) } @@ -107,7 +107,7 @@ def get_default_state(n_entities_dict): to_jit= jnp.array([1]), use_fori_loop=jnp.array([0]), collision_alpha=jnp.array([0.]), collision_eps=jnp.array([0.])), - nve_state=NVEState(position=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), + entities_state=EntitiesState(position=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), momentum=None, force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), mass=RigidBody(center=jnp.zeros((n_entities, 1)), orientation=jnp.zeros(n_entities)), @@ -130,7 +130,7 @@ def get_default_state(n_entities_dict): object_state=ObjectState(nve_idx=jnp.zeros(max_objects, dtype=int), color=jnp.zeros((max_objects, 3)))) -NVETuple = namedtuple('NVETuple', ['idx', 'col', 'val']) +EntitiesTuple = namedtuple('EntitiesTuple', ['idx', 'col', 'val']) ValueTuple = namedtuple('ValueData', ['nve_idx', 'col_idx', 'row_map', 'col_map', 'val']) StateChangeTuple = namedtuple('StateChange', ['nested_field', 'nve_idx', 'column_idx', 'value']) @@ -147,12 +147,12 @@ def events_to_nve_data(events, state): val = state_field_info.config_to_state(e.new) if state_field_info.column_idx is None: - nve_data[nested_field].append(NVETuple(idx, None, val)) + nve_data[nested_field].append(EntitiesTuple(idx, None, val)) elif isinstance(state_field_info.column_idx, int): - nve_data[nested_field].append(NVETuple(idx, state_field_info.column_idx, val)) + nve_data[nested_field].append(EntitiesTuple(idx, state_field_info.column_idx, val)) else: for c, v in zip(state_field_info.column_idx, val): - nve_data[nested_field].append(NVETuple(idx, c, v)) + nve_data[nested_field].append(EntitiesTuple(idx, c, v)) return nve_data @@ -221,15 +221,15 @@ def set_state_from_config_dict(config_dict, state=None): params = configs[0].param_names() for p in params: state_field_info = configs_to_state_dict[stype][p] - nve_idx = [c.idx for c in configs] if state_field_info.nested_field[0] == 'nve_state' else range(len(configs)) + nve_idx = [c.idx for c in configs] if state_field_info.nested_field[0] == 'entities_state' else range(len(configs)) change = rec_set_dataclass(state, state_field_info.nested_field, jnp.array(nve_idx), state_field_info.column_idx, jnp.array([state_field_info.config_to_state(getattr(c, p)) for c in configs])) state = state.set(**change) if stype.is_entity(): e_idx.at[state.field(stype).nve_idx].set(jnp.array(range(n_entities_dict[stype]))) - # TODO: something weird with the to lines below, the second one will have no effect (would need state = state.set(.)), but if we fix it we get only zeros in nve_state.entitiy_idx. As it is it seems to get correct values though - change = rec_set_dataclass(state, ('nve_state', 'entity_idx'), jnp.array(range(sum(n_entities_dict.values()))), None, e_idx) + # TODO: something weird with the to lines below, the second one will have no effect (would need state = state.set(.)), but if we fix it we get only zeros in entities_state.entitiy_idx. As it is it seems to get correct values though + change = rec_set_dataclass(state, ('entities_state', 'entity_idx'), jnp.array(range(sum(n_entities_dict.values()))), None, e_idx) state.set(**change) return state @@ -238,7 +238,7 @@ def set_state_from_config_dict(config_dict, state=None): def set_configs_from_state(state, config_dict=None): if config_dict is None: config_dict = {stype: [] for stype in list(StateType)} - for idx, stype_int in enumerate(state.nve_state.entity_type): + for idx, stype_int in enumerate(state.entities_state.entity_type): stype = StateType(stype_int) config_dict[stype].append(stype_to_config[stype](idx=idx)) config_dict[StateType.SIMULATOR].append(SimulatorConfig()) diff --git a/vivarium/simulator/grpc_server/converters.py b/vivarium/simulator/grpc_server/converters.py index b3c09a1..1de40e1 100644 --- a/vivarium/simulator/grpc_server/converters.py +++ b/vivarium/simulator/grpc_server/converters.py @@ -3,12 +3,12 @@ import simulator_pb2 from vivarium.simulator.grpc_server.numproto.numproto import proto_to_ndarray, ndarray_to_proto -from vivarium.simulator.states import State, SimulatorState, NVEState, AgentState, ObjectState +from vivarium.simulator.states import State, SimulatorState, EntitiesState, AgentState, ObjectState def proto_to_state(state): return State(simulator_state=proto_to_simulator_state(state.simulator_state), - nve_state=proto_to_nve_state(state.nve_state), + entities_state=proto_to_nve_state(state.entities_state), agent_state=proto_to_agent_state(state.agent_state), object_state=proto_to_object_state(state.object_state)) @@ -29,20 +29,20 @@ def proto_to_simulator_state(simulator_state): ) -def proto_to_nve_state(nve_state): - return NVEState(position=RigidBody(center=proto_to_ndarray(nve_state.position.center).astype(float), - orientation=proto_to_ndarray(nve_state.position.orientation).astype(float)), - momentum=RigidBody(center=proto_to_ndarray(nve_state.momentum.center).astype(float), - orientation=proto_to_ndarray(nve_state.momentum.orientation).astype(float)), - force=RigidBody(center=proto_to_ndarray(nve_state.force.center).astype(float), - orientation=proto_to_ndarray(nve_state.force.orientation).astype(float)), - mass=RigidBody(center=proto_to_ndarray(nve_state.mass.center).astype(float), - orientation=proto_to_ndarray(nve_state.mass.orientation).astype(float)), - entity_type=proto_to_ndarray(nve_state.entity_type).astype(int), - entity_idx=proto_to_ndarray(nve_state.entity_idx).astype(int), - diameter=proto_to_ndarray(nve_state.diameter).astype(float), - friction=proto_to_ndarray(nve_state.friction).astype(float), - exists=proto_to_ndarray(nve_state.exists).astype(int) +def proto_to_nve_state(entities_state): + return EntitiesState(position=RigidBody(center=proto_to_ndarray(entities_state.position.center).astype(float), + orientation=proto_to_ndarray(entities_state.position.orientation).astype(float)), + momentum=RigidBody(center=proto_to_ndarray(entities_state.momentum.center).astype(float), + orientation=proto_to_ndarray(entities_state.momentum.orientation).astype(float)), + force=RigidBody(center=proto_to_ndarray(entities_state.force.center).astype(float), + orientation=proto_to_ndarray(entities_state.force.orientation).astype(float)), + mass=RigidBody(center=proto_to_ndarray(entities_state.mass.center).astype(float), + orientation=proto_to_ndarray(entities_state.mass.orientation).astype(float)), + entity_type=proto_to_ndarray(entities_state.entity_type).astype(int), + entity_idx=proto_to_ndarray(entities_state.entity_idx).astype(int), + diameter=proto_to_ndarray(entities_state.diameter).astype(float), + friction=proto_to_ndarray(entities_state.friction).astype(float), + exists=proto_to_ndarray(entities_state.exists).astype(int) ) @@ -68,7 +68,7 @@ def proto_to_object_state(object_state): def state_to_proto(state): return simulator_pb2.State(simulator_state=simulator_state_to_proto(state.simulator_state), - nve_state=nve_state_to_proto(state.nve_state), + entities_state=nve_state_to_proto(state.entities_state), agent_state=agent_state_to_proto(state.agent_state), object_state=object_state_to_proto(state.object_state)) @@ -90,20 +90,20 @@ def simulator_state_to_proto(simulator_state): ) -def nve_state_to_proto(nve_state): - return simulator_pb2.NVEState(position=simulator_pb2.RigidBody(center=ndarray_to_proto(nve_state.position.center), - orientation=ndarray_to_proto(nve_state.position.orientation)), - momentum=simulator_pb2.RigidBody(center=ndarray_to_proto(nve_state.momentum.center), - orientation=ndarray_to_proto(nve_state.momentum.orientation)), - force=simulator_pb2.RigidBody(center=ndarray_to_proto(nve_state.force.center), - orientation=ndarray_to_proto(nve_state.force.orientation)), - mass=simulator_pb2.RigidBody(center=ndarray_to_proto(nve_state.mass.center), - orientation=ndarray_to_proto(nve_state.mass.orientation)), - entity_type=ndarray_to_proto(nve_state.entity_type), - entity_idx=ndarray_to_proto(nve_state.entity_idx), - diameter=ndarray_to_proto(nve_state.diameter), - friction=ndarray_to_proto(nve_state.friction), - exists=ndarray_to_proto(nve_state.exists) +def nve_state_to_proto(entities_state): + return simulator_pb2.EntitiesState(position=simulator_pb2.RigidBody(center=ndarray_to_proto(entities_state.position.center), + orientation=ndarray_to_proto(entities_state.position.orientation)), + momentum=simulator_pb2.RigidBody(center=ndarray_to_proto(entities_state.momentum.center), + orientation=ndarray_to_proto(entities_state.momentum.orientation)), + force=simulator_pb2.RigidBody(center=ndarray_to_proto(entities_state.force.center), + orientation=ndarray_to_proto(entities_state.force.orientation)), + mass=simulator_pb2.RigidBody(center=ndarray_to_proto(entities_state.mass.center), + orientation=ndarray_to_proto(entities_state.mass.orientation)), + entity_type=ndarray_to_proto(entities_state.entity_type), + entity_idx=ndarray_to_proto(entities_state.entity_idx), + diameter=ndarray_to_proto(entities_state.diameter), + friction=ndarray_to_proto(entities_state.friction), + exists=ndarray_to_proto(entities_state.exists) ) diff --git a/vivarium/simulator/grpc_server/protos/simulator.proto b/vivarium/simulator/grpc_server/protos/simulator.proto index d69fc7d..30ad91e 100644 --- a/vivarium/simulator/grpc_server/protos/simulator.proto +++ b/vivarium/simulator/grpc_server/protos/simulator.proto @@ -29,7 +29,7 @@ service SimulatorServer { rpc Step(google.protobuf.Empty) returns (State) {} rpc GetState(google.protobuf.Empty) returns (State) {} - rpc GetNVEState(google.protobuf.Empty) returns (NVEState) {} + rpc GetNVEState(google.protobuf.Empty) returns (EntitiesState) {} rpc GetAgentState(google.protobuf.Empty) returns (AgentState) {} rpc GetObjectState(google.protobuf.Empty) returns (ObjectState) {} rpc SetState(StateChange) returns (google.protobuf.Empty) {} @@ -70,7 +70,7 @@ message SimulatorState { NDArray collision_alpha = 12; } -message NVEState { +message EntitiesState { RigidBody position = 1; RigidBody momentum = 2; RigidBody force = 3; @@ -103,7 +103,7 @@ message ObjectState { message State { SimulatorState simulator_state = 1; - NVEState nve_state = 2; + EntitiesState entities_state = 2; AgentState agent_state = 3; ObjectState object_state = 4; } diff --git a/vivarium/simulator/grpc_server/simulator_client.py b/vivarium/simulator/grpc_server/simulator_client.py index 6b6e12d..a43e7a7 100644 --- a/vivarium/simulator/grpc_server/simulator_client.py +++ b/vivarium/simulator/grpc_server/simulator_client.py @@ -40,8 +40,8 @@ def get_state(self): return proto_to_state(state) def get_nve_state(self): - nve_state = self.stub.GetNVEState(Empty()) - return proto_to_nve_state(nve_state) + entities_state = self.stub.GetNVEState(Empty()) + return proto_to_nve_state(entities_state) def get_agent_state(self): agent_state = self.stub.GetAgentState(Empty()) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.py b/vivarium/simulator/grpc_server/simulator_pb2.py index caee8e3..c4bf8ba 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.py +++ b/vivarium/simulator/grpc_server/simulator_pb2.py @@ -15,7 +15,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nmax_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0bmax_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe4\x02\n\x08NVEState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\x90\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\n \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xbd\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12&\n\tnve_state\x18\x02 \x01(\x0b\x32\x13.simulator.NVEState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\">\n\rAddAgentInput\x12\x12\n\nmax_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xb6\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12<\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x13.simulator.NVEState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nmax_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0bmax_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x02\n\rEntitiesState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\x90\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\n \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xc7\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12\x30\n\x0e\x65ntities_state\x18\x02 \x01(\x0b\x32\x18.simulator.EntitiesState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\">\n\rAddAgentInput\x12\x12\n\nmax_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xbb\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x41\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x18.simulator.EntitiesState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -31,20 +31,20 @@ _globals['_RIGIDBODY']._serialized_end=200 _globals['_SIMULATORSTATE']._serialized_start=203 _globals['_SIMULATORSTATE']._serialized_end=692 - _globals['_NVESTATE']._serialized_start=695 - _globals['_NVESTATE']._serialized_end=1051 - _globals['_AGENTSTATE']._serialized_start=1054 - _globals['_AGENTSTATE']._serialized_end=1454 - _globals['_OBJECTSTATE']._serialized_start=1456 - _globals['_OBJECTSTATE']._serialized_end=1583 - _globals['_STATE']._serialized_start=1586 - _globals['_STATE']._serialized_end=1775 - _globals['_STATECHANGE']._serialized_start=1777 - _globals['_STATECHANGE']._serialized_end=1881 - _globals['_ADDAGENTINPUT']._serialized_start=1883 - _globals['_ADDAGENTINPUT']._serialized_end=1945 - _globals['_ISSTARTEDSTATE']._serialized_start=1947 - _globals['_ISSTARTEDSTATE']._serialized_end=1983 - _globals['_SIMULATORSERVER']._serialized_start=1986 - _globals['_SIMULATORSERVER']._serialized_end=2552 + _globals['_ENTITIESSTATE']._serialized_start=695 + _globals['_ENTITIESSTATE']._serialized_end=1056 + _globals['_AGENTSTATE']._serialized_start=1059 + _globals['_AGENTSTATE']._serialized_end=1459 + _globals['_OBJECTSTATE']._serialized_start=1461 + _globals['_OBJECTSTATE']._serialized_end=1588 + _globals['_STATE']._serialized_start=1591 + _globals['_STATE']._serialized_end=1790 + _globals['_STATECHANGE']._serialized_start=1792 + _globals['_STATECHANGE']._serialized_end=1896 + _globals['_ADDAGENTINPUT']._serialized_start=1898 + _globals['_ADDAGENTINPUT']._serialized_end=1960 + _globals['_ISSTARTEDSTATE']._serialized_start=1962 + _globals['_ISSTARTEDSTATE']._serialized_end=1998 + _globals['_SIMULATORSERVER']._serialized_start=2001 + _globals['_SIMULATORSERVER']._serialized_end=2572 # @@protoc_insertion_point(module_scope) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.pyi b/vivarium/simulator/grpc_server/simulator_pb2.pyi index 8e05925..50d039e 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.pyi +++ b/vivarium/simulator/grpc_server/simulator_pb2.pyi @@ -54,7 +54,7 @@ class SimulatorState(_message.Message): collision_alpha: NDArray def __init__(self, idx: _Optional[_Union[NDArray, _Mapping]] = ..., box_size: _Optional[_Union[NDArray, _Mapping]] = ..., max_agents: _Optional[_Union[NDArray, _Mapping]] = ..., max_objects: _Optional[_Union[NDArray, _Mapping]] = ..., num_steps_lax: _Optional[_Union[NDArray, _Mapping]] = ..., dt: _Optional[_Union[NDArray, _Mapping]] = ..., freq: _Optional[_Union[NDArray, _Mapping]] = ..., neighbor_radius: _Optional[_Union[NDArray, _Mapping]] = ..., to_jit: _Optional[_Union[NDArray, _Mapping]] = ..., use_fori_loop: _Optional[_Union[NDArray, _Mapping]] = ..., collision_eps: _Optional[_Union[NDArray, _Mapping]] = ..., collision_alpha: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... -class NVEState(_message.Message): +class EntitiesState(_message.Message): __slots__ = ("position", "momentum", "force", "mass", "diameter", "entity_type", "entity_idx", "friction", "exists") POSITION_FIELD_NUMBER: _ClassVar[int] MOMENTUM_FIELD_NUMBER: _ClassVar[int] @@ -111,16 +111,16 @@ class ObjectState(_message.Message): def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., custom_field: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class State(_message.Message): - __slots__ = ("simulator_state", "nve_state", "agent_state", "object_state") + __slots__ = ("simulator_state", "entities_state", "agent_state", "object_state") SIMULATOR_STATE_FIELD_NUMBER: _ClassVar[int] - NVE_STATE_FIELD_NUMBER: _ClassVar[int] + ENTITIES_STATE_FIELD_NUMBER: _ClassVar[int] AGENT_STATE_FIELD_NUMBER: _ClassVar[int] OBJECT_STATE_FIELD_NUMBER: _ClassVar[int] simulator_state: SimulatorState - nve_state: NVEState + entities_state: EntitiesState agent_state: AgentState object_state: ObjectState - def __init__(self, simulator_state: _Optional[_Union[SimulatorState, _Mapping]] = ..., nve_state: _Optional[_Union[NVEState, _Mapping]] = ..., agent_state: _Optional[_Union[AgentState, _Mapping]] = ..., object_state: _Optional[_Union[ObjectState, _Mapping]] = ...) -> None: ... + def __init__(self, simulator_state: _Optional[_Union[SimulatorState, _Mapping]] = ..., entities_state: _Optional[_Union[EntitiesState, _Mapping]] = ..., agent_state: _Optional[_Union[AgentState, _Mapping]] = ..., object_state: _Optional[_Union[ObjectState, _Mapping]] = ...) -> None: ... class StateChange(_message.Message): __slots__ = ("nve_idx", "col_idx", "nested_field", "value") diff --git a/vivarium/simulator/grpc_server/simulator_pb2_grpc.py b/vivarium/simulator/grpc_server/simulator_pb2_grpc.py index e3496e4..f11b97c 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2_grpc.py +++ b/vivarium/simulator/grpc_server/simulator_pb2_grpc.py @@ -29,7 +29,7 @@ def __init__(self, channel): self.GetNVEState = channel.unary_unary( '/simulator.SimulatorServer/GetNVEState', request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - response_deserializer=simulator__pb2.NVEState.FromString, + response_deserializer=simulator__pb2.EntitiesState.FromString, ) self.GetAgentState = channel.unary_unary( '/simulator.SimulatorServer/GetAgentState', @@ -137,7 +137,7 @@ def add_SimulatorServerServicer_to_server(servicer, server): 'GetNVEState': grpc.unary_unary_rpc_method_handler( servicer.GetNVEState, request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=simulator__pb2.NVEState.SerializeToString, + response_serializer=simulator__pb2.EntitiesState.SerializeToString, ), 'GetAgentState': grpc.unary_unary_rpc_method_handler( servicer.GetAgentState, @@ -227,7 +227,7 @@ def GetNVEState(request, metadata=None): return grpc.experimental.unary_unary(request, target, '/simulator.SimulatorServer/GetNVEState', google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - simulator__pb2.NVEState.FromString, + simulator__pb2.EntitiesState.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/vivarium/simulator/grpc_server/simulator_server.py b/vivarium/simulator/grpc_server/simulator_server.py index 1c070ca..aa919a1 100644 --- a/vivarium/simulator/grpc_server/simulator_server.py +++ b/vivarium/simulator/grpc_server/simulator_server.py @@ -45,8 +45,8 @@ def GetState(self, request, context): return state_to_proto(state) def GetNVEState(self, request, context): - nve_state = self.simulator.state.nve_state - return nve_state_to_proto(nve_state) + entities_state = self.simulator.state.entities_state + return nve_state_to_proto(entities_state) def GetAgentState(self, request, context): agent_state = self.simulator.state.agent_state diff --git a/vivarium/simulator/sim_computation.py b/vivarium/simulator/sim_computation.py index d664afe..8fa4042 100644 --- a/vivarium/simulator/sim_computation.py +++ b/vivarium/simulator/sim_computation.py @@ -105,40 +105,40 @@ def get_verlet_force_fn(displacement): def collision_force(state, neighbor, exists_mask): return coll_force_fn( - state.nve_state.position.center, + state.entities_state.position.center, neighbor=neighbor, exists_mask=exists_mask, - diameter=state.nve_state.diameter, + diameter=state.entities_state.diameter, epsilon=state.simulator_state.collision_eps, alpha=state.simulator_state.collision_alpha ) def friction_force(state, exists_mask): - cur_vel = state.nve_state.momentum.center / state.nve_state.mass.center + cur_vel = state.entities_state.momentum.center / state.entities_state.mass.center # stack the mask to give it the same shape as cur_vel (that has 2 rows for forward and angular velocities) mask = jnp.stack([exists_mask] * 2, axis=1) cur_vel = jnp.where(mask, cur_vel, 0.) - return - jnp.tile(state.nve_state.friction, (SPACE_NDIMS, 1)).T * cur_vel + return - jnp.tile(state.entities_state.friction, (SPACE_NDIMS, 1)).T * cur_vel def motor_force(state, exists_mask): agent_idx = state.agent_state.nve_idx body = rigid_body.RigidBody( - center=state.nve_state.position.center[agent_idx], - orientation=state.nve_state.position.orientation[agent_idx] + center=state.entities_state.position.center[agent_idx], + orientation=state.entities_state.position.orientation[agent_idx] ) n = normal(body.orientation) fwd, rot = motor_command( state.agent_state.motor, - state.nve_state.diameter[agent_idx], + state.entities_state.diameter[agent_idx], state.agent_state.wheel_diameter ) - cur_vel = state.nve_state.momentum.center[agent_idx] / state.nve_state.mass.center[agent_idx] + cur_vel = state.entities_state.momentum.center[agent_idx] / state.entities_state.mass.center[agent_idx] cur_fwd_vel = vmap(jnp.dot)(cur_vel, n) - cur_rot_vel = state.nve_state.momentum.orientation[agent_idx] / state.nve_state.mass.orientation[agent_idx] + cur_rot_vel = state.entities_state.momentum.orientation[agent_idx] / state.entities_state.mass.orientation[agent_idx] fwd_delta = fwd - cur_fwd_vel rot_delta = rot - cur_rot_vel @@ -146,8 +146,8 @@ def motor_force(state, exists_mask): fwd_force = n * jnp.tile(fwd_delta, (SPACE_NDIMS, 1)).T * jnp.tile(state.agent_state.speed_mul, (SPACE_NDIMS, 1)).T rot_force = rot_delta * state.agent_state.theta_mul - center=jnp.zeros_like(state.nve_state.position.center).at[agent_idx].set(fwd_force) - orientation=jnp.zeros_like(state.nve_state.position.orientation).at[agent_idx].set(rot_force) + center=jnp.zeros_like(state.entities_state.position.center).at[agent_idx].set(fwd_force) + orientation=jnp.zeros_like(state.entities_state.position.orientation).at[agent_idx].set(rot_force) # apply mask to make non existing agents stand still orientation = jnp.where(exists_mask, orientation, 0.) @@ -226,37 +226,37 @@ def dynamics_rigid(displacement, shift, behavior_bank, force_fn=None): def init_fn(state, key, kT=0.): key, _ = jax.random.split(key) - assert state.nve_state.momentum is None - assert not jnp.any(state.nve_state.force.center) and not jnp.any(state.nve_state.force.orientation) + assert state.entities_state.momentum is None + assert not jnp.any(state.entities_state.force.center) and not jnp.any(state.entities_state.force.orientation) - state = state.set(nve_state=simulate.initialize_momenta(state.nve_state, key, kT)) + state = state.set(entities_state=simulate.initialize_momenta(state.entities_state, key, kT)) return state - def mask_momentum(nve_state, exists_mask): + def mask_momentum(entities_state, exists_mask): """ Set the momentum values to zeros for non existing entities - :param nve_state: nve_state + :param entities_state: entities_state :param exists_mask: bool array specifying which entities exist or not - :return: nve_state: new nve state state with masked momentum values + :return: entities_state: new entities state state with masked momentum values """ - orientation = jnp.where(exists_mask, nve_state.momentum.orientation, 0) + orientation = jnp.where(exists_mask, entities_state.momentum.orientation, 0) exists_mask = jnp.stack([exists_mask] * SPACE_NDIMS, axis=1) - center = jnp.where(exists_mask, nve_state.momentum.center, 0) + center = jnp.where(exists_mask, entities_state.momentum.center, 0) momentum = rigid_body.RigidBody(center=center, orientation=orientation) - return nve_state.set(momentum=momentum) + return entities_state.set(momentum=momentum) def physics_fn(state, force, shift_fn, dt, neighbor, mask): """Apply a single step of velocity Verlet integration to a state.""" # dt = f32(dt) dt_2 = dt / 2. # f32(dt / 2) # state = sensorimotor(state, neighbor) # now in step_fn - nve_state = simulate.momentum_step(state.nve_state, dt_2) - nve_state = simulate.position_step(nve_state, shift_fn, dt, neighbor=neighbor) - nve_state = nve_state.set(force=force) - nve_state = simulate.momentum_step(nve_state, dt_2) - nve_state = mask_momentum(nve_state, mask) + entities_state = simulate.momentum_step(state.entities_state, dt_2) + entities_state = simulate.position_step(entities_state, shift_fn, dt, neighbor=neighbor) + entities_state = entities_state.set(force=force) + entities_state = simulate.momentum_step(entities_state, dt_2) + entities_state = mask_momentum(entities_state, mask) - return state.set(nve_state=nve_state) + return state.set(entities_state=entities_state) def compute_prox(state, agent_neighs_idx, target_exists_mask): """ @@ -265,10 +265,10 @@ def compute_prox(state, agent_neighs_idx, target_exists_mask): :param agent_neighs_idx: Neighbor representation, where sources are only agents. Matrix of shape (2, n_pairs), where n_pairs is the number of neighbor entity pairs where sources (first row) are agent indexes. :param target_exists_mask: Specify which target entities exist. Vector with shape (n_entities,). - target_exists_mask[i] is True (resp. False) if entity of index i in state.nve_state exists (resp. don't exist). + target_exists_mask[i] is True (resp. False) if entity of index i in state.entities_state exists (resp. don't exist). :return: """ - body = state.nve_state.position + body = state.entities_state.position mask = target_exists_mask[agent_neighs_idx[1, :]] senders, receivers = agent_neighs_idx Ra = body.center[senders] @@ -283,7 +283,7 @@ def sensorimotor(agent_state): return agent_state.set(motor=motor) def step_fn(state, neighbor, agent_neighs_idx): - exists_mask = (state.nve_state.exists == 1) # Only existing entities have effect on others + exists_mask = (state.entities_state.exists == 1) # Only existing entities have effect on others state = state.set(agent_state=compute_prox(state, agent_neighs_idx, target_exists_mask=exists_mask)) state = state.set(agent_state=sensorimotor(state.agent_state)) force = force_fn(state, neighbor, exists_mask) diff --git a/vivarium/simulator/simulator.py b/vivarium/simulator/simulator.py index 8a2071d..964d794 100644 --- a/vivarium/simulator/simulator.py +++ b/vivarium/simulator/simulator.py @@ -90,7 +90,7 @@ def _step(self, state, neighbors, num_iterations): # If the neighbor list can't fit in the allocation, rebuild it but bigger. if neighbors.did_buffer_overflow: lg.warning('REBUILDING NEIGHBORS ARRAY') - neighbors = self.allocate_neighbors(current_state.nve_state.position.center) + neighbors = self.allocate_neighbors(current_state.entities_state.position.center) # Because there was an error, we need to re-run this simulation loop from the copy of the current_state we created new_state, neighbors = self.simulation_loop(state=current_state, neighbors=neighbors, num_iterations=num_iterations) # Check that neighbors array is now ok but should be the case (allocate neighbors tries to compute a new list that is large enough according to the simulation state) @@ -230,7 +230,7 @@ def update_function_update(self): def update_fn(_, state_and_neighbors): state, neighs = state_and_neighbors - neighs = neighs.update(state.nve_state.position.center) + neighs = neighs.update(state.entities_state.position.center) return (self.step_fn(state=state, neighbor=neighs, agent_neighs_idx=self.agent_neighs_idx), neighs) @@ -257,9 +257,9 @@ def update_neighbor_fn(self, box_size, neighbor_radius): def allocate_neighbors(self, position=None): lg.info('allocate_neighbors') - position = self.state.nve_state.position.center if position is None else position + position = self.state.entities_state.position.center if position is None else position self.neighbors = self.neighbor_fn.allocate(position) - mask = self.state.nve_state.entity_type[self.neighbors.idx[0]] == EntityType.AGENT.value + mask = self.state.entities_state.entity_type[self.neighbors.idx[0]] == EntityType.AGENT.value self.agent_neighs_idx = self.neighbors.idx[:, mask] return self.neighbors diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index f3be3c8..cdc69ed 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -30,9 +30,9 @@ def to_entity_type(self): assert self.is_entity() return EntityType(self.value) -# No need to define position, momentum, force, and mass (i.e already in simulate.NVEState) +# No need to define position, momentum, force, and mass (i.e already in simulate.EntitiesState) @dataclass -class NVEState(simulate.NVEState): +class EntitiesState(simulate.NVEState): entity_type: util.Array entity_idx: util.Array # idx in XState (e.g. AgentState) diameter: util.Array @@ -46,7 +46,7 @@ def velocity(self) -> util.Array: @dataclass class AgentState: - nve_idx: util.Array # idx in NVEState + nve_idx: util.Array # idx in EntitiesState prox: util.Array motor: util.Array behavior: util.Array @@ -60,7 +60,7 @@ class AgentState: @dataclass class ObjectState: - nve_idx: util.Array # idx in NVEState + nve_idx: util.Array # idx in EntitiesState color: util.Array @@ -95,7 +95,7 @@ def get_type(attr): @dataclass class State: simulator_state: SimulatorState - nve_state: NVEState + entities_state: EntitiesState agent_state: AgentState object_state: ObjectState @@ -121,17 +121,17 @@ def nve_idx(self, etype, entity_idx): return self.field(etype).nve_idx[entity_idx] def e_idx(self, etype): - return self.nve_state.entity_idx[self.nve_state.entity_type == etype.value] + return self.entities_state.entity_idx[self.entities_state.entity_type == etype.value] def e_cond(self, etype): - return self.nve_state.entity_type == etype.value + return self.entities_state.entity_type == etype.value def row_idx(self, field, nve_idx): - return nve_idx if field == 'nve_state' else self.nve_state.entity_idx[jnp.array(nve_idx)] + return nve_idx if field == 'entities_state' else self.entities_state.entity_idx[jnp.array(nve_idx)] def __getattr__(self, name): def wrapper(e_type): - value = getattr(self.nve_state, name) + value = getattr(self.entities_state, name) if isinstance(value, rigid_body.RigidBody): return rigid_body.RigidBody(center=value.center[self.e_cond(e_type)], orientation=value.orientation[self.e_cond(e_type)]) @@ -209,9 +209,9 @@ def init_nve_state( existing_agents: Optional[Union[int, List[float], None]] = None, existing_objects: Optional[Union[int, List[float], None]] = None, seed: int = 0, - ) -> NVEState: + ) -> EntitiesState: """ - Initialize nve state with given parameters + Initialize entities state with given parameters """ max_agents = simulator_state.max_agents[0] max_objects = simulator_state.max_objects[0] @@ -238,7 +238,7 @@ def init_nve_state( existing_objects = _init_existing(existing_objects, max_objects) exists = jnp.concatenate((existing_agents, existing_objects), dtype=int) - return NVEState( + return EntitiesState( position=RigidBody(center=positions, orientation=orientations), momentum=None, force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), @@ -300,13 +300,13 @@ def init_state( simulator_state: SimulatorState, agents_state: AgentState, objects_state: ObjectState, - nve_state: NVEState + entities_state: EntitiesState ) -> State: return State( simulator_state=simulator_state, agent_state=agents_state, object_state=objects_state, - nve_state=nve_state + entities_state=entities_state ) \ No newline at end of file From 58264b0c991e9ae24cf864be6c983a16d236f966 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 4 Apr 2024 19:43:39 +0200 Subject: [PATCH 3/6] Change PRNG system in states generation --- vivarium/simulator/states.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index cdc69ed..85a4653 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -2,7 +2,7 @@ from enum import Enum import matplotlib.colors as mcolors -import jax.numpy as jnp +import jax.numpy as jnp from jax import random from jax_md import util, simulate, rigid_body @@ -112,7 +112,7 @@ def field(self, stype_or_nested_fields): return res - # TODO : Should we keep this function because it is duplicated below ? + # TODO : Should we keep this function because it is duplicated below ? # def nve_idx(self, etype): # cond = self.e_cond(etype) # return compress(range(len(cond)), cond) # https://stackoverflow.com/questions/21448225/getting-indices-of-true-values-in-a-boolean-list @@ -218,17 +218,16 @@ def init_nve_state( n_entities = max_agents + max_objects key = random.PRNGKey(seed) - key_pos, key_or = random.split(key) - key_ag, key_obj = random.split(key_pos) + key, key_agents_pos, key_objects_pos, key_orientations = random.split(key, 4) # If we have a list of agents or objects positions, transform it into a jax array, else initialize random positions - agents_positions = _init_positions(key_ag, agents_positions, max_agents, simulator_state.box_size) - objects_positions = _init_positions(key_obj, objects_positions, max_objects, simulator_state.box_size) + agents_positions = _init_positions(key_agents_pos, agents_positions, max_agents, simulator_state.box_size) + objects_positions = _init_positions(key_objects_pos, objects_positions, max_objects, simulator_state.box_size) # Assign their positions to each entities positions = jnp.concatenate((agents_positions, objects_positions)) # Assign random orientations between 0 and 2*pi - orientations = random.uniform(key_or, (n_entities,)) * 2 * jnp.pi + orientations = random.uniform(key_orientations, (n_entities,)) * 2 * jnp.pi agents_entities = jnp.full(max_agents, EntityType.AGENT.value) object_entities = jnp.full(max_objects, EntityType.OBJECT.value) From 25f7e6e24ea6b7c05b83c513f5441876d828cc5a Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 11 Apr 2024 16:13:43 +0200 Subject: [PATCH 4/6] Fix hydra problems with new init system and add physics_engine name --- conf/scene/default.yaml | 6 +- conf/scene/session_1.yaml | 4 +- scripts/run_server.py | 9 +- scripts/run_simulation.py | 4 +- vivarium/simulator/physics_engine.py | 293 +++++++++++++++++++++++++++ 5 files changed, 302 insertions(+), 14 deletions(-) create mode 100644 vivarium/simulator/physics_engine.py diff --git a/conf/scene/default.yaml b/conf/scene/default.yaml index e810bd9..f7e71bb 100644 --- a/conf/scene/default.yaml +++ b/conf/scene/default.yaml @@ -1,6 +1,6 @@ simulator: - n_agents: 10 - n_objects: 2 + max_agents: 10 + max_objects: 2 box_size: 100.0 num_steps_lax: 4 dt: 0.1 @@ -11,7 +11,7 @@ simulator: collision_eps: 0.1 collision_alpha: 0.5 -nve: +entities: diameter: 5. friction: 0.1 mass_center: 1. diff --git a/conf/scene/session_1.yaml b/conf/scene/session_1.yaml index e65dded..263f122 100644 --- a/conf/scene/session_1.yaml +++ b/conf/scene/session_1.yaml @@ -1,3 +1,3 @@ simulator: - n_agents: 1 - n_objects: 1 + max_agents: 1 + max_objects: 1 diff --git a/scripts/run_server.py b/scripts/run_server.py index f6f3ec8..b9b9699 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -10,7 +10,7 @@ from vivarium.simulator.states import init_nve_state from vivarium.simulator.states import init_state from vivarium.simulator.simulator import Simulator -from vivarium.simulator.sim_computation import dynamics_rigid +from vivarium.simulator.physics_engine import dynamics_rigid from vivarium.simulator.grpc_server.simulator_server import serve lg = logging.getLogger(__name__) @@ -27,12 +27,7 @@ def main(cfg: DictConfig = None) -> None: objects_state = init_object_state(simulator_state=simulator_state, **args.objects) - nve_state = init_nve_state(simulator_state=simulator_state, **args.nve) - entities_state = init_nve_state( - simulator_state=simulator_state, - existing_agents=args.n_existing_agents, - existing_objects=args.n_existing_objects, - ) + entities_state = init_nve_state(simulator_state=simulator_state, **args.entities) state = init_state( simulator_state=simulator_state, diff --git a/scripts/run_simulation.py b/scripts/run_simulation.py index 84dd4e5..7f5ddff 100644 --- a/scripts/run_simulation.py +++ b/scripts/run_simulation.py @@ -1,5 +1,5 @@ -import hydra import logging +import hydra from omegaconf import DictConfig, OmegaConf @@ -26,7 +26,7 @@ def main(cfg: DictConfig = None) -> None: objects_state = init_object_state(simulator_state=simulator_state, **args.objects) - entities_state = init_nve_state(simulator_state=simulator_state) + entities_state = init_nve_state(simulator_state=simulator_state, **args.entities) state = init_state( simulator_state=simulator_state, diff --git a/vivarium/simulator/physics_engine.py b/vivarium/simulator/physics_engine.py new file mode 100644 index 0000000..8fa4042 --- /dev/null +++ b/vivarium/simulator/physics_engine.py @@ -0,0 +1,293 @@ +from functools import partial + +import jax +import jax.numpy as jnp + +from jax import ops, vmap, lax +from jax_md import space, rigid_body, util, simulate, energy, quantity +f32 = util.f32 + + +# Only work on 2D environments atm +SPACE_NDIMS = 2 + +@vmap +def normal(theta): + return jnp.array([jnp.cos(theta), jnp.sin(theta)]) + +def switch_fn(fn_list): + def switch(index, *operands): + return jax.lax.switch(index, fn_list, *operands) + return switch + + +# Helper functions for collisions + +def collision_energy(displacement_fn, r_a, r_b, l_a, l_b, epsilon, alpha, mask): + """Compute the collision energy between a pair of particles + + :param displacement_fn: displacement function of jax_md + :param r_a: position of particle a + :param r_b: position of particle b + :param l_a: diameter of particle a + :param l_b: diameter of particle b + :param epsilon: interaction energy scale + :param alpha: interaction stiffness + :param mask: set the energy to 0 if one of the particles is masked + :return: collision energy between both particles + """ + dist = jnp.linalg.norm(displacement_fn(r_a, r_b)) + sigma = (l_a + l_b) / 2 + e = energy.soft_sphere(dist, sigma=sigma, epsilon=epsilon, alpha=f32(alpha)) + return jnp.where(mask, e, 0.) + +collision_energy = vmap(collision_energy, (None, 0, 0, 0, 0, None, None, 0)) + + +def total_collision_energy(positions, diameter, neighbor, displacement, exists_mask, epsilon, alpha): + """Compute the collision energy between all neighboring pairs of particles in the system + + :param positions: positions of all the particles + :param diameter: diameters of all the particles + :param neighbor: neighbor array of the system + :param displacement: dipalcement function of jax_md + :param exists_mask: mask to specify which particles exist + :param epsilon: interaction energy scale between two particles + :param alpha: interaction stiffness between two particles + :return: sum of all collisions energies of the system + """ + diameter = lax.stop_gradient(diameter) + senders, receivers = neighbor.idx + + r_senders = positions[senders] + r_receivers = positions[receivers] + l_senders = diameter[senders] + l_receivers = diameter[receivers] + + # Set collision energy to zero if the sender or receiver is non existing + mask = exists_mask[senders] * exists_mask[receivers] + energies = collision_energy(displacement, + r_senders, + r_receivers, + l_senders, + l_receivers, + epsilon, + alpha, + mask) + return jnp.sum(energies) + + +# Helper functions for motor function + +def lr_2_fwd_rot(left_spd, right_spd, base_length, wheel_diameter): + fwd = (wheel_diameter / 4.) * (left_spd + right_spd) + rot = 0.5 * (wheel_diameter / base_length) * (right_spd - left_spd) + return fwd, rot + + +def fwd_rot_2_lr(fwd, rot, base_length, wheel_diameter): + left = ((2.0 * fwd) - (rot * base_length)) / wheel_diameter + right = ((2.0 * fwd) + (rot * base_length)) / wheel_diameter + return left, right + + +def motor_command(wheel_activation, base_length, wheel_diameter): + fwd, rot = lr_2_fwd_rot(wheel_activation[0], wheel_activation[1], base_length, wheel_diameter) + return fwd, rot + +motor_command = vmap(motor_command, (0, 0, 0)) + + +# Functions to compute the verlet force on the whole system + +def get_verlet_force_fn(displacement): + coll_force_fn = quantity.force(partial(total_collision_energy, displacement=displacement)) + + def collision_force(state, neighbor, exists_mask): + return coll_force_fn( + state.entities_state.position.center, + neighbor=neighbor, + exists_mask=exists_mask, + diameter=state.entities_state.diameter, + epsilon=state.simulator_state.collision_eps, + alpha=state.simulator_state.collision_alpha + ) + + def friction_force(state, exists_mask): + cur_vel = state.entities_state.momentum.center / state.entities_state.mass.center + # stack the mask to give it the same shape as cur_vel (that has 2 rows for forward and angular velocities) + mask = jnp.stack([exists_mask] * 2, axis=1) + cur_vel = jnp.where(mask, cur_vel, 0.) + return - jnp.tile(state.entities_state.friction, (SPACE_NDIMS, 1)).T * cur_vel + + def motor_force(state, exists_mask): + agent_idx = state.agent_state.nve_idx + + body = rigid_body.RigidBody( + center=state.entities_state.position.center[agent_idx], + orientation=state.entities_state.position.orientation[agent_idx] + ) + + n = normal(body.orientation) + + fwd, rot = motor_command( + state.agent_state.motor, + state.entities_state.diameter[agent_idx], + state.agent_state.wheel_diameter + ) + + cur_vel = state.entities_state.momentum.center[agent_idx] / state.entities_state.mass.center[agent_idx] + cur_fwd_vel = vmap(jnp.dot)(cur_vel, n) + cur_rot_vel = state.entities_state.momentum.orientation[agent_idx] / state.entities_state.mass.orientation[agent_idx] + + fwd_delta = fwd - cur_fwd_vel + rot_delta = rot - cur_rot_vel + + fwd_force = n * jnp.tile(fwd_delta, (SPACE_NDIMS, 1)).T * jnp.tile(state.agent_state.speed_mul, (SPACE_NDIMS, 1)).T + rot_force = rot_delta * state.agent_state.theta_mul + + center=jnp.zeros_like(state.entities_state.position.center).at[agent_idx].set(fwd_force) + orientation=jnp.zeros_like(state.entities_state.position.orientation).at[agent_idx].set(rot_force) + + # apply mask to make non existing agents stand still + orientation = jnp.where(exists_mask, orientation, 0.) + # Because position has SPACE_NDMS dims, need to stack the mask to give it the same shape as center + exists_mask = jnp.stack([exists_mask] * SPACE_NDIMS, axis=1) + center = jnp.where(exists_mask, center, 0.) + + + return rigid_body.RigidBody(center=center, + orientation=orientation) + + def force_fn(state, neighbor, exists_mask): + mf = motor_force(state, exists_mask) + cf = collision_force(state, neighbor, exists_mask) + ff = friction_force(state, exists_mask) + + center = cf + ff + mf.center + orientation = mf.orientation + return rigid_body.RigidBody(center=center, orientation=orientation) + + return force_fn + + +# Helper functions for sensors + +def dist_theta(displ, theta): + """ + Compute the relative distance and angle from a source agent to a target agent + :param displ: Displacement vector (jnp arrray with shape (2,) from source to target + :param theta: Orientation of the source agent (in the reference frame of the map) + :return: dist: distance from source to target. + relative_theta: relative angle of the target in the reference frame of the source agent (front direction at angle 0) + """ + dist = jnp.linalg.norm(displ) + norm_displ = displ / dist + theta_displ = jnp.arccos(norm_displ[0]) * jnp.sign(jnp.arcsin(norm_displ[1])) + relative_theta = theta_displ - theta + return dist, relative_theta + +proximity_map = vmap(dist_theta, (0, 0)) + + +def sensor_fn(dist, relative_theta, dist_max, cos_min, target_exists): + """ + Compute the proximeter activations (left, right) induced by the presence of an entity + :param dist: distance from the agent to the entity + :param relative_theta: angle of the entity in the reference frame of the agent (front direction at angle 0) + :param dist_max: Max distance of the proximiter (will return 0. above this distance) + :param cos_min: Field of view as a cosinus (e.g. cos_min = 0 means a pi/4 FoV on each proximeter, so pi/2 in total) + :return: left and right proximeter activation in a jnp array with shape (2,) + """ + cos_dir = jnp.cos(relative_theta) + prox = 1. - (dist / dist_max) + in_view = jnp.logical_and(dist < dist_max, cos_dir > cos_min) + at_left = jnp.logical_and(True, jnp.sin(relative_theta) >= 0) + left = in_view * at_left * prox + right = in_view * (1. - at_left) * prox + return jnp.array([left, right]) * target_exists # i.e. 0 if target does not exist + +sensor_fn = vmap(sensor_fn, (0, 0, 0, 0, 0)) + + +def sensor(displ, theta, dist_max, cos_min, max_agents, senders, target_exists): + dist, relative_theta = proximity_map(displ, theta) + proxs = ops.segment_max(sensor_fn(dist, relative_theta, dist_max, cos_min, target_exists), + senders, max_agents) + return proxs + + +# Functions to compute the dynamics of the whole system + +def dynamics_rigid(displacement, shift, behavior_bank, force_fn=None): + force_fn = force_fn or get_verlet_force_fn(displacement) + multi_switch = jax.vmap(switch_fn(behavior_bank), (0, 0, 0)) + # shape = rigid_body.monomer + + def init_fn(state, key, kT=0.): + key, _ = jax.random.split(key) + assert state.entities_state.momentum is None + assert not jnp.any(state.entities_state.force.center) and not jnp.any(state.entities_state.force.orientation) + + state = state.set(entities_state=simulate.initialize_momenta(state.entities_state, key, kT)) + return state + + def mask_momentum(entities_state, exists_mask): + """ + Set the momentum values to zeros for non existing entities + :param entities_state: entities_state + :param exists_mask: bool array specifying which entities exist or not + :return: entities_state: new entities state state with masked momentum values + """ + orientation = jnp.where(exists_mask, entities_state.momentum.orientation, 0) + exists_mask = jnp.stack([exists_mask] * SPACE_NDIMS, axis=1) + center = jnp.where(exists_mask, entities_state.momentum.center, 0) + momentum = rigid_body.RigidBody(center=center, orientation=orientation) + return entities_state.set(momentum=momentum) + + def physics_fn(state, force, shift_fn, dt, neighbor, mask): + """Apply a single step of velocity Verlet integration to a state.""" + # dt = f32(dt) + dt_2 = dt / 2. # f32(dt / 2) + # state = sensorimotor(state, neighbor) # now in step_fn + entities_state = simulate.momentum_step(state.entities_state, dt_2) + entities_state = simulate.position_step(entities_state, shift_fn, dt, neighbor=neighbor) + entities_state = entities_state.set(force=force) + entities_state = simulate.momentum_step(entities_state, dt_2) + entities_state = mask_momentum(entities_state, mask) + + return state.set(entities_state=entities_state) + + def compute_prox(state, agent_neighs_idx, target_exists_mask): + """ + Set agents' proximeter activations + :param state: full simulation State + :param agent_neighs_idx: Neighbor representation, where sources are only agents. Matrix of shape (2, n_pairs), + where n_pairs is the number of neighbor entity pairs where sources (first row) are agent indexes. + :param target_exists_mask: Specify which target entities exist. Vector with shape (n_entities,). + target_exists_mask[i] is True (resp. False) if entity of index i in state.entities_state exists (resp. don't exist). + :return: + """ + body = state.entities_state.position + mask = target_exists_mask[agent_neighs_idx[1, :]] + senders, receivers = agent_neighs_idx + Ra = body.center[senders] + Rb = body.center[receivers] + dR = - space.map_bond(displacement)(Ra, Rb) # Looks like it should be opposite, but don't understand why + prox = sensor(dR, body.orientation[senders], state.agent_state.proxs_dist_max[senders], + state.agent_state.proxs_cos_min[senders], len(state.agent_state.nve_idx), senders, mask) + return state.agent_state.set(prox=prox) + + def sensorimotor(agent_state): + motor = multi_switch(agent_state.behavior, agent_state.prox, agent_state.motor) + return agent_state.set(motor=motor) + + def step_fn(state, neighbor, agent_neighs_idx): + exists_mask = (state.entities_state.exists == 1) # Only existing entities have effect on others + state = state.set(agent_state=compute_prox(state, agent_neighs_idx, target_exists_mask=exists_mask)) + state = state.set(agent_state=sensorimotor(state.agent_state)) + force = force_fn(state, neighbor, exists_mask) + state = physics_fn(state, force, shift, state.simulator_state.dt[0], neighbor=neighbor, mask=exists_mask) + return state + + return init_fn, step_fn From 56e34c7b85ff146ecf7541b6c49a507ae097e810 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 11 Apr 2024 16:18:41 +0200 Subject: [PATCH 5/6] Clean state file --- vivarium/simulator/states.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index 85a4653..0573443 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -10,7 +10,6 @@ from jax_md.rigid_body import RigidBody -# TODO : Add documentation on these classes class EntityType(Enum): AGENT = 0 OBJECT = 1 @@ -29,7 +28,8 @@ def is_entity(self): def to_entity_type(self): assert self.is_entity() return EntityType(self.value) - + + # No need to define position, momentum, force, and mass (i.e already in simulate.EntitiesState) @dataclass class EntitiesState(simulate.NVEState): @@ -64,7 +64,6 @@ class ObjectState: color: util.Array -# TODO : I think it would make more sense to have max_agents, max_objects here instead of n_*** @dataclass class SimulatorState: idx: util.Array @@ -112,11 +111,6 @@ def field(self, stype_or_nested_fields): return res - # TODO : Should we keep this function because it is duplicated below ? - # def nve_idx(self, etype): - # cond = self.e_cond(etype) - # return compress(range(len(cond)), cond) # https://stackoverflow.com/questions/21448225/getting-indices-of-true-values-in-a-boolean-list - def nve_idx(self, etype, entity_idx): return self.field(etype).nve_idx[entity_idx] @@ -198,6 +192,7 @@ def _init_existing(n_existing, n_entities): return exists_array +# TODO : Add options to have either 1 value or a list for parameters such as diameter, friction ... def init_nve_state( simulator_state: SimulatorState, diameter: float = 5., From 0a6bc9282fe37a82d683fd972de1b5c786cd5c11 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 11 Apr 2024 16:21:41 +0200 Subject: [PATCH 6/6] Rename init_nve init_entities --- scripts/run_server.py | 4 ++-- scripts/run_simulation.py | 4 ++-- tests/test_simulator_init.py | 6 +++--- tests/test_simulator_run.py | 4 ++-- vivarium/simulator/states.py | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/scripts/run_server.py b/scripts/run_server.py index b9b9699..feb6f98 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -7,7 +7,7 @@ from vivarium.simulator.states import init_simulator_state from vivarium.simulator.states import init_agent_state from vivarium.simulator.states import init_object_state -from vivarium.simulator.states import init_nve_state +from vivarium.simulator.states import init_entities_state from vivarium.simulator.states import init_state from vivarium.simulator.simulator import Simulator from vivarium.simulator.physics_engine import dynamics_rigid @@ -27,7 +27,7 @@ def main(cfg: DictConfig = None) -> None: objects_state = init_object_state(simulator_state=simulator_state, **args.objects) - entities_state = init_nve_state(simulator_state=simulator_state, **args.entities) + entities_state = init_entities_state(simulator_state=simulator_state, **args.entities) state = init_state( simulator_state=simulator_state, diff --git a/scripts/run_simulation.py b/scripts/run_simulation.py index 7f5ddff..4aae95b 100644 --- a/scripts/run_simulation.py +++ b/scripts/run_simulation.py @@ -7,7 +7,7 @@ from vivarium.simulator.states import init_simulator_state from vivarium.simulator.states import init_agent_state from vivarium.simulator.states import init_object_state -from vivarium.simulator.states import init_nve_state +from vivarium.simulator.states import init_entities_state from vivarium.simulator.states import init_state from vivarium.simulator.simulator import Simulator from vivarium.simulator.physics_engine import dynamics_rigid @@ -26,7 +26,7 @@ def main(cfg: DictConfig = None) -> None: objects_state = init_object_state(simulator_state=simulator_state, **args.objects) - entities_state = init_nve_state(simulator_state=simulator_state, **args.entities) + entities_state = init_entities_state(simulator_state=simulator_state, **args.entities) state = init_state( simulator_state=simulator_state, diff --git a/tests/test_simulator_init.py b/tests/test_simulator_init.py index a1cd46a..c0f09eb 100644 --- a/tests/test_simulator_init.py +++ b/tests/test_simulator_init.py @@ -2,7 +2,7 @@ from vivarium.simulator.states import init_simulator_state from vivarium.simulator.states import init_agent_state from vivarium.simulator.states import init_object_state -from vivarium.simulator.states import init_nve_state +from vivarium.simulator.states import init_entities_state from vivarium.simulator.states import init_state from vivarium.simulator.simulator import Simulator from vivarium.simulator.sim_computation import dynamics_rigid @@ -13,7 +13,7 @@ def test_init_simulator_no_args(): simulator_state = init_simulator_state() agents_state = init_agent_state(simulator_state=simulator_state) objects_state = init_object_state(simulator_state=simulator_state) - entities_state = init_nve_state(simulator_state=simulator_state) + entities_state = init_entities_state(simulator_state=simulator_state) state = init_state( simulator_state=simulator_state, @@ -52,7 +52,7 @@ def test_init_simulator_args(): collision_eps=col_eps, collision_alpha=col_alpha) - entities_state = init_nve_state( + entities_state = init_entities_state( simulator_state, diameter=diameter, friction=friction) diff --git a/tests/test_simulator_run.py b/tests/test_simulator_run.py index b4ffbbd..6d12175 100644 --- a/tests/test_simulator_run.py +++ b/tests/test_simulator_run.py @@ -2,7 +2,7 @@ from vivarium.simulator.states import init_simulator_state from vivarium.simulator.states import init_agent_state from vivarium.simulator.states import init_object_state -from vivarium.simulator.states import init_nve_state +from vivarium.simulator.states import init_entities_state from vivarium.simulator.states import init_state from vivarium.simulator.simulator import Simulator from vivarium.simulator.sim_computation import dynamics_rigid @@ -16,7 +16,7 @@ def test_simulator_run(): objects_state = init_object_state(simulator_state=simulator_state) - entities_state = init_nve_state(simulator_state=simulator_state) + entities_state = init_entities_state(simulator_state=simulator_state) state = init_state( simulator_state=simulator_state, diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index 0573443..8750f0b 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -193,7 +193,7 @@ def _init_existing(n_existing, n_entities): # TODO : Add options to have either 1 value or a list for parameters such as diameter, friction ... -def init_nve_state( +def init_entities_state( simulator_state: SimulatorState, diameter: float = 5., friction: float = 0.1,