Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug when setting n_agents or n_objects to 0 #55

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion vivarium/controllers/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def set_state_from_config_dict(config_dict, state=None):
state = state or get_default_state(n_entities_dict)
e_idx = jnp.zeros(sum(n_entities_dict.values()), dtype=int)
for stype, configs in config_dict.items():
params = configs[0].param_names()
params = configs[0].param_names() if len(configs) > 0 else []
for p in params:
state_field_info = configs_to_state_dict[stype][p]
nve_idx = [c.idx for c in configs] if state_field_info.nested_field[0] == 'nve_state' else range(len(configs))
Expand Down Expand Up @@ -248,6 +248,8 @@ def set_configs_from_state(state, config_dict=None):
for f in state_field_info.nested_field:
value = getattr(value, f)
for config in config_dict[stype]:
if config is None:
continue
t = type(getattr(config, param))
row_idx = state.row_idx(state_field_info.nested_field[0], config.idx)
if state_field_info.column_idx is None:
Expand Down
28 changes: 21 additions & 7 deletions vivarium/controllers/panel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class Selected(param.Parameterized):

def selection_nve_idx(self, nve_idx):
return nve_idx[np.array(self.selection)].tolist()

def __len__(self):
return len(self.selection)


class PanelController(SimulatorController):
Expand All @@ -57,16 +60,18 @@ def __init__(self, **params):
self.panel_configs = {stype: [stype_to_panel_config[stype]() for _ in range(len(configs))]
for stype, configs in self.configs.items()}
self.selected_panel_configs = {EntityType.AGENT: PanelAgentConfig(), EntityType.OBJECT: PanelObjectConfig()}
self.pull_selected_panel_configs()
self.panel_simulator_config = PanelSimulatorConfig()

self.update_entity_list()
for etype, selected in self.selected_entities.items():
selected.param.watch(self.pull_selected_configs, ['selection'], onlychanged=True, precedence=1)
selected.param.watch(self.pull_selected_panel_configs, ['selection'], onlychanged=True)
if selected is not None:
selected.param.watch(self.pull_selected_configs, ['selection'], onlychanged=True, precedence=1)
selected.param.watch(self.pull_selected_panel_configs, ['selection'], onlychanged=True)

def watch_selected_configs(self):
watchers = {etype: config.param.watch(self.push_selected_to_config_list, config.param_names(), onlychanged=True)
for etype, config in self.selected_configs.items()}
for etype, config in self.selected_configs.items() if config is not None}
return watchers

def watch_selected_panel_configs(self):
Expand All @@ -78,7 +83,8 @@ def watch_selected_panel_configs(self):
def dont_push_selected_configs(self):
if self._selected_configs_watchers is not None:
for etype, config in self.selected_configs.items():
config.param.unwatch(self._selected_configs_watchers[etype])
if config is not None:
config.param.unwatch(self._selected_configs_watchers[etype])
try:
yield
finally:
Expand All @@ -97,22 +103,30 @@ def dont_push_selected_panel_configs(self):
def update_entity_list(self, *events):
state = self.state
for etype, selected in self.selected_entities.items():
selected.param.selection.objects = state.entity_idx(etype).tolist()
if selected is not None:
selected.param.selection.objects = state.entity_idx(etype).tolist()

def pull_selected_configs(self, *events):
state = self.state

for etype, config in self.selected_configs.items():
if state.field(etype.to_state_type()).nve_idx.shape[0] == 0:
self.selected_configs[etype] = None
self.selected_entities[etype] = None
config_dict = {etype.to_state_type(): [config] for etype, config in self.selected_configs.items()}
with self.dont_push_selected_configs():
# Todo: check if for loop below is still required
for etype, selected in self.selected_entities.items():
config_dict[etype.to_state_type()][0].idx = int(state.nve_idx(etype.to_state_type(), selected.selection[0]))
if selected is not None:
config_dict[etype.to_state_type()][0].idx = int(state.nve_idx(etype.to_state_type(), selected.selection[0]))
converters.set_configs_from_state(state, config_dict)
return state

def pull_selected_panel_configs(self, *events):
with self.dont_push_selected_panel_configs():
for etype, panel_config in self.selected_panel_configs.items():
panel_config.param.update(**self.panel_configs[etype.to_state_type()][self.selected_entities[etype].selection[0]].to_dict())
if self.selected_entities[etype] is not None:
panel_config.param.update(**self.panel_configs[etype.to_state_type()][self.selected_entities[etype].selection[0]].to_dict())

