Skip to content

Commit

Permalink
Merge pull request #17 from oashour/tests
Browse files Browse the repository at this point in the history
Refactor PWxml
  • Loading branch information
oashour authored Sep 23, 2024
2 parents 8646d70 + 5824cd5 commit 48e2237
Show file tree
Hide file tree
Showing 15 changed files with 1,192 additions and 1,031 deletions.
3 changes: 2 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ plugins:
options:
show_source: true
members_order: source
docstring_section_style: table
docstring_section_style: spacy
merge_init_into_class: true
docstring_options:
ignore_init_summary: true
Expand All @@ -83,6 +83,7 @@ markdown_extensions:
pygments_lang_class: true
- pymdownx.arithmatex:
generic: true
- tables:

extra_javascript:
- js/katex.js
Expand Down
109 changes: 67 additions & 42 deletions pymatgen/io/espresso/inputs/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""
This module defines the base input file classes
This module defines the base input file classes.
"""

import logging
import os
import pathlib
import re
import warnings
Expand All @@ -16,14 +17,40 @@
from pymatgen.io.espresso.utils import parse_pwvals


class CardOptions(Enum):
"""Enum type of all supported options for a PWin card."""

def __str__(self) -> str:
return str(self.value)

def __repr__(self) -> str:
return self.__str__()

def __eq__(self, value: object) -> bool:
if isinstance(value, str):
return self.value.lower() == value.lower()
return self.value.lower() == value.value.lower()

@classmethod
def from_string(cls, s: str):
"""
:param s: String
:return: SupportedOptions
"""
for m in cls:
if m.value.lower() == s.lower():
return m
raise ValueError(f"Can't interpret option {s}.")


class BaseInputFile(ABC, MSONable):
"""
Abstract Base class for input files
"""

_indent = 2

def __init__(self, namelists, cards):
def __init__(self, namelists: list[dict[str, any]], cards: list["InputCard"]):
namelist_names = [nml.value.name for nml in self.namelist_classes]
self.namelists = OrderedDict(
{name: namelists.get(name, None) for name in namelist_names}
Expand Down Expand Up @@ -54,14 +81,14 @@ def card_classes(self):
"""All supported cards as a SupportedCards enum"""
pass

def _make_getter(self, name):
def _make_getter(self, name: str):
"""Returns a getter function for a property with name `name`"""
if name in [n.value.name for n in self.namelist_classes]:
return lambda self: self.namelists[name]
elif name in [c.value.name for c in self.card_classes]:
return lambda self: self.cards[name]

def _make_setter(self, name):
def _make_setter(self, name: str):
"""Returns a setter function for a property with name `name`"""
if name in [n.value.name for n in self.namelist_classes]:

Expand All @@ -84,15 +111,15 @@ def setter(self, value):

return setter

def _make_deleter(self, name):
def _make_deleter(self, name: str):
"""Returns a deleter function for a property with name `name`"""
if name in [n.value.name for n in self.namelist_classes]:
return lambda self: self.namelists.__setitem__(name, None)
elif name in [c.value.name for c in self.card_classes]:
return lambda self: self.cards.__setitem__(name, None)

@classmethod
def from_file(cls, filename):
def from_file(cls, filename: os.PathLike | str) -> "BaseInputFile":
"""
Reads an inputfile from file
Expand All @@ -115,7 +142,7 @@ def from_file(cls, filename):
return cls(namelists, cards)

@classmethod
def _parse_cards(cls, pwi_str):
def _parse_cards(cls, pwi_str: str) -> dict[str, "InputCard"]:
card_strings = pwi_str.rsplit("/", 1)[1].split("\n")
card_strings = [c for c in card_strings if c]
card_idx = [
Expand All @@ -132,7 +159,7 @@ def _parse_cards(cls, pwi_str):

return cards

def validate(self):
def validate(self) -> bool:
"""
Very basic validation for the input file.
Currently only checks that required namelists and cards are present.
Expand Down Expand Up @@ -176,22 +203,33 @@ def __str__(self):

return string

def to_file(self, filename, indent=2):
def to_file(self, filename: os.PathLike | str, indent: int = 2):
"""
Write the input file to a file
Write the input file to a file.
Args:
filename: path to file
indent: number of spaces to use for indentation
"""
self._indent = indent
with open(filename, "wb") as f:
f.write(self.__str__().encode("ascii"))


class InputNamelist(ABC, OrderedDict):
"""
Abstract Base class for namelists in input files
"""

indent = 2

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def __str__(self):
"""
Convert namelist to string
"""
nl = f90nml.Namelist({self.name: self})
nl.indent = self.indent * " "
string = str(nl)
Expand All @@ -209,9 +247,22 @@ def required(self):


class InputCard(ABC):
"""
Abstract Base class for cards in input files
Args:
option (str): The option for the card (e.g., "RELAX")
body (list): The body of the card
"""

indent = 2

def __init__(self, option, body):
def __init__(self, option: str | CardOptions, body: str):
"""
Args:
option (str): The option for the card (e.g., "RELAX")
body (list): The body of the card
"""
if isinstance(option, str):
option = self.opts.from_string(option)
self.option = option
Expand Down Expand Up @@ -250,7 +301,7 @@ def __str__(self):
return self.get_header() + self.get_body(" " * self.indent)

# TODO: this should become an abstract method when all cards are implemented
def get_body(self, indent):
def get_body(self, indent: str) -> str:
"""
Convert card body to string
This implementation is for generic (i.e., not fully implemented) cards
Expand All @@ -265,7 +316,7 @@ def body(self):
return self.get_body(self.indent)

@classmethod
def from_string(cls, s: str):
def from_string(cls, s: str) -> "InputCard":
"""
Create card object from string
This implementation is for generic (i.e., not fully implemented) cards
Expand All @@ -274,7 +325,7 @@ def from_string(cls, s: str):
return cls(option, body)

@classmethod
def get_option(cls, option):
def get_option(cls, option: str) -> CardOptions:
"""Initializes a card's options"""
if option is not None:
return cls.opts.from_string(option)
Expand All @@ -285,7 +336,7 @@ def get_option(cls, option):
return cls.default_option

@classmethod
def split_card_string(cls, s: str):
def split_card_string(cls, s: str) -> tuple[str, list]:
"""
Splits a card into an option and a list of values of the correct type.
:param s: String containing a card (as it would appear in a PWin file)
Expand All @@ -310,7 +361,7 @@ def split_card_string(cls, s: str):
option = None
return cls.get_option(option), parse_pwvals(body)

def get_header(self):
def get_header(self) -> str:
"""Gets a card's header as a string"""
if self.name is None:
return ""
Expand All @@ -320,32 +371,6 @@ def get_header(self):
return header


class CardOptions(Enum):
"""Enum type of all supported options for a PWin card."""

def __str__(self) -> str:
return str(self.value)

def __repr__(self) -> str:
return self.__str__()

def __eq__(self, value: object) -> bool:
if isinstance(value, str):
return self.value.lower() == value.lower()
return self.value.lower() == value.value.lower()

@classmethod
def from_string(cls, s: str):
"""
:param s: String
:return: SupportedOptions
"""
for m in cls:
if m.value.lower() == s.lower():
return m
raise ValueError(f"Can't interpret option {s}.")


class SupportedInputs(Enum):
"""Enum type of all supported input cards and namelists."""

Expand Down
Loading

0 comments on commit 48e2237

Please sign in to comment.