diff --git a/maml/describers/__init__.py b/maml/describers/__init__.py index af5082de..c21d6fca 100644 --- a/maml/describers/__init__.py +++ b/maml/describers/__init__.py @@ -5,16 +5,11 @@ from __future__ import annotations from ._composition import ElementProperty, ElementStats -from ._m3gnet import M3GNetStructure, M3GNetSite +from ._m3gnet import M3GNetSite, M3GNetStructure from ._matminer import wrap_matminer_describer from ._megnet import MEGNetSite, MEGNetStructure from ._rdf import RadialDistributionFunction -from ._site import ( - BispectrumCoefficients, - BPSymmetryFunctions, - SiteElementProperty, - SmoothOverlapAtomicPosition, -) +from ._site import BispectrumCoefficients, BPSymmetryFunctions, SiteElementProperty, SmoothOverlapAtomicPosition from ._structure import ( CoulombEigenSpectrum, CoulombMatrix, diff --git a/maml/describers/_m3gnet.py b/maml/describers/_m3gnet.py index 3f78bb11..e16bb74e 100644 --- a/maml/describers/_m3gnet.py +++ b/maml/describers/_m3gnet.py @@ -11,9 +11,7 @@ if TYPE_CHECKING: from pymatgen.core import Molecule, Structure -DEFAULT_MODEL = ( - Path(__file__).parent / "data/m3gnet_models/matbench_mp_e_form/0/m3gnet/" -) +DEFAULT_MODEL = Path(__file__).parent / "data/m3gnet_models/matbench_mp_e_form/0/m3gnet/" class M3GNetStructure(BaseDescriber): @@ -25,12 +23,14 @@ def __init__( **kwargs, ): """ + Args: model_path (str): m3gnet models path. If no path is provided, the models will be M3GNet formation energy model on figshare: https://figshare.com/articles/software/m3gnet_property_model_weights/20099465 Please refer to the M3GNet paper: https://doi.org/10.1038/s43588-022-00349-3. + **kwargs: Pass through to BaseDescriber. """ from m3gnet.models import M3GNet @@ -39,12 +39,13 @@ def __init__( self.model_path = model_path else: self.describer_model = M3GNet.from_dir(DEFAULT_MODEL) - self.model_path = DEFAULT_MODEL + self.model_path = str(DEFAULT_MODEL) super().__init__(**kwargs) def transform_one(self, structure: Structure | Molecule): """ Transform structure/molecule objects into structural features. + Args: structure (Structure/Molecule): target object structure or molecule Returns: M3GNet readout layer output as structural features. @@ -56,9 +57,7 @@ def transform_one(self, structure: Structure | Molecule): graph = self.describer_model.graph_converter.convert(structure).as_list() graph = tf_compute_distance_angle(graph) three_basis = self.describer_model.basis_expansion(graph) - three_cutoff = polynomial( - graph[Index.BONDS], self.describer_model.threebody_cutoff - ) + three_cutoff = polynomial(graph[Index.BONDS], self.describer_model.threebody_cutoff) g = self.describer_model.featurizer(graph) g = self.describer_model.feature_adjust(g) for i in range(self.describer_model.n_blocks): @@ -79,6 +78,7 @@ def __init__( **kwargs, ): """ + Args: model_path (str): m3gnet models path. If no path is provided, the models will be M3GNet formation energy model on figshare: @@ -90,6 +90,7 @@ def __init__( "gc_1" layer are returned. return_type: The data type of the returned the atom features. By default, atom features in different output_layers are concatenated to one vector per atom, and a dataframe of vectors are returned. + **kwargs: Pass through to BaseDescriber. E.g., feature_batch="pandas_concat" is very useful (see test). """ from m3gnet.models import M3GNet @@ -98,18 +99,12 @@ def __init__( self.model_path = model_path else: self.describer_model = M3GNet.from_dir(DEFAULT_MODEL) - self.model_path = DEFAULT_MODEL - allowed_output_layers = ["embedding"] + [ - f"gc_{i + 1}" for i in range(self.describer_model.n_blocks) - ] + self.model_path = str(DEFAULT_MODEL) + allowed_output_layers = ["embedding"] + [f"gc_{i + 1}" for i in range(self.describer_model.n_blocks)] if output_layers is None: output_layers = ["gc_1"] - elif not isinstance(output_layers, list) or set(output_layers).difference( - allowed_output_layers - ): - raise ValueError( - f"Invalid output_layers, it must be a sublist of {allowed_output_layers}." - ) + elif not isinstance(output_layers, list) or set(output_layers).difference(allowed_output_layers): + raise ValueError(f"Invalid output_layers, it must be a sublist of {allowed_output_layers}.") self.output_layers = output_layers self.return_type = return_type super().__init__(**kwargs) @@ -128,9 +123,7 @@ def transform_one(self, structure: Structure | Molecule): graph = self.describer_model.graph_converter.convert(structure).as_list() graph = tf_compute_distance_angle(graph) three_basis = self.describer_model.basis_expansion(graph) - three_cutoff = polynomial( - graph[Index.BONDS], self.describer_model.threebody_cutoff - ) + three_cutoff = polynomial(graph[Index.BONDS], self.describer_model.threebody_cutoff) g = self.describer_model.featurizer(graph) atom_fea = {"embedding": g[Index.ATOMS]} g = self.describer_model.feature_adjust(g) diff --git a/maml/sampling/clustering.py b/maml/sampling/clustering.py index 867937df..d64ccaa3 100644 --- a/maml/sampling/clustering.py +++ b/maml/sampling/clustering.py @@ -12,7 +12,7 @@ class BirchClustering(BaseEstimator, TransformerMixin): - """ "Birch Clustering as one step of the DIRECT pipeline.""" + """Birch Clustering as one step of the DIRECT pipeline.""" def __init__(self, n=None, threshold_init=0.5, **kwargs): """ @@ -27,6 +27,7 @@ def __init__(self, n=None, threshold_init=0.5, **kwargs): Users may tune this value for desired performance of birch, while 0.5 is generally a good starting point, and some automatic tuning is done with our built-in codes to achieve n clusters if given. + **kwargs: Pass to BIRCH. """ self.n = n self.threshold_init = threshold_init @@ -56,28 +57,16 @@ def transform(self, PCAfeatures): PCA feature, centroid positions of each cluster in PCA feature s pace, and the array of input PCA features. """ - model = Birch( - n_clusters=self.n, threshold=self.threshold_init, **self.kwargs - ).fit(PCAfeatures) + model = Birch(n_clusters=self.n, threshold=self.threshold_init, **self.kwargs).fit(PCAfeatures) if self.n is not None: - while ( - len(model.subcluster_labels_) < self.n - ): # decrease threshold until desired n clusters is achieved - logger.info( - f"Birch threshold of {self.threshold_init} gives {len(model.subcluster_labels_)} clusters." - ) - self.threshold_init = ( - self.threshold_init / self.n * len(model.subcluster_labels_) - ) - model = Birch( - n_clusters=self.n, threshold=self.threshold_init, **self.kwargs - ).fit(PCAfeatures) + while len(model.subcluster_labels_) < self.n: # decrease threshold until desired n clusters is achieved + logger.info(f"Birch threshold of {self.threshold_init} gives {len(model.subcluster_labels_)} clusters.") + self.threshold_init = self.threshold_init / self.n * len(model.subcluster_labels_) + model = Birch(n_clusters=self.n, threshold=self.threshold_init, **self.kwargs).fit(PCAfeatures) labels = model.predict(PCAfeatures) self.model = model - logger.info( - f"Birch threshold of {self.threshold_init} gives {len(model.subcluster_labels_)} clusters." - ) + logger.info(f"Birch threshold of {self.threshold_init} gives {len(model.subcluster_labels_)} clusters.") label_centers = dict(zip(model.subcluster_labels_, model.subcluster_centers_)) return { "labels": labels, diff --git a/maml/sampling/direct.py b/maml/sampling/direct.py index eea1b627..89338351 100644 --- a/maml/sampling/direct.py +++ b/maml/sampling/direct.py @@ -1,3 +1,4 @@ +"""DIRECT sampling.""" from __future__ import annotations from sklearn.pipeline import Pipeline @@ -40,21 +41,15 @@ def __init__( select_k_from_clusters: Straitified sampling of k structures from each cluster. """ - self.structure_encoder = ( - M3GNetStructure() if structure_encoder == "M3GNet" else structure_encoder - ) + self.structure_encoder = M3GNetStructure() if structure_encoder == "M3GNet" else structure_encoder self.scaler = StandardScaler() if scaler == "StandardScaler" else scaler self.pca = ( - PrincipalComponentAnalysis(weighting_PCs=weighting_PCs) - if pca == "PrincipalComponentAnalysis" - else pca + PrincipalComponentAnalysis(weighting_PCs=weighting_PCs) if pca == "PrincipalComponentAnalysis" else pca ) self.weighting_PCs = weighting_PCs self.clustering = BirchClustering() if clustering == "Birch" else clustering self.select_k_from_clusters = ( - SelectKFromClusters() - if select_k_from_clusters == "select_k_from_clusters" - else select_k_from_clusters + SelectKFromClusters() if select_k_from_clusters == "select_k_from_clusters" else select_k_from_clusters ) steps = [ (i.__class__.__name__, i) diff --git a/maml/sampling/stratified_sampling.py b/maml/sampling/stratified_sampling.py index 52f84c83..2e8b690e 100644 --- a/maml/sampling/stratified_sampling.py +++ b/maml/sampling/stratified_sampling.py @@ -1,4 +1,4 @@ -"""Implementation of stratefied sampling approaches.""" +"""Implementation of stratified sampling approaches.""" from __future__ import annotations import logging @@ -40,13 +40,9 @@ def __init__( self.allow_duplicate = allow_duplicate allowed_selection_criterion = ["random", "smallest", "center"] if selection_criteria not in allowed_selection_criterion: - raise ValueError( - f"Invalid selection_criteria, it must be one of {allowed_selection_criterion}." - ) - elif selection_criteria == "smallest" and not n_sites: - raise ValueError( - 'n_sites must be provided when selection_criteria="smallest."' - ) + raise ValueError(f"Invalid selection_criteria, it must be one of {allowed_selection_criterion}.") + if selection_criteria == "smallest" and not n_sites: + raise ValueError('n_sites must be provided when selection_criteria="smallest."') self.selection_criteria = selection_criteria self.n_sites = n_sites @@ -82,10 +78,7 @@ def transform(self, clustering_data: dict): raise Exception( "The data returned by clustering step should at least provide label and feature information." ) - if ( - self.selection_criteria == "center" - and "label_centers" not in clustering_data - ): + if self.selection_criteria == "center" and "label_centers" not in clustering_data: warnings.warn( "Centroid location is not provided, so random selection from each cluster will be performed, " "which likely will still outperform manual sampling in terms of feature coverage. " @@ -94,32 +87,21 @@ def transform(self, clustering_data: dict): try: assert len(self.n_sites) == len(clustering_data["PCAfeatures"]) except Exception: - raise ValueError( - "n_sites must have same length as features processed in clustering." - ) + raise ValueError("n_sites must have same length as features processed in clustering.") selected_indexes = [] for label in set(clustering_data["labels"]): indexes_same_label = np.where(label == clustering_data["labels"])[0] features_same_label = clustering_data["PCAfeatures"][indexes_same_label] n_same_label = len(features_same_label) - if ( - "label_centers" in clustering_data - and self.selection_criteria == "center" - ): + if "label_centers" in clustering_data and self.selection_criteria == "center": center_same_label = clustering_data["label_centers"][label] - distance_to_center = np.linalg.norm( - features_same_label - center_same_label, axis=1 - ).reshape(len(indexes_same_label)) - select_k_indexes = [ - int(i) for i in np.linspace(0, n_same_label - 1, self.k) - ] + distance_to_center = np.linalg.norm(features_same_label - center_same_label, axis=1).reshape( + len(indexes_same_label) + ) + select_k_indexes = [int(i) for i in np.linspace(0, n_same_label - 1, self.k)] selected_indexes.extend( - indexes_same_label[ - np.argpartition(distance_to_center, select_k_indexes)[ - select_k_indexes - ] - ] + indexes_same_label[np.argpartition(distance_to_center, select_k_indexes)[select_k_indexes]] ) elif self.selection_criteria == "smallest": if self.k >= n_same_label: @@ -135,9 +117,7 @@ def transform(self, clustering_data: dict): ] ) else: - selected_indexes.extend( - indexes_same_label[np.random.randint(n_same_label, size=self.k)] - ) + selected_indexes.extend(indexes_same_label[np.random.randint(n_same_label, size=self.k)]) n_duplicate = len(selected_indexes) - len(set(selected_indexes)) if not self.allow_duplicate and n_duplicate > 0: selected_indexes = list(set(selected_indexes)) diff --git a/tasks.py b/tasks.py index 9d002e0c..c700a315 100644 --- a/tasks.py +++ b/tasks.py @@ -53,7 +53,7 @@ def make_doc(ctx): contents = re.sub( r"\n## Official Documentation[^#]*", "{: .no_toc }\n\n## Table of contents\n{: .no_toc .text-delta }\n* TOC\n{:toc}\n\n", - contents + contents, ) contents = "---\nlayout: default\ntitle: Home\nnav_order: 1\n---\n\n" + contents diff --git a/tests/sampling/test_direct.py b/tests/sampling/test_direct.py index 2947b401..7dad3ad5 100644 --- a/tests/sampling/test_direct.py +++ b/tests/sampling/test_direct.py @@ -6,20 +6,14 @@ class TestDIRECTSampler: def setup(self): - self.direct_fixed_n = DIRECTSampler( - structure_encoder=None, clustering=BirchClustering(n=1) - ) + self.direct_fixed_n = DIRECTSampler(structure_encoder=None, clustering=BirchClustering(n=1)) self.direct_fixed_t = DIRECTSampler( structure_encoder=None, clustering=BirchClustering(n=None, threshold_init=0.5), ) def test_fit_transform(self, MPF_2021_2_8_first10_features_test): - result_fixed_n = self.direct_fixed_n.fit_transform( - MPF_2021_2_8_first10_features_test["M3GNet_features"] - ) - result_fixed_t = self.direct_fixed_t.fit_transform( - MPF_2021_2_8_first10_features_test["M3GNet_features"] - ) + result_fixed_n = self.direct_fixed_n.fit_transform(MPF_2021_2_8_first10_features_test["M3GNet_features"]) + result_fixed_t = self.direct_fixed_t.fit_transform(MPF_2021_2_8_first10_features_test["M3GNet_features"]) assert result_fixed_n["selected_indexes"] == [9] assert result_fixed_t["selected_indexes"] == [0, 6, 7, 8, 9] diff --git a/tests/sampling/test_stratified_sampling.py b/tests/sampling/test_stratified_sampling.py index 016625f1..8ac7e070 100644 --- a/tests/sampling/test_stratified_sampling.py +++ b/tests/sampling/test_stratified_sampling.py @@ -9,9 +9,7 @@ class TestSelectKFromClusters: def setup(self): self.selector_uni = SelectKFromClusters(k=2, allow_duplicate=False) self.selector_dup = SelectKFromClusters(k=2, allow_duplicate=True) - self.selector_rand = SelectKFromClusters( - k=2, allow_duplicate=False, selection_criteria="random" - ) + self.selector_rand = SelectKFromClusters(k=2, allow_duplicate=False, selection_criteria="random") self.selector_small = SelectKFromClusters( k=2, allow_duplicate=False, @@ -30,9 +28,7 @@ def test_exceptions(self, Birch_results): SelectKFromClusters(selection_criteria="whatever") with pytest.raises(ValueError, match="n_sites must be provided"): SelectKFromClusters(selection_criteria="smallest") - with pytest.raises( - ValueError, match="n_sites must have same length as features" - ): + with pytest.raises(ValueError, match="n_sites must have same length as features"): self.selector_small_wrong_n_sites.transform(Birch_results) def test_fit(self, Birch_results):