Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added max_speed attribute to agents #72

Merged
merged 6 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conf/scene/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ agents:
behavior: 1
wheel_diameter: 2.
speed_mul: 1.
max_speed: 10.
theta_mul: 1.
prox_dist_max: 40.
prox_cos_min: 0.
Expand Down
4 changes: 3 additions & 1 deletion tests/test_simulator_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from vivarium.simulator.states import init_entities_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.sim_computation import dynamics_rigid
from vivarium.simulator.physics_engine import dynamics_rigid


def test_init_simulator_no_args():
Expand Down Expand Up @@ -40,6 +40,7 @@ def test_init_simulator_args():
behavior = 1
wheel_diameter = 2.0
speed_mul = 1.0
max_speed = 10.0
theta_mul = 1.0
prox_dist_max = 20.0
prox_cos_min = 0.0
Expand All @@ -62,6 +63,7 @@ def test_init_simulator_args():
behavior=behavior,
wheel_diameter=wheel_diameter,
speed_mul=speed_mul,
max_speed=max_speed,
theta_mul=theta_mul,
prox_dist_max=prox_dist_max,
prox_cos_min=prox_cos_min)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_simulator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from vivarium.simulator.states import init_entities_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.sim_computation import dynamics_rigid
from vivarium.simulator.physics_engine import dynamics_rigid

NUM_STEPS = 50

Expand Down
1 change: 1 addition & 0 deletions vivarium/controllers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class AgentConfig(Config):
wheel_diameter = param.Number(2.)
diameter = param.Number(5.)
speed_mul = param.Number(1.)
max_speed = param.Number(10.)
theta_mul = param.Number(1.)
proxs_dist_max = param.Number(100., bounds=(0, None))
proxs_cos_min = param.Number(0., bounds=(-1., 1.))
Expand Down
1 change: 1 addition & 0 deletions vivarium/controllers/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def get_default_state(n_entities_dict):
behavior=jnp.zeros(max_agents, dtype=int),
wheel_diameter=jnp.zeros(max_agents),
speed_mul=jnp.zeros(max_agents),
max_speed=jnp.zeros(max_agents),
theta_mul=jnp.zeros(max_agents),
proxs_dist_max=jnp.zeros(max_agents),
proxs_cos_min=jnp.zeros(max_agents),
Expand Down
2 changes: 2 additions & 0 deletions vivarium/simulator/grpc_server/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def proto_to_agent_state(agent_state):
behavior=proto_to_ndarray(agent_state.behavior).astype(int),
wheel_diameter=proto_to_ndarray(agent_state.wheel_diameter).astype(float),
speed_mul=proto_to_ndarray(agent_state.speed_mul).astype(float),
max_speed=proto_to_ndarray(agent_state.max_speed).astype(float),
theta_mul=proto_to_ndarray(agent_state.theta_mul).astype(float),
proxs_dist_max=proto_to_ndarray(agent_state.proxs_dist_max).astype(float),
proxs_cos_min=proto_to_ndarray(agent_state.proxs_cos_min).astype(float),
Expand Down Expand Up @@ -114,6 +115,7 @@ def agent_state_to_proto(agent_state):
behavior=ndarray_to_proto(agent_state.behavior),
wheel_diameter=ndarray_to_proto(agent_state.wheel_diameter),
speed_mul=ndarray_to_proto(agent_state.speed_mul),
max_speed=ndarray_to_proto(agent_state.max_speed),
theta_mul=ndarray_to_proto(agent_state.theta_mul),
proxs_dist_max=ndarray_to_proto(agent_state.proxs_dist_max),
proxs_cos_min=ndarray_to_proto(agent_state.proxs_cos_min),
Expand Down
9 changes: 5 additions & 4 deletions vivarium/simulator/grpc_server/protos/simulator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,11 @@ message AgentState {
NDArray behavior = 4;
NDArray wheel_diameter = 5;
NDArray speed_mul = 6;
NDArray theta_mul = 7;
NDArray proxs_dist_max = 8;
NDArray proxs_cos_min = 9;
NDArray color = 10;
NDArray max_speed = 7;
NDArray theta_mul = 8;
NDArray proxs_dist_max = 9;
NDArray proxs_cos_min = 10;
NDArray color = 11;
}

message ObjectState {
Expand Down
28 changes: 14 additions & 14 deletions vivarium/simulator/grpc_server/simulator_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions vivarium/simulator/grpc_server/simulator_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,14 @@ class EntitiesState(_message.Message):
def __init__(self, position: _Optional[_Union[RigidBody, _Mapping]] = ..., momentum: _Optional[_Union[RigidBody, _Mapping]] = ..., force: _Optional[_Union[RigidBody, _Mapping]] = ..., mass: _Optional[_Union[RigidBody, _Mapping]] = ..., diameter: _Optional[_Union[NDArray, _Mapping]] = ..., entity_type: _Optional[_Union[NDArray, _Mapping]] = ..., entity_idx: _Optional[_Union[NDArray, _Mapping]] = ..., friction: _Optional[_Union[NDArray, _Mapping]] = ..., exists: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ...

class AgentState(_message.Message):
__slots__ = ("nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color")
__slots__ = ("nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "max_speed", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color")
NVE_IDX_FIELD_NUMBER: _ClassVar[int]
PROX_FIELD_NUMBER: _ClassVar[int]
MOTOR_FIELD_NUMBER: _ClassVar[int]
BEHAVIOR_FIELD_NUMBER: _ClassVar[int]
WHEEL_DIAMETER_FIELD_NUMBER: _ClassVar[int]
SPEED_MUL_FIELD_NUMBER: _ClassVar[int]
MAX_SPEED_FIELD_NUMBER: _ClassVar[int]
THETA_MUL_FIELD_NUMBER: _ClassVar[int]
PROXS_DIST_MAX_FIELD_NUMBER: _ClassVar[int]
PROXS_COS_MIN_FIELD_NUMBER: _ClassVar[int]
Expand All @@ -94,11 +95,12 @@ class AgentState(_message.Message):
behavior: NDArray
wheel_diameter: NDArray
speed_mul: NDArray
max_speed: NDArray
theta_mul: NDArray
proxs_dist_max: NDArray
proxs_cos_min: NDArray
color: NDArray
def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ...
def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., max_speed: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ...

class ObjectState(_message.Message):
__slots__ = ("nve_idx", "custom_field", "color")
Expand Down
2 changes: 2 additions & 0 deletions vivarium/simulator/physics_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ def motor_force(state, exists_mask):
state.entities_state.diameter[agent_idx],
state.agent_state.wheel_diameter
)
# `a_max` arg is deprecated in recent versions of jax, replaced by `max`
fwd = jnp.clip(fwd, a_max=state.agent_state.max_speed)

cur_vel = state.entities_state.momentum.center[agent_idx] / state.entities_state.mass.center[agent_idx]
cur_fwd_vel = vmap(jnp.dot)(cur_vel, n)
Expand Down
Loading
Loading