Skip to content

Commit

Permalink
Initial commit for SpeciesEmbedding
Browse files Browse the repository at this point in the history
  • Loading branch information
Anthony Onwuli committed Aug 5, 2024
1 parent a6bb950 commit 4429a00
Showing 1 changed file with 115 additions and 1 deletion.
116 changes: 115 additions & 1 deletion src/elementembeddings/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from sklearn.metrics import DistanceMetric
from tqdm import tqdm

from .core import Embedding
from .core import Embedding, SpeciesEmbedding
from .utils.math import cosine_distance

tqdm.pandas()
Expand Down Expand Up @@ -417,3 +417,117 @@ def composition_featuriser(
raise ValueError(
msg,
)


class SpeciesCompositionalEmbedding:
    """Class to handle species compositional embeddings.

    Builds a matrix of per-species embedding vectors for a composition and
    exposes stoichiometry-weighted statistics over that matrix.

    Args:
    ----
        formula_dict (dict): A dictionary of the form {species: amount}
        embedding (Union[str, SpeciesEmbedding]): Either a string name of the
            embedding or a SpeciesEmbedding instance
        x (int, optional): The non-stoichiometric amount. Currently accepted
            for interface compatibility only; it is not used by this class.

    Raises:
    ------
        ValueError: If any amount in ``formula_dict`` is negative.
        KeyError: If a species in ``formula_dict`` is missing from the
            embedding.

    """

    def __init__(
        self,
        formula_dict: dict,
        embedding: "Union[str, SpeciesEmbedding]",
        x=1,
    ) -> None:
        """Initialise a SpeciesCompositionalEmbedding instance."""
        self.embedding = embedding

        # If a string has been passed for embedding, create an Embedding instance
        if isinstance(embedding, str):
            self.embedding = SpeciesEmbedding.load_data(embedding)

        self.embedding_name: str = self.embedding.embedding_name

        # Set an attribute for the comp dict
        self.composition = formula_dict

        # Set an attribute for the number of atoms, rejecting negative amounts
        # before anything is summed.
        for amount in self.composition.values():
            if amount < 0:
                msg = "Formula cannot contain negative amounts of elements"
                raise ValueError(msg)
        self._natoms = sum(self.composition.values())

        # Set an attribute for the species list
        self.species_list = list(self.composition.keys())

        # Set an attribute for the species matrix (one row per species).
        # The embedding dimension is taken from any vector stored in the
        # embedding rather than from a hard-coded species key, so embeddings
        # that lack a particular species (e.g. "Zn2+") still work.
        sample_vector = next(iter(self.embedding.embeddings.values()), ())
        self.species_matrix = np.zeros(
            shape=(len(self.composition), len(sample_vector)),
        )
        for row, species in enumerate(self.species_list):
            self.species_matrix[row] = self.embedding.embeddings[species]
        # Replace NaN entries in the embedding vectors with finite numbers.
        self.species_matrix = np.nan_to_num(self.species_matrix)

        # Set an attribute for the stoichiometric vector
        self.stoich_vector = np.array(list(self.composition.values()))

        # Set an attribute for the normalised stoichiometric vector
        self.norm_stoich_vector = self.stoich_vector / np.sum(self.stoich_vector)

    @property
    def num_atoms(self) -> float:
        """Total number of atoms in Composition."""
        return self._natoms

    def as_dict(self) -> dict:
        # TO-DO: Need to create a dict representation for the embedding class
        """Return the SpeciesCompositionalEmbedding class as a dict."""
        return {
            "composition": self.composition,
        }

    @property
    def fractional_composition(self):
        """Fractional composition of the Composition."""
        return {k: v / self._natoms for k, v in self.composition.items()}

    def _mean_feature_vector(self) -> np.ndarray:
        """Compute a weighted mean feature vector based of the embedding.

        The dimension of the feature vector is the same as the embedding.
        """
        return np.dot(self.norm_stoich_vector, self.species_matrix)

    def _variance_feature_vector(self) -> np.ndarray:
        """Compute a weighted variance feature vector."""
        diff_matrix = self.species_matrix - self._mean_feature_vector()
        diff_matrix = diff_matrix**2
        return np.dot(self.norm_stoich_vector, diff_matrix)

    def _minpool_feature_vector(self) -> np.ndarray:
        """Compute the column-wise minimum over the species matrix."""
        return np.min(self.species_matrix, axis=0)

    def _maxpool_feature_vector(self) -> np.ndarray:
        """Compute the column-wise maximum over the species matrix."""
        return np.max(self.species_matrix, axis=0)

    def _range_feature_vector(self) -> np.ndarray:
        """Compute the column-wise range (max - min) over the species matrix."""
        return np.ptp(self.species_matrix, axis=0)

    def _sum_feature_vector(self) -> np.ndarray:
        """Compute the sum of species vectors weighted by their raw amounts."""
        return np.dot(self.stoich_vector, self.species_matrix)

    def _geometric_mean_feature_vector(self) -> np.ndarray:
        """Compute a weighted geometric mean feature vector.

        Evaluated as the exponential of the weighted mean of the logarithms.
        """
        return np.exp(np.dot(self.norm_stoich_vector, np.log(self.species_matrix)))

    def _harmonic_mean_feature_vector(self) -> np.ndarray:
        """Compute a weighted harmonic mean feature vector.

        Evaluated as the reciprocal of the weighted mean of the reciprocals.
        """
        return np.reciprocal(
            np.dot(self.norm_stoich_vector, np.reciprocal(self.species_matrix)),
        )

    # Maps a statistic name to the name of the method that computes it.
    _stats_functions_dict = {
        "mean": "_mean_feature_vector",
        "variance": "_variance_feature_vector",
        "minpool": "_minpool_feature_vector",
        "maxpool": "_maxpool_feature_vector",
        "range": "_range_feature_vector",
        "sum": "_sum_feature_vector",
        "geometric_mean": "_geometric_mean_feature_vector",
        "harmonic_mean": "_harmonic_mean_feature_vector",
    }

0 comments on commit 4429a00

Please sign in to comment.