diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fcb31c..13cae96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,41 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## Unreleased -[Compare with latest](https://github.com/WMD-group/ElementEmbeddings/compare/v0.1...HEAD) +[Compare with latest](https://github.com/WMD-group/ElementEmbeddings/compare/v0.1.1...HEAD) + +### Added + +- Added dev & doc reqs and updated version no. ([e9278c5](https://github.com/WMD-group/ElementEmbeddings/commit/e9278c579a031643576f137196aad34e0f5ea98f) by Anthony Onwuli). +- Added test for dimension plotter in 3D ([ef37daf](https://github.com/WMD-group/ElementEmbeddings/commit/ef37daff6aa824c3d9917fa1ba26fc37b95a9951) by Anthony Onwuli). +- Added test for plotting euclidean heatmap ([e9cc5ed](https://github.com/WMD-group/ElementEmbeddings/commit/e9cc5ed5508e624420b1330973425572ff5b1628) by Anthony Onwuli). +- Added deprecation warning to old plotting function ([542e2f2](https://github.com/WMD-group/ElementEmbeddings/commit/542e2f2e6bd96b0f0e1624192cb9a9a98fb3dfcc) by Anthony Onwuli). +- Added functions to test PCA and UMAP for Embedding ([b7ccc8f](https://github.com/WMD-group/ElementEmbeddings/commit/b7ccc8f41384e5e6095090aa016088279b5a0439) by Anthony Onwuli). +- Added test function for tSNE for Embedding class ([aaa1472](https://github.com/WMD-group/ElementEmbeddings/commit/aaa147279ba609984482813df2ce9530420da2be) by Anthony Onwuli). +- Added more tests for compute metrics function ([c4055bc](https://github.com/WMD-group/ElementEmbeddings/commit/c4055bcdad6e5bd7832a8568767ced72cd9cdfd9) by Anthony Onwuli). +- Added test for computing spearman's rank ([5f715aa](https://github.com/WMD-group/ElementEmbeddings/commit/5f715aaa3ba339b5e01012cb0a40c44652481b55) by Anthony Onwuli). +- Added more dataframe test functions ([edeeb87](https://github.com/WMD-group/ElementEmbeddings/commit/edeeb8714ae80b194159738b562606819ffc3ccb) by Anthony Onwuli). +- Added test for removing multiple elements ([b69d806](https://github.com/WMD-group/ElementEmbeddings/commit/b69d80699cad211166ff1b112886d19d387890b5) by Anthony Onwuli). +- Added test function for removing elements ([664cbbf](https://github.com/WMD-group/ElementEmbeddings/commit/664cbbf1846757b7d018c199745b6227465c0268) by Anthony Onwuli). +- Added setUpClass to test_core.py ([7cb8ab6](https://github.com/WMD-group/ElementEmbeddings/commit/7cb8ab6d3b731d04831cdfe83a90b926ab1e2a1b) by Anthony Onwuli). +- Added tests for `as_dataframe` method ([7256c23](https://github.com/WMD-group/ElementEmbeddings/commit/7256c23d8d2840b77983424ee9247a90f1caaded) by Anthony Onwuli). +- Added more tests to the Embedding loading function ([7f9f87b](https://github.com/WMD-group/ElementEmbeddings/commit/7f9f87b987a77f1d4b73cf9fed289a5d9a028417) by Anthony Onwuli). +- Added test functions for loading csv and json ([87ba958](https://github.com/WMD-group/ElementEmbeddings/commit/87ba9581c506bc16aa377961da28c0cbe60e80de) by Anthony Onwuli). +- Added test embedding json file ([188aea4](https://github.com/WMD-group/ElementEmbeddings/commit/188aea48e21b3c3d5a1b9624a9885b94f14b2fcc) by Anthony Onwuli). +- Added test embedding csv file ([173bcee](https://github.com/WMD-group/ElementEmbeddings/commit/173bcee057173ec1a48cdc7bb3141406236119ce) by Anthony Onwuli). + +### Fixed + +- Fixed spelling error for extras_require setup.py ([8e28e9a](https://github.com/WMD-group/ElementEmbeddings/commit/8e28e9a09550bfcaf21ec4d95989cd031d717596) by Anthony Onwuli). + +### Removed + +- Removed outdated installation instructions ([c69817f](https://github.com/WMD-group/ElementEmbeddings/commit/c69817fef331e203fb3861e603c7c0176097e51f) by Anthony Onwuli). +- Removed an else block from `load_data` ([5532de6](https://github.com/WMD-group/ElementEmbeddings/commit/5532de6d050580382f0fa9688be96f0e9cd231ec) by Anthony Onwuli). + + +## [v0.1.1](https://github.com/WMD-group/ElementEmbeddings/releases/tag/v0.1.1) - 2023-07-05 + +[Compare with v0.1](https://github.com/WMD-group/ElementEmbeddings/compare/v0.1...v0.1.1) ### Added @@ -20,7 +54,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - removed pandas, pytest and pytest-subtests in reqs ([cd1bf77](https://github.com/WMD-group/ElementEmbeddings/commit/cd1bf776220250377bb7cd48cca6b08e9a968f1d) by Anthony Onwuli). - ## [v0.1](https://github.com/WMD-group/ElementEmbeddings/releases/tag/v0.1) - 2023-06-30 [Compare with first commit](https://github.com/WMD-group/ElementEmbeddings/compare/262c7e99a438a3527fb73866093ae8cb1ee85ee6...v0.1) diff --git a/CITATION.cff b/CITATION.cff index 6c4b898..d12ab06 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -36,5 +36,5 @@ keywords: - element representation license: MIT commit: d3f30602abf825ba3dcd5f247694a174a358ef49 -version: '0.1' -date-released: '2023-06-30' +version: '0.2.0' +date-released: '2023-07-07' diff --git a/README.md b/README.md index 8cd0fac..e59034c 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,6 @@ ElementEmbeddings [![documentation](https://img.shields.io/badge/docs-mkdocs%20material-blue.svg?style=flat)](https://wmd-group.github.io/ElementEmbeddings/) ![python version](https://img.shields.io/pypi/pyversions/elementembeddings) - The **Element Embeddings** package provides high-level tools for analysing elemental embeddings data. This primarily involves visualising the correlation between embedding schemes using different statistical measures. @@ -37,14 +36,18 @@ importing the class. Installation -------- + The latest stable release can be installed via pip using: + ```bash pip install ElementEmbeddings ``` -The latest version can be installed using: + +For installing the development or documentation dependencies via pip: ```bash -pip install git+git://github.com/WMD-group/ElementEmbeddings.git +pip install "ElementEmbeddings[dev]" +pip install "ElementEmbeddings[docs]" ``` For development, you can clone the repository and install the package in editable mode. @@ -80,4 +83,3 @@ We can access some of the properties of the `Embedding` class. For example, we c The magpie representation has embeddings of dimension 22 ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk'] ``` - diff --git a/setup.py b/setup.py index 6497cd0..6a71566 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ module_dir = os.path.dirname(os.path.abspath(__file__)) -VERSION = "0.1.1" +VERSION = "0.2.0" DESCRIPTION = "Element Embeddings" with open(os.path.join(module_dir, "README.md"), encoding="utf-8") as f: LONG_DESCRIPTION = f.read() @@ -37,6 +37,26 @@ "umap-learn==0.5.3", "adjustText==0.8", ], + extras_require={ + "dev": [ + "pre-commit==2.20.0", + "black==23.3.0", + "isort==5.12.0", + "pytest==7.2.1", + "pytest-subtests==0.10.0", + "nbqa==1.5.3", + "pyupgrade==3.3.1", + "flake8==6.0.0", + "autopep8==2.0.1", + "pytest-cov==4.1.0", + ], + "docs": [ + "mkdocs==1.4.3", + "mkdocs-material==9.1.17", + "mkdocstrings ==0.21.2", + "mkdocstrings-python == 1.0.0", + ], + }, classifiers=[ "Programming Language :: Python", "Programming Language :: Python :: 3.8", diff --git a/src/elementembeddings/core.py b/src/elementembeddings/core.py index 5d8c132..21a94a0 100644 --- a/src/elementembeddings/core.py +++ b/src/elementembeddings/core.py @@ -151,11 +151,6 @@ def load_data(embedding_name: Optional[str] = None): _json = path.join(data_directory, _cbfv_files[embedding_name]) with open(_json) as f: embedding_data = json.load(f) - - # Load a json file from a file specified in the input - else: - with open(embedding_name) as f: - embedding_data = json.load(f) else: raise ( ValueError( @@ -853,6 +848,10 @@ def plot_PCA_2D( ax (matplotlib.axes.Axes): An Axes object with the PCA plot """ + warnings.warn( + "This method is deprecated and will be removed in a future release. ", + DeprecationWarning, + ) embeddings_array = np.array(list(self.embeddings.values())) element_array = np.array(self.element_list) @@ -913,6 +912,10 @@ def plot_tSNE( """ + warnings.warn( + "This method is deprecated and will be removed in a future release. ", + DeprecationWarning, + ) embeddings_array = np.array(list(self.embeddings.values())) element_array = np.array(self.element_list) diff --git a/src/elementembeddings/tests/files/test_embedding.csv b/src/elementembeddings/tests/files/test_embedding.csv new file mode 100644 index 0000000..74a2a4c --- /dev/null +++ b/src/elementembeddings/tests/files/test_embedding.csv @@ -0,0 +1,4 @@ +element,0,1,2,3,4,5,6,7,8,9 +H,1,0,0,0,0,0,0,0,0,0 +He,0,1,0,0,0,0,0,0,0,0 +Li,0,0,1,0,0,0,0,0,0,0 \ No newline at end of file diff --git a/src/elementembeddings/tests/files/test_embedding.json b/src/elementembeddings/tests/files/test_embedding.json new file mode 100644 index 0000000..abcd5a9 --- /dev/null +++ b/src/elementembeddings/tests/files/test_embedding.json @@ -0,0 +1 @@ +{"H":[1,0,0,0,0,0,0,0,0,0], "He":[0,1,0,0,0,0,0,0,0,0],"Li":[0,0,1,0,0,0,0,0,0,0]} \ No newline at end of file diff --git a/src/elementembeddings/tests/test_composition.py b/src/elementembeddings/tests/test_composition.py index 3641c67..6ef635d 100644 --- a/src/elementembeddings/tests/test_composition.py +++ b/src/elementembeddings/tests/test_composition.py @@ -11,35 +11,53 @@ class TestComposition(unittest.TestCase): """Test the composition module.""" + def setUp(self): + """Set up the test class.""" + self.formulas = [ + "Sr3Sc2(GeO4)3", + "Fe2O3", + "Li7La3ZrO12", + "CsPbI3", + ] + def test_formula_parser(self): """Test the formula_parser function.""" - LLZO_parsed = composition.formula_parser("Li7La3ZrO12") + LLZO_parsed = composition.formula_parser(self.formulas[2]) assert isinstance(LLZO_parsed, dict) assert "Zr" in LLZO_parsed assert LLZO_parsed["Li"] == 7 + def test_formula_parser_with_parentheses(self): + """Test the formula_parser function with parentheses.""" + SrScGeO4_parsed = composition.formula_parser(self.formulas[0]) + assert isinstance(SrScGeO4_parsed, dict) + assert "Sr" in SrScGeO4_parsed + assert SrScGeO4_parsed["Ge"] == 3 + + def test_formula_parser_with_invalid_formula(self): + """Test the formula_parser function with an invalid formula.""" + with self.assertRaises(ValueError): + composition.formula_parser("Sr3Sc2(GeO4)3)") + def test__get_fractional_composition(self): """Test the _get_fractional_composition function.""" - CsPbI3_frac = composition._get_fractional_composition("CsPbI3") + CsPbI3_frac = composition._get_fractional_composition(self.formulas[3]) assert isinstance(CsPbI3_frac, dict) assert "Pb" in CsPbI3_frac assert CsPbI3_frac["I"] == 0.6 - def test_Composition_class(self): - """Test the Composition class.""" - Fe2O3_magpie = composition.CompositionalEmbedding( - formula="Fe2O3", embedding="magpie" - ) - assert isinstance(Fe2O3_magpie.embedding, core.Embedding) - assert Fe2O3_magpie.formula == "Fe2O3" - assert Fe2O3_magpie.embedding_name == "magpie" - assert isinstance(Fe2O3_magpie.composition, dict) - assert {"Fe": 2, "O": 3} == Fe2O3_magpie.composition - assert Fe2O3_magpie._natoms == 5 - assert Fe2O3_magpie.fractional_composition == {"Fe": 0.4, "O": 0.6} - assert isinstance(Fe2O3_magpie._mean_feature_vector(), np.ndarray) - # Test that the feature vector function works - stats = [ + +class TestCompositionalEmbedding(unittest.TestCase): + """Test the CompositionalEmbedding class.""" + + def setUp(self): + """Set up the test formulas.""" + self.formulas = ["Sr3Sc2(GeO4)3", "Fe2O3", "Li7La3ZrO12", "CsPbI3", "CsPbI-3"] + self.valid_magpie_compositions = [ + composition.CompositionalEmbedding(formula=formula, embedding="magpie") + for formula in self.formulas[:3] + ] + self.stats = [ "mean", "variance", "minpool", @@ -49,16 +67,50 @@ def test_Composition_class(self): "geometric_mean", "harmonic_mean", ] - assert isinstance(Fe2O3_magpie.feature_vector(stats=stats), np.ndarray) + + def test_CompositionalEmbedding_attributes(self): + """Test the Composition class.""" + Fe2O3_magpie = self.valid_magpie_compositions[1] + assert isinstance(Fe2O3_magpie.embedding, core.Embedding) + assert Fe2O3_magpie.formula == "Fe2O3" + assert Fe2O3_magpie.embedding_name == "magpie" + assert isinstance(Fe2O3_magpie.composition, dict) + assert {"Fe": 2, "O": 3} == Fe2O3_magpie.composition + assert Fe2O3_magpie.num_atoms == 5 + assert Fe2O3_magpie.fractional_composition == {"Fe": 0.4, "O": 0.6} + assert Fe2O3_magpie.embedding.dim == 22 + + def test_CompositionalEmbedding_negative_formula(self): + """Test the Composition class with a negative formula.""" + with self.assertRaises(ValueError): + composition.CompositionalEmbedding( + formula=self.formulas[4], embedding="magpie" + ) + + def test__mean_feature_vector(self): + """Test the _mean_feature_vector function.""" + assert isinstance( + self.valid_magpie_compositions[1]._mean_feature_vector(), np.ndarray + ) + # Test that the feature vector function works + + def test_feature_vector(self): + """Test the feature_vector function.""" + assert isinstance( + self.valid_magpie_compositions[0].feature_vector(stats=self.stats), + np.ndarray, + ) assert len( - Fe2O3_magpie.feature_vector(stats=stats) - ) == Fe2O3_magpie.embedding.dim * len(stats) + self.valid_magpie_compositions[0].feature_vector(stats=self.stats) + ) == self.valid_magpie_compositions[0].embedding.dim * len(self.stats) # Test that the feature vector function works with a single stat - assert isinstance(Fe2O3_magpie.feature_vector(stats="mean"), np.ndarray) + assert isinstance( + self.valid_magpie_compositions[0].feature_vector(stats="mean"), np.ndarray + ) def test_composition_featuriser(self): """Test the composition featuriser function.""" - formulas = ["Fe2O3", "Li7La3ZrO12", "CsPbI3"] + formulas = self.formulas[:3] formula_df = pd.DataFrame(formulas, columns=["formula"]) assert isinstance(composition.composition_featuriser(formula_df), pd.DataFrame) assert composition.composition_featuriser(formula_df).shape == (3, 2) diff --git a/src/elementembeddings/tests/test_core.py b/src/elementembeddings/tests/test_core.py index de96d3d..845b3b0 100644 --- a/src/elementembeddings/tests/test_core.py +++ b/src/elementembeddings/tests/test_core.py @@ -1,4 +1,5 @@ """Test the core module of AtomicEmbeddings.""" +import os import unittest import matplotlib.pyplot as plt @@ -7,26 +8,49 @@ from elementembeddings.core import Embedding +test_files_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "files") +TEST_EMBEDDING_CSV = os.path.join(test_files_dir, "test_embedding.csv") +TEST_EMBEDDING_JSON = os.path.join(test_files_dir, "test_embedding.json") + class EmbeddingTest(unittest.TestCase): """Test the Embedding class.""" # High Level functions + @classmethod + def setUpClass(cls): + """Set up the test.""" + cls.test_skipatom = Embedding.load_data("skipatom") + cls.test_megnet16 = Embedding.load_data("megnet16") + cls.test_matscholar = Embedding.load_data("matscholar") + cls.test_mod_petti = Embedding.load_data("mod_petti") + cls.test_magpie = Embedding.load_data("magpie") + + def test_Embedding_attributes(self): + """Test attributes of the loaded embeddings.""" + assert self.test_skipatom.dim == 200 + assert self.test_skipatom.embedding_name == "skipatom" + assert self.test_megnet16.dim == 16 + assert self.test_megnet16.embedding_name == "megnet16" + assert self.test_matscholar.dim == 200 + assert self.test_matscholar.embedding_name == "matscholar" + assert self.test_mod_petti.dim == 1 + assert self.test_mod_petti.embedding_name == "mod_petti" + assert isinstance(self.test_skipatom.citation(), list) + assert isinstance(self.test_megnet16.citation(), list) + assert isinstance(self.test_matscholar.citation(), list) + assert isinstance(self.test_mod_petti.citation(), list) - def test_Embedding_loading(self): - """Test that the Embedding class can load the data.""" - skipatom = Embedding.load_data("skipatom") - megnet16 = Embedding.load_data("megnet16") - assert skipatom.dim == 200 - assert skipatom.embedding_name == "skipatom" - assert megnet16.dim == 16 - assert megnet16.embedding_name == "megnet16" - assert isinstance(skipatom.citation(), list) - assert isinstance(megnet16.citation(), list) + def test_Embedding_file_input(self): + """Test that the Embedding class can load custom data.""" + embedding_csv = Embedding.from_csv(TEST_EMBEDDING_CSV) + embedding_json = Embedding.from_json(TEST_EMBEDDING_JSON) + assert embedding_csv.dim == 10 + assert embedding_json.dim == 10 def test_Embeddings_class_magpie(self): """Test that the Embedding class can load the magpie data.""" - magpie = Embedding.load_data("magpie") + magpie = self.test_magpie # Check if the embeddings attribute is a dict assert isinstance(magpie.embeddings, dict) # Check if the embedding vector is a numpy array @@ -253,26 +277,73 @@ def test_Embeddings_class_magpie(self): # TO-DO # Create tests for checking dataframes and plotting functions + + def test_as_dataframe(self): + """Test the as_dataframe method.""" + magpie = self.test_magpie assert isinstance(magpie.as_dataframe(), pd.DataFrame) - assert isinstance(magpie.to(fmt="json"), str) - assert isinstance(magpie.to(fmt="csv"), str) + assert "H" in magpie.as_dataframe().index.tolist() + assert isinstance(magpie.as_dataframe(columns="elements"), pd.DataFrame) + assert "H" in magpie.as_dataframe(columns="elements").columns.tolist() + self.assertRaises(ValueError, magpie.as_dataframe, columns="test") + + def test_to(self): + """Test the to method.""" + assert isinstance(self.test_magpie.to(fmt="json"), str) + self.test_magpie.to(fmt="json", filename="test.json") + assert os.path.isfile("test.json") + os.remove("test.json") + + assert isinstance(self.test_magpie.to(fmt="csv"), str) + self.test_magpie.to(fmt="csv", filename="test.csv") + assert os.path.isfile("test.csv") + os.remove("test.csv") + + def test_compute_metric_functions(self): + """Test the compute metric functions.""" assert isinstance( - magpie.compute_correlation_metric("H", "O", metric="pearson"), + self.test_magpie.compute_correlation_metric("H", "O", metric="pearson"), float, ) assert isinstance( - magpie.compute_distance_metric( + self.test_magpie.compute_distance_metric( "H", "O", ), float, ) - assert isinstance(magpie.distance_df(), pd.DataFrame) - assert magpie.distance_df().shape == ( - len(list(magpie.create_pairs())) * 2 - len(magpie.embeddings), + assert isinstance( + self.test_magpie.compute_distance_metric("H", "O", "energy"), + float, + ) + assert isinstance( + self.test_magpie.compute_distance_metric("H", "O", "cosine_distance"), + float, + ) + assert isinstance( + self.test_magpie.compute_correlation_metric("H", "O", metric="spearman"), + float, + ) + + self.assertRaises( + ValueError, self.test_skipatom.compute_distance_metric, "He", "O" + ) + self.assertRaises( + ValueError, self.test_skipatom.compute_distance_metric, "O", "He" + ) + self.assertRaises( + ValueError, self.test_skipatom.compute_distance_metric, "Li", "O", "euclid" + ) + + def test_distance_dataframe_functions(self): + """Test the distance dataframe functions.""" + assert isinstance(self.test_magpie.distance_df(), pd.DataFrame) + assert self.test_magpie.distance_df().shape == ( + len(list(self.test_magpie.create_pairs())) * 2 + - len(self.test_magpie.embeddings), 7, ) - assert magpie.distance_df().columns.tolist() == [ + assert self.test_magpie.distance_df().columns.tolist() == [ "ele_1", "ele_2", "mend_1", @@ -281,8 +352,49 @@ def test_Embeddings_class_magpie(self): "Z_2", "euclidean", ] - assert isinstance(magpie.distance_pivot_table(), pd.DataFrame) - assert isinstance(magpie.plot_distance_correlation(), plt.Axes) + assert isinstance(self.test_magpie.distance_pivot_table(), pd.DataFrame) + assert isinstance( + self.test_magpie.distance_pivot_table(sortby="atomic_number"), pd.DataFrame + ) + assert isinstance(self.test_magpie.plot_distance_correlation(), plt.Axes) assert isinstance( - magpie.plot_distance_correlation(metric="euclidean"), plt.Axes + self.test_magpie.plot_distance_correlation(metric="euclidean"), plt.Axes + ) + assert isinstance(self.test_magpie.stats_correlation_df(), pd.DataFrame) + + def test_remove_elements(self): + """Test the remove_elements function.""" + assert isinstance(self.test_skipatom.remove_elements("H"), Embedding) + assert isinstance(self.test_skipatom.remove_elements(["H", "Li"]), Embedding) + self.assertIsNone(self.test_skipatom.remove_elements("H", inplace=True)) + self.assertFalse(self.test_skipatom._is_el_in_embedding("H")) + self.assertIsNone( + self.test_skipatom.remove_elements(["Li", "Ti", "Bi"], inplace=True) + ) + assert "Li" not in self.test_skipatom.element_list + assert "Ti" not in self.test_skipatom.element_list + assert "Bi" not in self.test_skipatom.element_list + + def test_PCA(self): + """Test the PCA function.""" + assert isinstance(self.test_matscholar.calculate_PC(), np.ndarray) + assert self.test_matscholar.calculate_PC().shape == ( + len(self.test_matscholar.element_list), + 2, + ) + + def test_tSNE(self): + """Test the tSNE function.""" + assert isinstance(self.test_matscholar.calculate_tSNE(), np.ndarray) + assert self.test_matscholar.calculate_tSNE().shape == ( + len(self.test_matscholar.element_list), + 2, + ) + + def test_UMAP(self): + """Test the UMAP function.""" + assert isinstance(self.test_matscholar.calculate_UMAP(), np.ndarray) + assert self.test_matscholar.calculate_UMAP().shape == ( + len(self.test_matscholar.element_list), + 2, ) diff --git a/src/elementembeddings/tests/test_plotter.py b/src/elementembeddings/tests/test_plotter.py index d34c799..59f9c81 100644 --- a/src/elementembeddings/tests/test_plotter.py +++ b/src/elementembeddings/tests/test_plotter.py @@ -8,30 +8,91 @@ from elementembeddings.plotter import dimension_plotter, heatmap_plotter -class heatmapTest(unittest.TestCase): +class HeatmapTest(unittest.TestCase): """Test the heatmap_plotter function.""" + @classmethod + def setUpClass(cls): + """Set up the test class.""" + cls.test_skipatom = Embedding.load_data("skipatom") + def test_heatmap_plotter(self): """Test that the heatmap_plotter function works.""" - # Load the data - skipatom = Embedding.load_data("skipatom") # Get the embeddings skipatom_cos_plot = heatmap_plotter( - skipatom, + self.test_skipatom, metric="cosine_similarity", ) assert isinstance(skipatom_cos_plot, plt.Axes) + skipatom_euc_plot = heatmap_plotter( + self.test_skipatom, metric="euclidean", show_axislabels=False + ) + assert isinstance(skipatom_euc_plot, plt.Axes) -class dimensionTest(unittest.TestCase): +class DimensionTest(unittest.TestCase): """Test the dimension_plotter function.""" - def test_dimension_plotter(self): + @classmethod + def setUpClass(cls): + """Set up the test class.""" + cls.test_skipatom = Embedding.load_data("skipatom") + + def test_dimension_2d_plotter(self): """Test that the dimension_plotter function works.""" - # Load the data - skipatom = Embedding.load_data("skipatom") - # Get the embeddings skipatom_pca_plot = dimension_plotter( - skipatom, n_components=2, reducer="pca", adjusttext=True + self.test_skipatom, n_components=2, reducer="pca", adjusttext=False ) assert isinstance(skipatom_pca_plot, plt.Axes) + skipatom_tsne_plot = dimension_plotter( + self.test_skipatom, n_components=2, reducer="tsne", adjusttext=False + ) + assert isinstance(skipatom_tsne_plot, plt.Axes) + skipatom_umap_plot = dimension_plotter( + self.test_skipatom, n_components=2, reducer="umap", adjusttext=True + ) + assert isinstance(skipatom_umap_plot, plt.Axes) + + self.assertRaises( + ValueError, + dimension_plotter, + self.test_skipatom, + n_components=2, + reducer="badreducer", + ) + + def test_dimension_2d_plotter_preloaded_reduction(self): + """Test that the dimension_plotter function works with a preloaded reduction.""" + self.test_skipatom.calculate_PC() + self.test_skipatom.calculate_tSNE() + self.test_skipatom.calculate_UMAP() + + skipatom_pca_plot = dimension_plotter( + self.test_skipatom, n_components=2, reducer="pca", adjusttext=False + ) + assert isinstance(skipatom_pca_plot, plt.Axes) + skipatom_tsne_plot = dimension_plotter( + self.test_skipatom, n_components=2, reducer="tsne", adjusttext=False + ) + assert isinstance(skipatom_tsne_plot, plt.Axes) + skipatom_umap_plot = dimension_plotter( + self.test_skipatom, n_components=2, reducer="umap", adjusttext=False + ) + assert isinstance(skipatom_umap_plot, plt.Axes) + + def test_dimension_3d_plotter(self): + """Test that the dimension_plotter function works in 3D.""" + skipatom_3d_pca_plot = dimension_plotter( + self.test_skipatom, n_components=3, reducer="pca", adjusttext=False + ) + assert isinstance(skipatom_3d_pca_plot, plt.Axes) + + def test_dimension_Nd_plotter(self): + """Test that the dimension_plotter function will fail in d>3.""" + self.assertRaises( + ValueError, + dimension_plotter, + self.test_skipatom, + n_components=4, + reducer="pca", + )