Skip to content

Commit

Permalink
Merge pull request #85 from clement-moulin-frier/simulator_tuto_notebook
Browse files Browse the repository at this point in the history
Add simulator tutorial notebook
  • Loading branch information
corentinlger authored May 6, 2024
2 parents 0f43534 + e3cb20e commit bf118d6
Show file tree
Hide file tree
Showing 13 changed files with 1,182 additions and 65 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,8 @@ You can add your own tests in the tests/ subdirector. Make sure that the name or

## Tutorials

Several notebooks tutorials can be found in the [notebooks folder](https://github.com/clement-moulin-frier/vivarium/tree/main/notebooks), along with a tutorial for the web interface.
To help you get started and explore the project, we provide a set of Jupyter notebook tutorials located in the [notebooks folder](https://github.com/clement-moulin-frier/vivarium/tree/main/notebooks). These tutorials cover various aspects of the project, from using the graphical interface to interacting with simulations and understanding the backend.

- **Web Interface Tutorial**: Begin with the [web interface tutorial](https://github.com/clement-moulin-frier/vivarium/tree/main/notebooks/web_interface_tutorial.md) to gain a basic understanding of the project and learn how to use the graphical interface.
- **Quickstart Tutorial**: To learn how to interact with a simulation from a Jupyter notebook, follow the [quickstart tutorial](notebooks/quickstart_tutorial.ipynb). This tutorial will guide you through creating, running, and manipulating simulations within a notebook environment.
- **Simulator Tutorial**: For a deeper understanding of the simulator backend and its capabilities, check out the [simulator tutorial](notebooks/simulator_tutorial.ipynb). This tutorial provides insights into the underlying mechanics of the simulator and demonstrates how to leverage its features for advanced use cases.
1,034 changes: 1,034 additions & 0 deletions notebooks/simulator_tutorial.ipynb

Large diffs are not rendered by default.

30 changes: 15 additions & 15 deletions vivarium/controllers/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class StateFieldInfo:
'proximity_map_theta': StateFieldInfo(('agent_state', 'proximity_map_theta',), slice(None), neighbor_map_s_to_c, identity_c_to_s),
'behavior': StateFieldInfo(('agent_state', 'behavior',), None, behavior_s_to_c, behavior_c_to_s),
'color': StateFieldInfo(('agent_state', 'color',), np.arange(3), color_s_to_c, color_c_to_s),
'idx': StateFieldInfo(('agent_state', 'nve_idx',), None, identity_s_to_c, identity_c_to_s),
'idx': StateFieldInfo(('agent_state', 'ent_idx',), None, identity_s_to_c, identity_c_to_s),
'exists': StateFieldInfo(('entity_state', 'exists'), None, identity_s_to_c, exists_c_to_s)
}

Expand All @@ -83,7 +83,7 @@ class StateFieldInfo:
'diameter': StateFieldInfo(('entity_state', 'diameter'), None, identity_s_to_c, identity_c_to_s),
'friction': StateFieldInfo(('entity_state', 'friction'), None, identity_s_to_c, identity_c_to_s),
'color': StateFieldInfo(('object_state', 'color',), np.arange(3), color_s_to_c, color_c_to_s),
'idx': StateFieldInfo(('object_state', 'nve_idx',), None, identity_s_to_c, identity_c_to_s),
'idx': StateFieldInfo(('object_state', 'ent_idx',), None, identity_s_to_c, identity_c_to_s),
'exists': StateFieldInfo(('entity_state', 'exists'), None, identity_s_to_c, exists_c_to_s)

}
Expand Down Expand Up @@ -120,7 +120,7 @@ def get_default_state(n_entities_dict):
friction=jnp.zeros(n_entities),
exists=jnp.ones(n_entities, dtype=int)
),
agent_state=AgentState(nve_idx=jnp.zeros(max_agents, dtype=int),
agent_state=AgentState(ent_idx=jnp.zeros(max_agents, dtype=int),
prox=jnp.zeros((max_agents, 2)),
motor=jnp.zeros((max_agents, 2)),
proximity_map_dist=jnp.zeros((max_agents, 1)),
Expand All @@ -133,12 +133,12 @@ def get_default_state(n_entities_dict):
proxs_dist_max=jnp.zeros(max_agents),
proxs_cos_min=jnp.zeros(max_agents),
color=jnp.zeros((max_agents, 3))),
object_state=ObjectState(nve_idx=jnp.zeros(max_objects, dtype=int), color=jnp.zeros((max_objects, 3))))
object_state=ObjectState(ent_idx=jnp.zeros(max_objects, dtype=int), color=jnp.zeros((max_objects, 3))))


