From b43e7dbbf215c05b9b4cf590b0f8ef7c54c513f7 Mon Sep 17 00:00:00 2001 From: Clare Shanahan Date: Mon, 6 Jan 2025 16:13:24 -0500 Subject: [PATCH] row select --- jdaviz/configs/imviz/tests/test_catalogs.py | 44 +++++++++++++++++++++ jdaviz/core/template_mixin.py | 42 ++++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/jdaviz/configs/imviz/tests/test_catalogs.py b/jdaviz/configs/imviz/tests/test_catalogs.py index 956c32af68..ef4695d165 100644 --- a/jdaviz/configs/imviz/tests/test_catalogs.py +++ b/jdaviz/configs/imviz/tests/test_catalogs.py @@ -392,3 +392,47 @@ def test_offline_ecsv_catalog_with_extra_columns(imviz_helper, image_2d_wcs, tmp assert item['is_extended'] == tbl['is_extended'][idx] assert float(item['roundness']) == tbl['roundness'][idx] assert float(item['sharpness']) == tbl['sharpness'][idx] + + +def test_select_catalog_table_rows(imviz_helper, image_2d_wcs, tmp_path): + + """Test the ``select_rows`` functionality on table in plugin.""" + + arr = np.ones((500, 500)) + ndd = NDData(arr, wcs=image_2d_wcs) + imviz_helper.load_data(ndd) + + # write out table to load back in + # NOTE: if we ever support loading Table obj directly, replace this and + # remove tmp_path + sky_coord = SkyCoord(ra=[337.49056532, 337.46086081, 337.46586081, + 337.46786081, 337.47586081, 337.47686081], + dec=[-20.80555273, -20.7777673, -20.7877673, + -20.7877673, -20.7877673, -20.7877673], unit='deg') + tbl = Table({'sky_centroid': [sky_coord], + 'label': ['Source_1', 'Source_2', 'Source_3', 'Source_4', + 'Source_5', 'Source_6']}) + tbl_file = str(tmp_path / 'test_catalog.ecsv') + tbl.write(tbl_file, overwrite=True) + + catalogs_plugin = imviz_helper.plugins['Catalog Search'] + plugin_table = catalogs_plugin._obj.table + + # load catalog + catalogs_plugin._obj.from_file = tbl_file + catalogs_plugin._obj.search() + + # select a single row: + plugin_table.select_rows(3) + assert len(plugin_table.selected_rows) == 1 + assert plugin_table.selected_rows[0]['Right Ascension (degrees)'] == '337.46786' + + # select multiple rows by indices + # plugin_table.select_rows([1, 2, 4]) + + # select a range of rows: + plugin_table.select_rows(slice(0, 3)) + assert len(plugin_table.selected_rows) == 3 + + # select rows with multi dim numpy slice + # plugin_table.select_rows(np.s_[0:2, 3:4]) diff --git a/jdaviz/core/template_mixin.py b/jdaviz/core/template_mixin.py index a678570a95..fc1a3cdc3a 100644 --- a/jdaviz/core/template_mixin.py +++ b/jdaviz/core/template_mixin.py @@ -4829,6 +4829,48 @@ def clear_table(self): self._qtable = None self._plugin.session.hub.broadcast(PluginTableModifiedMessage(sender=self)) + def select_rows(self, rows): + """ + Select rows from the current table by index, indices, or slice. + + This method first clears the currently selected rows by resetting + `self.selected_rows` to an empty list. The rows specified by the input + `rows` are then applied as the new selection. + + Parameters + ---------- + rows : int, list of int, slice, or tuple of slice + The rows to select. This can be: + - An integer specifying a single row index. + - A list of integers specifying multiple row indices. + - A slice object specifying a range of rows. + - A tuple of slices (e.g using np.s_) + + """ + + if isinstance(rows, tuple): + # if 'rows' are a numpy slice, then it could be + # a tuple of slices. + if isinstance(rows[0], slice): + selected = [] + for sl in rows: + selected.append(self.items[rows][0]) + else: + selected = self.items[rows] + + + elif isinstance(rows, (slice, list)): + selected = self.items[rows] + elif isinstance(rows, int): + selected = [self.items[rows]] + + # first, deselect to revert current selection + self.selected_rows = [] + + # then apply current new selection + self.selected_rows = selected + + def vue_clear_table(self, data=None): # if the plugin (or via the TableMixin) has its own clear_table implementation, # call that, otherwise call the one defined here