Skip to content

Commit

Permalink
FreeViz: Allow setting ratio btw attractive and repulsive forces
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Aug 25, 2023
1 parent ca5bac8 commit 1947679
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 7 deletions.
17 changes: 12 additions & 5 deletions Orange/projection/freeviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class FreeViz(LinearProjector):
projection = FreeVizModel

def __init__(self, weights=None, center=True, scale=True, dim=2, p=1,
initial=None, maxiter=500, alpha=0.1,
initial=None, maxiter=500, alpha=0.1, gravity=None,
atol=1e-5, preprocessors=None):
super().__init__(preprocessors=preprocessors)
self.weights = weights
Expand All @@ -33,6 +33,7 @@ def __init__(self, weights=None, center=True, scale=True, dim=2, p=1,
self.maxiter = maxiter
self.alpha = alpha
self.atol = atol
self.gravity = gravity
self.is_class_discrete = False
self.components_ = None

Expand All @@ -50,6 +51,7 @@ def get_components(self, X, Y):
X, Y, weights=self.weights, center=self.center, scale=self.scale,
dim=self.dim, p=self.p, initial=self.initial,
maxiter=self.maxiter, alpha=self.alpha, atol=self.atol,
gravity=self.gravity,
is_class_discrete=self.is_class_discrete)[1].T

@classmethod
Expand Down Expand Up @@ -104,7 +106,7 @@ def forces_regression(cls, distances, y, p=1):
return F

@classmethod
def forces_classification(cls, distances, y, p=1):
def forces_classification(cls, distances, y, p=1, gravity=None):
diffclass = scipy.spatial.distance.pdist(y.reshape(-1, 1), "hamming") != 0
# handle attractive force
if p == 1:
Expand All @@ -120,6 +122,8 @@ def forces_classification(cls, distances, y, p=1):
F[mask] = 1 / distances[mask]
else:
F[mask] = 1 / (distances[mask] ** p)
if gravity is not None:
F[mask] *= -np.sum(F[~mask]) / np.sum(F[mask]) / gravity

Check warning on line 126 in Orange/projection/freeviz.py

View check run for this annotation

Codecov / codecov/patch

Orange/projection/freeviz.py#L126

Added line #L126 was not covered by tests
return F

@classmethod
Expand Down Expand Up @@ -180,7 +184,8 @@ def gradient(cls, X, embeddings, forces, embedding_dist=None, weights=None):
return G

