Skip to content

Commit

Permalink
OWDistributions: support for Dask Tables
Browse files Browse the repository at this point in the history
  • Loading branch information
JakaKokosar committed Jul 12, 2023
1 parent 82e8b7b commit 8c55b47
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
12 changes: 12 additions & 0 deletions Orange/widgets/visualize/owdistributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from itertools import count, groupby, repeat
from xml.sax.saxutils import escape

import dask
import numpy as np
from scipy.stats import norm, rayleigh, beta, gamma, pareto, expon

Expand All @@ -13,6 +14,7 @@
import pyqtgraph as pg

from Orange.data import Table, DiscreteVariable, ContinuousVariable, Domain
from Orange.data.dask import DaskTable
from Orange.preprocess.discretize import decimal_binnings, time_binnings, \
short_time_units
from Orange.statistics import distribution, contingency
Expand All @@ -29,6 +31,7 @@
LegendItem as SPGLegendItem



class ScatterPlotItem(pg.ScatterPlotItem):
Symbols = pg.graphicsItems.ScatterPlotItem.Symbols

Expand Down Expand Up @@ -444,6 +447,7 @@ def set_data(self, data):
self.var = varmodel[min(len(domain.class_vars), len(varmodel) - 1)]
if domain is not None and domain.has_discrete_class:
self.cvar = domain.class_var

self.reset_select()
self._user_var_bins.clear()
self.openContext(domain)
Expand Down Expand Up @@ -539,12 +543,17 @@ def set_valid_data(self):
return

column = self.data.get_column(self.var)
if isinstance(self.data, DaskTable):
column = dask.compute(column)[0]

valid_mask = np.isfinite(column)
if not np.any(valid_mask):
self.Error.no_defined_values_var(self.var.name)
return
if self.cvar:
ccolumn = self.data.get_column(self.cvar)
if isinstance(self.data, DaskTable):
ccolumn = dask.compute(ccolumn)[0]
valid_mask *= np.isfinite(ccolumn)
if not np.any(valid_mask):
self.Error.no_defined_values_pair(self.var.name, self.cvar.name)
Expand Down Expand Up @@ -882,6 +891,9 @@ def recompute_binnings(self):
if self.is_valid and self.var.is_continuous:
# binning is computed on valid var data, ignoring any cvar nans
column = self.data.get_column(self.var)
if isinstance(self.data, DaskTable):
column = dask.compute(column)[0]

if np.any(np.isfinite(column)):
if self.var.is_time:
self.binnings = time_binnings(column, min_unique=5)
Expand Down
14 changes: 12 additions & 2 deletions Orange/widgets/visualize/tests/test_owdistributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from orangewidget.utils.combobox import qcombobox_emit_activated

from Orange.data import Table, Domain, DiscreteVariable
from Orange.tests.test_dasktable import temp_dasktable
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.utils.annotated_data import ANNOTATED_DATA_FEATURE_NAME
from Orange.widgets.utils.itemmodels import DomainModel
Expand All @@ -21,6 +22,7 @@ class TestOWDistributions(WidgetTest):
def setUp(self):
    """Create the widget under test and load the shared data fixtures.

    The tables are stored on the instance so that individual tests read
    them from ``self`` rather than loading the datasets themselves.
    """
    self.widget = self.create_widget(OWDistributions)  #: OWDistributions
    # Load both datasets in the same order as they are assigned.
    self.iris, self.heart_disease = (
        Table(name) for name in ("iris", "heart_disease"))

def _set_cvar(self, cvar):
combo = self.widget.controls.cvar
Expand Down Expand Up @@ -526,7 +528,7 @@ def test_report(self):
widget.send_report()

def test_sort_by_freq_no_split(self):
data = Table("heart_disease")
data = self.heart_disease
domain = data.domain
sort_by_freq = self.widget.controls.sort_by_freq

Expand All @@ -549,7 +551,7 @@ def test_sort_by_freq_no_split(self):
self.assertEqual(out[1][1], 97)

def test_sort_by_freq_split(self):
data = Table("heart_disease")
data = self.heart_disease
domain = data.domain
sort_by_freq = self.widget.controls.sort_by_freq

Expand All @@ -576,5 +578,13 @@ def test_sort_by_freq_split(self):
self.assertEqual(out[4][2], 45)


class TestOWDistributionsWithDask(TestOWDistributions):
    """Run the inherited OWDistributions tests on ``temp_dasktable`` data.

    Only the fixtures differ from the parent class; every test method is
    inherited unchanged.
    """

    def setUp(self):
        """Create the widget and build fixtures with ``temp_dasktable``."""
        self.widget = self.create_widget(OWDistributions)
        # Same dataset names as the parent class, loaded in the same order.
        self.iris, self.heart_disease = (
            temp_dasktable(name) for name in ("iris", "heart_disease"))


if __name__ == "__main__":
unittest.main()

0 comments on commit 8c55b47

Please sign in to comment.