EntityTuple = namedtuple('EntityTuple', ['idx', 'col', 'val'])
ValueTuple = namedtuple('ValueData', ['nve_idx', 'col_idx', 'row_map', 'col_map', 'val'])
StateChangeTuple = namedtuple('StateChange', ['nested_field', 'nve_idx', 'column_idx', 'value'])
ValueTuple = namedtuple('ValueData', ['ent_idx', 'col_idx', 'row_map', 'col_map', 'val'])
StateChangeTuple = namedtuple('StateChange', ['nested_field', 'ent_idx', 'column_idx', 'value'])


def events_to_nve_data(events, state):
Expand Down Expand Up @@ -166,17 +166,17 @@ def events_to_nve_data(events, state):
def nve_data_to_state_changes(nve_data, state):
value_data = dict()
for nf, nve_tuples in nve_data.items():
nve_idx = sorted(list(set([int(t.idx) for t in nve_tuples])))
row_map = {idx: i for i, idx in enumerate(nve_idx)}
ent_idx = sorted(list(set([int(t.idx) for t in nve_tuples])))
row_map = {idx: i for i, idx in enumerate(ent_idx)}
if nve_tuples[0].col is None:
val = np.array(state.field(nf)[np.array(nve_idx)])
val = np.array(state.field(nf)[np.array(ent_idx)])
col_map = None
col_idx = None
else:
col_idx = sorted(list(set([t.col for t in nve_tuples])))
col_map = {idx: i for i, idx in enumerate(col_idx)}
val = np.array(state.field(nf)[np.ix_(state.row_idx(nf[0], nve_idx), col_idx)])
value_data[nf] = ValueTuple(nve_idx, col_idx, row_map, col_map, val)
val = np.array(state.field(nf)[np.ix_(state.row_idx(nf[0], ent_idx), col_idx)])
value_data[nf] = ValueTuple(ent_idx, col_idx, row_map, col_map, val)

