Skip to content

Commit

Permalink
scatterplot: interactive sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
JakaKokosar committed Jul 11, 2023
1 parent 9c7c014 commit 831f9ab
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 21 deletions.
58 changes: 45 additions & 13 deletions Orange/widgets/visualize/owscatterplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from itertools import chain
from xml.sax.saxutils import escape

import dask
import numpy as np
from scipy.stats import linregress
from sklearn.neighbors import NearestNeighbors
Expand All @@ -16,6 +17,7 @@

from Orange.data import Table, Domain, DiscreteVariable, Variable
from Orange.data.sql.table import SqlTable, AUTO_DL_LIMIT
from Orange.data.dask import DaskTable
from Orange.preprocess.score import ReliefF, RReliefF

from Orange.widgets import gui, report
Expand Down Expand Up @@ -353,7 +355,8 @@ def __init__(self):
self.cb_attr_y: ComboBoxSearch = None
self.vizrank: ScatterPlotVizRank = None
self.vizrank_button: QPushButton = None
self.sampling: QGroupBox = None
self.sql_sampling_box: QGroupBox = None
self.sampling_box: QGroupBox = None
self._xy_invalidated: bool = True

self.sql_data = None # Orange.data.sql.table.SqlTable
Expand All @@ -367,6 +370,8 @@ def __init__(self):
for w in [MatplotlibFormat, MatplotlibPDFFormat]:
self.graph_writers.append(w)

self.cached_x_data, self.cached_y_data = None, None

def _add_controls(self):
self._add_controls_axis()
self._add_controls_sampling()
Expand Down Expand Up @@ -411,10 +416,27 @@ def _add_controls_axis(self):
vizrank_box, self, "Find Informative Projections", self.set_attr)

def _add_controls_sampling(self):
self.sampling = gui.auto_commit(
self.controlArea, self, "auto_sample", "Sample", box="Sampling",
callback=self.switch_sampling, commit=lambda: self.add_data(1))
self.sampling.setVisible(False)
# Make gui boxes as there is no info about input data at this stage.
# This is temporary and probably need refactoring.

self.sql_sampling_box = gui.auto_commit(self.sampling_box, self,
"auto_sample",
"Sample", box='Sampling',
callback=self.switch_sampling,
commit=lambda: self.add_data(1))
self.sql_sampling_box.setVisible(False)

self.sampling_box = gui.vBox(self.controlArea, 'Sampling',
spacing=2 if gui.is_macstyle() else 8)

gui.spin(self.sampling_box, self, 'graph.sample_size',
minv=1000, maxv=10_000, step=500, label='Sample Size:',
callback=self.graph.resample_current_view_range)

gui.button(self.sampling_box, self, label='Resample',
callback=self.graph.resample_current_view_range)

self.sampling_box.setVisible(False)

@property
def effective_variables(self):
Expand Down Expand Up @@ -473,7 +495,8 @@ def findvar(name, iterable):
def check_data(self):
super().check_data()
self.__timer.stop()
self.sampling.setVisible(False)
self.sql_sampling_box.setVisible(False)
self.sampling_box.setVisible(False)
self.sql_data = None
if isinstance(self.data, SqlTable):
if self.data.approx_len() < 4000:
Expand All @@ -484,10 +507,13 @@ def check_data(self):
data_sample = self.data.sample_time(0.8, no_cache=True)
data_sample.download_data(2000, partial=True)
self.data = Table(data_sample)
self.sampling.setVisible(True)
self.sql_sampling_box.setVisible(True)
if self.auto_sample:
self.__timer.start()

elif isinstance(self.data, DaskTable):
self.sampling_box.setVisible(True)

if self.data is not None and (len(self.data) == 0 or
len(self.data.domain.variables) == 0):
self.data = None
Expand All @@ -497,18 +523,19 @@ def get_embedding(self):
if self.data is None:
return None

x_data = self.get_column(self.attr_x, filter_valid=False)
y_data = self.get_column(self.attr_y, filter_valid=False)
if x_data is None or y_data is None:
return None
if self.cached_x_data is None or self.cached_y_data is None:
self.cached_x_data = self.get_column(self.attr_x, filter_valid=False)
self.cached_y_data = self.get_column(self.attr_y, filter_valid=False)
if isinstance(self.data, DaskTable):
self.cached_x_data, self.cached_y_data = dask.compute(self.cached_x_data, self.cached_y_data)

