Skip to content

Commit

Permalink
selection_criteria for stratified sampling, e.g., smallest, random, a…
Browse files Browse the repository at this point in the history
…nd center
  • Loading branch information
JiQi535 committed Jan 28, 2024
1 parent 0ee886f commit 00c8eae
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 5 deletions.
61 changes: 56 additions & 5 deletions maml/sampling/stratified_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,40 @@
class SelectKFromClusters(BaseEstimator, TransformerMixin):
"""Wrapper around selection of k data from each cluster."""

def __init__(
    self,
    k: int = 1,
    allow_duplicate=False,
    selection_criteria="center",
    n_sites=None,
):
    """
    Configure how structures are picked from each cluster.

    Args:
        k: Select k structures from each cluster.
        allow_duplicate: Whether structures are allowed to be selected more than once.
        selection_criteria: The criteria used for stratified sampling from each cluster.
            Supported criteria are "random", "smallest", and "center" (default).
            For "center", structures are ranked by their distance to the centroid of
            the cluster they belong to, and up to k structures at fixed ranking
            intervals are selected from each cluster; when k=1, the structure with
            the smallest Euclidean distance to the centroid is sampled.
            For "random", k structures are randomly sampled with replacement.
            For "smallest", the k structures with the fewest atoms in each cluster
            are selected.
        n_sites: The number of sites of every structure to sample from, as a sized
            sequence (list, range, NumPy array, ...). Only needed when
            selection_criteria="smallest".

    Raises:
        ValueError: If selection_criteria is not one of the supported values, or if
            selection_criteria="smallest" and n_sites is missing or empty.
    """
    allowed_selection_criterion = ["random", "smallest", "center"]
    if selection_criteria not in allowed_selection_criterion:
        raise ValueError(
            f"Invalid selection_criteria, it must be one of {allowed_selection_criterion}."
        )

    # Test for "missing" explicitly instead of relying on truthiness: `not n_sites`
    # raises on a multi-element NumPy array, which would reject perfectly valid input.
    n_sites_missing = n_sites is None or (
        hasattr(n_sites, "__len__") and len(n_sites) == 0
    )
    if selection_criteria == "smallest" and n_sites_missing:
        raise ValueError(
            'n_sites must be provided when selection_criteria="smallest."'
        )

    self.k = k
    self.allow_duplicate = allow_duplicate
    self.selection_criteria = selection_criteria
    self.n_sites = n_sites

def fit(self, X, y=None):
"""
Expand Down Expand Up @@ -57,18 +82,31 @@ def transform(self, clustering_data: dict):
raise Exception(
"The data returned by clustering step should at least provide label and feature information."
)
if "label_centers" not in clustering_data:
if (
self.selection_criteria == "center"
and "label_centers" not in clustering_data
):
warnings.warn(
"Centroid location is not provided, so random selection from each cluster will be performed, "
"which likely will still outperform manual sampling in terms of feature coverage. "
)
if self.selection_criteria == "smallest":
try:
assert len(self.n_sites) == len(clustering_data["PCAfeatures"])
except Exception:
raise ValueError(
"n_sites must have same length as features processed in clustering."
)

selected_indexes = []
for label in set(clustering_data["labels"]):
indexes_same_label = np.where(label == clustering_data["labels"])[0]
features_same_label = clustering_data["PCAfeatures"][indexes_same_label]
n_same_label = len(features_same_label)
if "label_centers" in clustering_data:
if (
"label_centers" in clustering_data
and self.selection_criteria == "center"
):
center_same_label = clustering_data["label_centers"][label]
distance_to_center = np.linalg.norm(
features_same_label - center_same_label, axis=1
Expand All @@ -83,6 +121,19 @@ def transform(self, clustering_data: dict):
]
]
)
elif self.selection_criteria == "smallest":
if self.k >= n_same_label:
selected_indexes.extend(indexes_same_label)
else:
select_k_indexes = np.arange(self.k)
selected_indexes.extend(
indexes_same_label[
np.argpartition(
np.array(self.n_sites)[indexes_same_label],
select_k_indexes,
)[select_k_indexes]
]
)
else:
selected_indexes.extend(
indexes_same_label[np.random.randint(n_same_label, size=self.k)]
Expand Down
33 changes: 33 additions & 0 deletions tests/sampling/test_stratified_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,38 @@

from maml.sampling.stratified_sampling import SelectKFromClusters

import pytest


class TestSelectKFromClusters:
def setup(self):
    """Create the selectors shared by the tests: one per sampling configuration."""
    unique_pair = {"k": 2, "allow_duplicate": False}
    smallest_pair = dict(unique_pair, selection_criteria="smallest")

    self.selector_uni = SelectKFromClusters(**unique_pair)
    self.selector_dup = SelectKFromClusters(k=2, allow_duplicate=True)
    self.selector_rand = SelectKFromClusters(selection_criteria="random", **unique_pair)
    self.selector_small = SelectKFromClusters(n_sites=range(10), **smallest_pair)
    # Deliberately built with 11 entries; test_exceptions expects transform to reject it.
    self.selector_small_wrong_n_sites = SelectKFromClusters(
        n_sites=range(11), **smallest_pair
    )

def test_exceptions(self, Birch_results):
    """Check that invalid configurations and mismatched n_sites raise ValueError."""
    # Constructor-time failures: keyword arguments paired with the expected message.
    constructor_failures = (
        ({"selection_criteria": "whatever"}, "Invalid selection_criteria"),
        ({"selection_criteria": "smallest"}, "n_sites must be provided"),
    )
    for kwargs, expected_message in constructor_failures:
        with pytest.raises(ValueError, match=expected_message):
            SelectKFromClusters(**kwargs)

    # Transform-time failure: n_sites length disagrees with the clustered features.
    mismatched_selector = self.selector_small_wrong_n_sites
    with pytest.raises(
        ValueError, match="n_sites must have same length as features"
    ):
        mismatched_selector.transform(Birch_results)

def test_fit(self, Birch_results):
assert self.selector_uni == self.selector_uni.fit(Birch_results)
Expand All @@ -24,3 +51,9 @@ def test_transform(self, Birch_results):
7,
7,
]
assert len(self.selector_rand.transform(Birch_results)["selected_indexes"]) == 3
assert self.selector_small.transform(Birch_results)["selected_indexes"] == [
0,
1,
7,
]

0 comments on commit 00c8eae

Please sign in to comment.