state_changes = []
for nf, value_tuple in value_data.items():
Expand All @@ -187,7 +187,7 @@ def nve_data_to_state_changes(nve_data, state):
else:
col = value_tuple.col_map[nve_tuple.col]
value_tuple.val[row, col] = nve_tuple.val
state_changes.append(StateChangeTuple(nf, value_data[nf].nve_idx,
state_changes.append(StateChangeTuple(nf, value_data[nf].ent_idx,
value_data[nf].col_idx, value_tuple.val))

return state_changes
Expand Down Expand Up @@ -231,12 +231,12 @@ def set_state_from_config_dict(config_dict, state=None):
params = configs[0].param_names()
for p in params:
state_field_info = configs_to_state_dict[stype][p]
nve_idx = [c.idx for c in configs] if state_field_info.nested_field[0] == 'entity_state' else range(len(configs))
change = rec_set_dataclass(state, state_field_info.nested_field, jnp.array(nve_idx), state_field_info.column_idx,
ent_idx = [c.idx for c in configs] if state_field_info.nested_field[0] == 'entity_state' else range(len(configs))
change = rec_set_dataclass(state, state_field_info.nested_field, jnp.array(ent_idx), state_field_info.column_idx,
jnp.array([state_field_info.config_to_state(getattr(c, p)) for c in configs]))
state = state.set(**change)
if stype.is_entity():
e_idx.at[state.field(stype).nve_idx].set(jnp.array(range(n_entities_dict[stype])))
e_idx.at[state.field(stype).ent_idx].set(jnp.array(range(n_entities_dict[stype])))

# TODO: something weird with the to lines below, the second one will have no effect (would need state = state.set(.)), but if we fix it we get only zeros in entity_state.entitiy_idx. As it is it seems to get correct values though
change = rec_set_dataclass(state, ('entity_state', 'entity_idx'), jnp.array(range(sum(n_entities_dict.values()))), None, e_idx)
Expand Down
6 changes: 3 additions & 3 deletions vivarium/controllers/panel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class PanelSimulatorConfig(Config):
class Selected(param.Parameterized):
selection = param.ListSelector([0], objects=[0])

def selection_nve_idx(self, nve_idx):
return nve_idx[np.array(self.selection)].tolist()
def selection_nve_idx(self, ent_idx):
return ent_idx[np.array(self.selection)].tolist()

def __len__(self):
return len(self.selection)
Expand Down Expand Up @@ -108,7 +108,7 @@ def pull_selected_configs(self, *events):
with self.dont_push_selected_configs():
# Todo: check if for loop below is still required
for etype, selected in self.selected_entities.items():
config_dict[etype.to_state_type()][0].idx = int(state.nve_idx(etype.to_state_type(), selected.selection[0]))
config_dict[etype.to_state_type()][0].idx = int(state.ent_idx(etype.to_state_type(), selected.selection[0]))
converters.set_configs_from_state(state, config_dict)
return state

Expand Down
8 changes: 4 additions & 4 deletions vivarium/simulator/grpc_server/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def proto_to_nve_state(entity_state):


def proto_to_agent_state(agent_state):
return AgentState(nve_idx=proto_to_ndarray(agent_state.nve_idx).astype(int),
return AgentState(ent_idx=proto_to_ndarray(agent_state.ent_idx).astype(int),
proximity_map_dist=proto_to_ndarray(agent_state.proximity_map_dist).astype(float),
proximity_map_theta=proto_to_ndarray(agent_state.proximity_map_theta).astype(float),
prox=proto_to_ndarray(agent_state.prox).astype(float),
Expand All @@ -64,7 +64,7 @@ def proto_to_agent_state(agent_state):


def proto_to_object_state(object_state):
return ObjectState(nve_idx=proto_to_ndarray(object_state.nve_idx).astype(int),
return ObjectState(ent_idx=proto_to_ndarray(object_state.ent_idx).astype(int),
color=proto_to_ndarray(object_state.color).astype(float),
)

Expand Down Expand Up @@ -111,7 +111,7 @@ def nve_state_to_proto(entity_state):


def agent_state_to_proto(agent_state):
return simulator_pb2.AgentState(nve_idx=ndarray_to_proto(agent_state.nve_idx),
return simulator_pb2.AgentState(ent_idx=ndarray_to_proto(agent_state.ent_idx),
proximity_map_dist=ndarray_to_proto(agent_state.proximity_map_dist),
proximity_map_theta=ndarray_to_proto(agent_state.proximity_map_theta),
prox=ndarray_to_proto(agent_state.prox),
Expand All @@ -128,6 +128,6 @@ def agent_state_to_proto(agent_state):


def object_state_to_proto(object_state):
return simulator_pb2.ObjectState(nve_idx=ndarray_to_proto(object_state.nve_idx),
return simulator_pb2.ObjectState(ent_idx=ndarray_to_proto(object_state.ent_idx),
color=ndarray_to_proto(object_state.color)
)
6 changes: 3 additions & 3 deletions vivarium/simulator/grpc_server/protos/simulator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ message EntityState {
}

message AgentState {
NDArray nve_idx = 1;
NDArray ent_idx = 1;
NDArray prox = 2;
NDArray motor = 3;
NDArray behavior = 4;
Expand All @@ -99,7 +99,7 @@ message AgentState {
}

message ObjectState {
NDArray nve_idx = 1;
NDArray ent_idx = 1;
NDArray custom_field = 2;
NDArray color = 3;
}
Expand All @@ -112,7 +112,7 @@ message State {
}

message StateChange {
repeated int32 nve_idx= 1;
repeated int32 ent_idx= 1;
repeated int32 col_idx= 2;
repeated string nested_field = 3;
NDArray value = 4;
Expand Down
4 changes: 2 additions & 2 deletions vivarium/simulator/grpc_server/simulator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def stop(self):
def get_change_time(self):
return self.stub.GetChangeTime(Empty()).time

def set_state(self, nested_field, nve_idx, column_idx, value):
state_change = simulator_pb2.StateChange(nested_field=nested_field, nve_idx=nve_idx, col_idx=column_idx,
def set_state(self, nested_field, ent_idx, column_idx, value):
state_change = simulator_pb2.StateChange(nested_field=nested_field, ent_idx=ent_idx, col_idx=column_idx,
value=ndarray_to_proto(value))
self.stub.SetState(state_change)

Expand Down
Loading

0 comments on commit bf118d6

Please sign in to comment.