Skip to content

Commit

Permalink
Clean up the distance function kwargs for DISE
Browse files Browse the repository at this point in the history
  • Loading branch information
FanwangM committed Oct 5, 2024
1 parent d080d9c commit 5c3037f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 18 deletions.
22 changes: 4 additions & 18 deletions selector/methods/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#
# --
"""Module for Distance-Based Selection Methods."""
import inspect
import warnings

import bitarray
Expand Down Expand Up @@ -461,7 +462,7 @@ class DISE(SelectionBase):
"""

def __init__(
self, r0=None, ref_index=None, tol=0.05, n_iter=10, p=2.0, eps=0.0, fun_dist=None, **kwargs
self, r0=None, ref_index=None, tol=0.05, n_iter=10, p=2.0, eps=0.0, fun_dist=None
):
"""
Initialize class.
Expand Down Expand Up @@ -496,13 +497,6 @@ def __init__(
fun_dist: callable, optional
Function for calculating the distances between sample points. When `fun_dist` is `None`,
the Minkowski p-norm distance is used. Default is None.
kwargs: dict, optional
Additional keyword arguments to be passed to the distance function `fun_dist`.
Notes
-----
If `p` is also defined in `kwargs`, the value of `p` from the argument will be used. For
example, when `p=2` and `kwargs={"p": 3}`, the value of `p` will be 2.
"""
self.r0 = r0
Expand All @@ -519,14 +513,6 @@ def __init__(
# self.fun_dist = fun_dist
self.fun_dist = fun_dist

self.kwargs = kwargs
if "p" in self.kwargs.keys():
warnings.warn(
f"Value of p in kwargs is overwritten by: {self.p} as defined in the "
f"argument `p`."
)
self.kwargs["p"] = p

def algorithm(self, x, max_size):
"""Return selected samples based on directed sphere exclusion algorithm.
Expand All @@ -545,10 +531,10 @@ def algorithm(self, x, max_size):
"""
if self.fun_dist is None:
distances = spatial.distance.squareform(
spatial.distance.pdist(x, metric="minkowski", p=self.kwargs.get("p"))
spatial.distance.pdist(x, metric="minkowski", p=self.p)
)
else:
distances = spatial.distance.squareform(self.fun_dist(x, **self.kwargs))
distances = self.fun_dist(x)

# set up the ref_index as when is None
if self.ref_index is None:
Expand Down
17 changes: 17 additions & 0 deletions selector/methods/tests/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import numpy as np
import pytest
from numpy.testing import assert_equal, assert_raises
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics import pairwise_distances


from selector.methods.distance import DISE, MaxMin, MaxSum, OptiSim
from selector.methods.tests.common import generate_synthetic_data

Expand Down Expand Up @@ -369,3 +371,18 @@ def test_directed_sphere_on_line_with_larger_radius():
expected = [1, 5, 9]
assert_equal(selected, expected)
assert_equal(collector.r, 1.0)


def test_directed_sphere_dist_func():
"""Test Direct Sphere Exclusion with a distance function."""
# (0,0) as the reference point
x = np.array([[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]])
collector = DISE(r0=0.5,
tol=0,
ref_index=0,
fun_dist=lambda x: squareform(pdist(x, metric="minkowski", p=0.1))
)
selected = collector.select(x, size=3)
expected = [0, 3, 6]
assert_equal(selected, expected)
assert_equal(collector.r, 2.0)

0 comments on commit 5c3037f

Please sign in to comment.