def get_identifiers_registry():
    """Return a fresh registry used to detect duplicate potential identifiers.

    Returns
    -------
    dict
        Maps each potential-group tag name (``"AtomTypes"``,
        ``"BondTypes"``, ``"AngleTypes"``, ``"DihedralTypes"``,
        ``"ImproperTypes"`` and ``"PairPotentialTypes"``) to its own empty
        ``set``. A new dict with new sets is built on every call, so
        registries from separate calls never share state.
    """
    return {
        "AtomTypes": set(),
        "BondTypes": set(),
        "AngleTypes": set(),
        "DihedralTypes": set(),
        "ImproperTypes": set(),
        # The key must be spelled exactly like the XML tag, because the
        # registry is indexed with the "PairPotentialTypes" tag name.
        "PairPotentialTypes": set(),
    }
@lru_cache(maxsize=128)
def indep_vars(expr: str, dependent: frozenset) -> Set:
    """Given an expression and dependent variables, return independent variables for it.

    Parameters
    ----------
    expr : str
        The expression, parsed with ``sympy.sympify``.
    dependent : frozenset
        The dependent variables (parameters) of the expression, given either
        as names (``str``) or as sympy symbols. It must be hashable because
        the result is memoised with ``lru_cache``.

    Returns
    -------
    set
        The free symbols of ``expr`` that are not listed in ``dependent``.
    """
    free_symbols = sympy.sympify(expr).free_symbols
    # ``free_symbols`` contains sympy symbols, whereas parameter names are
    # usually passed as plain strings. A symbol never compares (or hashes)
    # equal to a string, so a bare set difference would remove nothing;
    # match on the symbol itself *and* on its printed name instead.
    return {
        symbol
        for symbol in free_symbols
        if symbol not in dependent and str(symbol) not in dependent
    }
class Parameters(GMSOXMLTag):
    """Model of a ``Parameters`` element: an ordered list of ``Parameter`` entries."""

    children: List[Parameter] = Field(
        ..., description="Parameter", alias="parameter"
    )

    @classmethod
    def load_from_etree(cls, root):
        """Build the model from ``root``, keeping only its ``Parameter`` children.

        Child elements with any other tag are ignored; document order is
        preserved.
        """
        parameter_elements = (
            child for child in root.iterchildren() if child.tag == "Parameter"
        )
        loaded = [
            Parameter.load_from_etree(child) for child in parameter_elements
        ]
        return cls(children=loaded)
class AtomTypes(GMSOXMLChild):
    """Model of an ``AtomTypes`` group: shared expression, unit definitions and atom types."""

    name: Optional[str] = Field(
        None, description="The name for this atom type group", alias="name"
    )

    expression: str = Field(
        ...,
        description="The expression for this atom type group",
        alias="expression",
    )

    children: List[Union[ParametersUnitDef, AtomType]] = Field(
        ..., description="The children of AtomTypes", alias="parameters"
    )

    def to_gmso_potentials(self, default_units):
        """Convert every ``AtomType`` child into a GMSO atom type.

        Parameters
        ----------
        default_units : dict
            Default units keyed by quantity name. Only the ``"charge"`` and
            ``"mass"`` entries are read here.

        Returns
        -------
        dict
            ``{"atom_types": {atom type name: GMSOAtomType}}``.
        """
        potentials = {"atom_types": {}}
        # Units of the individual parameters, as declared by the
        # ParametersUnitDef children of this group.
        units = {
            child.parameter: u.Unit(child.unit)
            for child in self.children
            if isinstance(child, ParametersUnitDef)
        }

        for atom_type in self.children:
            if not isinstance(atom_type, AtomType):
                continue

            atom_type_dict = atom_type.dict(
                by_alias=True,
                exclude={"children", "element"},
                exclude_none=True,
            )

            # ``overrides`` is a comma separated string in the XML; the GMSO
            # atom type receives it as a set of stripped names.
            overrides = atom_type_dict.get("overrides")
            if overrides:
                atom_type_dict["overrides"] = set(
                    o.strip() for o in overrides.split(",")
                )
            else:
                atom_type_dict["overrides"] = set()

            # An atom type without its own expression uses the group's one.
            if "expression" not in atom_type_dict:
                atom_type_dict["expression"] = self.expression
            atom_type_dict["parameters"] = atom_type.parameters(units)

            if not atom_type_dict.get("independent_variables"):
                atom_type_dict["independent_variables"] = indep_vars(
                    atom_type_dict["expression"],
                    frozenset(atom_type_dict["parameters"]),
                )

            for quantity in ("charge", "mass"):
                default_unit = default_units.get(quantity)
                value = atom_type_dict.get(quantity)
                # Compare against None instead of relying on truthiness so
                # that a value of exactly 0.0 still gets the default unit.
                if default_unit is not None and value is not None:
                    atom_type_dict[quantity] = value * default_unit

            gmso_atom_type = GMSOAtomType(**atom_type_dict)
            element = atom_type.element
            if element:
                # The element is excluded from the constructor arguments
                # above and attached as a tag instead.
                gmso_atom_type.add_tag("element", element)
            potentials["atom_types"][atom_type.name] = gmso_atom_type
        return potentials

    @classmethod
    def load_from_etree(cls, root, existing):
        """Build the group from ``root``.

        Parameters
        ----------
        root
            The ``AtomTypes`` element.
        existing : set
            Registry of atom type names seen so far. Every loaded name is
            passed to ``register_identifiers`` together with this registry.
        """
        attribs = root.attrib
        children = []
        for el in root.iterchildren():
            if el.tag == "ParametersUnitDef":
                children.append(ParametersUnitDef.load_from_etree(el))
            elif el.tag == "AtomType":
                atom_type = AtomType.load_from_etree(el)
                identifier = atom_type.name
                register_identifiers(existing, identifier, "AtomTypes")
                children.append(atom_type)
        return cls(children=children, **attribs)
