Skip to content

Commit

Permalink
black, ruff, mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
JiQi535 committed Jan 28, 2024
1 parent 46d7749 commit fe71d77
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 104 deletions.
9 changes: 2 additions & 7 deletions maml/describers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,11 @@
from __future__ import annotations

from ._composition import ElementProperty, ElementStats
from ._m3gnet import M3GNetStructure, M3GNetSite
from ._m3gnet import M3GNetSite, M3GNetStructure
from ._matminer import wrap_matminer_describer
from ._megnet import MEGNetSite, MEGNetStructure
from ._rdf import RadialDistributionFunction
from ._site import (
BispectrumCoefficients,
BPSymmetryFunctions,
SiteElementProperty,
SmoothOverlapAtomicPosition,
)
from ._site import BispectrumCoefficients, BPSymmetryFunctions, SiteElementProperty, SmoothOverlapAtomicPosition
from ._structure import (
CoulombEigenSpectrum,
CoulombMatrix,
Expand Down
33 changes: 13 additions & 20 deletions maml/describers/_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
if TYPE_CHECKING:
from pymatgen.core import Molecule, Structure

DEFAULT_MODEL = (
Path(__file__).parent / "data/m3gnet_models/matbench_mp_e_form/0/m3gnet/"
)
DEFAULT_MODEL = Path(__file__).parent / "data/m3gnet_models/matbench_mp_e_form/0/m3gnet/"


class M3GNetStructure(BaseDescriber):
Expand All @@ -25,12 +23,14 @@ def __init__(
**kwargs,
):
"""
Args:
model_path (str): m3gnet models path. If no path is provided,
the models will be M3GNet formation energy model on figshare:
https://figshare.com/articles/software/m3gnet_property_model_weights/20099465
Please refer to the M3GNet paper:
https://doi.org/10.1038/s43588-022-00349-3.
**kwargs: Pass through to BaseDescriber.
"""
from m3gnet.models import M3GNet

Expand All @@ -39,12 +39,13 @@ def __init__(
self.model_path = model_path
else:
self.describer_model = M3GNet.from_dir(DEFAULT_MODEL)
self.model_path = DEFAULT_MODEL
self.model_path = str(DEFAULT_MODEL)
super().__init__(**kwargs)

def transform_one(self, structure: Structure | Molecule):
"""
Transform structure/molecule objects into structural features.
Args:
structure (Structure/Molecule): target object structure or molecule
Returns: M3GNet readout layer output as structural features.
Expand All @@ -56,9 +57,7 @@ def transform_one(self, structure: Structure | Molecule):
graph = self.describer_model.graph_converter.convert(structure).as_list()
graph = tf_compute_distance_angle(graph)
three_basis = self.describer_model.basis_expansion(graph)
three_cutoff = polynomial(
graph[Index.BONDS], self.describer_model.threebody_cutoff
)
three_cutoff = polynomial(graph[Index.BONDS], self.describer_model.threebody_cutoff)
g = self.describer_model.featurizer(graph)
g = self.describer_model.feature_adjust(g)
for i in range(self.describer_model.n_blocks):
Expand All @@ -79,6 +78,7 @@ def __init__(
**kwargs,
):
"""
Args:
model_path (str): m3gnet models path. If no path is provided,
the models will be M3GNet formation energy model on figshare:
Expand All @@ -90,6 +90,7 @@ def __init__(
"gc_1" layer are returned.
return_type: The data type of the returned the atom features. By default, atom features in different
output_layers are concatenated to one vector per atom, and a dataframe of vectors are returned.
**kwargs: Pass through to BaseDescriber. E.g., feature_batch="pandas_concat" is very useful (see test).
"""
from m3gnet.models import M3GNet

Expand All @@ -98,18 +99,12 @@ def __init__(
self.model_path = model_path
else:
self.describer_model = M3GNet.from_dir(DEFAULT_MODEL)
self.model_path = DEFAULT_MODEL
allowed_output_layers = ["embedding"] + [
f"gc_{i + 1}" for i in range(self.describer_model.n_blocks)
]
self.model_path = str(DEFAULT_MODEL)
allowed_output_layers = ["embedding"] + [f"gc_{i + 1}" for i in range(self.describer_model.n_blocks)]
if output_layers is None:
output_layers = ["gc_1"]
elif not isinstance(output_layers, list) or set(output_layers).difference(
allowed_output_layers
):
raise ValueError(
f"Invalid output_layers, it must be a sublist of {allowed_output_layers}."
)
elif not isinstance(output_layers, list) or set(output_layers).difference(allowed_output_layers):
raise ValueError(f"Invalid output_layers, it must be a sublist of {allowed_output_layers}.")
self.output_layers = output_layers
self.return_type = return_type
super().__init__(**kwargs)
Expand All @@ -128,9 +123,7 @@ def transform_one(self, structure: Structure | Molecule):
graph = self.describer_model.graph_converter.convert(structure).as_list()
graph = tf_compute_distance_angle(graph)
three_basis = self.describer_model.basis_expansion(graph)
three_cutoff = polynomial(
graph[Index.BONDS], self.describer_model.threebody_cutoff
)
three_cutoff = polynomial(graph[Index.BONDS], self.describer_model.threebody_cutoff)
g = self.describer_model.featurizer(graph)
atom_fea = {"embedding": g[Index.ATOMS]}
g = self.describer_model.feature_adjust(g)
Expand Down
27 changes: 8 additions & 19 deletions maml/sampling/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class BirchClustering(BaseEstimator, TransformerMixin):
""" "Birch Clustering as one step of the DIRECT pipeline."""
"""Birch Clustering as one step of the DIRECT pipeline."""

def __init__(self, n=None, threshold_init=0.5, **kwargs):
"""
Expand All @@ -27,6 +27,7 @@ def __init__(self, n=None, threshold_init=0.5, **kwargs):
Users may tune this value for desired performance of birch, while 0.5
is generally a good starting point, and some automatic tuning is done
with our built-in codes to achieve n clusters if given.
**kwargs: Pass to BIRCH.
"""
self.n = n
self.threshold_init = threshold_init
Expand Down Expand Up @@ -56,28 +57,16 @@ def transform(self, PCAfeatures):
PCA feature, centroid positions of each cluster in PCA feature s
pace, and the array of input PCA features.
"""
model = Birch(
n_clusters=self.n, threshold=self.threshold_init, **self.kwargs
).fit(PCAfeatures)
model = Birch(n_clusters=self.n, threshold=self.threshold_init, **self.kwargs).fit(PCAfeatures)
if self.n is not None:
while (
len(model.subcluster_labels_) < self.n
): # decrease threshold until desired n clusters is achieved
logger.info(
f"Birch threshold of {self.threshold_init} gives {len(model.subcluster_labels_)} clusters."
)
self.threshold_init = (
self.threshold_init / self.n * len(model.subcluster_labels_)
)
model = Birch(
n_clusters=self.n, threshold=self.threshold_init, **self.kwargs
).fit(PCAfeatures)
while len(model.subcluster_labels_) < self.n: # decrease threshold until desired n clusters is achieved
logger.info(f"Birch threshold of {self.threshold_init} gives {len(model.subcluster_labels_)} clusters.")
self.threshold_init = self.threshold_init / self.n * len(model.subcluster_labels_)
model = Birch(n_clusters=self.n, threshold=self.threshold_init, **self.kwargs).fit(PCAfeatures)

labels = model.predict(PCAfeatures)
self.model = model
logger.info(
f"Birch threshold of {self.threshold_init} gives {len(model.subcluster_labels_)} clusters."
)
logger.info(f"Birch threshold of {self.threshold_init} gives {len(model.subcluster_labels_)} clusters.")
label_centers = dict(zip(model.subcluster_labels_, model.subcluster_centers_))
return {
"labels": labels,
Expand Down
13 changes: 4 additions & 9 deletions maml/sampling/direct.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""DIRECT sampling."""
from __future__ import annotations

from sklearn.pipeline import Pipeline
Expand Down Expand Up @@ -40,21 +41,15 @@ def __init__(
select_k_from_clusters: Straitified sampling of k structures from
each cluster.
"""
self.structure_encoder = (
M3GNetStructure() if structure_encoder == "M3GNet" else structure_encoder
)
self.structure_encoder = M3GNetStructure() if structure_encoder == "M3GNet" else structure_encoder
self.scaler = StandardScaler() if scaler == "StandardScaler" else scaler
self.pca = (
PrincipalComponentAnalysis(weighting_PCs=weighting_PCs)
if pca == "PrincipalComponentAnalysis"
else pca
PrincipalComponentAnalysis(weighting_PCs=weighting_PCs) if pca == "PrincipalComponentAnalysis" else pca
)
self.weighting_PCs = weighting_PCs
self.clustering = BirchClustering() if clustering == "Birch" else clustering
self.select_k_from_clusters = (
SelectKFromClusters()
if select_k_from_clusters == "select_k_from_clusters"
else select_k_from_clusters
SelectKFromClusters() if select_k_from_clusters == "select_k_from_clusters" else select_k_from_clusters
)
steps = [
(i.__class__.__name__, i)
Expand Down
46 changes: 13 additions & 33 deletions maml/sampling/stratified_sampling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Implementation of stratefied sampling approaches."""
"""Implementation of stratified sampling approaches."""
from __future__ import annotations

import logging
Expand Down Expand Up @@ -40,13 +40,9 @@ def __init__(
self.allow_duplicate = allow_duplicate
allowed_selection_criterion = ["random", "smallest", "center"]
if selection_criteria not in allowed_selection_criterion:
raise ValueError(
f"Invalid selection_criteria, it must be one of {allowed_selection_criterion}."
)
elif selection_criteria == "smallest" and not n_sites:
raise ValueError(
'n_sites must be provided when selection_criteria="smallest."'
)
raise ValueError(f"Invalid selection_criteria, it must be one of {allowed_selection_criterion}.")
if selection_criteria == "smallest" and not n_sites:
raise ValueError('n_sites must be provided when selection_criteria="smallest."')
self.selection_criteria = selection_criteria
self.n_sites = n_sites

Expand Down Expand Up @@ -82,10 +78,7 @@ def transform(self, clustering_data: dict):
raise Exception(
"The data returned by clustering step should at least provide label and feature information."
)
if (
self.selection_criteria == "center"
and "label_centers" not in clustering_data
):
if self.selection_criteria == "center" and "label_centers" not in clustering_data:
warnings.warn(
"Centroid location is not provided, so random selection from each cluster will be performed, "
"which likely will still outperform manual sampling in terms of feature coverage. "
Expand All @@ -94,32 +87,21 @@ def transform(self, clustering_data: dict):
try:
assert len(self.n_sites) == len(clustering_data["PCAfeatures"])
except Exception:
raise ValueError(
"n_sites must have same length as features processed in clustering."
)
raise ValueError("n_sites must have same length as features processed in clustering.")

selected_indexes = []
for label in set(clustering_data["labels"]):
indexes_same_label = np.where(label == clustering_data["labels"])[0]
features_same_label = clustering_data["PCAfeatures"][indexes_same_label]
n_same_label = len(features_same_label)
if (
"label_centers" in clustering_data
and self.selection_criteria == "center"
):
if "label_centers" in clustering_data and self.selection_criteria == "center":
center_same_label = clustering_data["label_centers"][label]
distance_to_center = np.linalg.norm(
features_same_label - center_same_label, axis=1
).reshape(len(indexes_same_label))
select_k_indexes = [
int(i) for i in np.linspace(0, n_same_label - 1, self.k)
]
distance_to_center = np.linalg.norm(features_same_label - center_same_label, axis=1).reshape(
len(indexes_same_label)
)
select_k_indexes = [int(i) for i in np.linspace(0, n_same_label - 1, self.k)]
selected_indexes.extend(
indexes_same_label[
np.argpartition(distance_to_center, select_k_indexes)[
select_k_indexes
]
]
indexes_same_label[np.argpartition(distance_to_center, select_k_indexes)[select_k_indexes]]
)
elif self.selection_criteria == "smallest":
if self.k >= n_same_label:
Expand All @@ -135,9 +117,7 @@ def transform(self, clustering_data: dict):
]
)
else:
selected_indexes.extend(
indexes_same_label[np.random.randint(n_same_label, size=self.k)]
)
selected_indexes.extend(indexes_same_label[np.random.randint(n_same_label, size=self.k)])
n_duplicate = len(selected_indexes) - len(set(selected_indexes))
if not self.allow_duplicate and n_duplicate > 0:
selected_indexes = list(set(selected_indexes))
Expand Down
2 changes: 1 addition & 1 deletion tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def make_doc(ctx):
contents = re.sub(
r"\n## Official Documentation[^#]*",
"{: .no_toc }\n\n## Table of contents\n{: .no_toc .text-delta }\n* TOC\n{:toc}\n\n",
contents
contents,
)
contents = "---\nlayout: default\ntitle: Home\nnav_order: 1\n---\n\n" + contents

Expand Down
12 changes: 3 additions & 9 deletions tests/sampling/test_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,14 @@

class TestDIRECTSampler:
def setup(self):
self.direct_fixed_n = DIRECTSampler(
structure_encoder=None, clustering=BirchClustering(n=1)
)
self.direct_fixed_n = DIRECTSampler(structure_encoder=None, clustering=BirchClustering(n=1))
self.direct_fixed_t = DIRECTSampler(
structure_encoder=None,
clustering=BirchClustering(n=None, threshold_init=0.5),
)

def test_fit_transform(self, MPF_2021_2_8_first10_features_test):
result_fixed_n = self.direct_fixed_n.fit_transform(
MPF_2021_2_8_first10_features_test["M3GNet_features"]
)
result_fixed_t = self.direct_fixed_t.fit_transform(
MPF_2021_2_8_first10_features_test["M3GNet_features"]
)
result_fixed_n = self.direct_fixed_n.fit_transform(MPF_2021_2_8_first10_features_test["M3GNet_features"])
result_fixed_t = self.direct_fixed_t.fit_transform(MPF_2021_2_8_first10_features_test["M3GNet_features"])
assert result_fixed_n["selected_indexes"] == [9]
assert result_fixed_t["selected_indexes"] == [0, 6, 7, 8, 9]
8 changes: 2 additions & 6 deletions tests/sampling/test_stratified_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ class TestSelectKFromClusters:
def setup(self):
self.selector_uni = SelectKFromClusters(k=2, allow_duplicate=False)
self.selector_dup = SelectKFromClusters(k=2, allow_duplicate=True)
self.selector_rand = SelectKFromClusters(
k=2, allow_duplicate=False, selection_criteria="random"
)
self.selector_rand = SelectKFromClusters(k=2, allow_duplicate=False, selection_criteria="random")
self.selector_small = SelectKFromClusters(
k=2,
allow_duplicate=False,
Expand All @@ -30,9 +28,7 @@ def test_exceptions(self, Birch_results):
SelectKFromClusters(selection_criteria="whatever")
with pytest.raises(ValueError, match="n_sites must be provided"):
SelectKFromClusters(selection_criteria="smallest")
with pytest.raises(
ValueError, match="n_sites must have same length as features"
):
with pytest.raises(ValueError, match="n_sites must have same length as features"):
self.selector_small_wrong_n_sites.transform(Birch_results)

def test_fit(self, Birch_results):
Expand Down

0 comments on commit fe71d77

Please sign in to comment.