self.Warning.missing_coords.clear()
self.Information.missing_coords.clear()
self.valid_data = np.isfinite(x_data) & np.isfinite(y_data)
self.valid_data = np.isfinite(self.cached_x_data) & np.isfinite(self.cached_y_data)
if self.valid_data is not None and not np.all(self.valid_data):
msg = self.Information if np.any(self.valid_data) else self.Warning
msg.missing_coords(self.attr_x.name, self.attr_y.name)
return np.vstack((x_data, y_data)).T
return np.vstack((self.cached_x_data, self.cached_y_data)).T

# Tooltip
def _point_tooltip(self, point_id, skip_attrs=()):
Expand Down Expand Up @@ -579,19 +606,23 @@ def handleNewSignals(self):
self.vizrank.setEnabled(False)
self._invalidated = self._invalidated or self._xy_invalidated
self._xy_invalidated = False
self.cached_x_data = None
self.cached_y_data = None
super().handleNewSignals()
if self._domain_invalidated:
self.graph.update_axes()
self._domain_invalidated = False
self.cb_reg_line.setEnabled(self.can_draw_regresssion_line())


@Inputs.features
def set_shown_attributes(self, attributes):
if attributes and len(attributes) >= 2:
self.attribute_selection_list = attributes[:2]
self._xy_invalidated = self._xy_invalidated \
or self.attr_x != attributes[0] \
or self.attr_y != attributes[1]

else:
self.attribute_selection_list = None

Expand All @@ -601,6 +632,7 @@ def set_attr(self, attr_x, attr_y):
self.attr_changed()

def set_attr_from_combo(self):
self.cached_x_data, self.cached_y_data = None, None
self.attr_changed()
self.xy_changed_manually.emit(self.attr_x, self.attr_y)

Expand Down
168 changes: 160 additions & 8 deletions Orange/widgets/visualize/owscatterplotgraph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
import itertools
import warnings
import random
from typing import Callable
from xml.sax.saxutils import escape
from datetime import datetime, timezone
Expand All @@ -18,6 +19,7 @@
from pyqtgraph.graphicsItems.LegendItem import LegendItem as PgLegendItem
from pyqtgraph.graphicsItems.TextItem import TextItem

from Orange.data.dask import DaskTable
from Orange.preprocess.discretize import _time_binnings
from Orange.util import utc_from_timestamp
from Orange.widgets import gui
Expand Down Expand Up @@ -504,6 +506,7 @@ def get_size_data(self):
show_legend = Setting(True)
class_density = Setting(False)
jitter_size = Setting(0)
sample_size = Setting(1000)

resolution = 256

Expand Down Expand Up @@ -548,7 +551,6 @@ def __init__(self, scatter_widget, parent=None, view_box=ViewBox):

self.n_valid = 0
self.n_shown = 0
self.sample_size = None
self.sample_indices = None

self.palette = None
Expand All @@ -571,10 +573,156 @@ def __init__(self, scatter_widget, parent=None, view_box=ViewBox):
self.view_box.sigTransformChanged.connect(self.update_density)
self.view_box.sigRangeChangedManually.connect(self.update_labels)

self.timer = None
self.x1, self.x2 = None, None
self.y1, self.y2 = None, None

def sample_points_in_current_view():
if not isinstance(self.master.data, DaskTable):
return
x_data = self.master.cached_x_data
y_data = self.master.cached_y_data
min_x, max_x = np.min(x_data), np.max(x_data)
min_y, max_y = np.min(y_data), np.max(y_data)

if self.x1 is None or self.x2 is None or \
self.y1 is None or self.y2 is None:
self.x1, self.x2 = min_x, max_x
self.y1, self.y2 = min_y, max_y

# Current graph view range
new_x1, new_x2, new_y1, new_y2 = self._get_current_view_range()

# Data points in previous view range
mask_old = np.flatnonzero(
((x_data > self.x1) & (x_data < self.x2)) &
((y_data > self.y1) & (y_data < self.y2))
)

# Data points in new view range
mask_new = np.flatnonzero(
((x_data > new_x1) & (x_data < new_x2)) &
((y_data > new_y1) & (y_data < new_y2))
)

# Find the differance of data points between current and previous views
sample_indices = set(self.sample_indices)
mask_old = set(mask_old)
mask_new = set(mask_new)


is_zoom_in = all((new_x1 > self.x1,
new_x2 < self.x2,
new_y1 > self.y1,
new_y2 < self.y2))

# is_zoom_out = all((new_x1 < self.x1,
# new_x2 > self.x2,
# new_y1 < self.y1,
# new_y2 > self.y2))

# Update previous view ranges
self.x1, self.x2 = new_x1, new_x2
self.y1, self.y2 = new_y1, new_y2

if is_zoom_in:
# remove currently sampled data points that got zoomed out.
zoomed_out_samples = sample_indices & (mask_old - mask_new)
samples_to_keep = sample_indices - zoomed_out_samples