class AngleType(GMSOXMLTag):
    """Model of one ``AngleType`` element: three member classes or types plus parameters."""

    # ``name`` defaults to None, so it is annotated Optional like every
    # other nullable field of this model.
    name: Optional[str] = Field(
        None, description="The name of the angle type", alias="name"
    )

    class1: Optional[str] = Field(
        None, description="Class 1 for this angle type", alias="class1"
    )

    class2: Optional[str] = Field(
        None, description="Class 2 for this angle type", alias="class2"
    )

    class3: Optional[str] = Field(
        None, description="Class 3 for this angle type", alias="class3"
    )

    type1: Optional[str] = Field(
        None, description="Type 1 for this angle type", alias="type1"
    )

    type2: Optional[str] = Field(
        None, description="Type 2 for this angle type", alias="type2"
    )

    type3: Optional[str] = Field(
        None, description="Type 3 for this angle type", alias="type3"
    )

    children: List[Parameters] = Field(
        ..., description="The parameters and their values", alias="children"
    )

    @classmethod
    def load_from_etree(cls, root):
        """Build the model from ``root``, keeping only its ``Parameters`` children.

        The element attributes are first passed through
        ``pad_with_wildcards(root.attrib, 3)``.
        """
        children = []
        # NOTE(review): presumably fills in missing class/type attributes
        # for the three members; confirm against ``pad_with_wildcards``.
        attribs = pad_with_wildcards(root.attrib, 3)
        for el in root.iterchildren():
            if el.tag == "Parameters":
                children.append(Parameters.load_from_etree(el))
        return cls(children=children, **attribs)
class TorsionType(GMSOXMLTag):
    """Common model for ``DihedralType`` and ``ImproperType`` elements.

    Holds up to four member classes or four member types plus the
    ``Parameters`` children of the element.
    """

    # ``name`` defaults to None, so it is annotated Optional like every
    # other nullable field of this model.
    name: Optional[str] = Field(
        None, description="The name of the Dihedral/Improper type", alias="name"
    )

    class1: Optional[str] = Field(
        None,
        description="Class 1 for this Dihedral/Improper type",
        alias="class1",
    )

    class2: Optional[str] = Field(
        None,
        description="Class 2 for this Dihedral/Improper type",
        alias="class2",
    )

    class3: Optional[str] = Field(
        None,
        description="Class 3 for this Dihedral/Improper type",
        alias="class3",
    )

    class4: Optional[str] = Field(
        None,
        description="Class 4 for this Dihedral/Improper type",
        alias="class4",
    )

    type1: Optional[str] = Field(
        None,
        description="Type 1 for this Dihedral/Improper type",
        alias="type1",
    )

    type2: Optional[str] = Field(
        None,
        description="Type 2 for this Dihedral/Improper type",
        alias="type2",
    )

    type3: Optional[str] = Field(
        None,
        description="Type 3 for this Dihedral/Improper type",
        alias="type3",
    )

    type4: Optional[str] = Field(
        None,
        description="Type 4 for this Dihedral/Improper type",
        alias="type4",
    )

    children: List[Parameters] = Field(
        ..., description="The parameters and their values", alias="children"
    )

    @classmethod
    def load_from_etree(cls, root):
        """Build the model from ``root``, keeping only its ``Parameters`` children.

        The element attributes are first passed through
        ``pad_with_wildcards(root.attrib, 4)``. Because ``cls`` is used for
        construction, subclasses get instances of their own type.
        """
        children = []
        # NOTE(review): presumably fills in missing class/type attributes
        # for the four members; confirm against ``pad_with_wildcards``.
        attribs = pad_with_wildcards(root.attrib, 4)
        for el in root.iterchildren():
            if el.tag == "Parameters":
                children.append(Parameters.load_from_etree(el))
        return cls(children=children, **attribs)
