Skip to content

Commit

Permalink
Merge pull request #46 from clement-moulin-frier/corentin/new_collisions
Browse files Browse the repository at this point in the history
Add collisions as part of the simulator state and change default values
  • Loading branch information
corentinlger authored Mar 23, 2024
2 parents cb3e256 + c2266a7 commit e085e43
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 93 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ param==2.0.2
matplotlib==3.8.2

# Client-server communication
grpcio==1.60.0
grpcio
grpcio-tools

# Tests and code formatting
pytest==8.0.0
Expand Down
7 changes: 6 additions & 1 deletion scripts/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def parse_args():
parser.add_argument('--to_jit', action='store_false', help='Whether to use JIT compilation')
parser.add_argument('--use_fori_loop', action='store_true', help='Whether to use fori loop')
parser.add_argument('--log_level', type=str, default='INFO', help='Logging level')
parser.add_argument('--collision_eps', type=float, required=False, default=0.1)
parser.add_argument('--collision_alpha', type=float, required=False, default=0.5)

return parser.parse_args()

Expand All @@ -44,7 +46,9 @@ def parse_args():
freq=args.freq,
neighbor_radius=args.neighbor_radius,
to_jit=args.to_jit,
use_fori_loop=args.use_fori_loop
use_fori_loop=args.use_fori_loop,
collision_eps=args.collision_eps,
collision_alpha=args.collision_alpha
)

agent_configs = [AgentConfig(idx=i,
Expand All @@ -65,5 +69,6 @@ def parse_args():
})

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

lg.info('Simulator server started')
serve(simulator)
10 changes: 7 additions & 3 deletions scripts/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
def parse_args():
parser = argparse.ArgumentParser(description='Simulator Configuration')
# Experiment run arguments
parser.add_argument('--num_steps', type=int, default=10, help='Number of simulation loops')
parser.add_argument('--log_level', type=str, default='INFO', help='Logging level')
parser.add_argument('--num_steps', type=int, default=10, help='Number of simulation steps')
# Simulator config arguments
parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box')
parser.add_argument('--n_agents', type=int, default=10, help='Number of agents')
Expand All @@ -26,7 +27,8 @@ def parse_args():
# By default jit compile the code and use normal python loops
parser.add_argument('--to_jit', action='store_false', help='Whether to use JIT compilation')
parser.add_argument('--use_fori_loop', action='store_true', help='Whether to use fori loop')
parser.add_argument('--log_level', type=str, default='INFO', help='Logging level')
parser.add_argument('--collision_eps', type=float, required=False, default=0.3)
parser.add_argument('--collision_alpha', type=float, required=False, default=0.7)

return parser.parse_args()

Expand All @@ -45,7 +47,9 @@ def parse_args():
freq=args.freq,
neighbor_radius=args.neighbor_radius,
to_jit=args.to_jit,
use_fori_loop=args.use_fori_loop
use_fori_loop=args.use_fori_loop,
collision_eps=args.collision_eps,
collision_alpha=args.collision_alpha
)

