diff --git a/vivarium/controllers/converters.py b/vivarium/controllers/converters.py index e3d91ae..13c58df 100644 --- a/vivarium/controllers/converters.py +++ b/vivarium/controllers/converters.py @@ -218,7 +218,7 @@ def set_state_from_config_dict(config_dict, state=None): state = state or get_default_state(n_entities_dict) e_idx = jnp.zeros(sum(n_entities_dict.values()), dtype=int) for stype, configs in config_dict.items(): - params = configs[0].param_names() + params = configs[0].param_names() if len(configs) > 0 else [] for p in params: state_field_info = configs_to_state_dict[stype][p] nve_idx = [c.idx for c in configs] if state_field_info.nested_field[0] == 'nve_state' else range(len(configs)) @@ -248,6 +248,8 @@ def set_configs_from_state(state, config_dict=None): for f in state_field_info.nested_field: value = getattr(value, f) for config in config_dict[stype]: + if config is None: + continue t = type(getattr(config, param)) row_idx = state.row_idx(state_field_info.nested_field[0], config.idx) if state_field_info.column_idx is None: diff --git a/vivarium/controllers/panel_controller.py b/vivarium/controllers/panel_controller.py index 485ad7e..a481eeb 100644 --- a/vivarium/controllers/panel_controller.py +++ b/vivarium/controllers/panel_controller.py @@ -44,6 +44,9 @@ class Selected(param.Parameterized): def selection_nve_idx(self, nve_idx): return nve_idx[np.array(self.selection)].tolist() + + def __len__(self): + return len(self.selection) class PanelController(SimulatorController): @@ -57,16 +60,18 @@ def __init__(self, **params): self.panel_configs = {stype: [stype_to_panel_config[stype]() for _ in range(len(configs))] for stype, configs in self.configs.items()} self.selected_panel_configs = {EntityType.AGENT: PanelAgentConfig(), EntityType.OBJECT: PanelObjectConfig()} + self.pull_selected_panel_configs() self.panel_simulator_config = PanelSimulatorConfig() self.update_entity_list() for etype, selected in self.selected_entities.items(): - selected.param.watch(self.pull_selected_configs, ['selection'], onlychanged=True, precedence=1) - selected.param.watch(self.pull_selected_panel_configs, ['selection'], onlychanged=True) + if selected is not None: + selected.param.watch(self.pull_selected_configs, ['selection'], onlychanged=True, precedence=1) + selected.param.watch(self.pull_selected_panel_configs, ['selection'], onlychanged=True) def watch_selected_configs(self): watchers = {etype: config.param.watch(self.push_selected_to_config_list, config.param_names(), onlychanged=True) - for etype, config in self.selected_configs.items()} + for etype, config in self.selected_configs.items() if config is not None} return watchers def watch_selected_panel_configs(self): @@ -78,7 +83,8 @@ def watch_selected_panel_configs(self): def dont_push_selected_configs(self): if self._selected_configs_watchers is not None: for etype, config in self.selected_configs.items(): - config.param.unwatch(self._selected_configs_watchers[etype]) + if config is not None: + config.param.unwatch(self._selected_configs_watchers[etype]) try: yield finally: @@ -97,22 +103,30 @@ def dont_push_selected_panel_configs(self): def update_entity_list(self, *events): state = self.state for etype, selected in self.selected_entities.items(): - selected.param.selection.objects = state.entity_idx(etype).tolist() + if selected is not None: + selected.param.selection.objects = state.entity_idx(etype).tolist() def pull_selected_configs(self, *events): state = self.state + + for etype, config in self.selected_configs.items(): + if state.field(etype.to_state_type()).nve_idx.shape[0] == 0: + self.selected_configs[etype] = None + self.selected_entities[etype] = None config_dict = {etype.to_state_type(): [config] for etype, config in self.selected_configs.items()} with self.dont_push_selected_configs(): # Todo: check if for loop below is still required for etype, selected in self.selected_entities.items(): - config_dict[etype.to_state_type()][0].idx = int(state.nve_idx(etype.to_state_type(), selected.selection[0])) + if selected is not None: + config_dict[etype.to_state_type()][0].idx = int(state.nve_idx(etype.to_state_type(), selected.selection[0])) converters.set_configs_from_state(state, config_dict) return state def pull_selected_panel_configs(self, *events): with self.dont_push_selected_panel_configs(): for etype, panel_config in self.selected_panel_configs.items(): - panel_config.param.update(**self.panel_configs[etype.to_state_type()][self.selected_entities[etype].selection[0]].to_dict()) + if self.selected_entities[etype] is not None: + panel_config.param.update(**self.panel_configs[etype.to_state_type()][self.selected_entities[etype].selection[0]].to_dict()) def pull_all_data(self): self.pull_selected_configs() diff --git a/vivarium/interface/panel_app.py b/vivarium/interface/panel_app.py index 76dfbd8..c7da929 100644 --- a/vivarium/interface/panel_app.py +++ b/vivarium/interface/panel_app.py @@ -32,11 +32,12 @@ def __init__(self, config, panel_configs, panel_simulator_config, selected, etyp self.cds.on_change('data', self.drag_cb) self.cds_view = self.create_cds_view() self.panel_simulator_config.param.watch(self.hide_all_non_existing, "hide_non_existing") - selected.param.watch(self.update_selected_plot, ['selection'], - onlychanged=True, precedence=0) + if selected is not None: + selected.param.watch(self.update_selected_plot, ['selection'], + onlychanged=True, precedence=0) for i, pc in enumerate(self.panel_configs): pc.param.watch(self.update_cds_view, pc.param_names(), onlychanged=True) - self.config[i].param.watch(self.hide_non_existing, "exists", onlychanged=False) + self.config[i].param.watch(self.hide_non_existing, "exists", onlychanged=True) def drag_cb(self, attr, old, new): for i, c in enumerate(self.config): @@ -58,15 +59,16 @@ def update_cds(self, state): def create_cds_view(self): # For each attribute in the panel config, create a filter # that is a logical AND of the visibility and the attribute + params = self.panel_configs[0].param_names() if len(self.panel_configs) else [] return { attr: CDSView(filter=BooleanFilter( [getattr(pc, attr) and pc.visible for pc in self.panel_configs] - )) for attr in self.panel_configs[0].param_names() + )) for attr in params } def update_cds_view(self, event): n = event.name - for attr in [n] if n != "visible" else self.panel_configs[0].param_names(): + for attr in [n] if n != "visible" else (self.panel_configs[0].param_names() if len(self.panel_configs) else []): f = [getattr(pc, attr) and pc.visible for pc in self.panel_configs] self.cds_view[attr].filter = BooleanFilter(f) @@ -214,7 +216,7 @@ def __init__(self, **kwargs): panel_simulator_config=self.controller.panel_simulator_config, selected=self.controller.selected_entities[etype], etype=etype, state=self.controller.state) - for etype, manager_class in self.entity_manager_classes.items() + for etype, manager_class in self.entity_manager_classes.items() if len(self.controller.configs[etype.to_state_type()]) } self.plot = self.create_plot() @@ -231,8 +233,9 @@ def start_toggle_cb(self, event): def entity_toggle_cb(self, event): - for i, t in enumerate(self.config_types): - self.config_columns[i].visible = t in event.new + self.config_columns[0].visible = "SIMULATOR" in event.new + for i, t in enumerate(self.entity_managers.keys()): + self.config_columns[i].visible = t.name in event.new def update_timestep_cb(self, event): self.pcb_plot.period = event.new @@ -262,8 +265,8 @@ def create_plot(self): p.add_tools(hover) p.x_range = Range1d(0, self.controller.simulator_config.box_size) p.y_range = Range1d(0, self.controller.simulator_config.box_size) - draw_tool = PointDrawTool(renderers=[self.entity_managers[etype].plot(p) - for etype in EntityType], add=False) + draw_tool = PointDrawTool(renderers=[em.plot(p) + for em in self.entity_managers.values()], add=False) p.add_tools(draw_tool) return p @@ -271,16 +274,16 @@ def create_app(self): self.config_columns = pn.Row(* [pn.Column( pn.pane.Markdown("### SIMULATOR", align="center"), - pn.panel(self.controller.panel_simulator_config, name="Visualization configurations"), + pn.panel(self.controller.panel_simulator_config, name="Visualization configuration"), pn.panel(self.controller.simulator_config, name="Configurations"), visible=False, sizing_mode="scale_height", scroll=True)] + [pn.Column( pn.pane.Markdown(f"### {etype.name}", align="center"), self.controller.selected_entities[etype], - pn.panel(self.controller.selected_panel_configs[etype], name="Visualization configurations"), - pn.panel(self.controller.selected_configs[etype], name="State configurations"), + pn.panel(self.controller.selected_panel_configs[etype], name="Visualization configuration"), + pn.panel(self.controller.selected_configs[etype], name="State configuration"), visible=True, sizing_mode="scale_height", scroll=True) - for etype in EntityType]) + for etype in self.entity_managers.keys()]) app = pn.Row(pn.Column(pn.Row(pn.pane.Markdown("### Start/Stop server", align="center"), self.start_toggle),