diff --git a/Orange/projection/freeviz.py b/Orange/projection/freeviz.py index 8e9451fcfec..83437f97485 100644 --- a/Orange/projection/freeviz.py +++ b/Orange/projection/freeviz.py @@ -21,7 +21,7 @@ class FreeViz(LinearProjector): projection = FreeVizModel def __init__(self, weights=None, center=True, scale=True, dim=2, p=1, - initial=None, maxiter=500, alpha=0.1, + initial=None, maxiter=500, alpha=0.1, gravity=None, atol=1e-5, preprocessors=None): super().__init__(preprocessors=preprocessors) self.weights = weights @@ -33,6 +33,7 @@ def __init__(self, weights=None, center=True, scale=True, dim=2, p=1, self.maxiter = maxiter self.alpha = alpha self.atol = atol + self.gravity = gravity self.is_class_discrete = False self.components_ = None @@ -50,6 +51,7 @@ def get_components(self, X, Y): X, Y, weights=self.weights, center=self.center, scale=self.scale, dim=self.dim, p=self.p, initial=self.initial, maxiter=self.maxiter, alpha=self.alpha, atol=self.atol, + gravity=self.gravity, is_class_discrete=self.is_class_discrete)[1].T @classmethod @@ -104,7 +106,7 @@ def forces_regression(cls, distances, y, p=1): return F @classmethod - def forces_classification(cls, distances, y, p=1): + def forces_classification(cls, distances, y, p=1, gravity=None): diffclass = scipy.spatial.distance.pdist(y.reshape(-1, 1), "hamming") != 0 # handle attractive force if p == 1: @@ -120,6 +122,8 @@ def forces_classification(cls, distances, y, p=1): F[mask] = 1 / distances[mask] else: F[mask] = 1 / (distances[mask] ** p) + if gravity is not None: + F[mask] *= -np.sum(F[~mask]) / np.sum(F[mask]) / gravity return F @classmethod @@ -180,7 +184,8 @@ def gradient(cls, X, embeddings, forces, embedding_dist=None, weights=None): return G @classmethod - def freeviz_gradient(cls, X, y, embedding, p=1, weights=None, is_class_discrete=False): + def freeviz_gradient(cls, X, y, embedding, p=1, weights=None, + gravity=None, is_class_discrete=False): """ Return the gradient for the FreeViz [1]_ projection. @@ -214,7 +219,7 @@ def freeviz_gradient(cls, X, y, embedding, p=1, weights=None, is_class_discrete= assert X.ndim == 2 and X.shape[0] == y.shape[0] == embedding.shape[0] D = scipy.spatial.distance.pdist(embedding) if is_class_discrete: - forces = cls.forces_classification(D, y, p=p) + forces = cls.forces_classification(D, y, p=p, gravity=gravity) else: forces = cls.forces_regression(D, y, p=p) G = cls.gradient(X, embedding, forces, embedding_dist=D, weights=weights) @@ -234,7 +239,8 @@ def _rotate(cls, A): @classmethod def freeviz(cls, X, y, weights=None, center=True, scale=True, dim=2, p=1, - initial=None, maxiter=500, alpha=0.1, atol=1e-5, is_class_discrete=False): + initial=None, maxiter=500, alpha=0.1, atol=1e-5, gravity=None, + is_class_discrete=False): """ FreeViz @@ -341,6 +347,7 @@ def freeviz(cls, X, y, weights=None, center=True, scale=True, dim=2, p=1, step_i = 0 while step_i < maxiter: G = cls.freeviz_gradient(X, y, embeddings, p=p, weights=weights, + gravity=gravity, is_class_discrete=is_class_discrete) # Scale the changes (the largest anchor move is alpha * radius) diff --git a/Orange/widgets/visualize/owfreeviz.py b/Orange/widgets/visualize/owfreeviz.py index 1110a1e207e..ecc27b72332 100644 --- a/Orange/widgets/visualize/owfreeviz.py +++ b/Orange/widgets/visualize/owfreeviz.py @@ -5,7 +5,7 @@ import numpy as np from AnyQt.QtCore import Qt, QRectF, QLineF, QPoint -from AnyQt.QtGui import QPalette +from AnyQt.QtGui import QPalette, QFontMetrics from AnyQt.QtWidgets import QSizePolicy import pyqtgraph as pg @@ -137,9 +137,13 @@ class OWFreeViz(OWAnchorProjectionWidget, ConcurrentWidgetMixin): settings_version = 3 initialization = settings.Setting(InitType.Circular) + balance = settings.Setting(False) + gravity_index = settings.Setting(4) GRAPH_CLASS = OWFreeVizGraph graph = settings.SettingProvider(OWFreeVizGraph) + GravityValues = [0.1, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 3, 4, 5] + class Error(OWAnchorProjectionWidget.Error): no_class_var = widget.Msg("Data must have a target variable.") multiple_class_vars = widget.Msg( @@ -159,6 +163,7 @@ class Warning(OWAnchorProjectionWidget.Warning): def __init__(self): OWAnchorProjectionWidget.__init__(self) ConcurrentWidgetMixin.__init__(self) + self.__optimized = False def _add_controls(self): self.__add_controls_start_box() @@ -177,6 +182,20 @@ def __add_controls_start_box(self): callback=self.__init_combo_changed, sizePolicy=(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed) ) + box2 = gui.hBox(box) + gui.checkBox( + box2, self, "balance", "Gravity", + callback=self.__gravity_changed) + self.grav_slider = gui.hSlider( + box2, self, "gravity_index", + minValue=0, maxValue=len(self.GravityValues) - 1, + callback=self.__gravity_dragged, createLabel=False) + self.gravity_label = gui.widgetLabel(box2) + self.gravity_label.setFixedWidth( + max(QFontMetrics(self.font()).horizontalAdvance(str(x)) + for x in self.GravityValues)) + self.gravity_label.setAlignment(Qt.AlignRight) + self.__update_gravity_label() self.run_button = gui.button(box, self, "Start", self._toggle_run) @property @@ -189,6 +208,21 @@ def effective_data(self): return self.data.transform(Domain(self.effective_variables, self.data.domain.class_vars)) + def __gravity_dragged(self): + self.balance = True + self.__gravity_changed() + + def __update_gravity_label(self): + self.gravity_label.setText(str(self.GravityValues[self.gravity_index])) + + def __gravity_changed(self): + gravity = self.GravityValues[self.gravity_index] + if self.projector is not None: + self.projector.gravity = gravity if self.balance else None + self.__update_gravity_label() + if self.task is None and self.__optimized: + self._run() + def __radius_slider_changed(self): self.graph.update_radius() @@ -232,6 +266,7 @@ def on_done(self, result: Result): self.projection = result.projection self.graph.set_sample_size(None) self.run_button.setText("Start") + self.__optimized = True self.commit.deferred() def on_exception(self, ex: Exception): @@ -253,14 +288,19 @@ def init_projection(self): anchors = FreeViz.init_radial(len(self.effective_variables)) \ if self.initialization == InitType.Circular \ else FreeViz.init_random(len(self.effective_variables), 2) + if self.balance: + gravity = self.GravityValues[self.gravity_index] + else: + gravity = None self.projector = FreeViz(scale=False, center=False, - initial=anchors, maxiter=10) + initial=anchors, maxiter=10, gravity=gravity) data = self.projector.preprocess(self.effective_data) self.projector.domain = data.domain self.projector.components_ = anchors.T self.projection = FreeVizModel(self.projector, self.projector.domain, 2) self.projection.pre_domain = data.domain self.projection.name = self.projector.name + self.__optimized = False def check_data(self): def error(err): diff --git a/Orange/widgets/visualize/tests/test_owfreeviz.py b/Orange/widgets/visualize/tests/test_owfreeviz.py index b27feb2aed0..f7a3d8554b7 100644 --- a/Orange/widgets/visualize/tests/test_owfreeviz.py +++ b/Orange/widgets/visualize/tests/test_owfreeviz.py @@ -1,7 +1,7 @@ # Test methods with long descriptive names can omit docstrings # pylint: disable=missing-docstring import unittest -from unittest.mock import Mock +from unittest.mock import Mock, patch import numpy as np @@ -156,6 +156,46 @@ def test_discrete_attributes(self): self.assertTrue(self.widget.Warning.removed_features.is_shown()) self.widget.run_button.click() + def test_gravity_slider(self): + w = self.widget + + w.balance = False + w.gravity_index = 0 + + w.grav_slider.setValue(2) + self.assertTrue(w.balance) + self.assertEqual(w.gravity_label.text(), str(w.GravityValues[2])) + + w.grav_slider.setValue(3) + self.assertTrue(w.balance) + self.assertEqual(w.gravity_label.text(), str(w.GravityValues[3])) + + assert w.projector is None + self.send_signal(self.widget.Inputs.data, Table("zoo")) + self.wait_until_finished() + assert w.projector is not None + + # w.projector.gravity has correct value if gravity was set before data + self.assertEqual(w.projector.gravity, w.GravityValues[3]) + + # ... and if set when the data is already present and projector exists + w.grav_slider.setValue(1) + self.assertEqual(w.projector.gravity, w.GravityValues[1]) + + # Check that optimization is restarted if the projection is optimized + with patch.object(w, "_run") as run, \ + patch.object(w, "_OWFreeViz__optimized", new=True): + w.grav_slider.setValue(2) + self.assertEqual(w.projector.gravity, w.GravityValues[2]) + run.assert_called_once() + + # Also, check that checkbox also does all that + run.reset_mock() + w.controls.balance.click() + self.assertFalse(w.balance) + self.assertIsNone(w.projector.gravity) + run.assert_called_once() + class TestOWFreeVizRunner(unittest.TestCase): @classmethod