diff --git a/vivarium/controllers/converters.py b/vivarium/controllers/converters.py index e3d91ae..13c58df 100644 --- a/vivarium/controllers/converters.py +++ b/vivarium/controllers/converters.py @@ -218,7 +218,7 @@ def set_state_from_config_dict(config_dict, state=None): state = state or get_default_state(n_entities_dict) e_idx = jnp.zeros(sum(n_entities_dict.values()), dtype=int) for stype, configs in config_dict.items(): - params = configs[0].param_names() + params = configs[0].param_names() if len(configs) > 0 else [] for p in params: state_field_info = configs_to_state_dict[stype][p] nve_idx = [c.idx for c in configs] if state_field_info.nested_field[0] == 'nve_state' else range(len(configs)) @@ -248,6 +248,8 @@ def set_configs_from_state(state, config_dict=None): for f in state_field_info.nested_field: value = getattr(value, f) for config in config_dict[stype]: + if config is None: + continue t = type(getattr(config, param)) row_idx = state.row_idx(state_field_info.nested_field[0], config.idx) if state_field_info.column_idx is None: diff --git a/vivarium/controllers/panel_controller.py b/vivarium/controllers/panel_controller.py index 485ad7e..b29e3a4 100644 --- a/vivarium/controllers/panel_controller.py +++ b/vivarium/controllers/panel_controller.py @@ -61,12 +61,13 @@ def __init__(self, **params): self.update_entity_list() for etype, selected in self.selected_entities.items(): - selected.param.watch(self.pull_selected_configs, ['selection'], onlychanged=True, precedence=1) - selected.param.watch(self.pull_selected_panel_configs, ['selection'], onlychanged=True) + if selected is not None: + selected.param.watch(self.pull_selected_configs, ['selection'], onlychanged=True, precedence=1) + selected.param.watch(self.pull_selected_panel_configs, ['selection'], onlychanged=True) def watch_selected_configs(self): watchers = {etype: config.param.watch(self.push_selected_to_config_list, config.param_names(), onlychanged=True) - for etype, config in self.selected_configs.items()} + for etype, config in self.selected_configs.items() if config is not None} return watchers def watch_selected_panel_configs(self): @@ -78,7 +79,8 @@ def watch_selected_panel_configs(self): def dont_push_selected_configs(self): if self._selected_configs_watchers is not None: for etype, config in self.selected_configs.items(): - config.param.unwatch(self._selected_configs_watchers[etype]) + if config is not None: + config.param.unwatch(self._selected_configs_watchers[etype]) try: yield finally: @@ -97,22 +99,30 @@ def dont_push_selected_panel_configs(self): def update_entity_list(self, *events): state = self.state for etype, selected in self.selected_entities.items(): - selected.param.selection.objects = state.entity_idx(etype).tolist() + if selected is not None: + selected.param.selection.objects = state.entity_idx(etype).tolist() def pull_selected_configs(self, *events): state = self.state + + for etype, config in self.selected_configs.items(): + if state.field(etype.to_state_type()).nve_idx.shape[0] == 0: + self.selected_configs[etype] = None + self.selected_entities[etype] = None config_dict = {etype.to_state_type(): [config] for etype, config in self.selected_configs.items()} with self.dont_push_selected_configs(): # Todo: check if for loop below is still required for etype, selected in self.selected_entities.items(): - config_dict[etype.to_state_type()][0].idx = int(state.nve_idx(etype.to_state_type(), selected.selection[0])) + if selected is not None: + config_dict[etype.to_state_type()][0].idx = int(state.nve_idx(etype.to_state_type(), selected.selection[0])) converters.set_configs_from_state(state, config_dict) return state def pull_selected_panel_configs(self, *events): with self.dont_push_selected_panel_configs(): for etype, panel_config in self.selected_panel_configs.items(): - panel_config.param.update(**self.panel_configs[etype.to_state_type()][self.selected_entities[etype].selection[0]].to_dict()) + if self.selected_entities[etype] is not None: + panel_config.param.update(**self.panel_configs[etype.to_state_type()][self.selected_entities[etype].selection[0]].to_dict()) def pull_all_data(self): self.pull_selected_configs() diff --git a/vivarium/interface/panel_app.py b/vivarium/interface/panel_app.py index 76dfbd8..a8c740d 100644 --- a/vivarium/interface/panel_app.py +++ b/vivarium/interface/panel_app.py @@ -32,8 +32,9 @@ def __init__(self, config, panel_configs, panel_simulator_config, selected, etyp self.cds.on_change('data', self.drag_cb) self.cds_view = self.create_cds_view() self.panel_simulator_config.param.watch(self.hide_all_non_existing, "hide_non_existing") - selected.param.watch(self.update_selected_plot, ['selection'], - onlychanged=True, precedence=0) + if selected is not None: + selected.param.watch(self.update_selected_plot, ['selection'], + onlychanged=True, precedence=0) for i, pc in enumerate(self.panel_configs): pc.param.watch(self.update_cds_view, pc.param_names(), onlychanged=True) self.config[i].param.watch(self.hide_non_existing, "exists", onlychanged=False) @@ -58,10 +59,11 @@ def update_cds(self, state): def create_cds_view(self): # For each attribute in the panel config, create a filter # that is a logical AND of the visibility and the attribute + params = self.panel_configs[0].param_names() if len(self.panel_configs) else [] return { attr: CDSView(filter=BooleanFilter( [getattr(pc, attr) and pc.visible for pc in self.panel_configs] - )) for attr in self.panel_configs[0].param_names() + )) for attr in params } def update_cds_view(self, event): @@ -214,7 +216,7 @@ def __init__(self, **kwargs): panel_simulator_config=self.controller.panel_simulator_config, selected=self.controller.selected_entities[etype], etype=etype, state=self.controller.state) - for etype, manager_class in self.entity_manager_classes.items() + for etype, manager_class in self.entity_manager_classes.items() if len(self.controller.configs[etype.to_state_type()]) } self.plot = self.create_plot() @@ -262,8 +264,8 @@ def create_plot(self): p.add_tools(hover) p.x_range = Range1d(0, self.controller.simulator_config.box_size) p.y_range = Range1d(0, self.controller.simulator_config.box_size) - draw_tool = PointDrawTool(renderers=[self.entity_managers[etype].plot(p) - for etype in EntityType], add=False) + draw_tool = PointDrawTool(renderers=[em.plot(p) + for em in self.entity_managers.values()], add=False) p.add_tools(draw_tool) return p