Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Init state without configs #57

Merged
merged 19 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
ec5c2a8
First refactoring step save
corentinlger Mar 13, 2024
4534c5d
Merge branch 'main' of github.com:clement-moulin-frier/vivarium into …
corentinlger Mar 18, 2024
98046d4
Add all init functions in state.py
corentinlger Mar 18, 2024
b498f04
Fix imports with new states file
corentinlger Mar 18, 2024
89ecef7
Initialize states without configs
corentinlger Mar 18, 2024
4308ce6
Fix init state values that turn into NaN in client side
corentinlger Mar 20, 2024
868074e
Fix converters import error
corentinlger Mar 20, 2024
b86da7a
Change arguments for agents and objects states init
corentinlger Mar 20, 2024
225c5f6
Add option to test initialization with custom positions
corentinlger Mar 20, 2024
4ee97dc
Update tests for new state initialization
corentinlger Mar 20, 2024
d0e62e8
Merge branch 'main' of github.com:clement-moulin-frier/vivarium into …
corentinlger Mar 21, 2024
b7e98d9
Reduce default value of agents proximeters range
corentinlger Mar 21, 2024
fd96746
Refactor init_positions function and update comments
corentinlger Mar 21, 2024
98a505b
Allow defining a initial number of existing entities
corentinlger Mar 21, 2024
0d42317
Remove code to test custom positions in run_server
corentinlger Apr 3, 2024
5bc67f3
Merge branch 'main' of github.com:clement-moulin-frier/vivarium into …
corentinlger Apr 3, 2024
adefcce
Fix vmap bug on normal function and clean sim_computation
corentinlger Apr 4, 2024
a54f920
Add collision parameters and clean files
corentinlger Apr 4, 2024
efa959e
Fix scipy version
corentinlger Apr 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
jax==0.4.23
jaxlib==0.4.23
jax-md==0.2.8
scipy==1.12.0

# Interface
panel==1.3.8
Expand Down
46 changes: 24 additions & 22 deletions scripts/run_server.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import logging
import argparse

import numpy as np

import vivarium.simulator.behaviors as behaviors
from vivarium.controllers.config import SimulatorConfig, AgentConfig, ObjectConfig
from vivarium.simulator.sim_computation import StateType
from vivarium.simulator import behaviors
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_nve_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.sim_computation import dynamics_rigid
from vivarium.controllers.converters import set_state_from_config_dict
from vivarium.simulator.grpc_server.simulator_server import serve

lg = logging.getLogger(__name__)
Expand All @@ -17,7 +17,9 @@ def parse_args():
parser = argparse.ArgumentParser(description='Simulator Configuration')
parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box')
parser.add_argument('--n_agents', type=int, default=10, help='Number of agents')
parser.add_argument('--n_existing_agents', type=int, default=10, help='Number of agents')
parser.add_argument('--n_objects', type=int, default=2, help='Number of objects')
parser.add_argument('--n_existing_objects', type=int, default=2, help='Number of agents')
Comment on lines +20 to +22

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with your proposition that " it would make more sense to rename the current n_agents to max_agents, and n_existing_agents to n_agents (same for objects)."

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if you already do it, but it will be useful to add an assert checking that the number of existing agents/objects does not exceed the max number

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I already check this

parser.add_argument('--num_steps-lax', type=int, default=4, help='Number of lax steps per loop')
parser.add_argument('--dt', type=float, default=0.1, help='Time step size')
parser.add_argument('--freq', type=float, default=40.0, help='Frequency parameter')
Expand All @@ -37,36 +39,36 @@ def parse_args():

logging.basicConfig(level=args.log_level.upper())

simulator_config = SimulatorConfig(
simulator_state = init_simulator_state(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part works to show that we can use a list of coordinates, but I think it should be removed before the merge

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, this is meant to be deleted. I just added it to make the testing of the PR easier

box_size=args.box_size,
n_agents=args.n_agents,
n_objects=args.n_objects,
num_steps_lax=args.num_steps_lax,
neighbor_radius=args.neighbor_radius,
dt=args.dt,
freq=args.freq,
neighbor_radius=args.neighbor_radius,
to_jit=args.to_jit,
use_fori_loop=args.use_fori_loop,
collision_eps=args.collision_eps,
collision_alpha=args.collision_alpha
)

agent_configs = [AgentConfig(idx=i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_agents)]
agents_state = init_agent_state(simulator_state=simulator_state)

objects_state = init_object_state(simulator_state=simulator_state)

object_configs = [ObjectConfig(idx=simulator_config.n_agents + i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_objects)]
nve_state = init_nve_state(
simulator_state=simulator_state,
existing_agents=args.n_existing_agents,
existing_objects=args.n_existing_objects,
)

state = set_state_from_config_dict({StateType.AGENT: agent_configs,
StateType.OBJECT: object_configs,
StateType.SIMULATOR: [simulator_config]
})
state = init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
nve_state=nve_state
)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

Expand Down
53 changes: 21 additions & 32 deletions scripts/run_simulation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import argparse
import logging

import numpy as np

from vivarium.simulator import behaviors
from vivarium.simulator.sim_computation import dynamics_rigid, StateType
from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig
from vivarium.controllers import converters
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_nve_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.sim_computation import dynamics_rigid

lg = logging.getLogger(__name__)

Expand All @@ -27,8 +28,8 @@ def parse_args():
# By default jit compile the code and use normal python loops
parser.add_argument('--to_jit', action='store_false', help='Whether to use JIT compilation')
parser.add_argument('--use_fori_loop', action='store_true', help='Whether to use fori loop')
parser.add_argument('--collision_eps', type=float, required=False, default=0.3)
parser.add_argument('--collision_alpha', type=float, required=False, default=0.7)
parser.add_argument('--collision_eps', type=float, required=False, default=0.1)
parser.add_argument('--collision_alpha', type=float, required=False, default=0.5)

