From 013177b4d2f51173fbd53af6c72136601a08e1c2 Mon Sep 17 00:00:00 2001 From: Fanwang Meng Date: Sat, 5 Oct 2024 16:56:10 -0400 Subject: [PATCH 01/12] Use numpy to count the unique number of labels for efficiency --- selector/methods/base.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/selector/methods/base.py b/selector/methods/base.py index d9263d7e..2b1c7b54 100644 --- a/selector/methods/base.py +++ b/selector/methods/base.py @@ -71,11 +71,9 @@ def select(self, x: np.ndarray, size: int, labels: np.ndarray = None) -> np.ndar ) # compute the number of samples (i.e. population or pop) in each cluster - unique_labels = np.unique(labels) + unique_labels, unique_label_counts = np.unique(labels, return_counts=True) num_clusters = len(unique_labels) - pop_clusters = { - unique_label: len(np.where(labels == unique_label)[0]) for unique_label in unique_labels - } + pop_clusters = dict(zip(unique_labels, unique_label_counts)) # compute number of samples to be selected from each cluster n = size // num_clusters From cc070c2a744a2825dd9a589b1e4a44f073e35824 Mon Sep 17 00:00:00 2001 From: Fanwang Meng Date: Sat, 5 Oct 2024 20:56:08 -0400 Subject: [PATCH 02/12] Add support of selection different cluster subsets proportionately --- selector/methods/base.py | 103 +++++++++++++++++++++++++++------------ 1 file changed, 72 insertions(+), 31 deletions(-) diff --git a/selector/methods/base.py b/selector/methods/base.py index 2b1c7b54..650ec29e 100644 --- a/selector/methods/base.py +++ b/selector/methods/base.py @@ -34,7 +34,12 @@ class SelectionBase(ABC): """Base class for selecting subset of sample points.""" - def select(self, x: np.ndarray, size: int, labels: np.ndarray = None) -> np.ndarray: + def select(self, + x: np.ndarray, + size: int, + labels: np.ndarray = None, + proportional_selection: bool = False, + ) -> list: """Return indices representing subset of sample points. Parameters @@ -48,6 +53,10 @@ def select(self, x: np.ndarray, size: int, labels: np.ndarray = None) -> np.ndar Array of integers or strings representing the labels of the clusters that each sample belongs to. If `None`, the samples are treated as one cluster. If labels are provided, selection is made from each cluster. + proportional_selection: bool, optional + If True, the number of samples to be selected from each cluster is proportional. + Otherwise, the number of samples to be selected from each cluster is equal. + Default is False. Returns ------- @@ -70,49 +79,81 @@ def select(self, x: np.ndarray, size: int, labels: np.ndarray = None) -> np.ndar f"Number of labels {len(labels)} does not match number of samples {len(x)}." ) + selected_ids = [] + # compute the number of samples (i.e. population or pop) in each cluster unique_labels, unique_label_counts = np.unique(labels, return_counts=True) num_clusters = len(unique_labels) pop_clusters = dict(zip(unique_labels, unique_label_counts)) # compute number of samples to be selected from each cluster - n = size // num_clusters - - # update number of samples to select from each cluster based on the cluster population. - # this is needed when some clusters do not have enough samples in them (pop < n) and - # needs to be done iteratively until all remaining clusters have at least n samples - selected_ids = [] - while np.any([value <= n for value in pop_clusters.values() if value != 0]): - for unique_label in unique_labels: - if pop_clusters[unique_label] != 0: - # get index of sample labelled with unique_label - cluster_ids = np.where(labels == unique_label)[0] - if len(cluster_ids) <= n: - # all samples in the cluster are selected & population becomes zero - selected_ids.append(cluster_ids) - pop_clusters[unique_label] = 0 - # update number of samples to be selected from each cluster - totally_used_clusters = list(pop_clusters.values()).count(0) - n = (size - len(np.hstack(selected_ids))) // (num_clusters - totally_used_clusters) - - warnings.warn( - f"Number of molecules in one cluster is less than" - f" {size}/{num_clusters}.\nNumber of selected " - f"molecules might be less than desired.\nIn order to avoid this " - f"problem. Try to use less number of clusters" - ) - - for unique_label in unique_labels: + if proportional_selection: + # make sure that tht total number of samples selected is equal to size + size_each_cluster = size * unique_label_counts / len(labels) + # using np.round to get to the nearest integer + # not using int function directly to avoid truncation of decimal values + size_each_cluster = np.round(size_each_cluster).astype(int) + # the total number of samples selected from all clusters at this point + size_each_cluster_total = np.sum(size_each_cluster) + # Adjust if the total is less than the required number + if size_each_cluster_total < size: + while size_each_cluster_total < size: + largest_cluster_index = np.argmax(unique_label_counts - size_each_cluster) + size_each_cluster[largest_cluster_index] += 1 + size_each_cluster_total += 1 + # Adjust if the total is more than the required number + elif size_each_cluster_total > size: + while size_each_cluster_total > size: + smallest_cluster_index = np.argmin(unique_label_counts - size_each_cluster) + size_each_cluster[smallest_cluster_index] -= 1 + size_each_cluster_total -= 1 + # perfect case where the total is equal to the required number + else: + pass + else: + size_each_cluster = size // num_clusters + + # update number of samples to select from each cluster based on the cluster population. + # this is needed when some clusters do not have enough samples in them + # (pop < size_each_cluster) and needs to be done iteratively until all remaining clusters + # have at least size_each_cluster samples + while np.any( + [value <= size_each_cluster for value in pop_clusters.values() if value != 0] + ): + # while list(pop_clusters.values()).count(0) < num_clusters: + for unique_label in unique_labels: + if pop_clusters[unique_label] != 0: + # get index of sample labelled with unique_label + cluster_ids = np.where(labels == unique_label)[0] + if len(cluster_ids) <= size_each_cluster: + # all samples in the cluster are selected & population becomes zero + selected_ids.append(cluster_ids) + pop_clusters[unique_label] = 0 + # update number of samples to be selected from each cluster + totally_used_clusters = list(pop_clusters.values()).count(0) + size_each_cluster = (size - len(np.hstack(selected_ids))) // ( + num_clusters - totally_used_clusters) + + warnings.warn( + f"Number of molecules in one cluster is less than" + f" {size}/{num_clusters}.\nNumber of selected " + f"molecules might be less than desired.\nIn order to avoid this " + f"problem. Try to use less number of clusters." + ) + # save the number of samples to be selected from each cluster in an array + size_each_cluster = np.full(num_clusters, size_each_cluster) + + for unique_label, size_sub in zip(unique_labels, size_each_cluster): if pop_clusters[unique_label] != 0: - # sample n ids from cluster labeled unique_label + # sample size_each_cluster ids from cluster labeled unique_label cluster_ids = np.where(labels == unique_label)[0] - selected = self.select_from_cluster(x, n, cluster_ids) + selected = self.select_from_cluster(x, size_sub, cluster_ids) selected_ids.append(cluster_ids[selected]) return np.hstack(selected_ids).flatten().tolist() @abstractmethod def select_from_cluster( - self, x: np.ndarray, size: int, labels: np.ndarray = None + self, x: np.ndarray, size: int, labels: np.ndarray = None ) -> np.ndarray: """Return indices representing subset of sample points from one cluster. From 2952169855781006c5d3eac6890a610271caf9c2 Mon Sep 17 00:00:00 2001 From: Fanwang Meng Date: Sat, 5 Oct 2024 22:28:54 -0400 Subject: [PATCH 03/12] Add tests for the proportional selection --- selector/methods/tests/common.py | 15 +++ selector/methods/tests/test_distance.py | 145 +++++++++++++++++++++--- 2 files changed, 147 insertions(+), 13 deletions(-) diff --git a/selector/methods/tests/common.py b/selector/methods/tests/common.py index 98a843e4..3b08cc9e 100644 --- a/selector/methods/tests/common.py +++ b/selector/methods/tests/common.py @@ -30,10 +30,25 @@ from sklearn.metrics import pairwise_distances __all__ = [ + "generate_synthetic_cluster_data", "generate_synthetic_data", ] +def generate_synthetic_cluster_data(): + # generate the first cluster with 3 points + cluster_one = np.array([[0, 0], [0, 1], [0, 2]]) + # generate the second cluster with 6 points + cluster_two = np.array([[3, 0], [3, 1], [3, 2], [3, 3], [3, 4], [3, 5]]) + # generate the third cluster with 9 points + cluster_three = np.array([[6, 0], [6, 1], [6, 2], [6, 3], [6, 4], [6, 5], [6, 6], [6, 7], [6, 8]]) + # concatenate the clusters + coords = np.vstack([cluster_one, cluster_two, cluster_three]) + # generate the labels + labels = np.hstack([[0 for _ in range(3)], [1 for _ in range(6)], [2 for _ in range(9)]]) + + return coords, labels, cluster_one, cluster_two, cluster_three + def generate_synthetic_data( n_samples: int = 100, n_features: int = 2, diff --git a/selector/methods/tests/test_distance.py b/selector/methods/tests/test_distance.py index 12ec65aa..e7c7bb8f 100644 --- a/selector/methods/tests/test_distance.py +++ b/selector/methods/tests/test_distance.py @@ -29,9 +29,11 @@ from scipy.spatial.distance import pdist, squareform from sklearn.metrics import pairwise_distances - from selector.methods.distance import DISE, MaxMin, MaxSum, OptiSim -from selector.methods.tests.common import generate_synthetic_data +from selector.methods.tests.common import ( + generate_synthetic_cluster_data, + generate_synthetic_data, +) def test_maxmin(): @@ -58,7 +60,12 @@ def test_maxmin(): # use MaxMin algorithm to select points from clustered data collector = MaxMin() - selected_ids = collector.select(arr_dist_cluster, size=12, labels=class_labels_cluster) + selected_ids = collector.select( + arr_dist_cluster, + size=12, + labels=class_labels_cluster, + proportional_selection=False, + ) # make sure all the selected indices are the same with expectation assert_equal(selected_ids, [41, 34, 94, 85, 51, 50, 66, 78, 21, 64, 29, 83]) @@ -125,9 +132,37 @@ def test_maxmin(): # selecting molecules collector = MaxMin(lambda x: pairwise_distances(x, metric="euclidean")) - selected_mocked = collector.select(mocked_cluster_coords, size=15, labels=labels_mocked) + selected_mocked = collector.select( + mocked_cluster_coords, + size=15, + labels=labels_mocked, + proportional_selection=False, + ) assert_equal(selected_mocked, [0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 15, 10, 13, 9, 18]) +def test_maxmin_proportional_selection(): + """Test MaxMin class with proportional selection.""" + # generate the first cluster with 3 points + coords, labels, cluster_one, cluster_two, cluster_three = generate_synthetic_cluster_data() + # instantiate the MaxMin class + collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) + # select 6 points with proportional selection from each cluster + selected_ids = collector.select(coords, + size=6, + labels=labels, + proportional_selection=True, + ) + # make sure all the selected indices are the same with expectation + assert_equal(selected_ids,[0, 3, 8, 9, 17, 13]) + # check how many points are selected from each cluster + assert_equal(len(selected_ids), 6) + # check the number of points selected from cluster one + assert_equal((labels[selected_ids] == 0).sum(), 1) + # check the number of points selected from cluster two + assert_equal((labels[selected_ids] == 1).sum(), 2) + # check the number of points selected from cluster three + assert_equal((labels[selected_ids] == 2).sum(), 3) + def test_maxmin_invalid_input(): """Testing MaxMin class with invalid input.""" @@ -158,20 +193,32 @@ def test_maxsum_clustered_data(): # use MaxSum algorithm to select points from clustered data, instantiating with euclidean distance metric collector = MaxSum(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=None) - selected_ids = collector.select(coords_cluster, size=12, labels=class_labels_cluster) + selected_ids = collector.select( + coords_cluster, + size=12, + labels=class_labels_cluster, + proportional_selection=False, + ) # make sure all the selected indices are the same with expectation assert_equal(selected_ids, [41, 34, 85, 94, 51, 50, 78, 66, 21, 64, 0, 83]) # use MaxSum algorithm to select points from clustered data without instantiating with euclidean distance metric collector = MaxSum(ref_index=None) - selected_ids = collector.select(coords_cluster_dist, size=12, labels=class_labels_cluster) + selected_ids = collector.select( + coords_cluster_dist, + size=12, + labels=class_labels_cluster, + proportional_selection=False, + ) # make sure all the selected indices are the same with expectation assert_equal(selected_ids, [41, 34, 85, 94, 51, 50, 78, 66, 21, 64, 0, 83]) # check that ValueError is raised when number of points requested is greater than number of points in array with pytest.raises(ValueError): - selected_ids = collector.select_from_cluster( - coords_cluster, size=101, labels=class_labels_cluster + _ = collector.select_from_cluster( + coords_cluster, + size=101, + labels=class_labels_cluster, ) @@ -239,6 +286,30 @@ def test_maxsum_invalid_input(): _ = collector.select(x_dist, size=2) +def test_maxsum_proportional_selection(): + """Test MaxSum class with proportional selection.""" + # generate the first cluster with 3 points + coords, labels, cluster_one, cluster_two, cluster_three = generate_synthetic_cluster_data() + # instantiate the MaxSum class + collector = MaxSum(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) + # select 6 points with proportional selection from each cluster + selected_ids = collector.select(coords, + size=6, + labels=labels, + proportional_selection=True, + ) + # make sure all the selected indices are the same with expectation + assert_equal(selected_ids,[0, 3, 8, 9, 17, 10]) + # check how many points are selected from each cluster + assert_equal(len(selected_ids), 6) + # check the number of points selected from cluster one + assert_equal((labels[selected_ids] == 0).sum(), 1) + # check the number of points selected from cluster two + assert_equal((labels[selected_ids] == 1).sum(), 2) + # check the number of points selected from cluster three + assert_equal((labels[selected_ids] == 2).sum(), 3) + + def test_optisim(): """Testing OptiSim class.""" # generate random data points belonging to one cluster - coordinates and pairwise distance matrix @@ -286,6 +357,29 @@ def test_optisim(): _ = collector.select(coords, size=12) +def test_optisim_proportional_selection(): + """Test OptiSim class with proportional selection.""" + # generate the first cluster with 3 points + coords, labels, cluster_one, cluster_two, cluster_three = generate_synthetic_cluster_data() + # instantiate the Optisim class + collector = OptiSim(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) + # select 6 points with proportional selection from each cluster + selected_ids = collector.select(coords, + size=6, + labels=labels, + proportional_selection=True, + ) + # make sure all the selected indices are the same with expectation + assert_equal(selected_ids,[0, 3, 8, 9, 17, 13]) + # check how many points are selected from each cluster + assert_equal(len(selected_ids), 6) + # check the number of points selected from cluster one + assert_equal((labels[selected_ids] == 0).sum(), 1) + # check the number of points selected from cluster two + assert_equal((labels[selected_ids] == 1).sum(), 2) + # check the number of points selected from cluster three + assert_equal((labels[selected_ids] == 2).sum(), 3) + def test_directed_sphere_size_error(): """Test DirectedSphereExclusion error when too many points requested.""" x = np.array([[1, 9]] * 100) @@ -377,12 +471,37 @@ def test_directed_sphere_dist_func(): """Test Direct Sphere Exclusion with a distance function.""" # (0,0) as the reference point x = np.array([[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]]) - collector = DISE(r0=0.5, - tol=0, - ref_index=0, - fun_dist=lambda x: squareform(pdist(x, metric="minkowski", p=0.1)) - ) + collector = DISE( + r0=0.5, + tol=0, + ref_index=0, + fun_dist=lambda x: squareform(pdist(x, metric="minkowski", p=0.1)), + ) selected = collector.select(x, size=3) expected = [0, 3, 6] assert_equal(selected, expected) assert_equal(collector.r, 2.0) + + +def test_directed_sphere_proportional_selection(): + """Test DISE class with proportional selection.""" + # generate the first cluster with 3 points + coords, labels, cluster_one, cluster_two, cluster_three = generate_synthetic_cluster_data() + # instantiate the DISE class + collector = DISE(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) + # select 6 points with proportional selection from each cluster + selected_ids = collector.select(coords, + size=6, + labels=labels, + proportional_selection=True, + ) + # make sure all the selected indices are the same with expectation + assert_equal(selected_ids,[0, 3, 7, 9, 12, 15]) + # check how many points are selected from each cluster + assert_equal(len(selected_ids), 6) + # check the number of points selected from cluster one + assert_equal((labels[selected_ids] == 0).sum(), 1) + # check the number of points selected from cluster two + assert_equal((labels[selected_ids] == 1).sum(), 2) + # check the number of points selected from cluster three + assert_equal((labels[selected_ids] == 2).sum(), 3) From b2380fab1b0bbabc26646b71f29af8d6689f0b30 Mon Sep 17 00:00:00 2001 From: Fanwang Meng Date: Sun, 6 Oct 2024 01:08:06 -0400 Subject: [PATCH 04/12] Fix the problem of when zero number of elements selected from the minority class --- selector/methods/base.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/selector/methods/base.py b/selector/methods/base.py index 650ec29e..bfba7730 100644 --- a/selector/methods/base.py +++ b/selector/methods/base.py @@ -38,7 +38,7 @@ def select(self, x: np.ndarray, size: int, labels: np.ndarray = None, - proportional_selection: bool = False, + proportional_selection: bool = True, ) -> list: """Return indices representing subset of sample points. @@ -56,7 +56,7 @@ def select(self, proportional_selection: bool, optional If True, the number of samples to be selected from each cluster is proportional. Otherwise, the number of samples to be selected from each cluster is equal. - Default is False. + Default is True. Returns ------- @@ -92,20 +92,33 @@ def select(self, # using np.round to get to the nearest integer # not using int function directly to avoid truncation of decimal values size_each_cluster = np.round(size_each_cluster).astype(int) + # make sure each cluster has at least one sample + size_each_cluster[size_each_cluster < 1] = 1 + # the total number of samples selected from all clusters at this point size_each_cluster_total = np.sum(size_each_cluster) # Adjust if the total is less than the required number if size_each_cluster_total < size: while size_each_cluster_total < size: + # select the largest cluster with maximum number of data points not selected + # and add one sample to it largest_cluster_index = np.argmax(unique_label_counts - size_each_cluster) size_each_cluster[largest_cluster_index] += 1 size_each_cluster_total += 1 # Adjust if the total is more than the required number elif size_each_cluster_total > size: while size_each_cluster_total > size: - smallest_cluster_index = np.argmin(unique_label_counts - size_each_cluster) - size_each_cluster[smallest_cluster_index] -= 1 + largest_cluster_index = np.argmax(unique_label_counts - size_each_cluster) + size_each_cluster[largest_cluster_index] -= 1 size_each_cluster_total -= 1 + + # # when the total number of samples selected is more than the required number + # # we need to remove samples from the largest clusters + # while size_each_cluster_total > size: + # largest_cluster_index = np.argmax(size_each_cluster) + # size_each_cluster[largest_cluster_index] -= 1 + # size_each_cluster_total -= 1 + # perfect case where the total is equal to the required number else: pass @@ -119,7 +132,6 @@ def select(self, while np.any( [value <= size_each_cluster for value in pop_clusters.values() if value != 0] ): - # while list(pop_clusters.values()).count(0) < num_clusters: for unique_label in unique_labels: if pop_clusters[unique_label] != 0: # get index of sample labelled with unique_label From 6a08c5609a4be84f0d4f17f010dce87c3fba715f Mon Sep 17 00:00:00 2001 From: Fanwang Meng Date: Sun, 6 Oct 2024 01:10:39 -0400 Subject: [PATCH 05/12] Add tests for imbalance case of multiple classes --- selector/methods/tests/common.py | 23 +++++++++ selector/methods/tests/test_distance.py | 62 +++++++++++++++++++++++ selector/methods/tests/test_similarity.py | 24 +-------- 3 files changed, 87 insertions(+), 22 deletions(-) diff --git a/selector/methods/tests/common.py b/selector/methods/tests/common.py index 3b08cc9e..1ed5ad0d 100644 --- a/selector/methods/tests/common.py +++ b/selector/methods/tests/common.py @@ -23,6 +23,7 @@ # -- """Common functions for test module.""" +from importlib import resources from typing import Any, Tuple, Union import numpy as np @@ -32,6 +33,7 @@ __all__ = [ "generate_synthetic_cluster_data", "generate_synthetic_data", + "get_data_file_path", ] @@ -118,3 +120,24 @@ def generate_synthetic_data( ) return syn_data, class_labels, dist return syn_data, class_labels + + +def get_data_file_path(file_name): + """Get the absolute path of the data file inside the package. + + Parameters + ---------- + file_name : str + The name of the data file to load. + + Returns + ------- + str + The absolute path of the data file inside the package + + """ + data_file_path = resources.files("selector.methods.tests").joinpath( + f"data/{file_name}" + ) + + return data_file_path diff --git a/selector/methods/tests/test_distance.py b/selector/methods/tests/test_distance.py index e7c7bb8f..2cfbbf31 100644 --- a/selector/methods/tests/test_distance.py +++ b/selector/methods/tests/test_distance.py @@ -33,6 +33,7 @@ from selector.methods.tests.common import ( generate_synthetic_cluster_data, generate_synthetic_data, + get_data_file_path, ) @@ -164,6 +165,67 @@ def test_maxmin_proportional_selection(): assert_equal((labels[selected_ids] == 2).sum(), 3) +def test_maxmin_proportional_selection_imbalance_1(): + """Test MaxMin class with proportional selection with imbalance case 1.""" + # load three-cluster data from file + # 2 from class 0, 10 from class 1, 40 from class 2 + coords_file_path = get_data_file_path("coords_imbalance_case1.txt") + coords = np.genfromtxt(coords_file_path, delimiter=",", skip_header=0) + labels_file_path = get_data_file_path("labels_imbalance_case1.txt") + labels = np.genfromtxt(labels_file_path, delimiter=",", skip_header=0) + + # instantiate the MaxMin class + collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) + # select 12 points with proportional selection from each cluster + selected_ids = collector.select(coords, + size=9, + labels=labels, + proportional_selection=True, + ) + + # make sure all the selected indices are the same with expectation + assert_equal(selected_ids,[0, 2, 6, 12, 15, 38, 16, 41, 36]) + # check how many points are selected from each cluster + assert_equal(len(selected_ids), 9) + # check the number of points selected from cluster one + assert_equal((labels[selected_ids] == 0).sum(), 1) + # check the number of points selected from cluster two + assert_equal((labels[selected_ids] == 1).sum(), 2) + # check the number of points selected from cluster three + assert_equal((labels[selected_ids] == 2).sum(), 6) + + +def test_maxmin_proportional_selection_imbalance_2(): + """Test MaxMin class with proportional selection with imbalance case 2.""" + # load three-cluster data from file + # 3 from class 0, 11 from class 1, 40 from class 2 + coords_file_path = get_data_file_path("coords_imbalance_case2.txt") + coords = np.genfromtxt(coords_file_path, delimiter=",", skip_header=0) + labels_file_path = get_data_file_path("labels_imbalance_case2.txt") + labels = np.genfromtxt(labels_file_path, delimiter=",", skip_header=0) + + # instantiate the MaxMin class + collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) + # select 12 points with proportional selection from each cluster + selected_ids = collector.select(coords, + size=14, + labels=labels, + proportional_selection=True, + ) + + # # make sure all the selected indices are the same with expectation + assert_equal(selected_ids,[0, 3, 9, 6, 14, 36, 53, 17, 44, 23, 28, 50, 52, 49]) + print(f"selected_ids: {selected_ids}") + # check how many points are selected from each cluster + assert_equal(len(selected_ids), 14) + # check the number of points selected from cluster one + assert_equal((labels[selected_ids] == 0).sum(), 1) + # check the number of points selected from cluster two + assert_equal((labels[selected_ids] == 1).sum(), 3) + # check the number of points selected from cluster three + assert_equal((labels[selected_ids] == 2).sum(), 10) + + def test_maxmin_invalid_input(): """Testing MaxMin class with invalid input.""" # case when the distance matrix is not square diff --git a/selector/methods/tests/test_similarity.py b/selector/methods/tests/test_similarity.py index 09d7d654..b246b619 100644 --- a/selector/methods/tests/test_similarity.py +++ b/selector/methods/tests/test_similarity.py @@ -31,13 +31,14 @@ import pytest from numpy.testing import assert_almost_equal, assert_equal, assert_raises -from selector.methods.similarity import NSimilarity, SimilarityIndex from selector.measures.similarity import ( modified_tanimoto, pairwise_similarity_bit, scaled_similarity_matrix, tanimoto, ) +from selector.methods.similarity import NSimilarity, SimilarityIndex +from selector.methods.tests.common import get_data_file_path def test_pairwise_similarity_bit_raises(): @@ -1513,27 +1514,6 @@ def test_get_new_index_esim(c_threshold, w_factor, n_ary): # --------------------------------------------------------------------------------------------- # # Get reference data for testing the selection of the diverse subset # --------------------------------------------------------------------------------------------- # -def get_data_file_path(file_name): - """Get the absolute path of the data file inside the package. - - Parameters - ---------- - file_name : str - The name of the data file to load. - - Returns - ------- - str - The absolute path of the data file inside the package - - """ - data_file_path = importlib.resources.files("selector.methods.tests").joinpath( - f"data/{file_name}" - ) - - return data_file_path - - # get reference selected data for esim method def _get_selections_esim_ref_dict(): """Returns a dictionary with the reference values for the selection of samples. From 3e398750bc9796e43adba31d8d036381d33ace94 Mon Sep 17 00:00:00 2001 From: Fanwang Meng Date: Sun, 6 Oct 2024 01:12:20 -0400 Subject: [PATCH 06/12] Add testing data for imbalance cases --- .../tests/data/coords_imbalance_case1.txt | 52 ++++++++++++++++++ .../tests/data/coords_imbalance_case2.txt | 54 +++++++++++++++++++ .../tests/data/labels_imbalance_case1.txt | 52 ++++++++++++++++++ .../tests/data/labels_imbalance_case2.txt | 54 +++++++++++++++++++ 4 files changed, 212 insertions(+) create mode 100644 selector/methods/tests/data/coords_imbalance_case1.txt create mode 100644 selector/methods/tests/data/coords_imbalance_case2.txt create mode 100644 selector/methods/tests/data/labels_imbalance_case1.txt create mode 100644 selector/methods/tests/data/labels_imbalance_case2.txt diff --git a/selector/methods/tests/data/coords_imbalance_case1.txt b/selector/methods/tests/data/coords_imbalance_case1.txt new file mode 100644 index 00000000..8a29ea2e --- /dev/null +++ b/selector/methods/tests/data/coords_imbalance_case1.txt @@ -0,0 +1,52 @@ +-2.988371860898040300e+00,8.828627151534506723e+00 +-2.522694847790684314e+00,7.956575199242423402e+00 +2.721107620929060111e+00,1.946655808491515094e+00 +3.856625543891864183e+00,1.651108167735056309e+00 +4.447517871446978965e+00,2.274717026274344356e+00 +4.247770683095943411e+00,5.096547358086134238e-01 +5.161820401844998685e+00,2.270154357173918225e+00 +3.448575339025452990e+00,2.629723292574561722e+00 +4.110118632461063015e+00,2.486437117054088208e+00 +4.605167066522858121e+00,8.044916463211999602e-01 +3.959854114649610679e+00,2.205423381101735636e+00 +4.935999113292677265e+00,2.234224956120621108e+00 +-7.194896435791616085e+00,-6.121140372782679862e+00 +-6.521839830802987237e+00,-6.319325066907712340e+00 +-6.665533447021066316e+00,-8.125848371987935082e+00 +-4.564968624477761416e+00,-8.747374785867695124e+00 +-4.735683101825944874e+00,-6.246190570957935506e+00 +-7.144284024389226495e+00,-4.159940426686327797e+00 +-6.364591923942610308e+00,-6.366323642363737711e+00 +-7.769141620776792934e+00,-7.695919878241385348e+00 +-6.821418472705270020e+00,-8.023079891106569050e+00 +-7.541413655919658510e+00,-6.027676258479722549e+00 +-6.706446265300088250e+00,-6.494792213547110116e+00 +-6.406389566577725070e+00,-6.952938505932819702e+00 +-7.609993822868406532e+00,-6.663651003693972008e+00 +-5.796575947975993515e+00,-5.826307541241043886e+00 +-7.351559056940703663e+00,-5.791158996308579887e+00 +-7.364990738980373486e+00,-6.798235453889623692e+00 +-6.956728900565374296e+00,-6.538957618459303234e+00 +-6.253959843386263984e+00,-7.737267149692229395e+00 +-6.057567031156779969e+00,-4.983316610621999487e+00 +-7.594930900411238639e+00,-6.200511844341271228e+00 +-7.125015307154140665e+00,-7.633845757633435980e+00 +-7.672147929583970516e+00,-6.994846034742845831e+00 +-7.103089976477121148e+00,-6.166109099183854525e+00 +-6.602936391821250695e+00,-6.052926344239923040e+00 +-8.904769777808876796e+00,-6.693655278506518869e+00 +-8.257296559108361578e+00,-7.817934633191069516e+00 +-6.364579504845222502e+00,-3.027378102621225864e+00 +-6.834055351247456223e+00,-7.531709940881763821e+00 +-7.652452405688841885e+00,-7.116928200015955497e+00 +-7.726420909219674726e+00,-8.394956817961810813e+00 +-6.866625299273363403e+00,-5.426575516118630205e+00 +-6.374639912170812828e+00,-6.014354399105824811e+00 +-7.326142143218291380e+00,-6.023710798952474299e+00 +-6.308736680458102875e+00,-5.744543953095347710e+00 +-8.079923598207045643e+00,-7.214610829116894664e+00 +-6.193367000776756726e+00,-8.492825464465598273e+00 +-5.925625427658067323e+00,-6.228718341970148842e+00 +-7.950519689212382168e+00,-6.397637178032761440e+00 +-7.763484627352402967e+00,-6.726384487330419049e+00 +-6.815347172055806979e+00,-7.957854371205252519e+00 diff --git a/selector/methods/tests/data/coords_imbalance_case2.txt b/selector/methods/tests/data/coords_imbalance_case2.txt new file mode 100644 index 00000000..9f5600b2 --- /dev/null +++ b/selector/methods/tests/data/coords_imbalance_case2.txt @@ -0,0 +1,54 @@ +-2.545023662162701594e+00,1.057892978401232931e+01 +-3.348415146275388832e+00,8.705073752347109561e+00 +-3.186119623358708797e+00,9.625962417039191976e+00 +6.526064737438631802e+00,2.147747496772570930e+00 +5.265546183993107476e+00,1.116012127524449449e+00 +3.793085118159696290e+00,4.583224592548673648e-01 +4.605167066522858121e+00,8.044916463211999602e-01 +3.665197166000779827e+00,2.760254287683184149e+00 +4.890371686573978138e+00,2.319617893437707856e+00 +3.089215405161968686e+00,2.041732658746759466e+00 +4.416416050902250312e+00,2.687170178032824097e+00 +3.568986338166989292e+00,2.455642099183917182e+00 +4.447517871446978965e+00,2.274717026274344356e+00 +5.161820401844998685e+00,2.270154357173918225e+00 +-6.598635323416237597e+00,-7.502809113096540194e+00 +-6.364591923942610308e+00,-6.366323642363737711e+00 +-7.351559056940703663e+00,-5.791158996308579887e+00 +-4.757470994138636833e+00,-5.847644332724799554e+00 +-7.132195342544430439e+00,-8.127892775240795231e+00 +-6.766109845900022179e+00,-6.217978918754900164e+00 +-6.680567495577800052e+00,-7.480326470434741637e+00 +-7.354572502312226590e+00,-7.533438825849658294e+00 +-4.735683101825944874e+00,-6.246190570957935506e+00 +-8.140511145486314604e+00,-5.962247646221170427e+00 +-6.374639912170812828e+00,-6.014354399105824811e+00 +-4.746593816495003892e+00,-8.832197392798448732e+00 +-7.652452405688841885e+00,-7.116928200015955497e+00 +-6.435807763005041870e+00,-6.105475539846610289e+00 +-6.308736680458102875e+00,-5.744543953095347710e+00 +-6.900528785115418451e+00,-6.762782209967165059e+00 +-6.834055351247456223e+00,-7.531709940881763821e+00 +-8.079923598207045643e+00,-7.214610829116894664e+00 +-7.672147929583970516e+00,-6.994846034742845831e+00 +-5.293610375005918023e+00,-8.117925092102796114e+00 +-7.087749441508545800e+00,-7.373110527934779945e+00 +-5.247215887219635277e+00,-8.310250971236579076e+00 +-6.364579504845222502e+00,-3.027378102621225864e+00 +-7.364990738980373486e+00,-6.798235453889623692e+00 +-7.861135842199221457e+00,-6.418006119012676258e+00 +-6.132333586028008376e+00,-6.269739327842481558e+00 +-5.925625427658067323e+00,-6.228718341970148842e+00 +-5.612716041964647573e+00,-7.587779058894727591e+00 +-5.980027315718019487e+00,-6.572810072399337677e+00 +-7.031412286186853322e+00,-6.291792386791370539e+00 +-4.564968624477761416e+00,-8.747374785867695124e+00 +-7.319671677848253566e+00,-6.749369015989855392e+00 +-5.438353902085154346e+00,-8.315971744455385561e+00 +-6.809825106161251362e+00,-7.265423190137706655e+00 +-8.904769777808876796e+00,-6.693655278506518869e+00 +-6.193367000776756726e+00,-8.492825464465598273e+00 +-8.398997157105283051e+00,-7.364343666142198153e+00 +-6.522611705186222686e+00,-7.573019188536600943e+00 +-5.716463438996310487e+00,-6.869876532256359525e+00 +-1.012089453122034222e+01,-7.904497234610236234e+00 diff --git a/selector/methods/tests/data/labels_imbalance_case1.txt b/selector/methods/tests/data/labels_imbalance_case1.txt new file mode 100644 index 00000000..96ed3569 --- /dev/null +++ b/selector/methods/tests/data/labels_imbalance_case1.txt @@ -0,0 +1,52 @@ +0.000000000000000000e+00 +0.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 diff --git a/selector/methods/tests/data/labels_imbalance_case2.txt b/selector/methods/tests/data/labels_imbalance_case2.txt new file mode 100644 index 00000000..b3bb027f --- /dev/null +++ b/selector/methods/tests/data/labels_imbalance_case2.txt @@ -0,0 +1,54 @@ +0.000000000000000000e+00 +0.000000000000000000e+00 +0.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +1.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 +2.000000000000000000e+00 From 9ad3a0d7f0cc9d0565f253efe9c10023a9a34107 Mon Sep 17 00:00:00 2001 From: Fanwang Meng Date: Sun, 6 Oct 2024 08:16:07 -0400 Subject: [PATCH 07/12] Reformat with black --- selector/methods/base.py | 20 +++--- selector/methods/distance.py | 4 +- selector/methods/partition.py | 1 - selector/methods/tests/common.py | 9 +-- selector/methods/tests/test_distance.py | 80 +++++++++++++----------- selector/methods/tests/test_partition.py | 1 - 6 files changed, 61 insertions(+), 54 deletions(-) diff --git a/selector/methods/base.py b/selector/methods/base.py index bfba7730..317d4013 100644 --- a/selector/methods/base.py +++ b/selector/methods/base.py @@ -34,12 +34,13 @@ class SelectionBase(ABC): """Base class for selecting subset of sample points.""" - def select(self, - x: np.ndarray, - size: int, - labels: np.ndarray = None, - proportional_selection: bool = True, - ) -> list: + def select( + self, + x: np.ndarray, + size: int, + labels: np.ndarray = None, + proportional_selection: bool = True, + ) -> list: """Return indices representing subset of sample points. Parameters @@ -130,7 +131,7 @@ def select(self, # (pop < size_each_cluster) and needs to be done iteratively until all remaining clusters # have at least size_each_cluster samples while np.any( - [value <= size_each_cluster for value in pop_clusters.values() if value != 0] + [value <= size_each_cluster for value in pop_clusters.values() if value != 0] ): for unique_label in unique_labels: if pop_clusters[unique_label] != 0: @@ -143,7 +144,8 @@ def select(self, # update number of samples to be selected from each cluster totally_used_clusters = list(pop_clusters.values()).count(0) size_each_cluster = (size - len(np.hstack(selected_ids))) // ( - num_clusters - totally_used_clusters) + num_clusters - totally_used_clusters + ) warnings.warn( f"Number of molecules in one cluster is less than" @@ -165,7 +167,7 @@ def select(self, @abstractmethod def select_from_cluster( - self, x: np.ndarray, size: int, labels: np.ndarray = None + self, x: np.ndarray, size: int, labels: np.ndarray = None ) -> np.ndarray: """Return indices representing subset of sample points from one cluster. diff --git a/selector/methods/distance.py b/selector/methods/distance.py index 26c4c052..40339446 100644 --- a/selector/methods/distance.py +++ b/selector/methods/distance.py @@ -459,9 +459,7 @@ class DISE(SelectionBase): """ - def __init__( - self, r0=None, ref_index=None, tol=0.05, n_iter=10, p=2.0, eps=0.0, fun_dist=None - ): + def __init__(self, r0=None, ref_index=None, tol=0.05, n_iter=10, p=2.0, eps=0.0, fun_dist=None): """ Initialize class. diff --git a/selector/methods/partition.py b/selector/methods/partition.py index ee8a5367..39dc6cae 100644 --- a/selector/methods/partition.py +++ b/selector/methods/partition.py @@ -666,4 +666,3 @@ def select_from_cluster(self, arr, num_selected, cluster_ids=None): ) count += 1 return selected - diff --git a/selector/methods/tests/common.py b/selector/methods/tests/common.py index 1ed5ad0d..9c152d30 100644 --- a/selector/methods/tests/common.py +++ b/selector/methods/tests/common.py @@ -43,7 +43,9 @@ def generate_synthetic_cluster_data(): # generate the second cluster with 6 points cluster_two = np.array([[3, 0], [3, 1], [3, 2], [3, 3], [3, 4], [3, 5]]) # generate the third cluster with 9 points - cluster_three = np.array([[6, 0], [6, 1], [6, 2], [6, 3], [6, 4], [6, 5], [6, 6], [6, 7], [6, 8]]) + cluster_three = np.array( + [[6, 0], [6, 1], [6, 2], [6, 3], [6, 4], [6, 5], [6, 6], [6, 7], [6, 8]] + ) # concatenate the clusters coords = np.vstack([cluster_one, cluster_two, cluster_three]) # generate the labels @@ -51,6 +53,7 @@ def generate_synthetic_cluster_data(): return coords, labels, cluster_one, cluster_two, cluster_three + def generate_synthetic_data( n_samples: int = 100, n_features: int = 2, @@ -136,8 +139,6 @@ def get_data_file_path(file_name): The absolute path of the data file inside the package """ - data_file_path = resources.files("selector.methods.tests").joinpath( - f"data/{file_name}" - ) + data_file_path = resources.files("selector.methods.tests").joinpath(f"data/{file_name}") return data_file_path diff --git a/selector/methods/tests/test_distance.py b/selector/methods/tests/test_distance.py index 2cfbbf31..1e22fb0d 100644 --- a/selector/methods/tests/test_distance.py +++ b/selector/methods/tests/test_distance.py @@ -141,6 +141,7 @@ def test_maxmin(): ) assert_equal(selected_mocked, [0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 15, 10, 13, 9, 18]) + def test_maxmin_proportional_selection(): """Test MaxMin class with proportional selection.""" # generate the first cluster with 3 points @@ -148,13 +149,14 @@ def test_maxmin_proportional_selection(): # instantiate the MaxMin class collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) # select 6 points with proportional selection from each cluster - selected_ids = collector.select(coords, - size=6, - labels=labels, - proportional_selection=True, - ) + selected_ids = collector.select( + coords, + size=6, + labels=labels, + proportional_selection=True, + ) # make sure all the selected indices are the same with expectation - assert_equal(selected_ids,[0, 3, 8, 9, 17, 13]) + assert_equal(selected_ids, [0, 3, 8, 9, 17, 13]) # check how many points are selected from each cluster assert_equal(len(selected_ids), 6) # check the number of points selected from cluster one @@ -177,14 +179,15 @@ def test_maxmin_proportional_selection_imbalance_1(): # instantiate the MaxMin class collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) # select 12 points with proportional selection from each cluster - selected_ids = collector.select(coords, - size=9, - labels=labels, - proportional_selection=True, - ) + selected_ids = collector.select( + coords, + size=9, + labels=labels, + proportional_selection=True, + ) # make sure all the selected indices are the same with expectation - assert_equal(selected_ids,[0, 2, 6, 12, 15, 38, 16, 41, 36]) + assert_equal(selected_ids, [0, 2, 6, 12, 15, 38, 16, 41, 36]) # check how many points are selected from each cluster assert_equal(len(selected_ids), 9) # check the number of points selected from cluster one @@ -207,14 +210,15 @@ def test_maxmin_proportional_selection_imbalance_2(): # instantiate the MaxMin class collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) # select 12 points with proportional selection from each cluster - selected_ids = collector.select(coords, - size=14, - labels=labels, - proportional_selection=True, - ) + selected_ids = collector.select( + coords, + size=14, + labels=labels, + proportional_selection=True, + ) # # make sure all the selected indices are the same with expectation - assert_equal(selected_ids,[0, 3, 9, 6, 14, 36, 53, 17, 44, 23, 28, 50, 52, 49]) + assert_equal(selected_ids, [0, 3, 9, 6, 14, 36, 53, 17, 44, 23, 28, 50, 52, 49]) print(f"selected_ids: {selected_ids}") # check how many points are selected from each cluster assert_equal(len(selected_ids), 14) @@ -355,13 +359,14 @@ def test_maxsum_proportional_selection(): # instantiate the MaxSum class collector = MaxSum(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) # select 6 points with proportional selection from each cluster - selected_ids = collector.select(coords, - size=6, - labels=labels, - proportional_selection=True, - ) + selected_ids = collector.select( + coords, + size=6, + labels=labels, + proportional_selection=True, + ) # make sure all the selected indices are the same with expectation - assert_equal(selected_ids,[0, 3, 8, 9, 17, 10]) + assert_equal(selected_ids, [0, 3, 8, 9, 17, 10]) # check how many points are selected from each cluster assert_equal(len(selected_ids), 6) # check the number of points selected from cluster one @@ -426,13 +431,14 @@ def test_optisim_proportional_selection(): # instantiate the Optisim class collector = OptiSim(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) # select 6 points with proportional selection from each cluster - selected_ids = collector.select(coords, - size=6, - labels=labels, - proportional_selection=True, - ) + selected_ids = collector.select( + coords, + size=6, + labels=labels, + proportional_selection=True, + ) # make sure all the selected indices are the same with expectation - assert_equal(selected_ids,[0, 3, 8, 9, 17, 13]) + assert_equal(selected_ids, [0, 3, 8, 9, 17, 13]) # check how many points are selected from each cluster assert_equal(len(selected_ids), 6) # check the number of points selected from cluster one @@ -442,6 +448,7 @@ def test_optisim_proportional_selection(): # check the number of points selected from cluster three assert_equal((labels[selected_ids] == 2).sum(), 3) + def test_directed_sphere_size_error(): """Test DirectedSphereExclusion error when too many points requested.""" x = np.array([[1, 9]] * 100) @@ -552,13 +559,14 @@ def test_directed_sphere_proportional_selection(): # instantiate the DISE class collector = DISE(fun_dist=lambda x: pairwise_distances(x, metric="euclidean"), ref_index=0) # select 6 points with proportional selection from each cluster - selected_ids = collector.select(coords, - size=6, - labels=labels, - proportional_selection=True, - ) + selected_ids = collector.select( + coords, + size=6, + labels=labels, + proportional_selection=True, + ) # make sure all the selected indices are the same with expectation - assert_equal(selected_ids,[0, 3, 7, 9, 12, 15]) + assert_equal(selected_ids, [0, 3, 7, 9, 12, 15]) # check how many points are selected from each cluster assert_equal(len(selected_ids), 6) # check the number of points selected from cluster one diff --git a/selector/methods/tests/test_partition.py b/selector/methods/tests/test_partition.py index ced6cc8d..3e613e2d 100644 --- a/selector/methods/tests/test_partition.py +++ b/selector/methods/tests/test_partition.py @@ -183,4 +183,3 @@ def test_medoid(): selector = Medoid() selected_ids = selector.select(features, size=2) assert_equal(selected_ids, [0, 3]) - From fd4d39f5857cd484621730570452a91b843e09a1 Mon Sep 17 00:00:00 2001 From: Fanwang Meng Date: Sun, 6 Oct 2024 08:39:42 -0400 Subject: [PATCH 08/12] Add data points to smallest cluster when not enough data points --- selector/methods/base.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/selector/methods/base.py b/selector/methods/base.py index 317d4013..32971915 100644 --- a/selector/methods/base.py +++ b/selector/methods/base.py @@ -98,28 +98,25 @@ def select( # the total number of samples selected from all clusters at this point size_each_cluster_total = np.sum(size_each_cluster) - # Adjust if the total is less than the required number + # when the total of data points in each class is less than the required number + # add one sample to the smallest cluster iteratively until the total is equal to the + # required number if size_each_cluster_total < size: while size_each_cluster_total < size: - # select the largest cluster with maximum number of data points not selected - # and add one sample to it - largest_cluster_index = np.argmax(unique_label_counts - size_each_cluster) - size_each_cluster[largest_cluster_index] += 1 + # the number of remaining data points in each cluster + size_each_cluster_remaining = unique_label_counts - size_each_cluster_total + # skip the clusters with no data points left + size_each_cluster_remaining[size_each_cluster_remaining == 0] = np.inf + smallest_cluster_index = np.argmin(size_each_cluster_remaining) + size_each_cluster[smallest_cluster_index] += 1 size_each_cluster_total += 1 - # Adjust if the total is more than the required number + # when the total of data points in each class is more than the required number + # we need to remove samples from the largest clusters elif size_each_cluster_total > size: while size_each_cluster_total > size: - largest_cluster_index = np.argmax(unique_label_counts - size_each_cluster) + largest_cluster_index = np.argmax(size_each_cluster) size_each_cluster[largest_cluster_index] -= 1 size_each_cluster_total -= 1 - - # # when the total number of samples selected is more than the required number - # # we need to remove samples from the largest clusters - # while size_each_cluster_total > size: - # largest_cluster_index = np.argmax(size_each_cluster) - # size_each_cluster[largest_cluster_index] -= 1 - # size_each_cluster_total -= 1 - # perfect case where the total is equal to the required number else: pass From c0d68577ce6ed748c83d347294b4f72adb980224 Mon Sep 17 00:00:00 2001 From: Fanwang Meng Date: Sun, 6 Oct 2024 08:47:58 -0400 Subject: [PATCH 09/12] Add test for checking the number of labels match the number of total data points --- selector/methods/tests/test_distance.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/selector/methods/tests/test_distance.py b/selector/methods/tests/test_distance.py index 1e22fb0d..9f6d92d8 100644 --- a/selector/methods/tests/test_distance.py +++ b/selector/methods/tests/test_distance.py @@ -107,11 +107,18 @@ def test_maxmin(): # test failing case when ref_index contains a complex number with pytest.raises(ValueError): collector_float = MaxMin(ref_index=[1 + 5j, 2, 5]) - selected_ids_float = collector_float.select(arr_dist, size=12) + _ = collector_float.select(arr_dist, size=12) # test failing case when ref_index contains a negative number with pytest.raises(ValueError): collector_float = MaxMin(ref_index=[-1, 2, 5]) - selected_ids_float = collector_float.select(arr_dist, size=12) + _ = collector_float.select(arr_dist, size=12) + + # test failing case when the number of labels is not equal to the number of samples + with pytest.raises(ValueError): + collector_float = MaxMin(ref_index=85) + _ = collector_float.select( + arr_dist, size=12, labels=class_labels_cluster[:90], proportional_selection=False + ) # use MaxMin algorithm, this time instantiating with a distance metric collector = MaxMin(fun_dist=lambda x: pairwise_distances(x, metric="euclidean")) From 948bc27bc4956e3fd10866f387be999d057d6d69 Mon Sep 17 00:00:00 2001 From: Fanwang Meng Date: Sun, 6 Oct 2024 09:23:36 -0400 Subject: [PATCH 10/12] Ignore NotImplementedError in coverage report --- .coveragerc | 4 ++++ selector/methods/base.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.coveragerc b/.coveragerc index 44911367..769d7432 100644 --- a/.coveragerc +++ b/.coveragerc @@ -7,3 +7,7 @@ omit = [report] show_missing = True +exclude_also = + pragma: no cover + raise NotImplementedError + if __name__ == .__main__.: diff --git a/selector/methods/base.py b/selector/methods/base.py index 32971915..2a7cd86e 100644 --- a/selector/methods/base.py +++ b/selector/methods/base.py @@ -165,7 +165,7 @@ def select( @abstractmethod def select_from_cluster( self, x: np.ndarray, size: int, labels: np.ndarray = None - ) -> np.ndarray: + ) -> np.ndarray: # pragma: no cover """Return indices representing subset of sample points from one cluster. Parameters From a7de226f6081a624388aa52037d2a715ab0f0aa7 Mon Sep 17 00:00:00 2001 From: Fanwang Meng Date: Sun, 6 Oct 2024 15:58:37 -0400 Subject: [PATCH 11/12] Add typing hints for returns --- selector/methods/distance.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/selector/methods/distance.py b/selector/methods/distance.py index 40339446..c35acdc7 100644 --- a/selector/methods/distance.py +++ b/selector/methods/distance.py @@ -26,6 +26,7 @@ import bitarray import numpy as np from scipy import spatial +from typing import List, Iterable, Union from selector.methods.base import SelectionBase from selector.methods.utils import optimize_radius @@ -86,7 +87,7 @@ def __init__(self, fun_dist=None, ref_index=None): self.fun_dist = fun_dist self.ref_index = ref_index - def select_from_cluster(self, x, size, labels=None): + def select_from_cluster(self, x, size, labels=None) -> Union[List, Iterable]: """Return selected samples from a cluster based on MaxMin algorithm. Parameters @@ -102,7 +103,7 @@ def select_from_cluster(self, x, size, labels=None): Returns ------- - selected : list + selected : Union[List, Iterable] List of indices of selected samples. """ # calculate pairwise distance between points @@ -134,6 +135,8 @@ def select_from_cluster(self, x, size, labels=None): new_id = np.argmax(min_distances) selected.append(new_id) + selected = [int(i) for i in selected] + return selected @@ -184,7 +187,7 @@ def __init__(self, fun_dist=None, ref_index=None): self.fun_dist = fun_dist self.ref_index = ref_index - def select_from_cluster(self, x, size, labels=None): + def select_from_cluster(self, x, size, labels=None) -> Union[List, Iterable]: """Return selected samples from a cluster based on MaxSum algorithm. Parameters @@ -200,7 +203,7 @@ def select_from_cluster(self, x, size, labels=None): Returns ------- - selected : list + selected : Union[List, Iterable] List of indices of selected samples. """ @@ -237,6 +240,8 @@ def select_from_cluster(self, x, size, labels=None): # already-selected points new_id = np.argmax(sum_distances) selected.append(new_id) + + selected = [int(i) for i in selected] return selected @@ -261,6 +266,7 @@ class 0 and `ref_index=[3, 6]` class 1 respectively. References ---------- [1] J. Chem. Inf. Comput. Sci. 1997, 37, 6, 1181–1188. https://doi.org/10.1021/ci970282v + """ def __init__( @@ -330,7 +336,7 @@ def __init__( self.random_seed = random_seed self.fun_dist = fun_dist - def algorithm(self, x, max_size) -> list: + def algorithm(self, x, max_size) -> Union[List, Iterable]: """Return selected sample indices based on OptiSim algorithm. Parameters @@ -342,7 +348,7 @@ def algorithm(self, x, max_size) -> list: Returns ------- - selected : list + selected : Union[List, Iterable] List of indices of selected sample indices. """ @@ -402,7 +408,7 @@ def algorithm(self, x, max_size) -> list: return selected - def select_from_cluster(self, x, size, labels=None): + def select_from_cluster(self, x, size, labels=None) -> Union[List, Iterable]: """Return selected samples from a cluster based on OptiSim algorithm. Parameters @@ -416,7 +422,7 @@ def select_from_cluster(self, x, size, labels=None): Returns ------- - selected : list + selected : Union[List, Iterable] List of indices of selected samples. """ @@ -509,7 +515,7 @@ def __init__(self, r0=None, ref_index=None, tol=0.05, n_iter=10, p=2.0, eps=0.0, # self.fun_dist = fun_dist self.fun_dist = fun_dist - def algorithm(self, x, max_size): + def algorithm(self, x, max_size) -> Union[List, Iterable]: """Return selected samples based on directed sphere exclusion algorithm. Parameters @@ -521,7 +527,7 @@ def algorithm(self, x, max_size): Returns ------- - selected: list + selected: Union[List, Iterable] List of indices of selected samples. """ @@ -591,7 +597,7 @@ def algorithm(self, x, max_size): return selected - def select_from_cluster(self, x, size, labels=None): + def select_from_cluster(self, x, size, labels=None) -> Union[List, Iterable]: """Return selected samples from a cluster based on directed sphere exclusion algorithm Parameters @@ -605,7 +611,7 @@ def select_from_cluster(self, x, size, labels=None): Returns ------- - selected: list + selected: Union[List, Iterable] List of indices of selected samples. """ @@ -623,7 +629,7 @@ def select_from_cluster(self, x, size, labels=None): return optimize_radius(self, x, size, labels) -def get_initial_selection(x=None, x_dist=None, ref_index=None, fun_dist=None): +def get_initial_selection(x=None, x_dist=None, ref_index=None, fun_dist=None) -> List: """Set up the reference index for selecting. Parameters @@ -648,7 +654,7 @@ def get_initial_selection(x=None, x_dist=None, ref_index=None, fun_dist=None): Returns ------- - initial_selections: list + initial_selections: List List of indices of the initial selected data points. """ From fafc517b2e8db333fc3e72ed9c4d607904577546 Mon Sep 17 00:00:00 2001 From: Fanwang Meng Date: Sun, 6 Oct 2024 17:07:47 -0400 Subject: [PATCH 12/12] Add typing hints --- selector/methods/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/selector/methods/base.py b/selector/methods/base.py index 2a7cd86e..e49c2663 100644 --- a/selector/methods/base.py +++ b/selector/methods/base.py @@ -25,6 +25,7 @@ import warnings from abc import ABC, abstractmethod +from typing import List, Iterable, Union import numpy as np @@ -40,7 +41,7 @@ def select( size: int, labels: np.ndarray = None, proportional_selection: bool = True, - ) -> list: + ) -> Union[List, Iterable]: """Return indices representing subset of sample points. Parameters