Skip to content

Commit

Permalink
Merge pull request #57 from clement-moulin-frier/corentin/init_state_…
Browse files Browse the repository at this point in the history
…without_configs

Init state without configs
  • Loading branch information
corentinlger committed Apr 4, 2024
2 parents 45291ed + efa959e commit a44e5ed
Show file tree
Hide file tree
Showing 16 changed files with 510 additions and 267 deletions.
46 changes: 24 additions & 22 deletions scripts/run_server.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import logging
import argparse

import numpy as np

import vivarium.simulator.behaviors as behaviors
from vivarium.controllers.config import SimulatorConfig, AgentConfig, ObjectConfig
from vivarium.simulator.sim_computation import StateType
from vivarium.simulator import behaviors
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_nve_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.sim_computation import dynamics_rigid
from vivarium.controllers.converters import set_state_from_config_dict
from vivarium.simulator.grpc_server.simulator_server import serve

lg = logging.getLogger(__name__)
Expand All @@ -17,7 +17,9 @@ def parse_args():
parser = argparse.ArgumentParser(description='Simulator Configuration')
parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box')
parser.add_argument('--n_agents', type=int, default=10, help='Number of agents')
parser.add_argument('--n_existing_agents', type=int, default=10, help='Number of agents')
parser.add_argument('--n_objects', type=int, default=2, help='Number of objects')
parser.add_argument('--n_existing_objects', type=int, default=2, help='Number of agents')
parser.add_argument('--num_steps-lax', type=int, default=4, help='Number of lax steps per loop')
parser.add_argument('--dt', type=float, default=0.1, help='Time step size')
parser.add_argument('--freq', type=float, default=40.0, help='Frequency parameter')
Expand All @@ -37,36 +39,36 @@ def parse_args():

logging.basicConfig(level=args.log_level.upper())

simulator_config = SimulatorConfig(
simulator_state = init_simulator_state(
box_size=args.box_size,
n_agents=args.n_agents,
n_objects=args.n_objects,
num_steps_lax=args.num_steps_lax,
neighbor_radius=args.neighbor_radius,
dt=args.dt,
freq=args.freq,
neighbor_radius=args.neighbor_radius,
to_jit=args.to_jit,
use_fori_loop=args.use_fori_loop,
collision_eps=args.collision_eps,
collision_alpha=args.collision_alpha
)

agent_configs = [AgentConfig(idx=i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_agents)]
agents_state = init_agent_state(simulator_state=simulator_state)

objects_state = init_object_state(simulator_state=simulator_state)

object_configs = [ObjectConfig(idx=simulator_config.n_agents + i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_objects)]
nve_state = init_nve_state(
simulator_state=simulator_state,
existing_agents=args.n_existing_agents,
existing_objects=args.n_existing_objects,
)

state = set_state_from_config_dict({StateType.AGENT: agent_configs,
StateType.OBJECT: object_configs,
StateType.SIMULATOR: [simulator_config]
})
state = init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
nve_state=nve_state
)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

Expand Down
53 changes: 21 additions & 32 deletions scripts/run_simulation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import argparse
import logging

import numpy as np

from vivarium.simulator import behaviors
from vivarium.simulator.sim_computation import dynamics_rigid, StateType
from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig
from vivarium.controllers import converters
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_nve_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.sim_computation import dynamics_rigid

lg = logging.getLogger(__name__)

Expand All @@ -27,8 +28,8 @@ def parse_args():
# By default jit compile the code and use normal python loops
parser.add_argument('--to_jit', action='store_false', help='Whether to use JIT compilation')
parser.add_argument('--use_fori_loop', action='store_true', help='Whether to use fori loop')
parser.add_argument('--collision_eps', type=float, required=False, default=0.3)
parser.add_argument('--collision_alpha', type=float, required=False, default=0.7)
parser.add_argument('--collision_eps', type=float, required=False, default=0.1)
parser.add_argument('--collision_alpha', type=float, required=False, default=0.5)

return parser.parse_args()

Expand All @@ -37,45 +38,33 @@ def parse_args():
args = parse_args()

logging.basicConfig(level=args.log_level.upper())

simulator_config = SimulatorConfig(
simulator_state = init_simulator_state(
box_size=args.box_size,
n_agents=args.n_agents,
n_objects=args.n_objects,
num_steps_lax=args.num_steps_lax,
neighbor_radius=args.neighbor_radius,
dt=args.dt,
freq=args.freq,
neighbor_radius=args.neighbor_radius,
to_jit=args.to_jit,
use_fori_loop=args.use_fori_loop,
collision_eps=args.collision_eps,
collision_alpha=args.collision_alpha
)

agent_configs = [
AgentConfig(idx=i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_agents)
]

object_configs = [
ObjectConfig(idx=simulator_config.n_agents + i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_objects)
]
agents_state = init_agent_state(simulator_state=simulator_state)

state = converters.set_state_from_config_dict(
{
StateType.AGENT: agent_configs,
StateType.OBJECT: object_configs,
StateType.SIMULATOR: [simulator_config]
}
)
objects_state = init_object_state(simulator_state=simulator_state)

nve_state = init_nve_state(simulator_state=simulator_state)

state = init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
nve_state=nve_state
)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

