diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5fcb31c..13cae96 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,41 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## Unreleased
-[Compare with latest](https://github.com/WMD-group/ElementEmbeddings/compare/v0.1...HEAD)
+[Compare with latest](https://github.com/WMD-group/ElementEmbeddings/compare/v0.1.1...HEAD)
+
+### Added
+
+- Added dev & doc reqs and updated version no. ([e9278c5](https://github.com/WMD-group/ElementEmbeddings/commit/e9278c579a031643576f137196aad34e0f5ea98f) by Anthony Onwuli).
+- Added test for dimension plotter in 3D ([ef37daf](https://github.com/WMD-group/ElementEmbeddings/commit/ef37daff6aa824c3d9917fa1ba26fc37b95a9951) by Anthony Onwuli).
+- Added test for plotting euclidean heatmap ([e9cc5ed](https://github.com/WMD-group/ElementEmbeddings/commit/e9cc5ed5508e624420b1330973425572ff5b1628) by Anthony Onwuli).
+- Added deprecation warning to old plotting function ([542e2f2](https://github.com/WMD-group/ElementEmbeddings/commit/542e2f2e6bd96b0f0e1624192cb9a9a98fb3dfcc) by Anthony Onwuli).
+- Added functions to test PCA and UMAP for Embedding ([b7ccc8f](https://github.com/WMD-group/ElementEmbeddings/commit/b7ccc8f41384e5e6095090aa016088279b5a0439) by Anthony Onwuli).
+- Added test function for tSNE for Embedding class ([aaa1472](https://github.com/WMD-group/ElementEmbeddings/commit/aaa147279ba609984482813df2ce9530420da2be) by Anthony Onwuli).
+- Added more tests for compute metrics function ([c4055bc](https://github.com/WMD-group/ElementEmbeddings/commit/c4055bcdad6e5bd7832a8568767ced72cd9cdfd9) by Anthony Onwuli).
+- Added test for computing spearman's rank ([5f715aa](https://github.com/WMD-group/ElementEmbeddings/commit/5f715aaa3ba339b5e01012cb0a40c44652481b55) by Anthony Onwuli).
+- Added more dataframe test functions ([edeeb87](https://github.com/WMD-group/ElementEmbeddings/commit/edeeb8714ae80b194159738b562606819ffc3ccb) by Anthony Onwuli).
+- Added test for removing multiple elements ([b69d806](https://github.com/WMD-group/ElementEmbeddings/commit/b69d80699cad211166ff1b112886d19d387890b5) by Anthony Onwuli).
+- Added test function for removing elements ([664cbbf](https://github.com/WMD-group/ElementEmbeddings/commit/664cbbf1846757b7d018c199745b6227465c0268) by Anthony Onwuli).
+- Added setUpClass to test_core.py ([7cb8ab6](https://github.com/WMD-group/ElementEmbeddings/commit/7cb8ab6d3b731d04831cdfe83a90b926ab1e2a1b) by Anthony Onwuli).
+- Added tests for `as_dataframe` method ([7256c23](https://github.com/WMD-group/ElementEmbeddings/commit/7256c23d8d2840b77983424ee9247a90f1caaded) by Anthony Onwuli).
+- Added more tests to the Embedding loading function ([7f9f87b](https://github.com/WMD-group/ElementEmbeddings/commit/7f9f87b987a77f1d4b73cf9fed289a5d9a028417) by Anthony Onwuli).
+- Added test functions for loading csv and json ([87ba958](https://github.com/WMD-group/ElementEmbeddings/commit/87ba9581c506bc16aa377961da28c0cbe60e80de) by Anthony Onwuli).
+- Added test embedding json file ([188aea4](https://github.com/WMD-group/ElementEmbeddings/commit/188aea48e21b3c3d5a1b9624a9885b94f14b2fcc) by Anthony Onwuli).
+- Added test embedding csv file ([173bcee](https://github.com/WMD-group/ElementEmbeddings/commit/173bcee057173ec1a48cdc7bb3141406236119ce) by Anthony Onwuli).
+
+### Fixed
+
+- Fixed spelling error for extras_require setup.py ([8e28e9a](https://github.com/WMD-group/ElementEmbeddings/commit/8e28e9a09550bfcaf21ec4d95989cd031d717596) by Anthony Onwuli).
+
+### Removed
+
+- Removed outdated installation instructions ([c69817f](https://github.com/WMD-group/ElementEmbeddings/commit/c69817fef331e203fb3861e603c7c0176097e51f) by Anthony Onwuli).
+- Removed an else block from `load_data` ([5532de6](https://github.com/WMD-group/ElementEmbeddings/commit/5532de6d050580382f0fa9688be96f0e9cd231ec) by Anthony Onwuli).
+
+
+## [v0.1.1](https://github.com/WMD-group/ElementEmbeddings/releases/tag/v0.1.1) - 2023-07-05
+
+[Compare with v0.1](https://github.com/WMD-group/ElementEmbeddings/compare/v0.1...v0.1.1)
### Added
@@ -20,7 +54,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- removed pandas, pytest and pytest-subtests in reqs ([cd1bf77](https://github.com/WMD-group/ElementEmbeddings/commit/cd1bf776220250377bb7cd48cca6b08e9a968f1d) by Anthony Onwuli).
-
## [v0.1](https://github.com/WMD-group/ElementEmbeddings/releases/tag/v0.1) - 2023-06-30
[Compare with first commit](https://github.com/WMD-group/ElementEmbeddings/compare/262c7e99a438a3527fb73866093ae8cb1ee85ee6...v0.1)
diff --git a/CITATION.cff b/CITATION.cff
index 6c4b898..d12ab06 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -36,5 +36,5 @@ keywords:
- element representation
license: MIT
commit: d3f30602abf825ba3dcd5f247694a174a358ef49
-version: '0.1'
-date-released: '2023-06-30'
+version: '0.2.0'
+date-released: '2023-07-07'
diff --git a/README.md b/README.md
index 8cd0fac..e59034c 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,6 @@ ElementEmbeddings
[![documentation](https://img.shields.io/badge/docs-mkdocs%20material-blue.svg?style=flat)](https://wmd-group.github.io/ElementEmbeddings/)
![python version](https://img.shields.io/pypi/pyversions/elementembeddings)
-
The **Element Embeddings** package provides high-level tools for analysing elemental
embeddings data. This primarily involves visualising the correlation between
embedding schemes using different statistical measures.
@@ -37,14 +36,18 @@ importing the class.
Installation
--------
+
The latest stable release can be installed via pip using:
+
```bash
pip install ElementEmbeddings
```
-The latest version can be installed using:
+
+For installing the development or documentation dependencies via pip:
```bash
-pip install git+git://github.com/WMD-group/ElementEmbeddings.git
+pip install "ElementEmbeddings[dev]"
+pip install "ElementEmbeddings[docs]"
```
For development, you can clone the repository and install the package in editable mode.
@@ -80,4 +83,3 @@ We can access some of the properties of the `Embedding` class. For example, we c
The magpie representation has embeddings of dimension 22
['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk']
```
-
diff --git a/setup.py b/setup.py
index 6497cd0..6a71566 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
module_dir = os.path.dirname(os.path.abspath(__file__))
-VERSION = "0.1.1"
+VERSION = "0.2.0"
DESCRIPTION = "Element Embeddings"
with open(os.path.join(module_dir, "README.md"), encoding="utf-8") as f:
LONG_DESCRIPTION = f.read()
@@ -37,6 +37,26 @@
"umap-learn==0.5.3",
"adjustText==0.8",
],
+ extras_require={
+ "dev": [
+ "pre-commit==2.20.0",
+ "black==23.3.0",
+ "isort==5.12.0",
+ "pytest==7.2.1",
+ "pytest-subtests==0.10.0",
+ "nbqa==1.5.3",
+ "pyupgrade==3.3.1",
+ "flake8==6.0.0",
+ "autopep8==2.0.1",
+ "pytest-cov==4.1.0",
+ ],
+ "docs": [
+ "mkdocs==1.4.3",
+ "mkdocs-material==9.1.17",
+ "mkdocstrings ==0.21.2",
+ "mkdocstrings-python == 1.0.0",
+ ],
+ },
classifiers=[
"Programming Language :: Python",
"Programming Language :: Python :: 3.8",
diff --git a/src/elementembeddings/core.py b/src/elementembeddings/core.py
index 5d8c132..21a94a0 100644
--- a/src/elementembeddings/core.py
+++ b/src/elementembeddings/core.py
@@ -151,11 +151,6 @@ def load_data(embedding_name: Optional[str] = None):
_json = path.join(data_directory, _cbfv_files[embedding_name])
with open(_json) as f:
embedding_data = json.load(f)
-
- # Load a json file from a file specified in the input
- else:
- with open(embedding_name) as f:
- embedding_data = json.load(f)
else:
raise (
ValueError(
@@ -853,6 +848,10 @@ def plot_PCA_2D(
ax (matplotlib.axes.Axes): An Axes object with the PCA plot
"""
+ warnings.warn(
+ "This method is deprecated and will be removed in a future release. ",
+ DeprecationWarning,
+ )
embeddings_array = np.array(list(self.embeddings.values()))
element_array = np.array(self.element_list)
@@ -913,6 +912,10 @@ def plot_tSNE(
"""
+ warnings.warn(
+ "This method is deprecated and will be removed in a future release. ",
+ DeprecationWarning,
+ )
embeddings_array = np.array(list(self.embeddings.values()))
element_array = np.array(self.element_list)
diff --git a/src/elementembeddings/tests/files/test_embedding.csv b/src/elementembeddings/tests/files/test_embedding.csv
new file mode 100644
index 0000000..74a2a4c
--- /dev/null
+++ b/src/elementembeddings/tests/files/test_embedding.csv
@@ -0,0 +1,4 @@
+element,0,1,2,3,4,5,6,7,8,9
+H,1,0,0,0,0,0,0,0,0,0
+He,0,1,0,0,0,0,0,0,0,0
+Li,0,0,1,0,0,0,0,0,0,0
\ No newline at end of file
diff --git a/src/elementembeddings/tests/files/test_embedding.json b/src/elementembeddings/tests/files/test_embedding.json
new file mode 100644
index 0000000..abcd5a9
--- /dev/null
+++ b/src/elementembeddings/tests/files/test_embedding.json
@@ -0,0 +1 @@
+{"H":[1,0,0,0,0,0,0,0,0,0], "He":[0,1,0,0,0,0,0,0,0,0],"Li":[0,0,1,0,0,0,0,0,0,0]}
\ No newline at end of file
diff --git a/src/elementembeddings/tests/test_composition.py b/src/elementembeddings/tests/test_composition.py
index 3641c67..6ef635d 100644
--- a/src/elementembeddings/tests/test_composition.py
+++ b/src/elementembeddings/tests/test_composition.py
@@ -11,35 +11,53 @@
class TestComposition(unittest.TestCase):
"""Test the composition module."""
+ def setUp(self):
+ """Set up the test class."""
+ self.formulas = [
+ "Sr3Sc2(GeO4)3",
+ "Fe2O3",
+ "Li7La3ZrO12",
+ "CsPbI3",
+ ]
+
def test_formula_parser(self):
"""Test the formula_parser function."""
- LLZO_parsed = composition.formula_parser("Li7La3ZrO12")
+ LLZO_parsed = composition.formula_parser(self.formulas[2])
assert isinstance(LLZO_parsed, dict)
assert "Zr" in LLZO_parsed
assert LLZO_parsed["Li"] == 7
+ def test_formula_parser_with_parentheses(self):
+ """Test the formula_parser function with parentheses."""
+ SrScGeO4_parsed = composition.formula_parser(self.formulas[0])
+ assert isinstance(SrScGeO4_parsed, dict)
+ assert "Sr" in SrScGeO4_parsed
+ assert SrScGeO4_parsed["Ge"] == 3
+
+ def test_formula_parser_with_invalid_formula(self):
+ """Test the formula_parser function with an invalid formula."""
+ with self.assertRaises(ValueError):
+ composition.formula_parser("Sr3Sc2(GeO4)3)")
+
def test__get_fractional_composition(self):
"""Test the _get_fractional_composition function."""
- CsPbI3_frac = composition._get_fractional_composition("CsPbI3")
+ CsPbI3_frac = composition._get_fractional_composition(self.formulas[3])
assert isinstance(CsPbI3_frac, dict)
assert "Pb" in CsPbI3_frac
assert CsPbI3_frac["I"] == 0.6
- def test_Composition_class(self):
- """Test the Composition class."""
- Fe2O3_magpie = composition.CompositionalEmbedding(
- formula="Fe2O3", embedding="magpie"
- )
- assert isinstance(Fe2O3_magpie.embedding, core.Embedding)
- assert Fe2O3_magpie.formula == "Fe2O3"
- assert Fe2O3_magpie.embedding_name == "magpie"
- assert isinstance(Fe2O3_magpie.composition, dict)
- assert {"Fe": 2, "O": 3} == Fe2O3_magpie.composition
- assert Fe2O3_magpie._natoms == 5
- assert Fe2O3_magpie.fractional_composition == {"Fe": 0.4, "O": 0.6}
- assert isinstance(Fe2O3_magpie._mean_feature_vector(), np.ndarray)
- # Test that the feature vector function works
- stats = [
+
+class TestCompositionalEmbedding(unittest.TestCase):
+ """Test the CompositionalEmbedding class."""
+
+ def setUp(self):
+ """Set up the test formulas."""
+ self.formulas = ["Sr3Sc2(GeO4)3", "Fe2O3", "Li7La3ZrO12", "CsPbI3", "CsPbI-3"]
+ self.valid_magpie_compositions = [
+ composition.CompositionalEmbedding(formula=formula, embedding="magpie")
+ for formula in self.formulas[:3]
+ ]
+ self.stats = [
"mean",
"variance",
"minpool",
@@ -49,16 +67,50 @@ def test_Composition_class(self):
"geometric_mean",
"harmonic_mean",
]
- assert isinstance(Fe2O3_magpie.feature_vector(stats=stats), np.ndarray)
+
+ def test_CompositionalEmbedding_attributes(self):
+ """Test the Composition class."""
+ Fe2O3_magpie = self.valid_magpie_compositions[1]
+ assert isinstance(Fe2O3_magpie.embedding, core.Embedding)
+ assert Fe2O3_magpie.formula == "Fe2O3"
+ assert Fe2O3_magpie.embedding_name == "magpie"
+ assert isinstance(Fe2O3_magpie.composition, dict)
+ assert {"Fe": 2, "O": 3} == Fe2O3_magpie.composition
+ assert Fe2O3_magpie.num_atoms == 5
+ assert Fe2O3_magpie.fractional_composition == {"Fe": 0.4, "O": 0.6}
+ assert Fe2O3_magpie.embedding.dim == 22
+
+ def test_CompositionalEmbedding_negative_formula(self):
+ """Test the Composition class with a negative formula."""
+ with self.assertRaises(ValueError):
+ composition.CompositionalEmbedding(
+ formula=self.formulas[4], embedding="magpie"
+ )
+
+ def test__mean_feature_vector(self):
+ """Test the _mean_feature_vector function."""
+ assert isinstance(
+ self.valid_magpie_compositions[1]._mean_feature_vector(), np.ndarray
+ )
+ # Test that the feature vector function works
+
+ def test_feature_vector(self):
+ """Test the feature_vector function."""
+ assert isinstance(
+ self.valid_magpie_compositions[0].feature_vector(stats=self.stats),
+ np.ndarray,
+ )
assert len(
- Fe2O3_magpie.feature_vector(stats=stats)
- ) == Fe2O3_magpie.embedding.dim * len(stats)
+ self.valid_magpie_compositions[0].feature_vector(stats=self.stats)
+ ) == self.valid_magpie_compositions[0].embedding.dim * len(self.stats)
# Test that the feature vector function works with a single stat
- assert isinstance(Fe2O3_magpie.feature_vector(stats="mean"), np.ndarray)
+ assert isinstance(
+ self.valid_magpie_compositions[0].feature_vector(stats="mean"), np.ndarray
+ )
def test_composition_featuriser(self):
"""Test the composition featuriser function."""
- formulas = ["Fe2O3", "Li7La3ZrO12", "CsPbI3"]
+ formulas = self.formulas[:3]
formula_df = pd.DataFrame(formulas, columns=["formula"])
assert isinstance(composition.composition_featuriser(formula_df), pd.DataFrame)
assert composition.composition_featuriser(formula_df).shape == (3, 2)
diff --git a/src/elementembeddings/tests/test_core.py b/src/elementembeddings/tests/test_core.py
index de96d3d..845b3b0 100644
--- a/src/elementembeddings/tests/test_core.py
+++ b/src/elementembeddings/tests/test_core.py
@@ -1,4 +1,5 @@
"""Test the core module of AtomicEmbeddings."""
+import os
import unittest
import matplotlib.pyplot as plt
@@ -7,26 +8,49 @@
from elementembeddings.core import Embedding
+test_files_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "files")
+TEST_EMBEDDING_CSV = os.path.join(test_files_dir, "test_embedding.csv")
+TEST_EMBEDDING_JSON = os.path.join(test_files_dir, "test_embedding.json")
+
class EmbeddingTest(unittest.TestCase):
"""Test the Embedding class."""
# High Level functions
+ @classmethod
+ def setUpClass(cls):
+ """Set up the test."""
+ cls.test_skipatom = Embedding.load_data("skipatom")
+ cls.test_megnet16 = Embedding.load_data("megnet16")
+ cls.test_matscholar = Embedding.load_data("matscholar")
+ cls.test_mod_petti = Embedding.load_data("mod_petti")
+ cls.test_magpie = Embedding.load_data("magpie")
+
+ def test_Embedding_attributes(self):
+ """Test attributes of the loaded embeddings."""
+ assert self.test_skipatom.dim == 200
+ assert self.test_skipatom.embedding_name == "skipatom"
+ assert self.test_megnet16.dim == 16
+ assert self.test_megnet16.embedding_name == "megnet16"
+ assert self.test_matscholar.dim == 200
+ assert self.test_matscholar.embedding_name == "matscholar"
+ assert self.test_mod_petti.dim == 1
+ assert self.test_mod_petti.embedding_name == "mod_petti"
+ assert isinstance(self.test_skipatom.citation(), list)
+ assert isinstance(self.test_megnet16.citation(), list)
+ assert isinstance(self.test_matscholar.citation(), list)
+ assert isinstance(self.test_mod_petti.citation(), list)
- def test_Embedding_loading(self):
- """Test that the Embedding class can load the data."""
- skipatom = Embedding.load_data("skipatom")
- megnet16 = Embedding.load_data("megnet16")
- assert skipatom.dim == 200
- assert skipatom.embedding_name == "skipatom"
- assert megnet16.dim == 16
- assert megnet16.embedding_name == "megnet16"
- assert isinstance(skipatom.citation(), list)
- assert isinstance(megnet16.citation(), list)
+ def test_Embedding_file_input(self):
+ """Test that the Embedding class can load custom data."""
+ embedding_csv = Embedding.from_csv(TEST_EMBEDDING_CSV)
+ embedding_json = Embedding.from_json(TEST_EMBEDDING_JSON)
+ assert embedding_csv.dim == 10
+ assert embedding_json.dim == 10
def test_Embeddings_class_magpie(self):
"""Test that the Embedding class can load the magpie data."""
- magpie = Embedding.load_data("magpie")
+ magpie = self.test_magpie
# Check if the embeddings attribute is a dict
assert isinstance(magpie.embeddings, dict)
# Check if the embedding vector is a numpy array
@@ -253,26 +277,73 @@ def test_Embeddings_class_magpie(self):
# TO-DO
# Create tests for checking dataframes and plotting functions
+
+ def test_as_dataframe(self):
+ """Test the as_dataframe method."""
+ magpie = self.test_magpie
assert isinstance(magpie.as_dataframe(), pd.DataFrame)
- assert isinstance(magpie.to(fmt="json"), str)
- assert isinstance(magpie.to(fmt="csv"), str)
+ assert "H" in magpie.as_dataframe().index.tolist()
+ assert isinstance(magpie.as_dataframe(columns="elements"), pd.DataFrame)
+ assert "H" in magpie.as_dataframe(columns="elements").columns.tolist()
+ self.assertRaises(ValueError, magpie.as_dataframe, columns="test")
+
+ def test_to(self):
+ """Test the to method."""
+ assert isinstance(self.test_magpie.to(fmt="json"), str)
+ self.test_magpie.to(fmt="json", filename="test.json")
+ assert os.path.isfile("test.json")
+ os.remove("test.json")
+
+ assert isinstance(self.test_magpie.to(fmt="csv"), str)
+ self.test_magpie.to(fmt="csv", filename="test.csv")
+ assert os.path.isfile("test.csv")
+ os.remove("test.csv")
+
+ def test_compute_metric_functions(self):
+ """Test the compute metric functions."""
assert isinstance(
- magpie.compute_correlation_metric("H", "O", metric="pearson"),
+ self.test_magpie.compute_correlation_metric("H", "O", metric="pearson"),
float,
)
assert isinstance(
- magpie.compute_distance_metric(
+ self.test_magpie.compute_distance_metric(
"H",
"O",
),
float,
)
- assert isinstance(magpie.distance_df(), pd.DataFrame)
- assert magpie.distance_df().shape == (
- len(list(magpie.create_pairs())) * 2 - len(magpie.embeddings),
+ assert isinstance(
+ self.test_magpie.compute_distance_metric("H", "O", "energy"),
+ float,
+ )
+ assert isinstance(
+ self.test_magpie.compute_distance_metric("H", "O", "cosine_distance"),
+ float,
+ )
+ assert isinstance(
+ self.test_magpie.compute_correlation_metric("H", "O", metric="spearman"),
+ float,
+ )
+
+ self.assertRaises(
+ ValueError, self.test_skipatom.compute_distance_metric, "He", "O"
+ )
+ self.assertRaises(
+ ValueError, self.test_skipatom.compute_distance_metric, "O", "He"
+ )
+ self.assertRaises(
+ ValueError, self.test_skipatom.compute_distance_metric, "Li", "O", "euclid"
+ )
+
+ def test_distance_dataframe_functions(self):
+ """Test the distance dataframe functions."""
+ assert isinstance(self.test_magpie.distance_df(), pd.DataFrame)
+ assert self.test_magpie.distance_df().shape == (
+ len(list(self.test_magpie.create_pairs())) * 2
+ - len(self.test_magpie.embeddings),
7,
)
- assert magpie.distance_df().columns.tolist() == [
+ assert self.test_magpie.distance_df().columns.tolist() == [
"ele_1",
"ele_2",
"mend_1",
@@ -281,8 +352,49 @@ def test_Embeddings_class_magpie(self):
"Z_2",
"euclidean",
]
- assert isinstance(magpie.distance_pivot_table(), pd.DataFrame)
- assert isinstance(magpie.plot_distance_correlation(), plt.Axes)
+ assert isinstance(self.test_magpie.distance_pivot_table(), pd.DataFrame)
+ assert isinstance(
+ self.test_magpie.distance_pivot_table(sortby="atomic_number"), pd.DataFrame
+ )
+ assert isinstance(self.test_magpie.plot_distance_correlation(), plt.Axes)
assert isinstance(
- magpie.plot_distance_correlation(metric="euclidean"), plt.Axes
+ self.test_magpie.plot_distance_correlation(metric="euclidean"), plt.Axes
+ )
+ assert isinstance(self.test_magpie.stats_correlation_df(), pd.DataFrame)
+
+ def test_remove_elements(self):
+ """Test the remove_elements function."""
+ assert isinstance(self.test_skipatom.remove_elements("H"), Embedding)
+ assert isinstance(self.test_skipatom.remove_elements(["H", "Li"]), Embedding)
+ self.assertIsNone(self.test_skipatom.remove_elements("H", inplace=True))
+ self.assertFalse(self.test_skipatom._is_el_in_embedding("H"))
+ self.assertIsNone(
+ self.test_skipatom.remove_elements(["Li", "Ti", "Bi"], inplace=True)
+ )
+ assert "Li" not in self.test_skipatom.element_list
+ assert "Ti" not in self.test_skipatom.element_list
+ assert "Bi" not in self.test_skipatom.element_list
+
+ def test_PCA(self):
+ """Test the PCA function."""
+ assert isinstance(self.test_matscholar.calculate_PC(), np.ndarray)
+ assert self.test_matscholar.calculate_PC().shape == (
+ len(self.test_matscholar.element_list),
+ 2,
+ )
+
+ def test_tSNE(self):
+ """Test the tSNE function."""
+ assert isinstance(self.test_matscholar.calculate_tSNE(), np.ndarray)
+ assert self.test_matscholar.calculate_tSNE().shape == (
+ len(self.test_matscholar.element_list),
+ 2,
+ )
+
+ def test_UMAP(self):
+ """Test the UMAP function."""
+ assert isinstance(self.test_matscholar.calculate_UMAP(), np.ndarray)
+ assert self.test_matscholar.calculate_UMAP().shape == (
+ len(self.test_matscholar.element_list),
+ 2,
)
diff --git a/src/elementembeddings/tests/test_plotter.py b/src/elementembeddings/tests/test_plotter.py
index d34c799..59f9c81 100644
--- a/src/elementembeddings/tests/test_plotter.py
+++ b/src/elementembeddings/tests/test_plotter.py
@@ -8,30 +8,91 @@
from elementembeddings.plotter import dimension_plotter, heatmap_plotter
-class heatmapTest(unittest.TestCase):
+class HeatmapTest(unittest.TestCase):
"""Test the heatmap_plotter function."""
+ @classmethod
+ def setUpClass(cls):
+ """Set up the test class."""
+ cls.test_skipatom = Embedding.load_data("skipatom")
+
def test_heatmap_plotter(self):
"""Test that the heatmap_plotter function works."""
- # Load the data
- skipatom = Embedding.load_data("skipatom")
# Get the embeddings
skipatom_cos_plot = heatmap_plotter(
- skipatom,
+ self.test_skipatom,
metric="cosine_similarity",
)
assert isinstance(skipatom_cos_plot, plt.Axes)
+ skipatom_euc_plot = heatmap_plotter(
+ self.test_skipatom, metric="euclidean", show_axislabels=False
+ )
+ assert isinstance(skipatom_euc_plot, plt.Axes)
-class dimensionTest(unittest.TestCase):
+class DimensionTest(unittest.TestCase):
"""Test the dimension_plotter function."""
- def test_dimension_plotter(self):
+ @classmethod
+ def setUpClass(cls):
+ """Set up the test class."""
+ cls.test_skipatom = Embedding.load_data("skipatom")
+
+ def test_dimension_2d_plotter(self):
"""Test that the dimension_plotter function works."""
- # Load the data
- skipatom = Embedding.load_data("skipatom")
- # Get the embeddings
skipatom_pca_plot = dimension_plotter(
- skipatom, n_components=2, reducer="pca", adjusttext=True
+ self.test_skipatom, n_components=2, reducer="pca", adjusttext=False
)
assert isinstance(skipatom_pca_plot, plt.Axes)
+ skipatom_tsne_plot = dimension_plotter(
+ self.test_skipatom, n_components=2, reducer="tsne", adjusttext=False
+ )
+ assert isinstance(skipatom_tsne_plot, plt.Axes)
+ skipatom_umap_plot = dimension_plotter(
+ self.test_skipatom, n_components=2, reducer="umap", adjusttext=True
+ )
+ assert isinstance(skipatom_umap_plot, plt.Axes)
+
+ self.assertRaises(
+ ValueError,
+ dimension_plotter,
+ self.test_skipatom,
+ n_components=2,
+ reducer="badreducer",
+ )
+
+ def test_dimension_2d_plotter_preloaded_reduction(self):
+ """Test that the dimension_plotter function works with a preloaded reduction."""
+ self.test_skipatom.calculate_PC()
+ self.test_skipatom.calculate_tSNE()
+ self.test_skipatom.calculate_UMAP()
+
+ skipatom_pca_plot = dimension_plotter(
+ self.test_skipatom, n_components=2, reducer="pca", adjusttext=False
+ )
+ assert isinstance(skipatom_pca_plot, plt.Axes)
+ skipatom_tsne_plot = dimension_plotter(
+ self.test_skipatom, n_components=2, reducer="tsne", adjusttext=False
+ )
+ assert isinstance(skipatom_tsne_plot, plt.Axes)
+ skipatom_umap_plot = dimension_plotter(
+ self.test_skipatom, n_components=2, reducer="umap", adjusttext=False
+ )
+ assert isinstance(skipatom_umap_plot, plt.Axes)
+
+ def test_dimension_3d_plotter(self):
+ """Test that the dimension_plotter function works in 3D."""
+ skipatom_3d_pca_plot = dimension_plotter(
+ self.test_skipatom, n_components=3, reducer="pca", adjusttext=False
+ )
+ assert isinstance(skipatom_3d_pca_plot, plt.Axes)
+
+ def test_dimension_Nd_plotter(self):
+ """Test that the dimension_plotter function will fail in d>3."""
+ self.assertRaises(
+ ValueError,
+ dimension_plotter,
+ self.test_skipatom,
+ n_components=4,
+ reducer="pca",
+ )