def pull_all_data(self):
self.pull_selected_configs()
Expand Down
31 changes: 17 additions & 14 deletions vivarium/interface/panel_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ def __init__(self, config, panel_configs, panel_simulator_config, selected, etyp
self.cds.on_change('data', self.drag_cb)
self.cds_view = self.create_cds_view()
self.panel_simulator_config.param.watch(self.hide_all_non_existing, "hide_non_existing")
selected.param.watch(self.update_selected_plot, ['selection'],
onlychanged=True, precedence=0)
if selected is not None:
selected.param.watch(self.update_selected_plot, ['selection'],
onlychanged=True, precedence=0)
for i, pc in enumerate(self.panel_configs):
pc.param.watch(self.update_cds_view, pc.param_names(), onlychanged=True)
self.config[i].param.watch(self.hide_non_existing, "exists", onlychanged=False)
self.config[i].param.watch(self.hide_non_existing, "exists", onlychanged=True)

def drag_cb(self, attr, old, new):
for i, c in enumerate(self.config):
Expand All @@ -58,15 +59,16 @@ def update_cds(self, state):
def create_cds_view(self):
# For each attribute in the panel config, create a filter
# that is a logical AND of the visibility and the attribute
params = self.panel_configs[0].param_names() if len(self.panel_configs) else []
return {
attr: CDSView(filter=BooleanFilter(
[getattr(pc, attr) and pc.visible for pc in self.panel_configs]
)) for attr in self.panel_configs[0].param_names()
)) for attr in params
}

def update_cds_view(self, event):
n = event.name
for attr in [n] if n != "visible" else self.panel_configs[0].param_names():
for attr in [n] if n != "visible" else (self.panel_configs[0].param_names() if len(self.panel_configs) else []):
f = [getattr(pc, attr) and pc.visible for pc in self.panel_configs]
self.cds_view[attr].filter = BooleanFilter(f)

Expand Down Expand Up @@ -214,7 +216,7 @@ def __init__(self, **kwargs):
panel_simulator_config=self.controller.panel_simulator_config,
selected=self.controller.selected_entities[etype], etype=etype,
state=self.controller.state)
for etype, manager_class in self.entity_manager_classes.items()
for etype, manager_class in self.entity_manager_classes.items() if len(self.controller.configs[etype.to_state_type()])
}

self.plot = self.create_plot()
Expand All @@ -231,8 +233,9 @@ def start_toggle_cb(self, event):


def entity_toggle_cb(self, event):
for i, t in enumerate(self.config_types):
self.config_columns[i].visible = t in event.new
self.config_columns[0].visible = "SIMULATOR" in event.new
for i, t in enumerate(self.entity_managers.keys()):
self.config_columns[i].visible = t.name in event.new

def update_timestep_cb(self, event):
self.pcb_plot.period = event.new
Expand Down Expand Up @@ -262,25 +265,25 @@ def create_plot(self):
p.add_tools(hover)
p.x_range = Range1d(0, self.controller.simulator_config.box_size)
p.y_range = Range1d(0, self.controller.simulator_config.box_size)
draw_tool = PointDrawTool(renderers=[self.entity_managers[etype].plot(p)
for etype in EntityType], add=False)
draw_tool = PointDrawTool(renderers=[em.plot(p)
for em in self.entity_managers.values()], add=False)
p.add_tools(draw_tool)
return p

def create_app(self):
self.config_columns = pn.Row(*
[pn.Column(
pn.pane.Markdown("### SIMULATOR", align="center"),
pn.panel(self.controller.panel_simulator_config, name="Visualization configurations"),
pn.panel(self.controller.panel_simulator_config, name="Visualization configuration"),
pn.panel(self.controller.simulator_config, name="Configurations"),
visible=False, sizing_mode="scale_height", scroll=True)] +
[pn.Column(
pn.pane.Markdown(f"### {etype.name}", align="center"),
self.controller.selected_entities[etype],
pn.panel(self.controller.selected_panel_configs[etype], name="Visualization configurations"),
pn.panel(self.controller.selected_configs[etype], name="State configurations"),
pn.panel(self.controller.selected_panel_configs[etype], name="Visualization configuration"),
pn.panel(self.controller.selected_configs[etype], name="State configuration"),
visible=True, sizing_mode="scale_height", scroll=True)
for etype in EntityType])
for etype in self.entity_managers.keys()])

app = pn.Row(pn.Column(pn.Row(pn.pane.Markdown("### Start/Stop server", align="center"),
self.start_toggle),
Expand Down
Loading