diff --git a/Orange/projection/base.py b/Orange/projection/base.py index b07f0cb39e1..1dae7f7b21b 100644 --- a/Orange/projection/base.py +++ b/Orange/projection/base.py @@ -1,3 +1,5 @@ +import warnings + import copy import inspect import threading @@ -9,6 +11,7 @@ from Orange.data.util import SharedComputeValue, get_unique_names from Orange.misc.wrapper_meta import WrapperMeta from Orange.preprocess import RemoveNaNRows +from Orange.util import dummy_callback, wrap_callback, OrangeDeprecationWarning import Orange.preprocess __all__ = ["LinearCombinationSql", "Projector", "Projection", "SklProjector", @@ -44,17 +47,36 @@ def fit(self, X, Y=None): raise NotImplementedError( "Classes derived from Projector must overload method fit") - def __call__(self, data): - data = self.preprocess(data) + def __call__(self, data, progress_callback=None): + if progress_callback is None: + progress_callback = dummy_callback + progress_callback(0, "Preprocessing...") + try: + cb = wrap_callback(progress_callback, end=0.1) + data = self.preprocess(data, progress_callback=cb) + except TypeError: + data = self.preprocess(data) + warnings.warn("A keyword argument 'progress_callback' has been " + "added to the preprocess() signature. Implementing " + "the method without the argument is deprecated and " + "will result in an error in the future.", + OrangeDeprecationWarning, stacklevel=2) self.domain = data.domain + progress_callback(0.1, "Fitting...") clf = self.fit(data.X, data.Y) clf.pre_domain = data.domain clf.name = self.name + progress_callback(1) return clf - def preprocess(self, data): - for pp in self.preprocessors: + def preprocess(self, data, progress_callback=None): + if progress_callback is None: + progress_callback = dummy_callback + n_pps = len(self.preprocessors) + for i, pp in enumerate(self.preprocessors): + progress_callback(i / n_pps) data = pp(data) + progress_callback(1) return data # Projectors implemented using `fit` access the `domain` through the @@ -208,8 +230,8 @@ def _get_sklparams(self, values): raise TypeError("Wrapper does not define '__wraps__'") return params - def preprocess(self, data): - data = super().preprocess(data) + def preprocess(self, data, progress_callback=None): + data = super().preprocess(data, progress_callback) if any(v.is_discrete and len(v.values) > 2 for v in data.domain.attributes): raise ValueError("Wrapped scikit-learn methods do not support " diff --git a/Orange/widgets/unsupervised/owpca.py b/Orange/widgets/unsupervised/owpca.py index 0a917fe0f11..29611146fb2 100644 --- a/Orange/widgets/unsupervised/owpca.py +++ b/Orange/widgets/unsupervised/owpca.py @@ -12,6 +12,7 @@ from Orange.preprocess import preprocess from Orange.projection import PCA from Orange.widgets import widget, gui, settings +from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin from Orange.widgets.utils.slidergraph import SliderGraph from Orange.widgets.utils.widgetpreview import WidgetPreview from Orange.widgets.widget import Input, Output @@ -21,7 +22,7 @@ LINE_NAMES = ["component variance", "cumulative variance"] -class OWPCA(widget.OWWidget): +class OWPCA(widget.OWWidget, ConcurrentWidgetMixin): name = "PCA" description = "Principal component analysis with a scree-diagram." icon = "icons/PCA.svg" @@ -57,13 +58,13 @@ class Error(widget.OWWidget.Error): def __init__(self): super().__init__() - self.data = None + ConcurrentWidgetMixin.__init__(self) + self.data = None self._pca = None self._transformed = None self._variance_ratio = None self._cumulative = None - self._init_projector() # Components Selection form = QFormLayout() @@ -114,6 +115,7 @@ def __init__(self): @Inputs.data def set_data(self, data): + self.cancel() self.clear_messages() self.clear() self.information() @@ -138,12 +140,11 @@ def set_data(self, data): self.clear_outputs() return - self._init_projector() - self.data = data self.fit() def fit(self): + self.cancel() self.clear() self.Warning.trivial_components.clear() if self.data is None: @@ -151,27 +152,45 @@ def fit(self): data = self.data - if self.normalize: - self._pca_projector.preprocessors = \ - self._pca_preprocessors + [preprocess.Normalize(center=False)] - else: - self._pca_projector.preprocessors = self._pca_preprocessors + projector = self._create_projector() if not isinstance(data, SqlTable): - pca = self._pca_projector(data) - variance_ratio = pca.explained_variance_ratio_ - cumulative = numpy.cumsum(variance_ratio) - - if numpy.isfinite(cumulative[-1]): - self.components_spin.setRange(0, len(cumulative)) - self._pca = pca - self._variance_ratio = variance_ratio - self._cumulative = cumulative - self._setup_plot() - else: - self.Warning.trivial_components() + self.start(self._call_projector, data, projector) + + @staticmethod + def _call_projector(data: Table, projector, state): + + def callback(i: float, status=""): + state.set_progress_value(i * 100) + if status: + state.set_status(status) + if state.is_interruption_requested(): + raise Exception # pylint: disable=broad-exception-raised + + return projector(data, progress_callback=callback) + + def on_done(self, result): + pca = result + variance_ratio = pca.explained_variance_ratio_ + cumulative = numpy.cumsum(variance_ratio) + + if numpy.isfinite(cumulative[-1]): + self.components_spin.setRange(0, len(cumulative)) + self._pca = pca + self._variance_ratio = variance_ratio + self._cumulative = cumulative + self._setup_plot() + else: + self.Warning.trivial_components() + + self.commit.now() - self.commit.now() + def on_partial_result(self, result): + pass + + def onDeleteWidget(self): + self.shutdown() + super().onDeleteWidget() def clear(self): self._pca = None @@ -184,7 +203,7 @@ def clear_outputs(self): self.Outputs.transformed_data.send(None) self.Outputs.data.send(None) self.Outputs.components.send(None) - self.Outputs.pca.send(self._pca_projector) + self.Outputs.pca.send(self._create_projector()) def _setup_plot(self): if self._pca is None: @@ -251,10 +270,13 @@ def _update_normalize(self): if self.data is None: self._invalidate_selection() - def _init_projector(self): - self._pca_projector = PCA(n_components=MAX_COMPONENTS, random_state=0) - self._pca_projector.component = self.ncomponents - self._pca_preprocessors = PCA.preprocessors + def _create_projector(self): + projector = PCA(n_components=MAX_COMPONENTS, random_state=0) + projector.component = self.ncomponents # for use as a Scorer + if self.normalize: + projector.preprocessors = \ + PCA.preprocessors + [preprocess.Normalize(center=False)] + return projector def _nselected_components(self): """Return the number of selected components.""" @@ -338,11 +360,10 @@ def commit(self): numpy.hstack((self.data.metas, transformed.X)), ids=self.data.ids) - self._pca_projector.component = self.ncomponents self.Outputs.transformed_data.send(transformed) self.Outputs.components.send(components) self.Outputs.data.send(data) - self.Outputs.pca.send(self._pca_projector) + self.Outputs.pca.send(self._create_projector()) def send_report(self): if self.data is None: diff --git a/Orange/widgets/unsupervised/tests/test_owpca.py b/Orange/widgets/unsupervised/tests/test_owpca.py index 3ea5009afe5..993871980a0 100644 --- a/Orange/widgets/unsupervised/tests/test_owpca.py +++ b/Orange/widgets/unsupervised/tests/test_owpca.py @@ -7,7 +7,6 @@ from Orange.data import Table, Domain, ContinuousVariable, TimeVariable from Orange.preprocess import preprocess -from Orange.preprocess.preprocess import Normalize from Orange.widgets.tests.base import WidgetTest from Orange.widgets.tests.utils import table_dense_sparse, possible_duplicate_table from Orange.widgets.unsupervised.owpca import OWPCA @@ -33,7 +32,7 @@ def test_constant_data(self): # Ignore the warning: the test checks whether the widget shows # Warning.trivial_components when this happens with np.errstate(invalid="ignore"): - self.send_signal(self.widget.Inputs.data, data) + self.send_signal(self.widget.Inputs.data, data, wait=5000) self.assertTrue(self.widget.Warning.trivial_components.is_shown()) self.assertIsNone(self.get_output(self.widget.Outputs.transformed_data)) self.assertIsNone(self.get_output(self.widget.Outputs.components)) @@ -56,12 +55,12 @@ def test_limit_components(self): X = np.random.RandomState(0).rand(101, 101) data = Table.from_numpy(None, X) self.widget.ncomponents = 100 - self.send_signal(self.widget.Inputs.data, data) + self.send_signal(self.widget.Inputs.data, data, wait=5000) tran = self.get_output(self.widget.Outputs.transformed_data) self.assertEqual(len(tran.domain.attributes), 100) self.widget.ncomponents = 101 # should not be accesible with self.assertRaises(IndexError): - self.send_signal(self.widget.Inputs.data, data) + self.widget._setup_plot() # pylint: disable=protected-access def test_migrate_settings_limits_components(self): settings = dict(ncomponents=10) @@ -84,9 +83,11 @@ def test_variance_shown(self): self.send_signal(self.widget.Inputs.data, self.iris) self.widget.maxp = 2 self.widget._setup_plot() + self.wait_until_finished() var2 = self.widget.variance_covered self.widget.ncomponents = 3 self.widget._update_selection_component_spin() + self.wait_until_finished() var3 = self.widget.variance_covered self.assertGreater(var3, var2) @@ -98,10 +99,11 @@ def test_unique_domain_components(self): def test_variance_attr(self): self.widget.ncomponents = 2 - self.send_signal(self.widget.Inputs.data, self.iris) + self.send_signal(self.widget.Inputs.data, self.iris, wait=5000) self.wait_until_stop_blocking() self.widget._variance_ratio = np.array([0.5, 0.25, 0.2, 0.05]) self.widget.commit.now() + self.wait_until_finished() result = self.get_output(self.widget.Outputs.transformed_data) pc1, pc2 = result.domain.attributes @@ -162,8 +164,8 @@ def test_normalize_data(self, prepare_table): # Enable checkbox self.widget.controls.normalize.setChecked(True) self.assertTrue(self.widget.controls.normalize.isChecked()) - with patch.object(preprocess, "Normalize", wraps=Normalize) as normalize: - self.send_signal(self.widget.Inputs.data, data) + with patch.object(preprocess.Normalize, "__call__", wraps=lambda x: x) as normalize: + self.send_signal(self.widget.Inputs.data, data, wait=5000) self.wait_until_stop_blocking() self.assertTrue(self.widget.controls.normalize.isEnabled()) normalize.assert_called_once() @@ -171,8 +173,8 @@ def test_normalize_data(self, prepare_table): # Disable checkbox self.widget.controls.normalize.setChecked(False) self.assertFalse(self.widget.controls.normalize.isChecked()) - with patch.object(preprocess, "Normalize", wraps=Normalize) as normalize: - self.send_signal(self.widget.Inputs.data, data) + with patch.object(preprocess.Normalize, "__call__", wraps=lambda x: x) as normalize: + self.send_signal(self.widget.Inputs.data, data, wait=5000) self.wait_until_stop_blocking() self.assertTrue(self.widget.controls.normalize.isEnabled()) normalize.assert_not_called() @@ -185,13 +187,14 @@ def test_normalization_variance(self, prepare_table): # Enable normalization self.widget.controls.normalize.setChecked(True) self.assertTrue(self.widget.normalize) - self.send_signal(self.widget.Inputs.data, data) + self.send_signal(self.widget.Inputs.data, data, wait=5000) self.wait_until_stop_blocking() variance_normalized = self.widget.variance_covered # Disable normalization self.widget.controls.normalize.setChecked(False) self.assertFalse(self.widget.normalize) + self.wait_until_finished() self.wait_until_stop_blocking() variance_unnormalized = self.widget.variance_covered