diff --git a/forcefield_utilities/__init__.py b/forcefield_utilities/__init__.py
index a0b39f0..185c521 100644
--- a/forcefield_utilities/__init__.py
+++ b/forcefield_utilities/__init__.py
@@ -1,3 +1,3 @@
-from forcefield_utilities.xml_loader import FoyerFFs
+from forcefield_utilities.xml_loader import FoyerFFs, GMSOFFs
__version__ = "0.1.2"
diff --git a/forcefield_utilities/gmso_xml.py b/forcefield_utilities/gmso_xml.py
new file mode 100644
index 0000000..04c59b0
--- /dev/null
+++ b/forcefield_utilities/gmso_xml.py
@@ -0,0 +1,990 @@
+from functools import lru_cache
+from typing import List, Optional, Set, Union
+
+import numpy as np
+import sympy
+import unyt as u
+from gmso import ForceField as GMSOForceField
+from gmso.core.angle_type import AngleType as GMSOAngleType
+from gmso.core.atom_type import AtomType as GMSOAtomType
+from gmso.core.bond_type import BondType as GMSOBondType
+from gmso.core.dihedral_type import DihedralType as GMSODihedralType
+from gmso.core.improper_type import ImproperType as GMSOImproperType
+from gmso.utils._constants import FF_TOKENS_SEPARATOR
+from pydantic import BaseModel, Field
+
+from forcefield_utilities.utils import pad_with_wildcards
+
+
+def get_identifiers_registry():
+ return {
+ "AtomTypes": set(),
+ "BondTypes": set(),
+ "AngleTypes": set(),
+ "DihedralTypes": set(),
+ "ImproperTypes": set(),
+ "PariPotentialTypes": set(),
+ }
+
+
+def register_identifiers(registry, identifier, for_type="AtomTypes"):
+ if identifier in registry:
+ raise ValueError(
+ f"Duplicate identifier found for {for_type}: {identifier}"
+ )
+
+ if for_type == "AtomTypes":
+ registry.add(identifier)
+ elif (
+ for_type == "BondTypes"
+ or for_type == "AngleTypes"
+ or for_type == "DihedralTypes"
+ or for_type == "PairPotentialTypes"
+ ):
+ registry.add(identifier)
+ registry.add(tuple(reversed(identifier)))
+ elif for_type == "ImproperTypes":
+ (central, second, third, fourth) = identifier
+ mirrors = [
+ (central, second, third, fourth),
+ (central, second, fourth, third),
+ (central, third, second, fourth),
+ (central, third, fourth, second),
+ (central, fourth, second, third),
+ (central, fourth, third, second),
+ ]
+ for mirror in mirrors:
+ registry.add(mirror)
+
+
+@lru_cache(maxsize=128)
+def indep_vars(expr: str, dependent: frozenset) -> Set:
+ """Given an expression and dependent variables, return independent variables for it."""
+ return sympy.sympify(expr).free_symbols - dependent
+
+
+class GMSOXMLTag(BaseModel):
+ def parameters(self, units=None):
+ params = self.children[0]
+ params_dict = {}
+ for parameter in params.children:
+ if units is None:
+ params_dict[parameter.name] = parameter.value
+ else:
+ params_dict[parameter.name] = (
+ parameter.value * units[parameter.name]
+ )
+ return params_dict
+
+ class Config:
+ arbitrary_types_allowed = True
+ allow_population_by_field_name = True
+
+
+class GMSOXMLChild(GMSOXMLTag):
+ pass
+
+
+class ParametersUnitDef(GMSOXMLTag):
+ parameter: str = Field(
+ ..., description="The name of the parameter", alias="parameter"
+ )
+
+ unit: str = Field(
+ ..., description="The unit of the parameter", alias="unit"
+ )
+
+ @classmethod
+ def load_from_etree(cls, root):
+ return cls(**root.attrib)
+
+
+class Parameter(GMSOXMLTag):
+ name: str = Field(
+ ..., description="The name of the parameter", alias="name"
+ )
+
+ value: Union[float, np.ndarray] = Field(
+ ..., description="The value of the parameter", alias="value"
+ )
+
+ @classmethod
+ def load_from_etree(cls, root):
+ attribs = root.attrib
+ if "value" in root.attrib:
+ return cls(**attribs)
+ else:
+ children = root.getchildren()
+ if len(children) == 0:
+ raise ValueError(
+ "Neither a single value nor a sequence of values provided for "
+ f"parameter {attribs['name']}. Please provide one or the other"
+ )
+ value = np.array(
+ [param_value.text for param_value in children], dtype=float
+ )
+ return cls(name=attribs["name"], value=value)
+
+
+class Parameters(GMSOXMLTag):
+ children: List[Parameter] = Field(
+ ..., description="Parameter", alias="parameter"
+ )
+
+ @classmethod
+ def load_from_etree(cls, root):
+ children = []
+ for el in root.iterchildren():
+ if el.tag == "Parameter":
+ children.append(Parameter.load_from_etree(el))
+ return cls(children=children)
+
+
+class AtomType(GMSOXMLTag):
+ name: str = Field(
+ ..., description="The name for this atom type", alias="name"
+ )
+
+ element: Optional[str] = Field(
+ None, description="The element of the atom type", alias="element"
+ )
+
+ charge: Optional[float] = Field(
+ None, description="The charge of the atom type", alias="charge"
+ )
+
+ mass: Optional[float] = Field(
+ None, description="The mass of the atom type", alias="mass"
+ )
+
+ expression: Optional[str] = Field(
+ None,
+ description="The expression for this atom type",
+ alias="expression",
+ )
+
+ independent_variables: Optional[str] = Field(
+ None,
+ description="The independent variables for this atom type",
+ alias="independent_variables",
+ )
+
+ atomclass: Optional[str] = Field(
+ None, description="The atomclass of this atomtype", alias="atomclass"
+ )
+
+ doi: Optional[str] = Field(
+ None, description="The doi of this atomtype", alias="doi"
+ )
+
+ overrides: Optional[str] = Field(
+ None, description="The overrides", alias="overrides"
+ )
+
+ definition: Optional[str] = Field(
+ None,
+ description="The smarts definition of this atom type",
+ alias="definition",
+ )
+
+ description: Optional[str] = Field(
+ None,
+ description="The description of this atom type",
+ alias="description",
+ )
+
+ children: List[Parameters] = Field(
+ ..., description="The parameters and their values", alias="children"
+ )
+
+ @classmethod
+ def load_from_etree(cls, root):
+ attribs = root.attrib
+ children = []
+ for el in root.iterchildren():
+ if el.tag == "Parameters":
+ children.append(Parameters.load_from_etree(el))
+ return cls(children=children, **attribs)
+
+
+class AtomTypes(GMSOXMLChild):
+ name: Optional[str] = Field(
+ None, description="The name for this atom type group", alias="name"
+ )
+
+ expression: str = Field(
+ ...,
+ description="The expression for this atom type group",
+ alias="expression",
+ )
+
+ children: List[Union[ParametersUnitDef, AtomType]] = Field(
+ ..., description="The children of AtomTypes", alias="parameters"
+ )
+
+ def to_gmso_potentials(self, default_units):
+ potentials = {"atom_types": {}}
+ parameters_units = filter(
+ lambda c: isinstance(c, ParametersUnitDef), self.children
+ )
+ units = {
+ parameter_unit.parameter: u.Unit(parameter_unit.unit)
+ for parameter_unit in parameters_units
+ }
+
+ for atom_type in filter(
+ lambda c: isinstance(c, AtomType), self.children
+ ):
+ atom_type_dict = atom_type.dict(
+ by_alias=True,
+ exclude={"children", "element"},
+ exclude_none=True,
+ )
+
+ overrides = atom_type_dict.get("overrides")
+ if overrides:
+ atom_type_dict["overrides"] = set(
+ o.strip() for o in overrides.split(",")
+ )
+ else:
+ atom_type_dict["overrides"] = set()
+
+ if "expression" not in atom_type_dict:
+ atom_type_dict["expression"] = self.expression
+ atom_type_dict["parameters"] = atom_type.parameters(units)
+
+ if not atom_type_dict.get("independent_variables"):
+ atom_type_dict["independent_variables"] = indep_vars(
+ atom_type_dict["expression"],
+ frozenset(atom_type_dict["parameters"]),
+ )
+
+ if default_units.get("charge") and atom_type_dict.get("charge"):
+ atom_type_dict["charge"] = (
+ atom_type_dict["charge"] * default_units["charge"]
+ )
+
+ if default_units.get("mass") and atom_type_dict.get("mass"):
+ atom_type_dict["mass"] = (
+ atom_type_dict["mass"] * default_units["mass"]
+ )
+ gmso_atom_type = GMSOAtomType(**atom_type_dict)
+ element = atom_type.element
+ if element:
+ gmso_atom_type.add_tag("element", element)
+ potentials["atom_types"][atom_type.name] = gmso_atom_type
+ return potentials
+
+ @classmethod
+ def load_from_etree(cls, root, existing):
+ attribs = root.attrib
+ children = []
+ for el in root.iterchildren():
+ if el.tag == "ParametersUnitDef":
+ children.append(ParametersUnitDef.load_from_etree(el))
+ elif el.tag == "AtomType":
+ atom_type = AtomType.load_from_etree(el)
+ identifier = atom_type.name
+ register_identifiers(existing, identifier, "AtomTypes")
+ children.append(atom_type)
+ return cls(children=children, **attribs)
+
+
+class BondType(GMSOXMLTag):
+ name: str = Field(
+ None, description="The name of the bond type", alias="name"
+ )
+
+ class1: Optional[str] = Field(
+ None, description="Class 1 for this bond type", alias="class1"
+ )
+
+ class2: Optional[str] = Field(
+ None, description="Class 2 for this bond type", alias="class2"
+ )
+
+ type1: Optional[str] = Field(
+ None, description="Type 1 for this bond type", alias="type1"
+ )
+
+ type2: Optional[str] = Field(
+ None, description="Type 2 for this bond type", alias="type2"
+ )
+
+ children: List[Parameters] = Field(
+ ..., description="The parameters and their values", alias="children"
+ )
+
+ @classmethod
+ def load_from_etree(cls, root):
+ children = []
+ attribs = pad_with_wildcards(root.attrib, 2)
+ for el in root.iterchildren():
+ if el.tag == "Parameters":
+ children.append(Parameters.load_from_etree(el))
+ return cls(children=children, **attribs)
+
+
+class BondTypes(GMSOXMLChild):
+ name: Optional[str] = Field(
+ None, description="The name for this bond types group", alias="name"
+ )
+
+ expression: str = Field(
+ ..., description="The expression for this bond types group"
+ )
+
+ children: List[Union[ParametersUnitDef, BondType]] = Field(
+ ..., description="Children of this bond type tag", alias="children"
+ )
+
+ def to_gmso_potentials(self, default_units):
+ potentials = {"bond_types": {}}
+ parameters_units = filter(
+ lambda c: isinstance(c, ParametersUnitDef), self.children
+ )
+ units = {
+ parameter_unit.parameter: u.Unit(parameter_unit.unit)
+ for parameter_unit in parameters_units
+ }
+
+ for bond_type in filter(
+ lambda c: isinstance(c, BondType), self.children
+ ):
+ bond_type_dict = bond_type.dict(
+ by_alias=True,
+ exclude={"children", "type1", "type2", "class1", "class2"},
+ exclude_none=True,
+ )
+
+ if "expression" not in bond_type_dict:
+ bond_type_dict["expression"] = self.expression
+
+ if bond_type.type1 and bond_type.type2:
+ bond_type_dict["member_types"] = (
+ bond_type.type1,
+ bond_type.type2,
+ )
+
+ elif bond_type.class1 and bond_type.class2:
+ bond_type_dict["member_classes"] = (
+ bond_type.class1,
+ bond_type.class2,
+ )
+
+ bond_type_dict["parameters"] = bond_type.parameters(units)
+ bond_type_dict["independent_variables"] = indep_vars(
+ bond_type_dict["expression"],
+ frozenset(bond_type_dict["parameters"]),
+ )
+
+ gmso_bond_type = GMSOBondType(**bond_type_dict)
+ if gmso_bond_type.member_types:
+ potentials["bond_types"][
+ FF_TOKENS_SEPARATOR.join(gmso_bond_type.member_types)
+ ] = gmso_bond_type
+ else:
+ potentials["bond_types"][
+ FF_TOKENS_SEPARATOR.join(gmso_bond_type.member_classes)
+ ] = gmso_bond_type
+
+ return potentials
+
+ @classmethod
+ def load_from_etree(cls, root, existing):
+ attribs = root.attrib
+ children = []
+ for el in root.iterchildren():
+ if el.tag == "ParametersUnitDef":
+ children.append(ParametersUnitDef.load_from_etree(el))
+ elif el.tag == "BondType":
+ bond_type = BondType.load_from_etree(el)
+ identifier = tuple(
+ [bond_type.class1, bond_type.class2]
+ if bond_type.class1
+ else [bond_type.type1, bond_type.type2]
+ )
+ register_identifiers(existing, identifier, "BondTypes")
+ children.append(bond_type)
+
+ return cls(children=children, **attribs)
+
+
+class AngleType(GMSOXMLTag):
+ name: str = Field(
+ None, description="The name of the angle type", alias="name"
+ )
+
+ class1: Optional[str] = Field(
+ None, description="Class 1 for this angle type", alias="class1"
+ )
+
+ class2: Optional[str] = Field(
+ None, description="Class 2 for this angle type", alias="class2"
+ )
+
+ class3: Optional[str] = Field(
+ None, description="Class 3 for this angle type", alias="class3"
+ )
+
+ type1: Optional[str] = Field(
+ None, description="Type 1 for this angle type", alias="type1"
+ )
+
+ type2: Optional[str] = Field(
+ None, description="Type 2 for this angle type", alias="type2"
+ )
+
+ type3: Optional[str] = Field(
+ None, description="Type 3 for this angle type", alias="type3"
+ )
+
+ children: List[Parameters] = Field(
+ ..., description="The parameters and their values", alias="children"
+ )
+
+ @classmethod
+ def load_from_etree(cls, root):
+ children = []
+ attribs = pad_with_wildcards(root.attrib, 3)
+ for el in root.iterchildren():
+ if el.tag == "Parameters":
+ children.append(Parameters.load_from_etree(el))
+ return cls(children=children, **attribs)
+
+
+class AngleTypes(GMSOXMLChild):
+ name: Optional[str] = Field(
+ None, description="The name for this angle types group", alias="name"
+ )
+
+ expression: str = Field(
+ ..., description="The expression for this angle types group"
+ )
+
+ children: List[Union[ParametersUnitDef, AngleType]] = Field(
+ ..., description="Children of this angle types tag", alias="children"
+ )
+
+ def to_gmso_potentials(self, default_units):
+ potentials = {"angle_types": {}}
+ parameters_units = filter(
+ lambda c: isinstance(c, ParametersUnitDef), self.children
+ )
+ units = {
+ parameter_unit.parameter: u.Unit(parameter_unit.unit)
+ for parameter_unit in parameters_units
+ }
+
+ for angle_type in filter(
+ lambda c: isinstance(c, AngleType), self.children
+ ):
+ angle_type_dict = angle_type.dict(
+ by_alias=True,
+ exclude={
+ "children",
+ "type1",
+ "type2",
+ "type3",
+ "class1",
+ "class2",
+ "class3",
+ },
+ exclude_none=True,
+ )
+
+ if "expression" not in angle_type_dict:
+ angle_type_dict["expression"] = self.expression
+
+ if angle_type.type1 and angle_type.type2 and angle_type.type3:
+ angle_type_dict["member_types"] = (
+ angle_type.type1,
+ angle_type.type2,
+ angle_type.type3,
+ )
+
+ elif angle_type.class1 and angle_type.class2 and angle_type.class3:
+ angle_type_dict["member_classes"] = (
+ angle_type.class1,
+ angle_type.class2,
+ angle_type.class3,
+ )
+
+ angle_type_dict["parameters"] = angle_type.parameters(units)
+ angle_type_dict["independent_variables"] = indep_vars(
+ angle_type_dict["expression"],
+ frozenset(angle_type_dict["parameters"]),
+ )
+ gmso_angle_type = GMSOAngleType(**angle_type_dict)
+ if gmso_angle_type.member_types:
+ potentials["angle_types"][
+ FF_TOKENS_SEPARATOR.join(gmso_angle_type.member_types)
+ ] = gmso_angle_type
+ else:
+ potentials["angle_types"][
+ FF_TOKENS_SEPARATOR.join(gmso_angle_type.member_classes)
+ ] = gmso_angle_type
+
+ return potentials
+
+ @classmethod
+ def load_from_etree(cls, root, existing):
+ attribs = root.attrib
+ children = []
+ for el in root.iterchildren():
+ if el.tag == "ParametersUnitDef":
+ children.append(ParametersUnitDef.load_from_etree(el))
+ elif el.tag == "AngleType":
+ angle_type = AngleType.load_from_etree(el)
+ identifier = tuple(
+ [angle_type.class1, angle_type.class2, angle_type.class3]
+ if angle_type.class1
+ else [angle_type.type1, angle_type.type2, angle_type.type3]
+ )
+ register_identifiers(existing, identifier, "AngleTypes")
+ children.append(angle_type)
+ return cls(children=children, **attribs)
+
+
+class TorsionType(GMSOXMLTag):
+ name: str = Field(
+ None, description="The name of the Dihedral/Improper type", alias="name"
+ )
+
+ class1: Optional[str] = Field(
+ None,
+ description="Class 1 for this Dihedral/Improper type",
+ alias="class1",
+ )
+
+ class2: Optional[str] = Field(
+ None,
+ description="Class 2 for this Dihedral/Improper type",
+ alias="class2",
+ )
+
+ class3: Optional[str] = Field(
+ None,
+ description="Class 3 for this Dihedral/Improper type",
+ alias="class3",
+ )
+
+ class4: Optional[str] = Field(
+ None,
+ description="Class 4 for this Dihedral/Improper type",
+ alias="class4",
+ )
+
+ type1: Optional[str] = Field(
+ None,
+ description="Type 1 for this Dihedral/Improper type",
+ alias="type1",
+ )
+
+ type2: Optional[str] = Field(
+ None,
+ description="Type 2 for this Dihedral/Improper type",
+ alias="type2",
+ )
+
+ type3: Optional[str] = Field(
+ None,
+ description="Type 3 for this Dihedral/Improper type",
+ alias="type3",
+ )
+
+ type4: Optional[str] = Field(
+ None,
+ description="Type 4 for this Dihedral/Improper type",
+ alias="type4",
+ )
+
+ children: List[Parameters] = Field(
+ ..., description="The parameters and their values", alias="children"
+ )
+
+ @classmethod
+ def load_from_etree(cls, root):
+ children = []
+ attribs = pad_with_wildcards(root.attrib, 4)
+ for el in root.iterchildren():
+ if el.tag == "Parameters":
+ children.append(Parameters.load_from_etree(el))
+ return cls(children=children, **attribs)
+
+
+class DihedralType(TorsionType):
+ pass
+
+
+class ImproperType(TorsionType):
+ pass
+
+
+class TorsionTypes(GMSOXMLChild):
+ name: Optional[str] = Field(
+ None, description="The name for this angle types group", alias="name"
+ )
+
+ expression: str = Field(
+ ..., description="The expression for this angle types group"
+ )
+
+ children: List[Union[ParametersUnitDef, TorsionType]] = Field(
+ ...,
+ description="Children of this dihedral/improper types tag",
+ alias="children",
+ )
+
+ def to_gmso_potentials(self, default_units):
+ potentials = {"dihedral_types": {}, "improper_types": {}}
+ parameters_units = filter(
+ lambda c: isinstance(c, ParametersUnitDef), self.children
+ )
+ units = {
+ parameter_unit.parameter: u.Unit(parameter_unit.unit)
+ for parameter_unit in parameters_units
+ }
+
+ for torsion_type in filter(
+ lambda c: isinstance(c, (DihedralType, ImproperType)), self.children
+ ):
+ torsion_dict = torsion_type.dict(
+ by_alias=True,
+ exclude={
+ "children",
+ "type1",
+ "type2",
+ "type3",
+ "type4",
+ "class1",
+ "class2",
+ "class3",
+ "class4",
+ },
+ exclude_none=True,
+ )
+
+ if "expression" not in torsion_dict:
+ torsion_dict["expression"] = self.expression
+
+ if (
+ torsion_type.type1
+ and torsion_type.type2
+ and torsion_type.type3
+ and torsion_type.type4
+ ):
+ torsion_dict["member_types"] = (
+ torsion_type.type1,
+ torsion_type.type2,
+ torsion_type.type3,
+ torsion_type.type4,
+ )
+
+ elif (
+ torsion_type.class1
+ and torsion_type.class2
+ and torsion_type.class3
+ and torsion_type.class4
+ ):
+ torsion_dict["member_classes"] = (
+ torsion_type.class1,
+ torsion_type.class2,
+ torsion_type.class3,
+ torsion_type.class4,
+ )
+
+ torsion_dict["parameters"] = torsion_type.parameters(units)
+ torsion_dict["independent_variables"] = indep_vars(
+ torsion_dict["expression"],
+ frozenset(torsion_dict["parameters"]),
+ )
+ if isinstance(torsion_type, DihedralType):
+ gmso_torsion_type = GMSODihedralType(**torsion_dict)
+ key = "dihedral_types"
+ else:
+ gmso_torsion_type = GMSOImproperType(**torsion_dict)
+ key = "improper_types"
+
+ if gmso_torsion_type.member_types:
+ potentials[key][
+ FF_TOKENS_SEPARATOR.join(gmso_torsion_type.member_types)
+ ] = gmso_torsion_type
+ else:
+ potentials[key][
+ FF_TOKENS_SEPARATOR.join(gmso_torsion_type.member_classes)
+ ] = gmso_torsion_type
+
+ return potentials
+
+ @classmethod
+ def load_from_etree(cls, root, existing_dihedrals, existing_impropers):
+ attribs = root.attrib
+ children = []
+ child_loaders = {
+ "DihedralType": DihedralType,
+ "ImproperType": ImproperType,
+ }
+ for el in root.iterchildren():
+ if el.tag == "ParametersUnitDef":
+ children.append(ParametersUnitDef.load_from_etree(el))
+ elif el.tag == "DihedralType" or el.tag == "ImproperType":
+ tor_type = child_loaders[el.tag].load_from_etree(el)
+ identifier = tuple(
+ [
+ tor_type.class1,
+ tor_type.class2,
+ tor_type.class3,
+ tor_type.class4,
+ ]
+ if tor_type.class1
+ else [
+ tor_type.type1,
+ tor_type.type2,
+ tor_type.type3,
+ tor_type.type4,
+ ]
+ )
+ register_identifiers(
+ existing_impropers
+ if el.tag == "ImproperType"
+ else existing_dihedrals,
+ identifier,
+ el.tag + "s",
+ )
+ children.append(tor_type)
+
+ return cls(children=children, **attribs)
+
+
+class ImproperTypes(TorsionTypes):
+ pass
+
+
+class DihedralTypes(TorsionTypes):
+ pass
+
+
+class PairPotentialType(GMSOXMLTag):
+ name: str = Field(
+ ..., description="Name of this PairPotential Type", alias="name"
+ )
+
+ type1: Optional[str] = Field(
+ None, description="The type1 of this PairPotential Type", alias="type1"
+ )
+
+ type2: Optional[str] = Field(
+ None, description="The type2 of this PairPotential Type", alias="type2"
+ )
+
+ class1: Optional[str] = Field(
+ None,
+ description="The class1 of this PairPotential Type",
+ alias="class1",
+ )
+
+ class2: Optional[str] = Field(
+ None,
+ description="The class2 of this PairPotential Type",
+ alias="class2",
+ )
+
+ children: List[Parameters] = Field(
+ ..., description="The parameters and their values", alias="children"
+ )
+
+ @classmethod
+ def load_from_etree(cls, root):
+ attribs = pad_with_wildcards(root.attrib, 2)
+ children = []
+ for el in root.iterchildren():
+ children.append(Parameters.load_from_etree(el))
+ return cls(children=children, **attribs)
+
+
+class PairPotentialTypes(GMSOXMLChild):
+ name: Optional[str] = Field(
+ None,
+ description="The name of this pair potential types group",
+ alias="name",
+ )
+
+ expression: str = Field(
+ ...,
+ description="The expression for this pair potential types group",
+ alias="expression",
+ )
+
+ children: List[Union[ParametersUnitDef, PairPotentialType]] = Field(
+ ..., description="The children", alias="children"
+ )
+
+ @classmethod
+ def load_from_etree(cls, root, existing):
+ attribs = root.attrib
+ children = []
+ for el in root.iterchildren():
+ if el.tag == "ParametersUnitDef":
+ children.append(ParametersUnitDef.load_from_etree(el))
+ elif el.tag == "PairPotentialType":
+ pptype = PairPotentialType.load_from_etree(el)
+ identifier = tuple(
+ [pptype.class1, pptype.class2]
+ if pptype.class1
+ else [pptype.type1, pptype.type2]
+ )
+ register_identifiers(existing, identifier, "PairPotentialTypes")
+ children.append(pptype)
+ return cls(children=children, **attribs)
+
+
+class Units(GMSOXMLTag):
+ energy: Optional[str] = Field(None, alias="energy")
+
+ distance: Optional[str] = Field(None, alias="distance")
+
+ mass: Optional[str] = Field(None, alias="mass")
+
+ charge: Optional[str] = Field(None, alias="charge")
+
+ temperature: Optional[str] = Field(None, alias="temperature")
+
+ angle: Optional[str] = Field(None, alias="angle")
+
+ time: Optional[str] = Field(None, alias="time")
+
+ @classmethod
+ def load_from_etree(cls, root):
+ attribs = root.attrib
+ return cls(**attribs)
+
+
+class FFMetaData(GMSOXMLChild):
+ children: List[Units] = Field([], alias="children")
+
+ electrostatics14Scale: float = Field(0.5, alias="electrostatics14Scale")
+
+ nonBonded14Scale: float = Field(0.5, alias="nonBonded14Scale")
+
+ combining_rule: str = Field("geometric", alias="combiningRule")
+
+ @classmethod
+ def load_from_etree(cls, root):
+ attribs = root.attrib
+ children = []
+ for unit in root.iterchildren():
+ children.append(Units.load_from_etree(unit))
+ return cls(children=children, **attribs)
+
+ def gmso_scaling_factors(self):
+ return self.dict(
+ include={"electrostatics14Scale", "nonBonded14Scale"},
+ exclude_none=True,
+ )
+
+ def get_default_units(self):
+ units_dict = {}
+ units = self.children[0].dict(by_alias=True, exclude_none=True)
+ for name, unit in units.items():
+ try:
+ units_dict[name] = u.Unit(unit)
+ except u.exceptions.UnitParseError:
+ units_dict[name] = getattr(u.physical_constants, unit)
+
+ return units_dict
+
+
+class ForceField(GMSOXMLTag):
+ name: str = Field(
+ "ForceField", description="Name of the ForceField", alias="name"
+ )
+
+ version: str = Field(
+ "1.0.0", description="The version of the ForceField", alias="version"
+ )
+
+ children: List[GMSOXMLChild] = Field(
+ ..., description="The children tags", alias="children"
+ )
+
+ def to_gmso_ff(self):
+ ff = GMSOForceField()
+ ff.name = self.name
+ ff.version = self.version
+ metadata = list(
+ filter(lambda child: isinstance(child, FFMetaData), self.children)
+ ).pop()
+ default_units = metadata.get_default_units()
+ ff.scaling_factors = metadata.gmso_scaling_factors()
+ ff.combining_rule = metadata.combining_rule
+ remaining_children = filter(
+ lambda c: not isinstance(c, (FFMetaData, Units)),
+ self.children,
+ )
+ ff_potentials = {}
+
+ for child in remaining_children:
+ if hasattr(child, "to_gmso_potentials"):
+ potentials = child.to_gmso_potentials(default_units)
+ for attr in potentials:
+ if attr in ff_potentials:
+ ff_potentials[attr].update(potentials[attr])
+ else:
+ ff_potentials[attr] = potentials[attr]
+
+ for attr in ff_potentials:
+ setattr(ff, attr, ff_potentials[attr])
+
+ return ff
+
+ @classmethod
+ def load_from_etree(cls, root) -> "ForceField":
+ attribs = root.attrib
+ children = []
+ identifiers_registry = get_identifiers_registry()
+ for el in root.iterchildren():
+ if el.tag == "FFMetaData":
+ children.append(FFMetaData.load_from_etree(el))
+ if el.tag == "AtomTypes":
+ children.append(
+ AtomTypes.load_from_etree(
+ el, identifiers_registry["AtomTypes"]
+ )
+ )
+ elif el.tag == "BondTypes":
+ children.append(
+ BondTypes.load_from_etree(
+ el, identifiers_registry["BondTypes"]
+ )
+ )
+ elif el.tag == "AngleTypes":
+ children.append(
+ AngleTypes.load_from_etree(
+ el, identifiers_registry["AngleTypes"]
+ )
+ )
+ elif el.tag == "DihedralTypes":
+ children.append(
+ DihedralTypes.load_from_etree(
+ el,
+ identifiers_registry["DihedralTypes"],
+ identifiers_registry["ImproperTypes"],
+ )
+ )
+ elif el.tag == "ImproperTypes":
+ children.append(
+ ImproperTypes.load_from_etree(
+ el,
+ identifiers_registry["DihedralTypes"],
+ identifiers_registry["ImproperTypes"],
+ )
+ )
+ elif el.tag == "PairPotentialTypes":
+ children.append(
+ PairPotentialTypes.load_from_etree(
+ el, identifiers_registry["PairPotentialTypes"]
+ )
+ )
+
+ return cls(children=children, **attribs)
diff --git a/forcefield_utilities/tests/files/propanol_Mie_ua.xml b/forcefield_utilities/tests/files/propanol_Mie_ua.xml
new file mode 100644
index 0000000..9c4939b
--- /dev/null
+++ b/forcefield_utilities/tests/files/propanol_Mie_ua.xml
@@ -0,0 +1,116 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_angle_type.xml b/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_angle_type.xml
new file mode 100644
index 0000000..2923710
--- /dev/null
+++ b/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_angle_type.xml
@@ -0,0 +1,130 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_bond_type.xml b/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_bond_type.xml
new file mode 100644
index 0000000..9333165
--- /dev/null
+++ b/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_bond_type.xml
@@ -0,0 +1,134 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_dihedral_type.xml b/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_dihedral_type.xml
new file mode 100644
index 0000000..885d4c0
--- /dev/null
+++ b/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_dihedral_type.xml
@@ -0,0 +1,154 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_improper_type.xml b/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_improper_type.xml
new file mode 100644
index 0000000..b9be0e1
--- /dev/null
+++ b/forcefield_utilities/tests/files/propanol_Mie_ua_duplicate_entries_improper_type.xml
@@ -0,0 +1,154 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/forcefield_utilities/tests/files/propanol_Mie_ua_list_wildcards.xml b/forcefield_utilities/tests/files/propanol_Mie_ua_list_wildcards.xml
new file mode 100644
index 0000000..3111feb
--- /dev/null
+++ b/forcefield_utilities/tests/files/propanol_Mie_ua_list_wildcards.xml
@@ -0,0 +1,139 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 0.0
+ 0.416955197548017
+ -0.0579667482245528
+ 0.37345529632637
+ 0.0
+ 0.0
+
+
+ 0.0
+ 180.0
+ 0.0
+ 0.0
+ 0.0
+ 0.0
+
+
+
+
+
+
+ 0.0
+ 0.6955197548017
+ -0.023
+ 0.37345529632637
+ 0.0
+ 0.0
+
+
+ 0.0
+ 90.0
+ 0.0
+ 0.0
+ 0.0
+ 0.0
+
+
+
+
+
diff --git a/forcefield_utilities/tests/test_gmso_xml.py b/forcefield_utilities/tests/test_gmso_xml.py
new file mode 100644
index 0000000..c35ca11
--- /dev/null
+++ b/forcefield_utilities/tests/test_gmso_xml.py
@@ -0,0 +1,301 @@
+import pytest
+import unyt as u
+from gmso.tests.utils import get_path
+from lxml import etree
+from sympy import sympify
+
+from forcefield_utilities.tests.base_test import BaseTest
+from forcefield_utilities.tests.utils import get_test_file_path
+from forcefield_utilities.xml_loader import GMSOFFs
+
+
+class TestEthyleneFF(BaseTest):
+ @pytest.fixture(scope="session")
+ def ff_example_zero(self):
+ example_zero = get_path("ethylene.xml")
+ return GMSOFFs().load(example_zero).to_gmso_ff()
+
+ def test_metadata(self, ff_example_zero):
+ assert ff_example_zero.scaling_factors == {
+ "electrostatics14Scale": 0.5,
+ "nonBonded14Scale": 0.5,
+ }
+ assert ff_example_zero.combining_rule == "geometric"
+
+ def test_atom_types(self, ff_example_zero):
+ opls_143 = ff_example_zero.atom_types["opls_143"]
+ assert opls_143.name == "opls_143"
+ assert opls_143.charge == -0.23 * u.elementary_charge
+ assert opls_143.get_tag("element") == "C"
+
+ def test_bond_types(self, ff_example_zero):
+ btype_harmonic_1 = ff_example_zero.bond_types["opls_143~opls_143"]
+ assert str(btype_harmonic_1.expression) == "0.5*k*(r - r_eq)**2"
+ assert btype_harmonic_1.member_types == ("opls_143", "opls_143")
+ assert btype_harmonic_1.parameters == {
+ "k": u.unyt_quantity(459403.2, "kJ/mol/nm**2"),
+ "r_eq": u.unyt_quantity(0.134, "nm"),
+ }
+
+ assert btype_harmonic_1.name == "BondType-Harmonic-1"
+
+ def test_angle_types(self, ff_example_zero):
+ angle_type_harmonic_2 = ff_example_zero.angle_types[
+ "opls_144~opls_143~opls_144"
+ ]
+ assert (
+ str(angle_type_harmonic_2.expression)
+ == "0.5*k*(theta - theta_eq)**2"
+ )
+ assert angle_type_harmonic_2.member_types == (
+ "opls_144",
+ "opls_143",
+ "opls_144",
+ )
+ assert angle_type_harmonic_2.parameters == {
+ "k": u.unyt_quantity(292.88, "kJ/(mol*radian**2)"),
+ "theta_eq": u.unyt_quantity(2.0420352248, "radian"),
+ }
+
+ assert angle_type_harmonic_2.name == "AngleType-Harmonic-2"
+
+ def test_dihedral_types(self, ff_example_zero):
+ dihedral_type_rb_1 = ff_example_zero.dihedral_types[
+ "opls_144~opls_143~opls_143~opls_144"
+ ]
+
+ assert (
+ str(dihedral_type_rb_1.expression)
+ == "c_0 + c_1*cos(psi) + c_2*cos(psi)**2 + c_3*cos(psi)**3 + c_4*cos(psi)**4 + c_5*cos(psi)**5"
+ )
+ assert dihedral_type_rb_1.member_types == (
+ "opls_144",
+ "opls_143",
+ "opls_143",
+ "opls_144",
+ )
+
+ assert dihedral_type_rb_1.parameters == {
+ "c_0": u.unyt_quantity(58.576, "kJ/mol"),
+ "c_1": u.unyt_quantity(0.0, "kJ/mol"),
+ "c_2": u.unyt_quantity(-58.576, "kJ/mol"),
+ "c_3": u.unyt_quantity(0.0, "kJ/mol"),
+ "c_4": u.unyt_quantity(0.0, "kJ/mol"),
+ "c_5": u.unyt_quantity(0.0, "kJ/mol"),
+ }
+
+ assert dihedral_type_rb_1.name == "DihedralType-RyckaertBellemans-1"
+
+ def test_improper_types(self, ff_example_zero):
+ pass
+
+
+class TestTwoPropanolMIEFF(BaseTest):
+ @pytest.fixture(scope="session")
+ def propanol_ua_mie(self):
+ propanol_ua_mie_path = get_test_file_path("propanol_Mie_ua.xml")
+ return GMSOFFs().load(propanol_ua_mie_path).to_gmso_ff()
+
+ def test_metadata(self, propanol_ua_mie):
+ assert (
+ propanol_ua_mie.name
+ == "Mie two-propanol- This is for testing only and not for use for simulations "
+ )
+ assert propanol_ua_mie.scaling_factors == {
+ "electrostatics14Scale": 0.0,
+ "nonBonded14Scale": 0.0,
+ }
+ assert propanol_ua_mie.combining_rule == "geometric"
+
+ def test_atom_types(self, propanol_ua_mie):
+ ch3_sp3 = propanol_ua_mie.atom_types["CH3_sp3"]
+
+ assert ch3_sp3.name == "CH3_sp3"
+ assert ch3_sp3.atomclass == "CH3"
+ assert ch3_sp3.get_tag("element") == "_CH3"
+ assert u.allclose_units(ch3_sp3.charge, 0.0 * u.coulomb)
+ assert ch3_sp3.definition == "[_CH3;X1][_CH3,_HC]"
+ assert u.allclose_units(ch3_sp3.mass, 15.03500 * u.amu)
+ assert (
+ ch3_sp3.description
+ == "Alkane CH3, Mie using the k constant from Trappe-UA"
+ )
+ assert ch3_sp3.doi == "10.1021/jp984742e and 10.1021/jp972543+"
+ assert ch3_sp3.overrides == set()
+
+ assert ch3_sp3.expression == sympify(
+ "(n/(n-m)) * (n/m)**(m/(n-m)) * epsilon * ((sigma/r)**n - (sigma/r)**m)"
+ )
+
+ parameters = ch3_sp3.get_parameters()
+ assert u.allclose_units(
+ parameters["epsilon"], 0.194746017346801 * u.kcal / u.mol
+ )
+ assert u.allclose_units(parameters["sigma"], 3.751 * u.angstrom)
+ assert u.allclose_units(parameters["n"], 11 * u.dimensionless)
+ assert u.allclose_units(parameters["m"], 6 * u.dimensionless)
+
+ o = propanol_ua_mie.atom_types["O"]
+
+ assert o.name == "O"
+ assert o.atomclass == "O"
+ assert o.get_tag("element") == "O"
+ assert u.allclose_units(o.charge, -0.700 * u.elementary_charge)
+ assert o.definition == "OH"
+ assert u.allclose_units(o.mass, 15.99940 * u.amu)
+ assert o.description == "Oxygen in hydroxyl"
+ assert o.doi == "10.1021/jp003882x"
+ assert o.overrides == set()
+
+ assert o.expression == sympify(
+ "(n/(n-m)) * (n/m)**(m/(n-m)) * epsilon * ((sigma/r)**n - (sigma/r)**m)"
+ )
+
+ parameters = o.get_parameters()
+ assert u.allclose_units(
+ parameters["epsilon"], 0.184809996053596 * u.kcal / u.mol
+ )
+ assert u.allclose_units(parameters["sigma"], 3.021 * u.angstrom)
+ assert u.allclose_units(parameters["n"], 13 * u.dimensionless)
+ assert u.allclose_units(parameters["m"], 6 * u.dimensionless)
+
+ def test_bond_types(self, propanol_ua_mie):
+ bond_type_ch3_ch = propanol_ua_mie.bond_types["CH3~CH"]
+ assert bond_type_ch3_ch.name == "BondType_Harmonic_CH3_CH"
+ assert bond_type_ch3_ch.member_classes == ("CH3", "CH")
+ assert bond_type_ch3_ch.expression == sympify("k * (r-r_eq)**2")
+
+ parameters = bond_type_ch3_ch.get_parameters()
+ assert u.allclose_units(
+ parameters["k"], 1200.80305927342 * u.kcal / u.mol / u.angstrom**2
+ )
+ assert u.allclose_units(parameters["r_eq"], 1.5401 * u.angstrom)
+
+ def test_angle_types(self, propanol_ua_mie):
+ angle_type_ch3_ch_o = propanol_ua_mie.angle_types["CH3~CH~O"]
+
+ assert angle_type_ch3_ch_o.name == "AngleType_Harmonic_CH3_CH_O"
+ assert angle_type_ch3_ch_o.member_classes == ("CH3", "CH", "O")
+ assert angle_type_ch3_ch_o.expression == sympify(
+ "k * (theta - theta_eq)**2"
+ )
+
+ parameters = angle_type_ch3_ch_o.get_parameters()
+ assert u.allclose_units(
+ parameters["k"], 100.155094635497 * u.kcal / u.mol / u.radian**2
+ )
+ assert u.allclose_units(parameters["theta_eq"], 109.51 * u.degree)
+
+ def test_dihedral_types(self, propanol_ua_mie):
+ dihedral_type_ch3_ch_o_h = propanol_ua_mie.dihedral_types["CH3~CH~O~H"]
+
+ assert (
+ dihedral_type_ch3_ch_o_h.name
+ == "DihedralType_Periodic_Proper_CH3_CH_O_H"
+ )
+ assert dihedral_type_ch3_ch_o_h.member_classes == (
+ "CH3",
+ "CH",
+ "O",
+ "H",
+ )
+ assert dihedral_type_ch3_ch_o_h.expression == sympify(
+ "k0 + k1 * (1 + cos(1 * phi - phi_eq1)) + "
+ "k2 * (1 + cos(2 * phi - phi_eq2)) + "
+ "k3 * (1 + cos(3 * phi - phi_eq3)) + "
+ "k4 * (1 + cos(4 * phi - phi_eq4)) + "
+ "k5 * (1 + cos(5 * phi - phi_eq5))"
+ )
+
+ parameters = dihedral_type_ch3_ch_o_h.get_parameters()
+ assert u.allclose_units(parameters["k0"], 0.0 * u.kcal / u.mol)
+ assert u.allclose_units(
+ parameters["k1"], 0.416955197548017 * u.kcal / u.mol
+ )
+ assert u.allclose_units(
+ parameters["k2"], -0.0579667482245528 * u.kcal / u.mol
+ )
+ assert u.allclose_units(
+ parameters["k3"], 0.37345529632637 * u.kcal / u.mol
+ )
+ assert u.allclose_units(parameters["k4"], 0.0 * u.kcal / u.mol)
+ assert u.allclose_units(parameters["k5"], 0.0 * u.kcal / u.mol)
+
+ assert u.allclose_units(parameters["phi_eq1"], 0.0 * u.degree)
+ assert u.allclose_units(parameters["phi_eq2"], 180 * u.degree)
+ assert u.allclose_units(parameters["phi_eq3"], 0.0 * u.degree)
+ assert u.allclose_units(parameters["phi_eq4"], 0.0 * u.degree)
+ assert u.allclose_units(parameters["phi_eq5"], 0.0 * u.degree)
+
+
+class TestListParameters(BaseTest):
+ @pytest.fixture(scope="session")
+ def propanol_ua_mie_list(self):
+ propanol_ua_mie_path = get_test_file_path(
+ "propanol_Mie_ua_list_wildcards.xml"
+ )
+ return GMSOFFs().load(propanol_ua_mie_path).to_gmso_ff()
+
+ def test_dihedral_params(self, propanol_ua_mie_list):
+ dih_with_list = propanol_ua_mie_list.dihedral_types["CH3~CH~O~H"]
+ params = dih_with_list.get_parameters()
+ assert u.allclose_units(
+ params["phi_eq"], [0.0, 180.0, 0.0, 0.0, 0.0, 0.0] * u.degree
+ )
+ assert u.allclose_units(
+ params["k"],
+ [0.0, 0.4169552, -0.05796675, 0.3734553, 0.0, 0.0] * u.kcal / u.mol,
+ )
+
+ dih_with_list_wildcards = propanol_ua_mie_list.get_potential(
+ group="dihedral_type", key=["", "CH", "O", "H"]
+ )
+ params_wildcards_dihedrals = dih_with_list_wildcards.get_parameters()
+ assert u.allclose_units(
+ params_wildcards_dihedrals["phi_eq"],
+ [0.0, 90.0, 0.0, 0.0, 0.0, 0.0] * u.degree,
+ )
+ assert u.allclose_units(
+ params_wildcards_dihedrals["k"],
+ [0.0, 0.69551975, -0.023, 0.3734553, 0.0, 0.0] * u.kcal / u.mol,
+ )
+
+
+class TestDuplicateEntries:
+ def test_value_error_bond_types(self):
+ with pytest.raises(ValueError):
+ GMSOFFs().load(
+ get_test_file_path(
+ "propanol_Mie_ua_duplicate_entries_bond_type.xml"
+ )
+ )
+
+ def test_value_error_angle_types(self):
+ with pytest.raises(
+ ValueError, match=r"Duplicate identifier found for AngleTypes.*"
+ ):
+ GMSOFFs().load(
+ get_test_file_path(
+ "propanol_Mie_ua_duplicate_entries_angle_type.xml"
+ )
+ )
+
+ def test_value_error_dihedral_types(self):
+ with pytest.raises(
+ ValueError, match=r"Duplicate identifier found for DihedralTypes.*"
+ ):
+ GMSOFFs().load(
+ get_test_file_path(
+ "propanol_Mie_ua_duplicate_entries_dihedral_type.xml"
+ )
+ )
+
+ def test_value_error_improper_types(self):
+ with pytest.raises(
+ ValueError, match=r"Duplicate identifier found for ImproperTypes.*"
+ ):
+ GMSOFFs().load(
+ get_test_file_path(
+ "propanol_Mie_ua_duplicate_entries_improper_type.xml"
+ )
+ )
diff --git a/forcefield_utilities/tests/test_xml_loader.py b/forcefield_utilities/tests/test_xml_loader.py
index 251c8df..f9891b1 100644
--- a/forcefield_utilities/tests/test_xml_loader.py
+++ b/forcefield_utilities/tests/test_xml_loader.py
@@ -1,8 +1,9 @@
import pytest
from forcefield_utilities.tests.base_test import BaseTest
+from forcefield_utilities.tests.utils import get_test_file_path
from forcefield_utilities.utils import get_package_file_path
-from forcefield_utilities.xml_loader import FoyerFFs
+from forcefield_utilities.xml_loader import FoyerFFs, GMSOFFs
class TestXMLLoader(BaseTest):
@@ -10,6 +11,10 @@ class TestXMLLoader(BaseTest):
def foyer_xml_loader(self):
return FoyerFFs()
+ @pytest.fixture
+ def gmso_xml_loader(self):
+ return GMSOFFs()
+
def test_load_gaff(self, foyer_xml_loader):
foyer_xml_loader.get_ff("gaff")
assert "gaff" in foyer_xml_loader.loaded_ffs
@@ -56,3 +61,16 @@ def test_custom_register(self, foyer_xml_loader):
assert "benzene_lb" in foyer_xml_loader.loaded_ffs
assert loaded_id != id(foyer_xml_loader.load("benzene_lb"))
+
+ def test_different_loading_entries(self, foyer_xml_loader, gmso_xml_loader):
+ assert id(foyer_xml_loader.loaded_ffs) != id(gmso_xml_loader.loaded_ffs)
+ gmso_xml_loader.load(get_test_file_path("propanol_Mie_ua.xml"))
+ assert "propanol_Mie_ua" in gmso_xml_loader.loaded_ffs
+ assert "propanol_Mie_ua" not in foyer_xml_loader.loaded_ffs
+
+ def test_class_methods_gmso_ff(self):
+ gmso_xml_loader = GMSOFFs()
+ gmso_xml_loader.load(get_test_file_path("propanol_Mie_ua.xml"))
+ ff1 = gmso_xml_loader.get_ff("propanol_Mie_ua")
+ ff2 = gmso_xml_loader.load("propanol_Mie_ua")
+ assert id(ff1) == id(ff2)
diff --git a/forcefield_utilities/tests/utils.py b/forcefield_utilities/tests/utils.py
new file mode 100644
index 0000000..84fd646
--- /dev/null
+++ b/forcefield_utilities/tests/utils.py
@@ -0,0 +1,13 @@
+from pathlib import Path
+
+
+def get_test_file_path(filename):
+ """Give a filename, return its location in test files."""
+ file_path = Path(__file__).parent / "files" / filename
+
+ if not file_path.resolve().exists():
+ raise FileNotFoundError(
+ f"File {filename} not found in {file_path.parent}"
+ )
+
+ return str(file_path)
diff --git a/forcefield_utilities/utils.py b/forcefield_utilities/utils.py
index 04f4058..eb94245 100644
--- a/forcefield_utilities/utils.py
+++ b/forcefield_utilities/utils.py
@@ -50,3 +50,23 @@ def _deprecate_kwargs(kwargs, deprecated_kwargs):
DeprecationWarning,
3,
)
+
+
+def pad_with_wildcards(input_dictionary, max_len, wildcard="*"):
+ """Pad empty type or classes with wildcards"""
+ types = [f"type{j+1}" for j in range(max_len)]
+ classes = [f"class{j+1}" for j in range(max_len)]
+
+ if types[0] in input_dictionary:
+ for type_ in types:
+ value = input_dictionary[type_]
+ if isinstance(value, str) and value.strip() == "":
+ input_dictionary[type_] = wildcard
+
+ elif classes[0] in input_dictionary:
+ for class_ in classes:
+ value = input_dictionary[class_]
+ if isinstance(value, str) and value.strip() == "":
+ input_dictionary[class_] = wildcard
+
+ return input_dictionary
diff --git a/forcefield_utilities/xml_loader.py b/forcefield_utilities/xml_loader.py
index edda286..9aa93e0 100644
--- a/forcefield_utilities/xml_loader.py
+++ b/forcefield_utilities/xml_loader.py
@@ -1,11 +1,12 @@
-import importlib
-import os
+import abc
from pathlib import Path
from typing import Union
+from gmso.utils.ff_utils import _validate_schema as validate_gmso_schema
from lxml import etree
-from forcefield_utilities.foyer_xml import ForceField
+from forcefield_utilities.foyer_xml import ForceField as FoyerForceField
+from forcefield_utilities.gmso_xml import ForceField as GMSOForceField
from forcefield_utilities.utils import (
call_on_import,
deprecate_kwargs,
@@ -15,35 +16,28 @@
custom_forcefields = {}
-def _parse_foyer_xml(xml_path):
- """Return the foyer Forcefield object from the relative path ``xml_path`` inside the foyer package."""
- with open(xml_path) as ff_file:
- root = etree.parse(ff_file).getroot()
- return ForceField.load_from_etree(root)
-
-
-class FoyerFFs:
- """Object to provide methods to forcefields shipped with Foyer.
+class XMLLoader:
+ """Object to provide methods to forcefields shipped with Foyer/GMSO.
Attributes
__________
loaded_ffs : dict, keys are strings, values are the loaded forcefield
This is a place to store loaded xmls, which can be accessed through
- the getter by indexing the FoyerFFs object. xmls are only stored once,
+ the getter by indexing the object. xmls are only stored once,
and custom_xml is where a forcefield is stored from path. Note: This is a
class level attribute.
Methods
_______
get_ff(ffname="oplsaa"):
- Load and directly return the forcefield xml from Foyer or a path.
+ Load and directly return the forcefield xml from Foyer/GMSO or a path.
Parameters
__________
ffname : str
- Name of forcefield to load. See self.ff_registry attribute
- for loading method.
- load(ffname="oplsaa", rel_to_module=False):
- Load and return the forcefield xml from Foyer or a path.
+ Name of forcefield to load or path to load
+
+ load(ffname="oplsaa"):
+ Load and return the forcefield xml from Foyer/GMSO or a path.
Notes
-----
@@ -51,14 +45,20 @@ class level attribute.
`clear_loaded_ffs` can be used to clear the cache.
"""
- loaded_ffs = {}
- overwritten_custom_ffs = set()
+ loaded_ffs = None
+ overwritten_custom_ffs = None
+ search_foyer = False
+
+ @abc.abstractmethod
+ def load_xml(self, xml_path) -> Union[FoyerForceField, GMSOForceField]:
+ """Load the xml file"""
+ return NotImplemented
@classmethod
@deprecate_kwargs(deprecated_kwargs={"rel_to_module"})
def get_ff(
cls, ffname: Union[str, Path], rel_to_module: bool = False
- ) -> ForceField:
+ ) -> Union[FoyerForceField, GMSOForceField]:
"""Load and directly return the forcefield xml from Foyer or a path.
Parameters
__________
@@ -72,10 +72,12 @@ def get_ff(
ff : forcefield_utilities.foyer_xml.Forcefield
"""
loader = cls()
- return loader.load(ffname, rel_to_module)
+ return loader.load(ffname)
@deprecate_kwargs(deprecated_kwargs={"rel_to_module"})
- def load(self, ffname, rel_to_module=False):
+ def load(
+ self, ffname, rel_to_module=False
+ ) -> Union[FoyerForceField, GMSOForceField]:
"""Load and return the forcefield xml from Foyer or a path.
Parameters
__________
@@ -86,7 +88,7 @@ def load(self, ffname, rel_to_module=False):
Returns
______
- ff : forcefield_utilities.foyer_xml.Forcefield
+ ff : forcefield_utilities.foyer_xml.Forcefield or forcefield_utilties.gmso_xml.Forcefield
Notes
-----
@@ -118,29 +120,33 @@ def load(self, ffname, rel_to_module=False):
ff_path = Path(ffname)
if ffname in custom_forcefields:
- self.loaded_ffs[ffname] = _parse_foyer_xml(
+ self.loaded_ffs[ffname] = self.load_xml(
xml_path=custom_forcefields[ffname]
)
self.overwritten_custom_ffs.discard(ffname)
return self.loaded_ffs[ff_path.name]
if self._is_xml(ff_path):
- self.loaded_ffs[ff_path.stem] = _parse_foyer_xml(
+ self.loaded_ffs[ff_path.stem] = self.load_xml(
xml_path=ff_path.resolve()
)
- else:
+ elif self.search_foyer:
xml_path = get_package_file_path(
"foyer", f"forcefields/xml/{ffname}.xml"
)
- self.loaded_ffs[ff_path.stem] = _parse_foyer_xml(xml_path)
+ self.loaded_ffs[ff_path.stem] = self.load_xml(xml_path)
+ else:
+ raise FileNotFoundError(
+ f"{ffname} not found, it isn't registered forcefiled name or a XML file."
+ )
return self.loaded_ffs[ff_path.stem]
- def __getitem__(self, ffname) -> ForceField:
+ def __getitem__(self, ffname) -> Union[FoyerForceField, GMSOForceField]:
"""Get function for indexing by loaded forcefields."""
return self.loaded_ffs[ffname]
- def __getattr__(self, item) -> ForceField:
+ def __getattr__(self, item) -> Union[FoyerForceField, GMSOForceField]:
"""Accessor for loaded forcefields."""
if item in self.loaded_ffs:
return self.loaded_ffs[item]
@@ -152,7 +158,7 @@ def __getattr__(self, item) -> ForceField:
def register_custom_forcefield(
self, name: str, path_: Union[str, Path], overwrite: bool = True
) -> None:
- """Register a custom foyer forcefield's XML path to load.
+ """Register a custom foyer/gmso forcefield's XML path to load.
Parameters
----------
@@ -181,10 +187,42 @@ def _is_xml(path_: Path) -> bool:
return path_.suffix == ".xml"
+class FoyerFFs(XMLLoader):
+ """Utility class to load foyer forcefields."""
+
+ loaded_ffs = {}
+ overwritten_custom_ffs = set()
+ search_foyer = True
+
+ def load_xml(self, xml_path):
+ """Return the foyer Forcefield object from the relative path ``xml_path`` inside the foyer package."""
+ with open(xml_path) as ff_file:
+ root = etree.parse(ff_file).getroot()
+ return FoyerForceField.load_from_etree(root)
+
+
+class GMSOFFs(XMLLoader):
+ """Utility class to load gmso forcefields."""
+
+ loaded_ffs = {}
+ overwritten_custom_ffs = set()
+ search_foyer = False
+
+ def load_xml(self, xml_path):
+ """Return the gmso Forcefield object from the relative path ``xml_path`` for a gmso XML."""
+ with open(xml_path) as ff_file:
+ ff_etree = etree.parse(ff_file)
+ validate_gmso_schema(ff_etree)
+ root = ff_etree.getroot()
+ return GMSOForceField.load_from_etree(root)
+
+
@call_on_import
def register_gaff():
"""Include GAFF as part of FoyerFFs if antefoyer is installed locally."""
try:
+ import importlib
+
importlib.import_module("antefoyer")
except ImportError:
return