Skip to content

Commit

Permalink
correct figure legend in notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
JiQi535 committed Jan 28, 2024
1 parent 96fc1bd commit c8f35cb
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 68 deletions.
24 changes: 6 additions & 18 deletions maml/describers/_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
if TYPE_CHECKING:
from pymatgen.core import Molecule, Structure

DEFAULT_MODEL = (
Path(__file__).parent / "data/m3gnet_models/matbench_mp_e_form/0/m3gnet/"
)
DEFAULT_MODEL = Path(__file__).parent / "data/m3gnet_models/matbench_mp_e_form/0/m3gnet/"


@describer_type("structure")
Expand Down Expand Up @@ -60,9 +58,7 @@ def transform_one(self, structure: Structure | Molecule):
graph = self.describer_model.graph_converter.convert(structure).as_list()
graph = tf_compute_distance_angle(graph)
three_basis = self.describer_model.basis_expansion(graph)
three_cutoff = polynomial(
graph[Index.BONDS], self.describer_model.threebody_cutoff
)
three_cutoff = polynomial(graph[Index.BONDS], self.describer_model.threebody_cutoff)
g = self.describer_model.featurizer(graph)
g = self.describer_model.feature_adjust(g)
for i in range(self.describer_model.n_blocks):
Expand Down Expand Up @@ -106,17 +102,11 @@ def __init__(
else:
self.describer_model = M3GNet.from_dir(DEFAULT_MODEL)
self.model_path = str(DEFAULT_MODEL)
allowed_output_layers = ["embedding"] + [
f"gc_{i + 1}" for i in range(self.describer_model.n_blocks)
]
allowed_output_layers = ["embedding"] + [f"gc_{i + 1}" for i in range(self.describer_model.n_blocks)]
if output_layers is None:
output_layers = ["gc_1"]
elif not isinstance(output_layers, list) or set(output_layers).difference(
allowed_output_layers
):
raise ValueError(
f"Invalid output_layers, it must be a sublist of {allowed_output_layers}."
)
elif not isinstance(output_layers, list) or set(output_layers).difference(allowed_output_layers):
raise ValueError(f"Invalid output_layers, it must be a sublist of {allowed_output_layers}.")
self.output_layers = output_layers
self.return_type = return_type
super().__init__(**kwargs)
Expand All @@ -135,9 +125,7 @@ def transform_one(self, structure: Structure | Molecule):
graph = self.describer_model.graph_converter.convert(structure).as_list()
graph = tf_compute_distance_angle(graph)
three_basis = self.describer_model.basis_expansion(graph)
three_cutoff = polynomial(
graph[Index.BONDS], self.describer_model.threebody_cutoff
)
three_cutoff = polynomial(graph[Index.BONDS], self.describer_model.threebody_cutoff)
g = self.describer_model.featurizer(graph)
atom_fea = {"embedding": g[Index.ATOMS]}
g = self.describer_model.feature_adjust(g)
Expand Down
40 changes: 30 additions & 10 deletions maml/sampling/stratified_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,13 @@ def __init__(
self.allow_duplicate = allow_duplicate
allowed_selection_criterion = ["random", "smallest", "center"]
if selection_criteria not in allowed_selection_criterion:
raise ValueError(f"Invalid selection_criteria, it must be one of {allowed_selection_criterion}.")
raise ValueError(
f"Invalid selection_criteria, it must be one of {allowed_selection_criterion}."
)
if selection_criteria == "smallest" and not n_sites:
raise ValueError('n_sites must be provided when selection_criteria="smallest."')
raise ValueError(
'n_sites must be provided when selection_criteria="smallest."'
)
self.selection_criteria = selection_criteria
self.n_sites = n_sites

Expand Down Expand Up @@ -79,7 +83,10 @@ def transform(self, clustering_data: dict):
raise Exception(
"The data returned by clustering step should at least provide label and feature information."
)
if self.selection_criteria == "center" and "label_centers" not in clustering_data:
if (
self.selection_criteria == "center"
and "label_centers" not in clustering_data
):
warnings.warn(
"Centroid location is not provided, so random selection from each cluster will be performed, "
"which likely will still outperform manual sampling in terms of feature coverage. "
Expand All @@ -88,21 +95,32 @@ def transform(self, clustering_data: dict):
try:
assert len(self.n_sites) == len(clustering_data["PCAfeatures"])
except Exception:
raise ValueError("n_sites must have same length as features processed in clustering.")
raise ValueError(
"n_sites must have same length as features processed in clustering."
)

selected_indexes = []
for label in set(clustering_data["labels"]):
indexes_same_label = np.where(label == clustering_data["labels"])[0]
features_same_label = clustering_data["PCAfeatures"][indexes_same_label]
n_same_label = len(features_same_label)
if "label_centers" in clustering_data and self.selection_criteria == "center":
if (
"label_centers" in clustering_data
and self.selection_criteria == "center"
):
center_same_label = clustering_data["label_centers"][label]
distance_to_center = np.linalg.norm(features_same_label - center_same_label, axis=1).reshape(
len(indexes_same_label)
distance_to_center = np.linalg.norm(
features_same_label - center_same_label, axis=1
).reshape(len(indexes_same_label))
select_k_indexes = np.array(
[int(i) for i in np.linspace(0, n_same_label - 1, self.k)]
)
select_k_indexes = np.array([int(i) for i in np.linspace(0, n_same_label - 1, self.k)])
selected_indexes.extend(
indexes_same_label[np.argpartition(distance_to_center, select_k_indexes)[select_k_indexes]]
indexes_same_label[
np.argpartition(distance_to_center, select_k_indexes)[
select_k_indexes
]
]
)
elif self.selection_criteria == "smallest":
if self.k >= n_same_label:
Expand All @@ -118,7 +136,9 @@ def transform(self, clustering_data: dict):
]
)
else:
selected_indexes.extend(indexes_same_label[np.random.randint(n_same_label, size=self.k)])
selected_indexes.extend(
indexes_same_label[np.random.randint(n_same_label, size=self.k)]
)
n_duplicate = len(selected_indexes) - len(set(selected_indexes))
if not self.allow_duplicate and n_duplicate > 0:
selected_indexes = list(set(selected_indexes))
Expand Down
72 changes: 32 additions & 40 deletions notebooks/direct/Example2_Ti-H.ipynb

Large diffs are not rendered by default.

0 comments on commit c8f35cb

Please sign in to comment.