Skip to content

Commit

Permalink
Fix bug when setting n_agents or n_objects to 0
Browse files Browse the repository at this point in the history
  • Loading branch information
clement-moulin-frier committed Mar 20, 2024
1 parent cb3e256 commit 8a20e3a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 14 deletions.
4 changes: 3 additions & 1 deletion vivarium/controllers/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def set_state_from_config_dict(config_dict, state=None):
state = state or get_default_state(n_entities_dict)
e_idx = jnp.zeros(sum(n_entities_dict.values()), dtype=int)
for stype, configs in config_dict.items():
params = configs[0].param_names()
params = configs[0].param_names() if len(configs) > 0 else []
for p in params:
state_field_info = configs_to_state_dict[stype][p]
nve_idx = [c.idx for c in configs] if state_field_info.nested_field[0] == 'nve_state' else range(len(configs))
Expand Down Expand Up @@ -248,6 +248,8 @@ def set_configs_from_state(state, config_dict=None):
for f in state_field_info.nested_field:
value = getattr(value, f)
for config in config_dict[stype]:
if config is None:
continue
t = type(getattr(config, param))
row_idx = state.row_idx(state_field_info.nested_field[0], config.idx)
if state_field_info.column_idx is None:
Expand Down
24 changes: 17 additions & 7 deletions vivarium/controllers/panel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,13 @@ def __init__(self, **params):

self.update_entity_list()
for etype, selected in self.selected_entities.items():
selected.param.watch(self.pull_selected_configs, ['selection'], onlychanged=True, precedence=1)
selected.param.watch(self.pull_selected_panel_configs, ['selection'], onlychanged=True)
if selected is not None:
selected.param.watch(self.pull_selected_configs, ['selection'], onlychanged=True, precedence=1)
selected.param.watch(self.pull_selected_panel_configs, ['selection'], onlychanged=True)

def watch_selected_configs(self):
watchers = {etype: config.param.watch(self.push_selected_to_config_list, config.param_names(), onlychanged=True)
for etype, config in self.selected_configs.items()}
for etype, config in self.selected_configs.items() if config is not None}
return watchers

def watch_selected_panel_configs(self):
Expand All @@ -78,7 +79,8 @@ def watch_selected_panel_configs(self):
def dont_push_selected_configs(self):
if self._selected_configs_watchers is not None:
for etype, config in self.selected_configs.items():
config.param.unwatch(self._selected_configs_watchers[etype])
if config is not None:
config.param.unwatch(self._selected_configs_watchers[etype])
try:
yield
finally:
Expand All @@ -97,22 +99,30 @@ def dont_push_selected_panel_configs(self):
def update_entity_list(self, *events):
state = self.state
for etype, selected in self.selected_entities.items():
selected.param.selection.objects = state.entity_idx(etype).tolist()
if selected is not None:
selected.param.selection.objects = state.entity_idx(etype).tolist()

def pull_selected_configs(self, *events):
state = self.state

for etype, config in self.selected_configs.items():
if state.field(etype.to_state_type()).nve_idx.shape[0] == 0:
self.selected_configs[etype] = None
self.selected_entities[etype] = None
config_dict = {etype.to_state_type(): [config] for etype, config in self.selected_configs.items()}
with self.dont_push_selected_configs():
# Todo: check if for loop below is still required
for etype, selected in self.selected_entities.items():
config_dict[etype.to_state_type()][0].idx = int(state.nve_idx(etype.to_state_type(), selected.selection[0]))
if selected is not None:
config_dict[etype.to_state_type()][0].idx = int(state.nve_idx(etype.to_state_type(), selected.selection[0]))
converters.set_configs_from_state(state, config_dict)
return state

def pull_selected_panel_configs(self, *events):
with self.dont_push_selected_panel_configs():
for etype, panel_config in self.selected_panel_configs.items():
panel_config.param.update(**self.panel_configs[etype.to_state_type()][self.selected_entities[etype].selection[0]].to_dict())
if self.selected_entities[etype] is not None:
panel_config.param.update(**self.panel_configs[etype.to_state_type()][self.selected_entities[etype].selection[0]].to_dict())

def pull_all_data(self):
self.pull_selected_configs()
Expand Down
14 changes: 8 additions & 6 deletions vivarium/interface/panel_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ def __init__(self, config, panel_configs, panel_simulator_config, selected, etyp
self.cds.on_change('data', self.drag_cb)
self.cds_view = self.create_cds_view()
self.panel_simulator_config.param.watch(self.hide_all_non_existing, "hide_non_existing")
selected.param.watch(self.update_selected_plot, ['selection'],
onlychanged=True, precedence=0)
if selected is not None:
selected.param.watch(self.update_selected_plot, ['selection'],
onlychanged=True, precedence=0)
for i, pc in enumerate(self.panel_configs):
pc.param.watch(self.update_cds_view, pc.param_names(), onlychanged=True)
self.config[i].param.watch(self.hide_non_existing, "exists", onlychanged=False)
Expand All @@ -58,10 +59,11 @@ def update_cds(self, state):
def create_cds_view(self):
# For each attribute in the panel config, create a filter
# that is a logical AND of the visibility and the attribute
params = self.panel_configs[0].param_names() if len(self.panel_configs) else []
return {
attr: CDSView(filter=BooleanFilter(
[getattr(pc, attr) and pc.visible for pc in self.panel_configs]
)) for attr in self.panel_configs[0].param_names()
)) for attr in params
}

def update_cds_view(self, event):
Expand Down Expand Up @@ -214,7 +216,7 @@ def __init__(self, **kwargs):
panel_simulator_config=self.controller.panel_simulator_config,
selected=self.controller.selected_entities[etype], etype=etype,
state=self.controller.state)
for etype, manager_class in self.entity_manager_classes.items()
for etype, manager_class in self.entity_manager_classes.items() if len(self.controller.configs[etype.to_state_type()])
}

self.plot = self.create_plot()
Expand Down Expand Up @@ -262,8 +264,8 @@ def create_plot(self):
p.add_tools(hover)
p.x_range = Range1d(0, self.controller.simulator_config.box_size)
p.y_range = Range1d(0, self.controller.simulator_config.box_size)
draw_tool = PointDrawTool(renderers=[self.entity_managers[etype].plot(p)
for etype in EntityType], add=False)
draw_tool = PointDrawTool(renderers=[em.plot(p)
for em in self.entity_managers.values()], add=False)
p.add_tools(draw_tool)
return p

Expand Down

0 comments on commit 8a20e3a

Please sign in to comment.