Skip to content

Commit

Permalink
FreeViz: Allow setting ratio btw attractive and repulsive forces
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Jul 19, 2023
1 parent ff152a9 commit 3de6f8b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
17 changes: 12 additions & 5 deletions Orange/projection/freeviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class FreeViz(LinearProjector):
projection = FreeVizModel

def __init__(self, weights=None, center=True, scale=True, dim=2, p=1,
initial=None, maxiter=500, alpha=0.1,
initial=None, maxiter=500, alpha=0.1, gravity=None,
atol=1e-5, preprocessors=None):
super().__init__(preprocessors=preprocessors)
self.weights = weights
Expand All @@ -33,6 +33,7 @@ def __init__(self, weights=None, center=True, scale=True, dim=2, p=1,
self.maxiter = maxiter
self.alpha = alpha
self.atol = atol
self.gravity = gravity
self.is_class_discrete = False
self.components_ = None

Expand All @@ -50,6 +51,7 @@ def get_components(self, X, Y):
X, Y, weights=self.weights, center=self.center, scale=self.scale,
dim=self.dim, p=self.p, initial=self.initial,
maxiter=self.maxiter, alpha=self.alpha, atol=self.atol,
gravity=self.gravity,
is_class_discrete=self.is_class_discrete)[1].T

@classmethod
Expand Down Expand Up @@ -104,7 +106,7 @@ def forces_regression(cls, distances, y, p=1):
return F

@classmethod
def forces_classification(cls, distances, y, p=1):
def forces_classification(cls, distances, y, p=1, gravity=None):
diffclass = scipy.spatial.distance.pdist(y.reshape(-1, 1), "hamming") != 0
# handle attractive force
if p == 1:
Expand All @@ -120,6 +122,8 @@ def forces_classification(cls, distances, y, p=1):
F[mask] = 1 / distances[mask]
else:
F[mask] = 1 / (distances[mask] ** p)
if gravity is not None:
F[mask] *= -np.sum(F[~mask]) / np.sum(F[mask]) / gravity
return F

@classmethod
Expand Down Expand Up @@ -180,7 +184,8 @@ def gradient(cls, X, embeddings, forces, embedding_dist=None, weights=None):
return G

