diff --git a/Orange/widgets/visualize/owdistributions.py b/Orange/widgets/visualize/owdistributions.py index bd6f656c443..dbc61af47cf 100644 --- a/Orange/widgets/visualize/owdistributions.py +++ b/Orange/widgets/visualize/owdistributions.py @@ -2,6 +2,7 @@ from itertools import count, groupby, repeat from xml.sax.saxutils import escape +import dask import numpy as np from scipy.stats import norm, rayleigh, beta, gamma, pareto, expon @@ -13,6 +14,7 @@ import pyqtgraph as pg from Orange.data import Table, DiscreteVariable, ContinuousVariable, Domain +from Orange.data.dask import DaskTable from Orange.preprocess.discretize import decimal_binnings, time_binnings, \ short_time_units from Orange.statistics import distribution, contingency @@ -29,6 +31,7 @@ LegendItem as SPGLegendItem + class ScatterPlotItem(pg.ScatterPlotItem): Symbols = pg.graphicsItems.ScatterPlotItem.Symbols @@ -444,6 +447,7 @@ def set_data(self, data): self.var = varmodel[min(len(domain.class_vars), len(varmodel) - 1)] if domain is not None and domain.has_discrete_class: self.cvar = domain.class_var + self.reset_select() self._user_var_bins.clear() self.openContext(domain) @@ -539,12 +543,17 @@ def set_valid_data(self): return column = self.data.get_column(self.var) + if isinstance(self.data, DaskTable): + column = dask.compute(column)[0] + valid_mask = np.isfinite(column) if not np.any(valid_mask): self.Error.no_defined_values_var(self.var.name) return if self.cvar: ccolumn = self.data.get_column(self.cvar) + if isinstance(self.data, DaskTable): + ccolumn = dask.compute(ccolumn)[0] valid_mask *= np.isfinite(ccolumn) if not np.any(valid_mask): self.Error.no_defined_values_pair(self.var.name, self.cvar.name) @@ -882,6 +891,9 @@ def recompute_binnings(self): if self.is_valid and self.var.is_continuous: # binning is computed on valid var data, ignoring any cvar nans column = self.data.get_column(self.var) + if isinstance(self.data, DaskTable): + column = dask.compute(column)[0] + if np.any(np.isfinite(column)): if self.var.is_time: self.binnings = time_binnings(column, min_unique=5) diff --git a/Orange/widgets/visualize/tests/test_owdistributions.py b/Orange/widgets/visualize/tests/test_owdistributions.py index 66c3ca2f1ea..6322379f827 100644 --- a/Orange/widgets/visualize/tests/test_owdistributions.py +++ b/Orange/widgets/visualize/tests/test_owdistributions.py @@ -11,6 +11,7 @@ from orangewidget.utils.combobox import qcombobox_emit_activated from Orange.data import Table, Domain, DiscreteVariable +from Orange.tests.test_dasktable import temp_dasktable from Orange.widgets.tests.base import WidgetTest from Orange.widgets.utils.annotated_data import ANNOTATED_DATA_FEATURE_NAME from Orange.widgets.utils.itemmodels import DomainModel @@ -21,6 +22,7 @@ class TestOWDistributions(WidgetTest): def setUp(self): self.widget = self.create_widget(OWDistributions) #: OWDistributions self.iris = Table("iris") + self.heart_disease = Table("heart_disease") def _set_cvar(self, cvar): combo = self.widget.controls.cvar @@ -526,7 +528,7 @@ def test_report(self): widget.send_report() def test_sort_by_freq_no_split(self): - data = Table("heart_disease") + data = self.heart_disease domain = data.domain sort_by_freq = self.widget.controls.sort_by_freq @@ -549,7 +551,7 @@ def test_sort_by_freq_no_split(self): self.assertEqual(out[1][1], 97) def test_sort_by_freq_split(self): - data = Table("heart_disease") + data = self.heart_disease domain = data.domain sort_by_freq = self.widget.controls.sort_by_freq @@ -576,5 +578,13 @@ def test_sort_by_freq_split(self): self.assertEqual(out[4][2], 45) +class TestOWDistributionsWithDask(TestOWDistributions): + + def setUp(self): + self.widget = self.create_widget(OWDistributions) + self.iris = temp_dasktable("iris") + self.heart_disease = temp_dasktable("heart_disease") + + if __name__ == "__main__": unittest.main()