agent_configs = [
Expand Down
4 changes: 3 additions & 1 deletion vivarium/controllers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class ObjectConfig(Config):
mass_orientation = param.Number(mass_orientation)
diameter = param.Number(5.)
color = param.Color('red')
friction = param.Number(10.)
friction = param.Number(0.1)
exists = param.Boolean(True)

def __init__(self, **params):
Expand All @@ -83,6 +83,8 @@ class SimulatorConfig(Config):
neighbor_radius = param.Number(100., bounds=(0, None))
to_jit = param.Boolean(True)
use_fori_loop = param.Boolean(False)
collision_eps = param.Number(0.1)
collision_alpha = param.Number(0.5)

def __init__(self, **params):
super().__init__(**params)
Expand Down
6 changes: 4 additions & 2 deletions vivarium/controllers/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
object_state_fields = [f.name for f in jax_md.dataclasses.fields(ObjectState)]

object_common_fields = [f for f in object_config_fields if f in object_state_fields]
#

simulator_config_fields = SimulatorConfig.param.objects().keys()
simulator_state_fields = [f.name for f in jax_md.dataclasses.fields(SimulatorState)]

Expand Down Expand Up @@ -106,7 +106,9 @@ def get_default_state(n_entities_dict):
n_agents=jnp.array([n_agents]), n_objects=jnp.array([n_objects]),
num_steps_lax=jnp.array([1]), dt=jnp.array([1.]), freq=jnp.array([1.]),
neighbor_radius=jnp.array([1.]),
to_jit= jnp.array([1]), use_fori_loop=jnp.array([0])),
to_jit= jnp.array([1]), use_fori_loop=jnp.array([0]),
collision_alpha=jnp.array([0.]),
collision_eps=jnp.array([0.])),
nve_state=NVEState(position=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)),
momentum=None,
force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)),
Expand Down
26 changes: 15 additions & 11 deletions vivarium/simulator/grpc_server/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@ def proto_to_state(state):


def proto_to_simulator_state(simulator_state):
return SimulatorState(idx=proto_to_ndarray(simulator_state.idx),
box_size=proto_to_ndarray(simulator_state.box_size),
n_agents=proto_to_ndarray(simulator_state.n_agents),
n_objects=proto_to_ndarray(simulator_state.n_objects),
num_steps_lax=proto_to_ndarray(simulator_state.num_steps_lax),
dt=proto_to_ndarray(simulator_state.dt),
freq=proto_to_ndarray(simulator_state.freq),
neighbor_radius=proto_to_ndarray(simulator_state.neighbor_radius),
to_jit=proto_to_ndarray(simulator_state.to_jit),
use_fori_loop=proto_to_ndarray(simulator_state.use_fori_loop)
return SimulatorState(idx=proto_to_ndarray(simulator_state.idx).astype(int),
box_size=proto_to_ndarray(simulator_state.box_size).astype(float),
n_agents=proto_to_ndarray(simulator_state.n_agents).astype(int),
n_objects=proto_to_ndarray(simulator_state.n_objects).astype(int),
num_steps_lax=proto_to_ndarray(simulator_state.num_steps_lax).astype(int),
dt=proto_to_ndarray(simulator_state.dt).astype(float),
freq=proto_to_ndarray(simulator_state.freq).astype(float),
neighbor_radius=proto_to_ndarray(simulator_state.neighbor_radius).astype(float),
to_jit=proto_to_ndarray(simulator_state.to_jit).astype(int),
use_fori_loop=proto_to_ndarray(simulator_state.use_fori_loop).astype(int),
collision_eps=proto_to_ndarray(simulator_state.collision_eps).astype(float),
collision_alpha=proto_to_ndarray(simulator_state.collision_alpha).astype(float)
)


Expand Down Expand Up @@ -81,7 +83,9 @@ def simulator_state_to_proto(simulator_state):
freq=ndarray_to_proto(simulator_state.freq),
neighbor_radius=ndarray_to_proto(simulator_state.neighbor_radius),
to_jit=ndarray_to_proto(simulator_state.to_jit),
use_fori_loop=ndarray_to_proto(simulator_state.use_fori_loop)
use_fori_loop=ndarray_to_proto(simulator_state.use_fori_loop),
collision_eps=ndarray_to_proto(simulator_state.collision_eps),
collision_alpha=ndarray_to_proto(simulator_state.collision_alpha)
)


Expand Down
2 changes: 2 additions & 0 deletions vivarium/simulator/grpc_server/protos/simulator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ message SimulatorState {
NDArray neighbor_radius = 8;
NDArray to_jit = 9;
NDArray use_fori_loop = 10;
NDArray collision_eps = 11;
NDArray collision_alpha = 12;
}

message NVEState {
Expand Down
41 changes: 21 additions & 20 deletions vivarium/simulator/grpc_server/simulator_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit e085e43

Please sign in to comment.