@classmethod
def freeviz_gradient(cls, X, y, embedding, p=1, weights=None, is_class_discrete=False):
def freeviz_gradient(cls, X, y, embedding, p=1, weights=None,
gravity=None, is_class_discrete=False):
"""
Return the gradient for the FreeViz [1]_ projection.
Expand Down Expand Up @@ -214,7 +219,7 @@ def freeviz_gradient(cls, X, y, embedding, p=1, weights=None, is_class_discrete=
assert X.ndim == 2 and X.shape[0] == y.shape[0] == embedding.shape[0]
D = scipy.spatial.distance.pdist(embedding)
if is_class_discrete:
forces = cls.forces_classification(D, y, p=p)
forces = cls.forces_classification(D, y, p=p, gravity=gravity)
else:
forces = cls.forces_regression(D, y, p=p)
G = cls.gradient(X, embedding, forces, embedding_dist=D, weights=weights)
Expand All @@ -234,7 +239,8 @@ def _rotate(cls, A):

@classmethod
def freeviz(cls, X, y, weights=None, center=True, scale=True, dim=2, p=1,
initial=None, maxiter=500, alpha=0.1, atol=1e-5, is_class_discrete=False):
initial=None, maxiter=500, alpha=0.1, atol=1e-5, gravity=None,
is_class_discrete=False):
"""
FreeViz
Expand Down Expand Up @@ -341,6 +347,7 @@ def freeviz(cls, X, y, weights=None, center=True, scale=True, dim=2, p=1,
step_i = 0
while step_i < maxiter:
G = cls.freeviz_gradient(X, y, embeddings, p=p, weights=weights,
gravity=gravity,
is_class_discrete=is_class_discrete)

# Scale the changes (the largest anchor move is alpha * radius)
Expand Down
41 changes: 39 additions & 2 deletions Orange/widgets/visualize/owfreeviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from AnyQt.QtCore import Qt, QRectF, QLineF, QPoint
from AnyQt.QtGui import QPalette
from AnyQt.QtGui import QPalette, QFontMetrics
from AnyQt.QtWidgets import QSizePolicy

import pyqtgraph as pg
Expand Down Expand Up @@ -137,9 +137,13 @@ class OWFreeViz(OWAnchorProjectionWidget, ConcurrentWidgetMixin):

settings_version = 3
initialization = settings.Setting(InitType.Circular)
balance = False
gravity_index = 4
GRAPH_CLASS = OWFreeVizGraph
graph = settings.SettingProvider(OWFreeVizGraph)

GravityValues = [0.1, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 3, 4, 5]

class Error(OWAnchorProjectionWidget.Error):
no_class_var = widget.Msg("Data must have a target variable.")
multiple_class_vars = widget.Msg(
Expand All @@ -159,6 +163,7 @@ class Warning(OWAnchorProjectionWidget.Warning):
def __init__(self):
OWAnchorProjectionWidget.__init__(self)
ConcurrentWidgetMixin.__init__(self)
self.__optimized = False

def _add_controls(self):
self.__add_controls_start_box()
Expand All @@ -177,6 +182,20 @@ def __add_controls_start_box(self):
callback=self.__init_combo_changed,
sizePolicy=(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
)
box2 = gui.hBox(box)
gui.checkBox(
box2, self, "balance", "Gravity",
callback=self.__gravity_changed)
self.grav_slider = gui.hSlider(
box2, self, "gravity_index",
minValue=0, maxValue=len(self.GravityValues) - 1,
callback=self.__gravity_dragged, createLabel=False)
self.gravity_label = gui.widgetLabel(box2)
self.gravity_label.setFixedWidth(
max(QFontMetrics(self.font()).horizontalAdvance(str(x))
for x in self.GravityValues))
self.gravity_label.setAlignment(Qt.AlignRight)
self.__update_gravity_label()
self.run_button = gui.button(box, self, "Start", self._toggle_run)

@property
Expand All @@ -189,6 +208,21 @@ def effective_data(self):
return self.data.transform(Domain(self.effective_variables,
self.data.domain.class_vars))

def __gravity_dragged(self):
self.balance = True
self.__update_gravity_label()
self.__gravity_changed()

def __update_gravity_label(self):
self.gravity_label.setText(str(self.GravityValues[self.gravity_index]))

def __gravity_changed(self):
gravity = self.GravityValues[self.gravity_index]
self.projector.gravity = gravity if self.balance else None
self.__update_gravity_label()
if self.task is None and self.__optimized:
self._run()

def __radius_slider_changed(self):
self.graph.update_radius()

Expand Down Expand Up @@ -232,6 +266,7 @@ def on_done(self, result: Result):
self.projection = result.projection
self.graph.set_sample_size(None)
self.run_button.setText("Start")
self.__optimized = True
self.commit.deferred()

def on_exception(self, ex: Exception):
Expand All @@ -254,13 +289,15 @@ def init_projection(self):
if self.initialization == InitType.Circular \
else FreeViz.init_random(len(self.effective_variables), 2)
self.projector = FreeViz(scale=False, center=False,
initial=anchors, maxiter=10)
initial=anchors, maxiter=10,
gravity=self.gravity if self.balance else None)
data = self.projector.preprocess(self.effective_data)
self.projector.domain = data.domain
self.projector.components_ = anchors.T
self.projection = FreeVizModel(self.projector, self.projector.domain, 2)
self.projection.pre_domain = data.domain
self.projection.name = self.projector.name
self.__optimized = False

def check_data(self):
def error(err):
Expand Down

0 comments on commit 3de6f8b

Please sign in to comment.