# data points that we should sample now must not already be sampled
new_sample_candidates = mask_new - samples_to_keep
k = self.sample_size - len(samples_to_keep)

# sample size iz larger than number of available data
if len(mask_new) < self.sample_size:
return

if len(new_sample_candidates) > k:
new_sample_candidates = set(
random.sample(new_sample_candidates, k=k))

# join currently sampled data points with proportion of new one in the
# zoomed in area
self.sample_indices = np.array(
list(samples_to_keep | new_sample_candidates))

else:
diff = mask_new - mask_old

# Sampled and non-sampled data points in the overlapped area.
intersect_data_points = mask_old & mask_new

# Data points that are already sampled and are in the overlapped area.
intersect_sampled_points = intersect_data_points & sample_indices

# Non-sampled datapoints in the overlapped area.
intersect_non_sampled_points = intersect_data_points - sample_indices

# sample size iz larger than number of available data
if len(mask_new) < self.sample_size:
return

if not len(diff):

# do nothing
if len(intersect_sampled_points) == self.sample_size:
return

# add remaining data samples
self.sample_indices = np.array(list(intersect_sampled_points | set(random.sample(intersect_non_sampled_points, k=self.sample_size - len(intersect_sampled_points)))))

else:
# size ratio between arrays
ratio = len(intersect_data_points) / len(diff)

num_of_samples_overlap = round(ratio / (ratio + 1) * self.sample_size)
num_of_samples_diff = self.sample_size - num_of_samples_overlap

m = num_of_samples_overlap - len(intersect_sampled_points)
if m > 0:
overlap_samples = intersect_sampled_points | set(random.sample(intersect_non_sampled_points, k=m))
new_samples = set(random.sample(diff, k=num_of_samples_diff))
else:
to_remove = set(random.sample(intersect_sampled_points, k=abs(m)))
overlap_samples = intersect_sampled_points - to_remove
new_samples = set(random.sample(diff, k=num_of_samples_diff))

self.sample_indices = np.array(list(overlap_samples | new_samples))

self.clear()
self.update_coordinates()
self.update_point_props()

self._proxy_sigRangeChanged = pg.SignalProxy(
# self.plot_widget.plotItem.sigRangeChangedManually, slot=test, delay=0.5
self.view_box.sigRangeChangedManually, slot=sample_points_in_current_view, delay=0.4
)
self.timer = None
self.parameter_setter = ScatterBaseParameterSetter(self)

def resample_current_view_range(self):
x1, x2, y1, y2 = self._get_current_view_range()
x_data = self.master.cached_x_data
y_data = self.master.cached_y_data
mask = np.flatnonzero(
((x_data > x1) & (x_data < x2)) &
((y_data > y1) & (y_data < y2))
)

self.sample_indices = None
self._create_sample(mask)
self.clear()
self.update_coordinates()
self.update_point_props()

def _get_current_view_range(self):
x_axis, y_axis = self.view_box.state['viewRange']

x1, x2 = x_axis
y1, y2 = y_axis

return x1, x2, y1, y2

def _create_legend(self, anchor):
legend = LegendItem()
legend.setParentItem(self.plot_widget.getViewBox())
Expand Down Expand Up @@ -778,7 +926,10 @@ def get_coordinates(self):
self.n_valid = self.n_shown = 0
return None, None
self.n_valid = len(x)
self._create_sample()

if self.sample_indices is None:
self._create_sample()

x = self._filter_visible(x)
y = self._filter_visible(y)
# Jittering after sampling is OK if widgets do not change the sample
Expand All @@ -790,17 +941,19 @@ def get_coordinates(self):
x, y = self.jitter_coordinates(x, y)
return x, y

def _create_sample(self):
def _create_sample(self, mask=None):
"""
Create a random sample if the data is larger than the set sample size
"""
self.n_shown = min(self.n_valid, self.sample_size or self.n_valid)
if self.sample_size is not None \
and self.sample_indices is None \
and self.n_valid != self.n_shown:
random = np.random.RandomState(seed=0)

random = np.random.RandomState()
self.sample_indices = random.choice(
self.n_valid, self.n_shown, replace=False)
mask if mask is not None else self.n_valid, self.n_shown, replace=False)

# TODO: Is this really needed?
np.sort(self.sample_indices)

Expand Down Expand Up @@ -838,8 +991,7 @@ def update_coordinates(self):
x, y = self.get_coordinates()
if x is None or len(x) == 0:
return

self._reset_view(x, y)
# self._reset_view(x, y)
if self.scatterplot_item is None:
if self.sample_indices is None:
indices = np.arange(self.n_valid)
Expand Down
Loading

0 comments on commit 831f9ab

Please sign in to comment.