@classmethod
def freeviz_gradient(cls, X, y, embedding, p=1, weights=None, is_class_discrete=False):
def freeviz_gradient(cls, X, y, embedding, p=1, weights=None,
gravity=None, is_class_discrete=False):
"""
Return the gradient for the FreeViz [1]_ projection.
Expand Down Expand Up @@ -214,7 +219,7 @@ def freeviz_gradient(cls, X, y, embedding, p=1, weights=None, is_class_discrete=
assert X.ndim == 2 and X.shape[0] == y.shape[0] == embedding.shape[0]
D = scipy.spatial.distance.pdist(embedding)
if is_class_discrete:
forces = cls.forces_classification(D, y, p=p)
forces = cls.forces_classification(D, y, p=p, gravity=gravity)
else:
forces = cls.forces_regression(D, y, p=p)
G = cls.gradient(X, embedding, forces, embedding_dist=D, weights=weights)
Expand All @@ -234,7 +239,8 @@ def _rotate(cls, A):

@classmethod
def freeviz(cls, X, y, weights=None, center=True, scale=True, dim=2, p=1,
initial=None, maxiter=500, alpha=0.1, atol=1e-5, is_class_discrete=False):
initial=None, maxiter=500, alpha=0.1, atol=1e-5, gravity=None,
is_class_discrete=False):
"""
FreeViz
Expand Down Expand Up @@ -341,6 +347,7 @@ def freeviz(cls, X, y, weights=None, center=True, scale=True, dim=2, p=1,
step_i = 0
while step_i < maxiter:
G = cls.freeviz_gradient(X, y, embeddings, p=p, weights=weights,
gravity=gravity,
is_class_discrete=is_class_discrete)

# Scale the changes (the largest anchor move is alpha * radius)
Expand Down
44 changes: 42 additions & 2 deletions Orange/widgets/visualize/owfreeviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from AnyQt.QtCore import Qt, QRectF, QLineF, QPoint
from AnyQt.QtGui import QPalette
from AnyQt.QtGui import QPalette, QFontMetrics
from AnyQt.QtWidgets import QSizePolicy

import pyqtgraph as pg
Expand Down Expand Up @@ -137,9 +137,13 @@ class OWFreeViz(OWAnchorProjectionWidget, ConcurrentWidgetMixin):

settings_version = 3
initialization = settings.Setting(InitType.Circular)
balance = settings.Setting(False)
gravity_index = settings.Setting(4)
GRAPH_CLASS = OWFreeVizGraph
graph = settings.SettingProvider(OWFreeVizGraph)

GravityValues = [0.1, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 3, 4, 5]

class Error(OWAnchorProjectionWidget.Error):
no_class_var = widget.Msg("Data must have a target variable.")
multiple_class_vars = widget.Msg(
Expand All @@ -159,6 +163,7 @@ class Warning(OWAnchorProjectionWidget.Warning):
def __init__(self):
OWAnchorProjectionWidget.__init__(self)
ConcurrentWidgetMixin.__init__(self)
self.__optimized = False

def _add_controls(self):
self.__add_controls_start_box()
Expand All @@ -177,6 +182,20 @@ def __add_controls_start_box(self):
callback=self.__init_combo_changed,
sizePolicy=(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
)
box2 = gui.hBox(box)
gui.checkBox(
box2, self, "balance", "Gravity",
callback=self.__gravity_changed)
self.grav_slider = gui.hSlider(
box2, self, "gravity_index",
minValue=0, maxValue=len(self.GravityValues) - 1,
callback=self.__gravity_dragged, createLabel=False)
self.gravity_label = gui.widgetLabel(box2)
self.gravity_label.setFixedWidth(
max(QFontMetrics(self.font()).horizontalAdvance(str(x))
for x in self.GravityValues))
self.gravity_label.setAlignment(Qt.AlignRight)
self.__update_gravity_label()
self.run_button = gui.button(box, self, "Start", self._toggle_run)

@property
Expand All @@ -189,6 +208,21 @@ def effective_data(self):
return self.data.transform(Domain(self.effective_variables,
self.data.domain.class_vars))

def __gravity_dragged(self):
self.balance = True
self.__update_gravity_label()
self.__gravity_changed()

Check warning on line 214 in Orange/widgets/visualize/owfreeviz.py

View check run for this annotation

Codecov / codecov/patch

Orange/widgets/visualize/owfreeviz.py#L212-L214

Added lines #L212 - L214 were not covered by tests

def __update_gravity_label(self):
self.gravity_label.setText(str(self.GravityValues[self.gravity_index]))

def __gravity_changed(self):
gravity = self.GravityValues[self.gravity_index]
self.projector.gravity = gravity if self.balance else None
self.__update_gravity_label()
if self.task is None and self.__optimized:
self._run()

Check warning on line 224 in Orange/widgets/visualize/owfreeviz.py

View check run for this annotation

Codecov / codecov/patch

Orange/widgets/visualize/owfreeviz.py#L220-L224

Added lines #L220 - L224 were not covered by tests

def __radius_slider_changed(self):
self.graph.update_radius()

Expand Down Expand Up @@ -232,6 +266,7 @@ def on_done(self, result: Result):
self.projection = result.projection
self.graph.set_sample_size(None)
self.run_button.setText("Start")
self.__optimized = True
self.commit.deferred()

def on_exception(self, ex: Exception):
Expand All @@ -253,14 +288,19 @@ def init_projection(self):
anchors = FreeViz.init_radial(len(self.effective_variables)) \
if self.initialization == InitType.Circular \
else FreeViz.init_random(len(self.effective_variables), 2)
if self.balance:
gravity = self.GravityValues[self.gravity_index]

Check warning on line 292 in Orange/widgets/visualize/owfreeviz.py

View check run for this annotation

Codecov / codecov/patch

Orange/widgets/visualize/owfreeviz.py#L292

Added line #L292 was not covered by tests
else:
gravity = None
self.projector = FreeViz(scale=False, center=False,
initial=anchors, maxiter=10)
initial=anchors, maxiter=10, gravity=gravity)
data = self.projector.preprocess(self.effective_data)
self.projector.domain = data.domain
self.projector.components_ = anchors.T
self.projection = FreeVizModel(self.projector, self.projector.domain, 2)
self.projection.pre_domain = data.domain
self.projection.name = self.projector.name
self.__optimized = False

def check_data(self):
def error(err):
Expand Down

0 comments on commit 1947679

Please sign in to comment.