diff --git a/CHANGELOG.md b/CHANGELOG.md index d30bcdb..0e459cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Changed +- to set the elemental composition it is now possible to use dicts with not only int but also the element symbols (str) +- dict keys for elemental compositions will now always be checked for validity ## [0.5.0] - 2024-12-16 ### Changed diff --git a/README.md b/README.md index 204fd3d..ba95768 100644 --- a/README.md +++ b/README.md @@ -133,6 +133,8 @@ def main(): config.generate.max_num_atoms = 15 config.generate.element_composition = "Ce:1-1" # alternatively as a dictionary: config.generate.element_composition = {39:(1,1)} + # or: config.generate.element_composition = {"Ce":(1,1)"} + # or as mixed-key dict, e.g. for Ce and O: {"Ce":(1,1), 7:(2,2)} config.generate.forbidden_elements = "21-30,39-48,57-80" # alternatively as a list: config.generate.forbidden_elements = [20,21,22,23] # 24,25,26... diff --git a/src/mindlessgen/prog/config.py b/src/mindlessgen/prog/config.py index d598b4a..cb1e8bc 100644 --- a/src/mindlessgen/prog/config.py +++ b/src/mindlessgen/prog/config.py @@ -12,6 +12,8 @@ import numpy as np import toml +from mindlessgen.molecules.molecule import PSE_SYMBOLS + from ..molecules import PSE_NUMBERS @@ -290,18 +292,19 @@ def element_composition(self): @element_composition.setter def element_composition( - self, composition: None | str | dict[int, tuple[int | None, int | None]] + self, composition: None | str | dict[int | str, tuple[int | None, int | None]] ) -> None: """ - If composition_str: str, it should be a string with the format: + If composition: str: Parses the element_composition string and stores the parsed data in the _element_composition dictionary. Format: "C:2-10, H:10-20, O:1-5, N:1-*" - If composition_str: dict, it should be a dictionary with integer keys and tuple values. Will be stored as is. + If composition: dict: + Should be a dictionary with integer/string keys and tuple values. Will be stored as is. Arguments: composition_str (str): String with the element composition - composition_str (dict): Dictionary with integer keys and tuple values + composition_str (dict): Dictionary with integer/str keys and tuple values Raises: TypeError: If composition_str is not a string or a dictionary AttributeError: If the element is not found in the periodic table @@ -312,25 +315,50 @@ def element_composition( if not composition: return + + # Will return if composition dict does not contain either int or str keys and tuple[int | None, int | None] values + # Will also return if dict is valid after setting property if isinstance(composition, dict): + tmp = {} + + # Check validity and also convert str keys into atomic numbers for key, value in composition.items(): if ( - not isinstance(key, int) + not (isinstance(key, int) or isinstance(key, str)) or not isinstance(value, tuple) or len(value) != 2 or not all(isinstance(val, int) or val is None for val in value) ): raise TypeError( - "Element composition dictionary should be a dictionary with integer keys and tuple values (int, int)." + "Element composition dictionary should be a dictionary with either integer or string keys and tuple values (int, int)." ) - self._element_composition = composition + + # Convert str keys + if isinstance(key, str): + element_number = PSE_NUMBERS.get(key.lower(), None) + if element_number is None: + raise KeyError( + f"Element {key} not found in the periodic table." + ) + tmp[element_number - 1] = composition[key] + # Check int keys + else: + if key + 1 in PSE_SYMBOLS: + tmp[key] = composition[key] + else: + raise KeyError( + f"Element with atomic number {key+1} (provided key: {key}) not found in the periodic table." + ) + self._element_composition = tmp return + if not isinstance(composition, str): raise TypeError( "Element composition should be a string (will be parsed) or " - + "a dictionary with integer keys and tuple values." + + "a dictionary with integer/string keys and tuple values." ) + # Parsing composition string element_dict: dict[int, tuple[int | None, int | None]] = {} elements = composition.split(",") # remove leading and trailing whitespaces @@ -537,9 +565,11 @@ def check_config(self, verbosity: int = 1) -> None: if ( np.sum( [ - self.element_composition.get(i, (0, 0))[0] - if self.element_composition.get(i, (0, 0))[0] is not None - else 0 + ( + self.element_composition.get(i, (0, 0))[0] + if self.element_composition.get(i, (0, 0))[0] is not None + else 0 + ) for i in self.element_composition ] ) diff --git a/test/test_generate/test_generate_molecule.py b/test/test_generate/test_generate_molecule.py index a020536..190843f 100644 --- a/test/test_generate/test_generate_molecule.py +++ b/test/test_generate/test_generate_molecule.py @@ -41,9 +41,9 @@ def test_generate_atom_list(min_atoms, max_atoms, default_generate_config): assert np.sum(atom_list) <= max_atoms -# Test the element composition property of the GenerateConfig class -def test_generate_config_element_composition(default_generate_config): - """Test the element composition property of the GenerateConfig class.""" +# Test the element composition property of the GenerateConfig class with a composition string +def test_generate_config_element_composition_string(default_generate_config): + """Test the element composition property of the GenerateConfig class with a composition string.""" default_generate_config.min_num_atoms = 10 default_generate_config.max_num_atoms = 15 default_generate_config.element_composition = "C:2-2, N:3-3, O:1-1" @@ -55,6 +55,60 @@ def test_generate_config_element_composition(default_generate_config): assert atom_list[7] == 1 +# Test the element composition property of the GenerateConfig class with an int key composition dict +def test_generate_config_element_composition_dict_int(default_generate_config): + """Test the element composition property of the GenerateConfig class with an int key composition dict.""" + + # Pure int keys + default_generate_config.min_num_atoms = 10 + default_generate_config.max_num_atoms = 15 + default_generate_config.element_composition = { + 5: (2, 2), + 6: (3, 3), + 7: (1, 1), + } # NOTE: mind 0-based indexing for atomic numbers + atom_list = generate_atom_list(default_generate_config, verbosity=1) + + # Check that the atom list contains the correct number of atoms for each element + assert atom_list[5] == 2 + assert atom_list[6] == 3 + assert atom_list[7] == 1 + + +# Test the element composition property of the GenerateConfig class with an int key composition dict +def test_generate_config_element_composition_dict_string(default_generate_config): + """Test the element composition property of the GenerateConfig class with a str key composition dict.""" + + default_generate_config.min_num_atoms = 10 + default_generate_config.max_num_atoms = 15 + default_generate_config.element_composition = { + "C": (2, 2), + "N": (3, 3), + "O": (1, 1), + } + atom_list = generate_atom_list(default_generate_config, verbosity=1) + + # Check that the atom list contains the correct number of atoms for each element + assert atom_list[5] == 2 + assert atom_list[6] == 3 + assert atom_list[7] == 1 + + +# Test the element composition property of the GenerateConfig class with an int key composition dict +def test_generate_config_element_composition_dict_mixed(default_generate_config): + """Test the element composition property of the GenerateConfig class with a str key composition dict.""" + + default_generate_config.min_num_atoms = 10 + default_generate_config.max_num_atoms = 15 + default_generate_config.element_composition = {5: (2, 2), "N": (3, 3), "O": (1, 1)} + atom_list = generate_atom_list(default_generate_config, verbosity=1) + + # Check that the atom list contains the correct number of atoms for each element + assert atom_list[5] == 2 + assert atom_list[6] == 3 + assert atom_list[7] == 1 + + # Test the forbidden_elements property of the GenerateConfig class def test_generate_config_forbidden_elements(default_generate_config): """Test the forbidden_elements property of the GenerateConfig class."""