Skip to content

Commit

Permalink
Merge pull request #6511 from JakaKokosar/dask-violinplot
Browse files Browse the repository at this point in the history
[ENH] OWViolinPlot: support for Dask Tables
  • Loading branch information
markotoplak committed Oct 10, 2023
2 parents a4d6537 + ad0a0f4 commit a82ec37
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
15 changes: 15 additions & 0 deletions Orange/widgets/visualize/owviolinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from itertools import chain, count
from typing import List, Optional, Tuple, Set, Sequence

import dask
import numpy as np
from scipy import stats
from sklearn.neighbors import KernelDensity
Expand All @@ -19,6 +20,7 @@
VisualSettingsDialog

from Orange.data import ContinuousVariable, DiscreteVariable, Table
from Orange.data.dask import DaskTable
from Orange.widgets import gui
from Orange.widgets.settings import ContextSetting, DomainContextHandler, \
Setting
Expand Down Expand Up @@ -997,6 +999,8 @@ def compute_score(attr):
if attr is group_var:
return 3
col = self.data.get_column(attr)
if isinstance(self.data, DaskTable):
col = dask.compute(col)[0]
groups = (col[group_col == i] for i in range(n_groups))
groups = (col[~np.isnan(col)] for col in groups)
groups = [group for group in groups if len(group)]
Expand All @@ -1011,6 +1015,8 @@ def compute_score(attr):
if self.order_by_importance and group_var is not None:
n_groups = len(group_var.values)
group_col = self.data.get_column(group_var)
if isinstance(self.data, DaskTable):
    # Materialise the lazily evaluated *column*, not the variable
    # descriptor: dask.compute() passes non-dask arguments through
    # unchanged, so computing `group_var` would leave `group_col` bound
    # to the variable object and the per-group masks built from it in
    # compute_score (`group_col == i`) would no longer select rows.
    group_col = dask.compute(group_col)[0]
self._sort_list(self._value_var_model, self._value_var_view,
compute_score)
else:
Expand All @@ -1023,6 +1029,8 @@ def compute_stat(group):
if group is None:
return -1
col = self.data.get_column(group)
if isinstance(self.data, DaskTable):
col = dask.compute(col)[0]
groups = (value_col[col == i] for i in range(len(group.values)))
groups = (col[~np.isnan(col)] for col in groups)
groups = [group for group in groups if len(group)]
Expand All @@ -1036,6 +1044,8 @@ def compute_stat(group):
value_var = self.value_var
if self.order_grouping_by_importance:
value_col = self.data.get_column(value_var)
if isinstance(self.data, DaskTable):
value_col = dask.compute(value_col)[0]
self._sort_list(self._group_var_model, self._group_var_view,
compute_stat)
else:
Expand Down Expand Up @@ -1069,9 +1079,14 @@ def setup_plot(self):
return

y = self.data.get_column(self.value_var)
if isinstance(self.data, DaskTable):
y = dask.compute(y)[0]

x = None
if self.group_var:
x = self.data.get_column(self.group_var)
if isinstance(self.data, DaskTable):
x = dask.compute(x)[0]
self.graph.set_data(y, self.value_var, x, self.group_var)

def apply_selection(self):
Expand Down
26 changes: 24 additions & 2 deletions Orange/widgets/visualize/tests/test_owviolinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pyqtgraph import ViewBox

from Orange.data import Table
from Orange.tests.test_dasktable import temp_dasktable
from Orange.widgets.tests.base import datasets, simulate, \
WidgetOutputsTestMixin, WidgetTest
from Orange.widgets.visualize.owviolinplot import OWViolinPlot, \
Expand All @@ -34,6 +35,7 @@ def setUpClass(cls):
cls.signal_name = OWViolinPlot.Inputs.data
cls.signal_data = cls.data
cls.housing = Table("housing")
cls.zoo = Table("zoo")

def setUp(self):
self.widget = self.create_widget(OWViolinPlot)
Expand All @@ -50,8 +52,7 @@ def test_kernels(self):
simulate.combobox_activate_item(kernel_combo, kernel)

def test_no_cont_features(self):
data = Table("zoo")
self.send_signal(self.widget.Inputs.data, data)
self.send_signal(self.widget.Inputs.data, self.zoo)
self.assertTrue(self.widget.Error.no_cont_features.is_shown())
self.send_signal(self.widget.Inputs.data, None)
self.assertFalse(self.widget.Error.no_cont_features.is_shown())
Expand Down Expand Up @@ -392,5 +393,26 @@ def __select_value(list_, value):
idx, QItemSelectionModel.ClearAndSelect)


class TestOWViolinPlotWithDask(TestOWViolinPlot):
    """Run the inherited violin plot tests on tables from temp_dasktable."""

    @classmethod
    def setUpClass(cls):
        # The parent prepares the regular fixtures first; each one is then
        # replaced by its temp_dasktable counterpart so that every inherited
        # test exercises the widget with that kind of table.
        super().setUpClass()
        iris = temp_dasktable("iris")
        cls.data = iris
        cls.signal_data = iris
        cls.housing = temp_dasktable("housing")
        cls.zoo = temp_dasktable("zoo")

    def test_datasets(self):
        # Enable both optional plot layers, then feed every bundled test
        # dataset (converted with temp_dasktable) through the widget while
        # activating the first three scale options.
        controls = self.widget.controls
        controls.show_strip_plot.setChecked(True)
        controls.show_rug_plot.setChecked(True)
        for dataset in datasets.datasets():
            self.send_signal(self.widget.Inputs.data,
                             temp_dasktable(dataset))
            for index in range(3):
                simulate.combobox_activate_index(controls.scale_index, index)


if __name__ == "__main__":
unittest.main()

0 comments on commit a82ec37

Please sign in to comment.