diff --git a/selector/methods/distance.py b/selector/methods/distance.py index d3f9b6ee..99aeb640 100644 --- a/selector/methods/distance.py +++ b/selector/methods/distance.py @@ -22,6 +22,7 @@ # # -- """Module for Distance-Based Selection Methods.""" +import inspect import warnings import bitarray @@ -461,7 +462,7 @@ class DISE(SelectionBase): """ def __init__( - self, r0=None, ref_index=None, tol=0.05, n_iter=10, p=2.0, eps=0.0, fun_dist=None, **kwargs + self, r0=None, ref_index=None, tol=0.05, n_iter=10, p=2.0, eps=0.0, fun_dist=None ): """ Initialize class. @@ -496,13 +497,6 @@ def __init__( fun_dist: callable, optional Function for calculating the distances between sample points. When `fun_dist` is `None`, the Minkowski p-norm distance is used. Default is None. - kwargs: dict, optional - Additional keyword arguments to be passed to the distance function `fun_dist`. - - Notes - ----- - If `p` is also defined in `kwargs`, the value of `p` from the argument will be used. For - example, when `p=2` and `kwargs={"p": 3}`, the value of `p` will be 2. """ self.r0 = r0 @@ -519,14 +513,6 @@ def __init__( # self.fun_dist = fun_dist self.fun_dist = fun_dist - self.kwargs = kwargs - if "p" in self.kwargs.keys(): - warnings.warn( - f"Value of p in kwargs is overwritten by: {self.p} as defined in the " - f"argument `p`." - ) - self.kwargs["p"] = p - def algorithm(self, x, max_size): """Return selected samples based on directed sphere exclusion algorithm. @@ -545,10 +531,10 @@ def algorithm(self, x, max_size): """ if self.fun_dist is None: distances = spatial.distance.squareform( - spatial.distance.pdist(x, metric="minkowski", p=self.kwargs.get("p")) + spatial.distance.pdist(x, metric="minkowski", p=self.p) ) else: - distances = spatial.distance.squareform(self.fun_dist(x, **self.kwargs)) + distances = self.fun_dist(x) # set up the ref_index as when is None if self.ref_index is None: diff --git a/selector/methods/tests/test_distance.py b/selector/methods/tests/test_distance.py index a65620da..12ec65aa 100644 --- a/selector/methods/tests/test_distance.py +++ b/selector/methods/tests/test_distance.py @@ -26,8 +26,10 @@ import numpy as np import pytest from numpy.testing import assert_equal, assert_raises +from scipy.spatial.distance import pdist, squareform from sklearn.metrics import pairwise_distances + from selector.methods.distance import DISE, MaxMin, MaxSum, OptiSim from selector.methods.tests.common import generate_synthetic_data @@ -369,3 +371,18 @@ def test_directed_sphere_on_line_with_larger_radius(): expected = [1, 5, 9] assert_equal(selected, expected) assert_equal(collector.r, 1.0) + + +def test_directed_sphere_dist_func(): + """Test Direct Sphere Exclusion with a distance function.""" + # (0,0) as the reference point + x = np.array([[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]]) + collector = DISE(r0=0.5, + tol=0, + ref_index=0, + fun_dist=lambda x: squareform(pdist(x, metric="minkowski", p=0.1)) + ) + selected = collector.select(x, size=3) + expected = [0, 3, 6] + assert_equal(selected, expected) + assert_equal(collector.r, 2.0)