Expand Down
46 changes: 0 additions & 46 deletions tests/test_simulator.py

This file was deleted.

81 changes: 81 additions & 0 deletions tests/test_simulator_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from vivarium.simulator import behaviors
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_nve_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.sim_computation import dynamics_rigid


def test_init_simulator_no_args():
""" Test the initialization of state without arguments """
simulator_state = init_simulator_state()
agents_state = init_agent_state(simulator_state=simulator_state)
objects_state = init_object_state(simulator_state=simulator_state)
nve_state = init_nve_state(simulator_state=simulator_state)

state = init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
nve_state=nve_state
)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

assert simulator


def test_init_simulator_args():
""" Test the initialization of state with arguments """
box_size = 100.0
n_agents = 10
n_objects = 2
col_eps = 0.1
col_alpha = 0.5

diameter = 5.0
friction = 0.1
behavior = 1
wheel_diameter = 2.0
speed_mul = 1.0
theta_mul = 1.0
prox_dist_max = 20.0
prox_cos_min = 0.0
color = "red"

simulator_state = init_simulator_state(
box_size=box_size,
n_agents=n_agents,
n_objects=n_objects,
collision_eps=col_eps,
collision_alpha=col_alpha)

nve_state = init_nve_state(
simulator_state,
diameter=diameter,
friction=friction)

agent_state = init_agent_state(
simulator_state,
behavior=behavior,
wheel_diameter=wheel_diameter,
speed_mul=speed_mul,
theta_mul=theta_mul,
prox_dist_max=prox_dist_max,
prox_cos_min=prox_cos_min)

object_state = init_object_state(
simulator_state,
color=color)

state = init_state(
simulator_state=simulator_state,
agents_state=agent_state,
objects_state=object_state,
nve_state=nve_state)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

assert simulator
34 changes: 34 additions & 0 deletions tests/test_simulator_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from vivarium.simulator import behaviors
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_nve_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.sim_computation import dynamics_rigid

NUM_STEPS = 50

def test_simulator_run():
simulator_state = init_simulator_state()

agents_state = init_agent_state(simulator_state=simulator_state)

objects_state = init_object_state(simulator_state=simulator_state)

nve_state = init_nve_state(simulator_state=simulator_state)

state = init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
nve_state=nve_state
)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

simulator.run(threaded=False, num_steps=NUM_STEPS)

assert simulator
2 changes: 1 addition & 1 deletion vivarium/controllers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from param import Parameterized

import vivarium.simulator.behaviors as behaviors
from vivarium.simulator.sim_computation import StateType
from vivarium.simulator.states import StateType

from jax_md.rigid_body import monomer

Expand Down
16 changes: 7 additions & 9 deletions vivarium/controllers/converters.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import typing
import dataclasses
from collections import namedtuple, defaultdict

import jax_md
import jax.numpy as jnp
import numpy as np
import matplotlib.colors as mcolors

import jax_md
from jax_md.util import f32
from jax_md.rigid_body import RigidBody

import dataclasses
import typing
from collections import namedtuple, defaultdict

from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig, stype_to_config, config_to_stype
from vivarium.simulator.sim_computation import State, SimulatorState, NVEState, AgentState, ObjectState, EntityType, StateType
from vivarium.simulator.states import State, SimulatorState, NVEState, AgentState, ObjectState, EntityType, StateType
from vivarium.simulator.behaviors import behavior_name_map, reversed_behavior_name_map

import matplotlib.colors as mcolors


agent_config_fields = AgentConfig.param.objects().keys()
agent_state_fields = [f.name for f in jax_md.dataclasses.fields(AgentState)]
Expand Down
2 changes: 1 addition & 1 deletion vivarium/controllers/notebook_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging

from vivarium.controllers.simulator_controller import SimulatorController
from vivarium.simulator.sim_computation import StateType, EntityType
from vivarium.simulator.states import StateType, EntityType

lg = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion vivarium/controllers/panel_controller.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from vivarium.controllers import converters
from vivarium.controllers.config import AgentConfig, ObjectConfig, config_to_stype, Config
from vivarium.controllers.simulator_controller import SimulatorController
from vivarium.simulator.sim_computation import EntityType, StateType
from vivarium.simulator.states import EntityType, StateType
from vivarium.simulator.grpc_server.simulator_client import SimulatorGRPCClient

import param
Expand Down
Loading

0 comments on commit a44e5ed

Please sign in to comment.