"type4", + "class1", + "class2", + "class3", + "class4", + }, + exclude_none=True, + ) + + if "expression" not in torsion_dict: + torsion_dict["expression"] = self.expression + + if ( + torsion_type.type1 + and torsion_type.type2 + and torsion_type.type3 + and torsion_type.type4 + ): + torsion_dict["member_types"] = ( + torsion_type.type1, + torsion_type.type2, + torsion_type.type3, + torsion_type.type4, + ) + + elif ( + torsion_type.class1 + and torsion_type.class2 + and torsion_type.class3 + and torsion_type.class4 + ): + torsion_dict["member_classes"] = ( + torsion_type.class1, + torsion_type.class2, + torsion_type.class3, + torsion_type.class4, + ) + + torsion_dict["parameters"] = torsion_type.parameters(units) + torsion_dict["independent_variables"] = indep_vars( + torsion_dict["expression"], + frozenset(torsion_dict["parameters"]), + ) + if isinstance(torsion_type, DihedralType): + gmso_torsion_type = GMSODihedralType(**torsion_dict) + key = "dihedral_types" + else: + gmso_torsion_type = GMSOImproperType(**torsion_dict) + key = "improper_types" + + if gmso_torsion_type.member_types: + potentials[key][ + FF_TOKENS_SEPARATOR.join(gmso_torsion_type.member_types) + ] = gmso_torsion_type + else: + potentials[key][ + FF_TOKENS_SEPARATOR.join(gmso_torsion_type.member_classes) + ] = gmso_torsion_type + + return potentials + + @classmethod + def load_from_etree(cls, root, existing_dihedrals, existing_impropers): + attribs = root.attrib + children = [] + child_loaders = { + "DihedralType": DihedralType, + "ImproperType": ImproperType, + } + for el in root.iterchildren(): + if el.tag == "ParametersUnitDef": + children.append(ParametersUnitDef.load_from_etree(el)) + elif el.tag == "DihedralType" or el.tag == "ImproperType": + tor_type = child_loaders[el.tag].load_from_etree(el) + identifier = tuple( + [ + tor_type.class1, + tor_type.class2, + tor_type.class3, + tor_type.class4, + ] + if tor_type.class1 + else [ + tor_type.type1, + tor_type.type2, + tor_type.type3, + 
class PairPotentialType(GMSOXMLTag):
    """Model of one ``PairPotentialType`` element: two member classes or types plus parameters."""

    name: str = Field(
        ..., description="Name of this PairPotential Type", alias="name"
    )

    type1: Optional[str] = Field(
        None, description="The type1 of this PairPotential Type", alias="type1"
    )

    type2: Optional[str] = Field(
        None, description="The type2 of this PairPotential Type", alias="type2"
    )

    class1: Optional[str] = Field(
        None,
        description="The class1 of this PairPotential Type",
        alias="class1",
    )

    class2: Optional[str] = Field(
        None,
        description="The class2 of this PairPotential Type",
        alias="class2",
    )

    children: List[Parameters] = Field(
        ..., description="The parameters and their values", alias="children"
    )

    @classmethod
    def load_from_etree(cls, root):
        """Build the model from ``root``, keeping only its ``Parameters`` children.

        The element attributes are first passed through
        ``pad_with_wildcards(root.attrib, 2)``.
        """
        # NOTE(review): presumably fills in missing class/type attributes
        # for the two members; confirm against ``pad_with_wildcards``.
        attribs = pad_with_wildcards(root.attrib, 2)
        children = []
        for el in root.iterchildren():
            # Only ``Parameters`` elements are parsed, as in the other
            # ``*Type`` loaders of this module. Without the filter any other
            # child node (an XML comment, for instance) would be turned into
            # an empty ``Parameters`` entry and could end up as
            # ``children[0]``, which ``parameters()`` reads.
            if el.tag == "Parameters":
                children.append(Parameters.load_from_etree(el))
        return cls(children=children, **attribs)
class ForceField(GMSOXMLTag):
    """Model of the root ``ForceField`` element and its child groups."""

    name: str = Field(
        "ForceField", description="Name of the ForceField", alias="name"
    )

    version: str = Field(
        "1.0.0", description="The version of the ForceField", alias="version"
    )

    children: List[GMSOXMLChild] = Field(
        ..., description="The children tags", alias="children"
    )

    def to_gmso_ff(self):
        """Convert this model into a ``gmso.ForceField``.

        Name, version, scaling factors and combining rule are copied over,
        then the potentials returned by every child that implements
        ``to_gmso_potentials`` are merged per attribute name and set on the
        force field.

        Raises
        ------
        IndexError
            If the force field has no ``FFMetaData`` child.
        """
        ff = GMSOForceField()
        ff.name = self.name
        ff.version = self.version
        # When several FFMetaData children exist the last one is used;
        # ``pop`` on an empty list raises IndexError when there is none.
        metadata = [
            child for child in self.children if isinstance(child, FFMetaData)
        ].pop()
        default_units = metadata.get_default_units()
        ff.scaling_factors = metadata.gmso_scaling_factors()
        ff.combining_rule = metadata.combining_rule

        ff_potentials = {}
        for child in self.children:
            if isinstance(child, (FFMetaData, Units)):
                continue
            # Groups without a converter are skipped.
            if not hasattr(child, "to_gmso_potentials"):
                continue
            potentials = child.to_gmso_potentials(default_units)
            for attr in potentials:
                # Several groups of the same kind contribute to one dict.
                if attr in ff_potentials:
                    ff_potentials[attr].update(potentials[attr])
                else:
                    ff_potentials[attr] = potentials[attr]

        for attr in ff_potentials:
            setattr(ff, attr, ff_potentials[attr])

        return ff

    @classmethod
    def load_from_etree(cls, root) -> "ForceField":
        """Build the model from the root element of a force field XML tree.

        Every recognised child tag is dispatched to the loader of the
        matching group model; children with any other tag are ignored. One
        identifier registry is shared by all groups of the document, so
        duplicates are detected across groups of the same kind.
        """
        attribs = root.attrib
        children = []
        identifiers_registry = get_identifiers_registry()

        def registry_for(tag):
            # Defensive lookup: a tag that has no entry in the registry gets
            # its own empty set instead of aborting the load with KeyError.
            return identifiers_registry.setdefault(tag, set())

        # Groups whose loader takes a single registry.
        single_registry_groups = {
            "AtomTypes": AtomTypes,
            "BondTypes": BondTypes,
            "AngleTypes": AngleTypes,
            "PairPotentialTypes": PairPotentialTypes,
        }
        # Groups whose loader takes both the dihedral and improper registry.
        torsion_groups = {
            "DihedralTypes": DihedralTypes,
            "ImproperTypes": ImproperTypes,
        }

        for el in root.iterchildren():
            tag = el.tag
            # The tags are mutually exclusive, hence a single if/elif chain.
            if tag == "FFMetaData":
                children.append(FFMetaData.load_from_etree(el))
            elif tag in single_registry_groups:
                children.append(
                    single_registry_groups[tag].load_from_etree(
                        el, registry_for(tag)
                    )
                )
            elif tag in torsion_groups:
                children.append(
                    torsion_groups[tag].load_from_etree(
                        el,
                        registry_for("DihedralTypes"),
                        registry_for("ImproperTypes"),
                    )
                )

        return cls(children=children, **attribs)
b/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_dihedral_type.xml new file mode 100644 index 0000000..885d4c0 --- /dev/null +++ b/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_dihedral_type.xml @@ -0,0 +1,154 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_improper_type.xml b/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_improper_type.xml new file mode 100644 index 0000000..b9be0e1 --- /dev/null +++ b/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_improper_type.xml @@ -0,0 +1,154 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/forcefield_utilities/tests/files/propanol_Mie_ua_list_wildcards.xml b/forcefield_utilities/tests/files/propanol_Mie_ua_list_wildcards.xml new file mode 100644 index 0000000..3111feb --- /dev/null +++ b/forcefield_utilities/tests/files/propanol_Mie_ua_list_wildcards.xml @@ -0,0 +1,139 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 0.0 + 0.416955197548017 + -0.0579667482245528 + 0.37345529632637 + 0.0 + 0.0 + + + 0.0 + 180.0 + 0.0 + 0.0 + 0.0 + 0.0 + + + + + + + 0.0 + 0.6955197548017 + -0.023 + 0.37345529632637 + 0.0 + 0.0 + + + 0.0 + 90.0 + 0.0 + 0.0 + 0.0 + 0.0 + + + + + diff --git a/forcefield_utilities/tests/test_gmso_xml.py 
b/forcefield_utilities/tests/test_gmso_xml.py new file mode 100644 index 0000000..c35ca11 --- /dev/null +++ b/forcefield_utilities/tests/test_gmso_xml.py @@ -0,0 +1,301 @@ +import pytest +import unyt as u +from gmso.tests.utils import get_path +from lxml import etree +from sympy import sympify + +from forcefield_utilities.tests.base_test import BaseTest +from forcefield_utilities.tests.utils import get_test_file_path +from forcefield_utilities.xml_loader import GMSOFFs + + +class TestEthyleneFF(BaseTest): + @pytest.fixture(scope="session") + def ff_example_zero(self): + example_zero = get_path("ethylene.xml") + return GMSOFFs().load(example_zero).to_gmso_ff() + + def test_metadata(self, ff_example_zero): + assert ff_example_zero.scaling_factors == { + "electrostatics14Scale": 0.5, + "nonBonded14Scale": 0.5, + } + assert ff_example_zero.combining_rule == "geometric" + + def test_atom_types(self, ff_example_zero): + opls_143 = ff_example_zero.atom_types["opls_143"] + assert opls_143.name == "opls_143" + assert opls_143.charge == -0.23 * u.elementary_charge + assert opls_143.get_tag("element") == "C" + + def test_bond_types(self, ff_example_zero): + btype_harmonic_1 = ff_example_zero.bond_types["opls_143~opls_143"] + assert str(btype_harmonic_1.expression) == "0.5*k*(r - r_eq)**2" + assert btype_harmonic_1.member_types == ("opls_143", "opls_143") + assert btype_harmonic_1.parameters == { + "k": u.unyt_quantity(459403.2, "kJ/mol/nm**2"), + "r_eq": u.unyt_quantity(0.134, "nm"), + } + + assert btype_harmonic_1.name == "BondType-Harmonic-1" + + def test_angle_types(self, ff_example_zero): + angle_type_harmonic_2 = ff_example_zero.angle_types[ + "opls_144~opls_143~opls_144" + ] + assert ( + str(angle_type_harmonic_2.expression) + == "0.5*k*(theta - theta_eq)**2" + ) + assert angle_type_harmonic_2.member_types == ( + "opls_144", + "opls_143", + "opls_144", + ) + assert angle_type_harmonic_2.parameters == { + "k": u.unyt_quantity(292.88, "kJ/(mol*radian**2)"), + "theta_eq": 
u.unyt_quantity(2.0420352248, "radian"), + } + + assert angle_type_harmonic_2.name == "AngleType-Harmonic-2" + + def test_dihedral_types(self, ff_example_zero): + dihedral_type_rb_1 = ff_example_zero.dihedral_types[ + "opls_144~opls_143~opls_143~opls_144" + ] + + assert ( + str(dihedral_type_rb_1.expression) + == "c_0 + c_1*cos(psi) + c_2*cos(psi)**2 + c_3*cos(psi)**3 + c_4*cos(psi)**4 + c_5*cos(psi)**5" + ) + assert dihedral_type_rb_1.member_types == ( + "opls_144", + "opls_143", + "opls_143", + "opls_144", + ) + + assert dihedral_type_rb_1.parameters == { + "c_0": u.unyt_quantity(58.576, "kJ/mol"), + "c_1": u.unyt_quantity(0.0, "kJ/mol"), + "c_2": u.unyt_quantity(-58.576, "kJ/mol"), + "c_3": u.unyt_quantity(0.0, "kJ/mol"), + "c_4": u.unyt_quantity(0.0, "kJ/mol"), + "c_5": u.unyt_quantity(0.0, "kJ/mol"), + } + + assert dihedral_type_rb_1.name == "DihedralType-RyckaertBellemans-1" + + def test_improper_types(self, ff_example_zero): + pass + + +class TestTwoPropanolMIEFF(BaseTest): + @pytest.fixture(scope="session") + def propanol_ua_mie(self): + propanol_ua_mie_path = get_test_file_path("propanol_Mie_ua.xml") + return GMSOFFs().load(propanol_ua_mie_path).to_gmso_ff() + + def test_metadata(self, propanol_ua_mie): + assert ( + propanol_ua_mie.name + == "Mie two-propanol- This is for testing only and not for use for simulations " + ) + assert propanol_ua_mie.scaling_factors == { + "electrostatics14Scale": 0.0, + "nonBonded14Scale": 0.0, + } + assert propanol_ua_mie.combining_rule == "geometric" + + def test_atom_types(self, propanol_ua_mie): + ch3_sp3 = propanol_ua_mie.atom_types["CH3_sp3"] + + assert ch3_sp3.name == "CH3_sp3" + assert ch3_sp3.atomclass == "CH3" + assert ch3_sp3.get_tag("element") == "_CH3" + assert u.allclose_units(ch3_sp3.charge, 0.0 * u.coulomb) + assert ch3_sp3.definition == "[_CH3;X1][_CH3,_HC]" + assert u.allclose_units(ch3_sp3.mass, 15.03500 * u.amu) + assert ( + ch3_sp3.description + == "Alkane CH3, Mie using the k constant from Trappe-UA" + ) 
+ assert ch3_sp3.doi == "10.1021/jp984742e and 10.1021/jp972543+" + assert ch3_sp3.overrides == set() + + assert ch3_sp3.expression == sympify( + "(n/(n-m)) * (n/m)**(m/(n-m)) * epsilon * ((sigma/r)**n - (sigma/r)**m)" + ) + + parameters = ch3_sp3.get_parameters() + assert u.allclose_units( + parameters["epsilon"], 0.194746017346801 * u.kcal / u.mol + ) + assert u.allclose_units(parameters["sigma"], 3.751 * u.angstrom) + assert u.allclose_units(parameters["n"], 11 * u.dimensionless) + assert u.allclose_units(parameters["m"], 6 * u.dimensionless) + + o = propanol_ua_mie.atom_types["O"] + + assert o.name == "O" + assert o.atomclass == "O" + assert o.get_tag("element") == "O" + assert u.allclose_units(o.charge, -0.700 * u.elementary_charge) + assert o.definition == "OH" + assert u.allclose_units(o.mass, 15.99940 * u.amu) + assert o.description == "Oxygen in hydroxyl" + assert o.doi == "10.1021/jp003882x" + assert o.overrides == set() + + assert o.expression == sympify( + "(n/(n-m)) * (n/m)**(m/(n-m)) * epsilon * ((sigma/r)**n - (sigma/r)**m)" + ) + + parameters = o.get_parameters() + assert u.allclose_units( + parameters["epsilon"], 0.184809996053596 * u.kcal / u.mol + ) + assert u.allclose_units(parameters["sigma"], 3.021 * u.angstrom) + assert u.allclose_units(parameters["n"], 13 * u.dimensionless) + assert u.allclose_units(parameters["m"], 6 * u.dimensionless) + + def test_bond_types(self, propanol_ua_mie): + bond_type_ch3_ch = propanol_ua_mie.bond_types["CH3~CH"] + assert bond_type_ch3_ch.name == "BondType_Harmonic_CH3_CH" + assert bond_type_ch3_ch.member_classes == ("CH3", "CH") + assert bond_type_ch3_ch.expression == sympify("k * (r-r_eq)**2") + + parameters = bond_type_ch3_ch.get_parameters() + assert u.allclose_units( + parameters["k"], 1200.80305927342 * u.kcal / u.mol / u.angstrom**2 + ) + assert u.allclose_units(parameters["r_eq"], 1.5401 * u.angstrom) + + def test_angle_types(self, propanol_ua_mie): + angle_type_ch3_ch_o = 
propanol_ua_mie.angle_types["CH3~CH~O"] + + assert angle_type_ch3_ch_o.name == "AngleType_Harmonic_CH3_CH_O" + assert angle_type_ch3_ch_o.member_classes == ("CH3", "CH", "O") + assert angle_type_ch3_ch_o.expression == sympify( + "k * (theta - theta_eq)**2" + ) + + parameters = angle_type_ch3_ch_o.get_parameters() + assert u.allclose_units( + parameters["k"], 100.155094635497 * u.kcal / u.mol / u.radian**2 + ) + assert u.allclose_units(parameters["theta_eq"], 109.51 * u.degree) + + def test_dihedral_types(self, propanol_ua_mie): + dihedral_type_ch3_ch_o_h = propanol_ua_mie.dihedral_types["CH3~CH~O~H"] + + assert ( + dihedral_type_ch3_ch_o_h.name + == "DihedralType_Periodic_Proper_CH3_CH_O_H" + ) + assert dihedral_type_ch3_ch_o_h.member_classes == ( + "CH3", + "CH", + "O", + "H", + ) + assert dihedral_type_ch3_ch_o_h.expression == sympify( + "k0 + k1 * (1 + cos(1 * phi - phi_eq1)) + " + "k2 * (1 + cos(2 * phi - phi_eq2)) + " + "k3 * (1 + cos(3 * phi - phi_eq3)) + " + "k4 * (1 + cos(4 * phi - phi_eq4)) + " + "k5 * (1 + cos(5 * phi - phi_eq5))" + ) + + parameters = dihedral_type_ch3_ch_o_h.get_parameters() + assert u.allclose_units(parameters["k0"], 0.0 * u.kcal / u.mol) + assert u.allclose_units( + parameters["k1"], 0.416955197548017 * u.kcal / u.mol + ) + assert u.allclose_units( + parameters["k2"], -0.0579667482245528 * u.kcal / u.mol + ) + assert u.allclose_units( + parameters["k3"], 0.37345529632637 * u.kcal / u.mol + ) + assert u.allclose_units(parameters["k4"], 0.0 * u.kcal / u.mol) + assert u.allclose_units(parameters["k5"], 0.0 * u.kcal / u.mol) + + assert u.allclose_units(parameters["phi_eq1"], 0.0 * u.degree) + assert u.allclose_units(parameters["phi_eq2"], 180 * u.degree) + assert u.allclose_units(parameters["phi_eq3"], 0.0 * u.degree) + assert u.allclose_units(parameters["phi_eq4"], 0.0 * u.degree) + assert u.allclose_units(parameters["phi_eq5"], 0.0 * u.degree) + + +class TestListParameters(BaseTest): + @pytest.fixture(scope="session") + def 
propanol_ua_mie_list(self): + propanol_ua_mie_path = get_test_file_path( + "propanol_Mie_ua_list_wildcards.xml" + ) + return GMSOFFs().load(propanol_ua_mie_path).to_gmso_ff() + + def test_dihedral_params(self, propanol_ua_mie_list): + dih_with_list = propanol_ua_mie_list.dihedral_types["CH3~CH~O~H"] + params = dih_with_list.get_parameters() + assert u.allclose_units( + params["phi_eq"], [0.0, 180.0, 0.0, 0.0, 0.0, 0.0] * u.degree + ) + assert u.allclose_units( + params["k"], + [0.0, 0.4169552, -0.05796675, 0.3734553, 0.0, 0.0] * u.kcal / u.mol, + ) + + dih_with_list_wildcards = propanol_ua_mie_list.get_potential( + group="dihedral_type", key=["", "CH", "O", "H"] + ) + params_wildcards_dihedrals = dih_with_list_wildcards.get_parameters() + assert u.allclose_units( + params_wildcards_dihedrals["phi_eq"], + [0.0, 90.0, 0.0, 0.0, 0.0, 0.0] * u.degree, + ) + assert u.allclose_units( + params_wildcards_dihedrals["k"], + [0.0, 0.69551975, -0.023, 0.3734553, 0.0, 0.0] * u.kcal / u.mol, + ) + + +class TestDuplicateEntries: + def test_value_error_bond_types(self): + with pytest.raises(ValueError): + GMSOFFs().load( + get_test_file_path( + "propanol_Mie_ua_duplicate_entries_bond_type.xml" + ) + ) + + def test_value_error_angle_types(self): + with pytest.raises( + ValueError, match=r"Duplicate identifier found for AngleTypes.*" + ): + GMSOFFs().load( + get_test_file_path( + "propanol_Mie_ua_duplicate_entries_angle_type.xml" + ) + ) + + def test_value_error_dihedral_types(self): + with pytest.raises( + ValueError, match=r"Duplicate identifier found for DihedralTypes.*" + ): + GMSOFFs().load( + get_test_file_path( + "propanol_Mie_ua_duplicate_entries_dihedral_type.xml" + ) + ) + + def test_value_error_improper_types(self): + with pytest.raises( + ValueError, match=r"Duplicate identifier found for ImproperTypes.*" + ): + GMSOFFs().load( + get_test_file_path( + "propanol_Mie_ua_duplicate_entries_improper_type.xml" + ) + ) diff --git a/forcefield_utilities/tests/test_xml_loader.py 
b/forcefield_utilities/tests/test_xml_loader.py index 251c8df..f9891b1 100644 --- a/forcefield_utilities/tests/test_xml_loader.py +++ b/forcefield_utilities/tests/test_xml_loader.py @@ -1,8 +1,9 @@ import pytest from forcefield_utilities.tests.base_test import BaseTest +from forcefield_utilities.tests.utils import get_test_file_path from forcefield_utilities.utils import get_package_file_path -from forcefield_utilities.xml_loader import FoyerFFs +from forcefield_utilities.xml_loader import FoyerFFs, GMSOFFs class TestXMLLoader(BaseTest): @@ -10,6 +11,10 @@ class TestXMLLoader(BaseTest): def foyer_xml_loader(self): return FoyerFFs() + @pytest.fixture + def gmso_xml_loader(self): + return GMSOFFs() + def test_load_gaff(self, foyer_xml_loader): foyer_xml_loader.get_ff("gaff") assert "gaff" in foyer_xml_loader.loaded_ffs @@ -56,3 +61,16 @@ def test_custom_register(self, foyer_xml_loader): assert "benzene_lb" in foyer_xml_loader.loaded_ffs assert loaded_id != id(foyer_xml_loader.load("benzene_lb")) + + def test_different_loading_entries(self, foyer_xml_loader, gmso_xml_loader): + assert id(foyer_xml_loader.loaded_ffs) != id(gmso_xml_loader.loaded_ffs) + gmso_xml_loader.load(get_test_file_path("propanol_Mie_ua.xml")) + assert "propanol_Mie_ua" in gmso_xml_loader.loaded_ffs + assert "propanol_Mie_ua" not in foyer_xml_loader.loaded_ffs + + def test_class_methods_gmso_ff(self): + gmso_xml_loader = GMSOFFs() + gmso_xml_loader.load(get_test_file_path("propanol_Mie_ua.xml")) + ff1 = gmso_xml_loader.get_ff("propanol_Mie_ua") + ff2 = gmso_xml_loader.load("propanol_Mie_ua") + assert id(ff1) == id(ff2) diff --git a/forcefield_utilities/tests/utils.py b/forcefield_utilities/tests/utils.py new file mode 100644 index 0000000..84fd646 --- /dev/null +++ b/forcefield_utilities/tests/utils.py @@ -0,0 +1,13 @@ +from pathlib import Path + + +def get_test_file_path(filename): + """Given a filename, return its location in test files.""" + file_path = Path(__file__).parent / "files" / filename 
+ + if not file_path.resolve().exists(): + raise FileNotFoundError( + f"File {filename} not found in {file_path.parent}" + ) + + return str(file_path) diff --git a/forcefield_utilities/utils.py b/forcefield_utilities/utils.py index 04f4058..eb94245 100644 --- a/forcefield_utilities/utils.py +++ b/forcefield_utilities/utils.py @@ -50,3 +50,23 @@ def _deprecate_kwargs(kwargs, deprecated_kwargs): DeprecationWarning, 3, ) + + +def pad_with_wildcards(input_dictionary, max_len, wildcard="*"): + """Pad empty type or classes with wildcards""" + types = [f"type{j+1}" for j in range(max_len)] + classes = [f"class{j+1}" for j in range(max_len)] + + if types[0] in input_dictionary: + for type_ in types: + value = input_dictionary[type_] + if isinstance(value, str) and value.strip() == "": + input_dictionary[type_] = wildcard + + elif classes[0] in input_dictionary: + for class_ in classes: + value = input_dictionary[class_] + if isinstance(value, str) and value.strip() == "": + input_dictionary[class_] = wildcard + + return input_dictionary diff --git a/forcefield_utilities/xml_loader.py b/forcefield_utilities/xml_loader.py index edda286..9aa93e0 100644 --- a/forcefield_utilities/xml_loader.py +++ b/forcefield_utilities/xml_loader.py @@ -1,11 +1,12 @@ -import importlib -import os +import abc from pathlib import Path from typing import Union +from gmso.utils.ff_utils import _validate_schema as validate_gmso_schema from lxml import etree -from forcefield_utilities.foyer_xml import ForceField +from forcefield_utilities.foyer_xml import ForceField as FoyerForceField +from forcefield_utilities.gmso_xml import ForceField as GMSOForceField from forcefield_utilities.utils import ( call_on_import, deprecate_kwargs, @@ -15,35 +16,28 @@ custom_forcefields = {} -def _parse_foyer_xml(xml_path): - """Return the foyer Forcefield object from the relative path ``xml_path`` inside the foyer package.""" - with open(xml_path) as ff_file: - root = etree.parse(ff_file).getroot() - return 
ForceField.load_from_etree(root) - - -class FoyerFFs: - """Object to provide methods to forcefields shipped with Foyer. +class XMLLoader: + """Object to provide methods to forcefields shipped with Foyer/GMSO. Attributes __________ loaded_ffs : dict, keys are strings, values are the loaded forcefield This is a place to store loaded xmls, which can be accessed through - the getter by indexing the FoyerFFs object. xmls are only stored once, + the getter by indexing the object. xmls are only stored once, and custom_xml is where a forcefield is stored from path. Note: This is a class level attribute. Methods _______ get_ff(ffname="oplsaa"): - Load and directly return the forcefield xml from Foyer or a path. + Load and directly return the forcefield xml from Foyer/GMSO or a path. Parameters __________ ffname : str - Name of forcefield to load. See self.ff_registry attribute - for loading method. - load(ffname="oplsaa", rel_to_module=False): - Load and return the forcefield xml from Foyer or a path. + Name of forcefield to load or path to load + + load(ffname="oplsaa"): + Load and return the forcefield xml from Foyer/GMSO or a path. Notes ----- @@ -51,14 +45,20 @@ class level attribute. `clear_loaded_ffs` can be used to clear the cache. """ - loaded_ffs = {} - overwritten_custom_ffs = set() + loaded_ffs = None + overwritten_custom_ffs = None + search_foyer = False + + @abc.abstractmethod + def load_xml(self, xml_path) -> Union[FoyerForceField, GMSOForceField]: + """Load the xml file""" + return NotImplemented @classmethod @deprecate_kwargs(deprecated_kwargs={"rel_to_module"}) def get_ff( cls, ffname: Union[str, Path], rel_to_module: bool = False - ) -> ForceField: + ) -> Union[FoyerForceField, GMSOForceField]: """Load and directly return the forcefield xml from Foyer or a path. 
Parameters __________ @@ -72,10 +72,12 @@ def get_ff( ff : forcefield_utilities.foyer_xml.Forcefield """ loader = cls() - return loader.load(ffname, rel_to_module) + return loader.load(ffname) @deprecate_kwargs(deprecated_kwargs={"rel_to_module"}) - def load(self, ffname, rel_to_module=False): + def load( + self, ffname, rel_to_module=False + ) -> Union[FoyerForceField, GMSOForceField]: """Load and return the forcefield xml from Foyer or a path. Parameters __________ @@ -86,7 +88,7 @@ def load(self, ffname, rel_to_module=False): Returns ______ - ff : forcefield_utilities.foyer_xml.Forcefield + ff : forcefield_utilities.foyer_xml.Forcefield or forcefield_utilities.gmso_xml.Forcefield Notes ----- @@ -118,29 +120,33 @@ def load(self, ffname, rel_to_module=False): ff_path = Path(ffname) if ffname in custom_forcefields: - self.loaded_ffs[ffname] = _parse_foyer_xml( + self.loaded_ffs[ffname] = self.load_xml( xml_path=custom_forcefields[ffname] ) self.overwritten_custom_ffs.discard(ffname) return self.loaded_ffs[ff_path.name] if self._is_xml(ff_path): - self.loaded_ffs[ff_path.stem] = _parse_foyer_xml( + self.loaded_ffs[ff_path.stem] = self.load_xml( xml_path=ff_path.resolve() ) - else: + elif self.search_foyer: xml_path = get_package_file_path( "foyer", f"forcefields/xml/{ffname}.xml" ) - self.loaded_ffs[ff_path.stem] = _parse_foyer_xml(xml_path) + self.loaded_ffs[ff_path.stem] = self.load_xml(xml_path) + else: + raise FileNotFoundError( + f"{ffname} not found, it isn't registered forcefiled name or a XML file." 
+ ) return self.loaded_ffs[ff_path.stem] - def __getitem__(self, ffname) -> ForceField: + def __getitem__(self, ffname) -> Union[FoyerForceField, GMSOForceField]: """Get function for indexing by loaded forcefields.""" return self.loaded_ffs[ffname] - def __getattr__(self, item) -> ForceField: + def __getattr__(self, item) -> Union[FoyerForceField, GMSOForceField]: """Accessor for loaded forcefields.""" if item in self.loaded_ffs: return self.loaded_ffs[item] @@ -152,7 +158,7 @@ def __getattr__(self, item) -> ForceField: def register_custom_forcefield( self, name: str, path_: Union[str, Path], overwrite: bool = True ) -> None: - """Register a custom foyer forcefield's XML path to load. + """Register a custom foyer/gmso forcefield's XML path to load. Parameters ---------- @@ -181,10 +187,42 @@ def _is_xml(path_: Path) -> bool: return path_.suffix == ".xml" +class FoyerFFs(XMLLoader): + """Utility class to load foyer forcefields.""" + + loaded_ffs = {} + overwritten_custom_ffs = set() + search_foyer = True + + def load_xml(self, xml_path): + """Return the foyer Forcefield object from the relative path ``xml_path`` inside the foyer package.""" + with open(xml_path) as ff_file: + root = etree.parse(ff_file).getroot() + return FoyerForceField.load_from_etree(root) + + +class GMSOFFs(XMLLoader): + """Utility class to load gmso forcefields.""" + + loaded_ffs = {} + overwritten_custom_ffs = set() + search_foyer = False + + def load_xml(self, xml_path): + """Return the gmso Forcefield object from the relative path ``xml_path`` for a gmso XML.""" + with open(xml_path) as ff_file: + ff_etree = etree.parse(ff_file) + validate_gmso_schema(ff_etree) + root = ff_etree.getroot() + return GMSOForceField.load_from_etree(root) + + @call_on_import def register_gaff(): """Include GAFF as part of FoyerFFs if antefoyer is installed locally.""" try: + import importlib + importlib.import_module("antefoyer") except ImportError: return