return parser.parse_args()

Expand All @@ -37,45 +38,33 @@ def parse_args():
args = parse_args()

logging.basicConfig(level=args.log_level.upper())

simulator_config = SimulatorConfig(
simulator_state = init_simulator_state(
box_size=args.box_size,
n_agents=args.n_agents,
n_objects=args.n_objects,
num_steps_lax=args.num_steps_lax,
neighbor_radius=args.neighbor_radius,
dt=args.dt,
freq=args.freq,
neighbor_radius=args.neighbor_radius,
to_jit=args.to_jit,
use_fori_loop=args.use_fori_loop,
collision_eps=args.collision_eps,
collision_alpha=args.collision_alpha
)

agent_configs = [
AgentConfig(idx=i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_agents)
]

object_configs = [
ObjectConfig(idx=simulator_config.n_agents + i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_objects)
]
agents_state = init_agent_state(simulator_state=simulator_state)

state = converters.set_state_from_config_dict(
{
StateType.AGENT: agent_configs,
StateType.OBJECT: object_configs,
StateType.SIMULATOR: [simulator_config]
}
)
objects_state = init_object_state(simulator_state=simulator_state)

nve_state = init_nve_state(simulator_state=simulator_state)

state = init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
nve_state=nve_state
)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

Expand Down
46 changes: 0 additions & 46 deletions tests/test_simulator.py

This file was deleted.

81 changes: 81 additions & 0 deletions tests/test_simulator_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from vivarium.simulator import behaviors
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_nve_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.sim_computation import dynamics_rigid


def test_init_simulator_no_args():
""" Test the initialization of state without arguments """
simulator_state = init_simulator_state()
agents_state = init_agent_state(simulator_state=simulator_state)
objects_state = init_object_state(simulator_state=simulator_state)
nve_state = init_nve_state(simulator_state=simulator_state)

state = init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
nve_state=nve_state
)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

assert simulator


def test_init_simulator_args():
""" Test the initialization of state with arguments """

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might also want to test the case where we only provide a subset of the arguments, to see if it merges well with the default values

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another useful test to do would be to check if (some of) the values in the created state are indeed the ones that are provided in the args or defaults (e.g. just adding a few assert in all your test functions)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed I'll add that

box_size = 100.0
n_agents = 10
n_objects = 2
col_eps = 0.1
col_alpha = 0.5

diameter = 5.0
friction = 0.1
behavior = 1
wheel_diameter = 2.0
speed_mul = 1.0
theta_mul = 1.0
prox_dist_max = 20.0
prox_cos_min = 0.0
color = "red"

simulator_state = init_simulator_state(
box_size=box_size,
n_agents=n_agents,
n_objects=n_objects,
collision_eps=col_eps,
collision_alpha=col_alpha)

nve_state = init_nve_state(
simulator_state,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wrong indentation

diameter=diameter,
friction=friction)

agent_state = init_agent_state(
simulator_state,
behavior=behavior,
wheel_diameter=wheel_diameter,
speed_mul=speed_mul,
theta_mul=theta_mul,
prox_dist_max=prox_dist_max,
prox_cos_min=prox_cos_min)

object_state = init_object_state(
simulator_state,
color=color)

state = init_state(
simulator_state=simulator_state,
agents_state=agent_state,
objects_state=object_state,
nve_state=nve_state)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

assert simulator
34 changes: 34 additions & 0 deletions tests/test_simulator_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from vivarium.simulator import behaviors
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_nve_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.sim_computation import dynamics_rigid

NUM_STEPS = 50

def test_simulator_run():
simulator_state = init_simulator_state()

agents_state = init_agent_state(simulator_state=simulator_state)

objects_state = init_object_state(simulator_state=simulator_state)

nve_state = init_nve_state(simulator_state=simulator_state)

state = init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
nve_state=nve_state
)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

simulator.run(threaded=False, num_steps=NUM_STEPS)

assert simulator
2 changes: 1 addition & 1 deletion vivarium/controllers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from param import Parameterized

import vivarium.simulator.behaviors as behaviors
from vivarium.simulator.sim_computation import StateType
from vivarium.simulator.states import StateType

from jax_md.rigid_body import monomer

Expand Down
16 changes: 7 additions & 9 deletions vivarium/controllers/converters.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import typing
import dataclasses
from collections import namedtuple, defaultdict

import jax_md
import jax.numpy as jnp
import numpy as np
import matplotlib.colors as mcolors

import jax_md
from jax_md.util import f32
from jax_md.rigid_body import RigidBody

import dataclasses
import typing
from collections import namedtuple, defaultdict

from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig, stype_to_config, config_to_stype
from vivarium.simulator.sim_computation import State, SimulatorState, NVEState, AgentState, ObjectState, EntityType, StateType
from vivarium.simulator.states import State, SimulatorState, NVEState, AgentState, ObjectState, EntityType, StateType
from vivarium.simulator.behaviors import behavior_name_map, reversed_behavior_name_map

import matplotlib.colors as mcolors


agent_config_fields = AgentConfig.param.objects().keys()
agent_state_fields = [f.name for f in jax_md.dataclasses.fields(AgentState)]
Expand Down
2 changes: 1 addition & 1 deletion vivarium/controllers/notebook_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging

from vivarium.controllers.simulator_controller import SimulatorController
from vivarium.simulator.sim_computation import StateType, EntityType
from vivarium.simulator.states import StateType, EntityType

lg = logging.getLogger(__name__)

Expand Down
Loading
Loading