Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 19, 2024
1 parent d534af0 commit ee6b19f
Show file tree
Hide file tree
Showing 14 changed files with 261 additions and 195 deletions.
23 changes: 18 additions & 5 deletions src/aiida_sssp_workflow/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from aiida.cmdline.params import options, types
from aiida.cmdline.utils import echo
from aiida.engine import ProcessBuilder, run_get_node, submit
from aiida.plugins import DataFactory, WorkflowFactory
from aiida.plugins import WorkflowFactory

from aiida_pseudo.data.pseudo.upf import UpfData
from aiida_sssp_workflow.cli import cmd_root
Expand All @@ -25,6 +25,7 @@

VerificationWorkChain = WorkflowFactory("sssp_workflow.verification")


def guess_properties_list(property: list) -> Tuple[List[str], str]:
# if the property is not specified, use the default list with all properties calculated.
# otherwise, use the specified properties.
Expand All @@ -43,21 +44,27 @@ def guess_properties_list(property: list) -> Tuple[List[str], str]:

return properties_list, extra_desc


def guess_is_convergence(properties_list: list) -> bool:
"""Check if it is a convergence test"""

return any([c for c in properties_list if c.startswith("convergence")])


def guess_is_full_convergence(properties_list: list) -> bool:
"""Check if all properties are run for convergence test"""

return len([c for c in properties_list if c.startswith("convergence")]) == len(DEFAULT_CONVERGENCE_PROPERTIES_LIST)
return len([c for c in properties_list if c.startswith("convergence")]) == len(
DEFAULT_CONVERGENCE_PROPERTIES_LIST
)


def guess_is_measure(properties_list: list) -> bool:
"""Check if it is a measure test"""

return any([c for c in properties_list if c.startswith("measure")])


def guess_is_ph(properties_list: list) -> bool:
"""Check if it has a measure test"""

Expand Down Expand Up @@ -175,15 +182,19 @@ def launch(
is_ph = guess_is_ph(properties_list)

if is_ph and not ph_code:
echo.echo_critical("ph_code must be provided since we run on it for phonon frequencies.")
echo.echo_critical(
"ph_code must be provided since we run on it for phonon frequencies."
)

if is_convergence and len(configuration) > 1:
echo.echo_critical(
"Only one configuration is allowed for convergence workflow."
)

if is_measure and not is_full_convergence:
echo.echo_warning("Full convergence tests are not run, so we use maximum cutoffs for transferability verification.")
echo.echo_warning(
"Full convergence tests are not run, so we use maximum cutoffs for transferability verification."
)

# Load the curent AiiDA profile and log to user
_profile = aiida.load_profile()
Expand Down Expand Up @@ -211,7 +222,9 @@ def launch(
clean_workdir=clean_workdir,
)

builder.metadata.label = f"({protocol} at {pw_code.computer.label} - {conf_label}) {pseudo.stem}"
builder.metadata.label = (
f"({protocol} at {pw_code.computer.label} - {conf_label}) {pseudo.stem}"
)
builder.metadata.description = f"""Calculation is run on protocol: {protocol}; on {pw_code.computer.label}; on configuration {conf_label}; on pseudo {pseudo.stem}."""

builder.pw_code = pw_code
Expand Down
1 change: 1 addition & 0 deletions src/aiida_sssp_workflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def get_default_mpi_options(
"withmpi": with_mpi,
}


def serialize_data(data):
from aiida.orm import (
AbstractCode,
Expand Down
23 changes: 12 additions & 11 deletions src/aiida_sssp_workflow/utils/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aiida_sssp_workflow.utils.pseudo import DualType, get_dual_type


def get_protocol(category: str, name: str | None=None):
def get_protocol(category: str, name: str | None = None):
"""Load and read protocol from faml file to a verbose dict
if name not set, return whole protocol."""
import_path = resources.path("aiida_sssp_workflow.protocol", f"{category}.yml")
Expand All @@ -17,24 +17,25 @@ def get_protocol(category: str, name: str | None=None):
else:
return protocol_dict

def generate_cutoff_list(protocol_name: str, element: str, pp_type: str) -> List[Tuple[int, int]]:
"""From the control protocol name, get the cutoff list
"""

def generate_cutoff_list(
protocol_name: str, element: str, pp_type: str
) -> List[Tuple[int, int]]:
"""From the control protocol name, get the cutoff list"""
match get_dual_type(pp_type, element):
case DualType.NC:
dual_type = 'nc_dual_scan'
dual_type = "nc_dual_scan"
case DualType.AUGLOW:
dual_type = 'nonnc_dual_scan'
dual_type = "nonnc_dual_scan"
case DualType.AUGHIGH:
dual_type = 'nonnc_high_dual_scan'
dual_type = "nonnc_high_dual_scan"

dual_scan_list = get_protocol('control', protocol_name)[dual_type]
dual_scan_list = get_protocol("control", protocol_name)[dual_type]
if len(dual_scan_list) > 0:
max_dual = int(max(dual_scan_list))
else:
max_dual = 8

ecutwfc_list = get_protocol('control', protocol_name)['wfc_scan']

return [(e, e*max_dual) for e in ecutwfc_list]
ecutwfc_list = get_protocol("control", protocol_name)["wfc_scan"]

return [(e, e * max_dual) for e in ecutwfc_list]
22 changes: 13 additions & 9 deletions src/aiida_sssp_workflow/utils/pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,15 @@ class DualType(Enum):
AUGLOW = "charge augmentation low"
AUGHIGH = "charge augmentation high"


def get_dual_type(pp_type: str, element: str) -> DualType:
if element in HIGH_DUAL_ELEMENTS and pp_type != 'nc':
return DualType.AUGHIGH
elif pp_type == 'nc':
return DualType.NC
else:
return DualType.AUGLOW
if element in HIGH_DUAL_ELEMENTS and pp_type != "nc":
return DualType.AUGHIGH
elif pp_type == "nc":
return DualType.NC
else:
return DualType.AUGLOW


def extract_pseudo_info(pseudo_text: str) -> PseudoInfo:
"""Giving a pseudo, extract the pseudo info and return as a `PseudoInfo` object"""
Expand All @@ -143,17 +145,19 @@ def extract_pseudo_info(pseudo_text: str) -> PseudoInfo:
z_valence=upf_info["z_valence"],
)


def extract_pseudo_info_from_filename(filename: str) -> PseudoInfo:
"""We give standard filename for PP, so it now can be parsed"""
parts = filename.split('.')
parts = filename.split(".")

return PseudoInfo(
element=parts[0],
type=parts[1],
functional=parts[2],
z_valence=int(parts[3].split('_')[1])
z_valence=int(parts[3].split("_")[1]),
)


def _get_proper_dual(pp_info: PseudoInfo) -> int:
if pp_info.type == "nc":
dual = 4
Expand Down
2 changes: 0 additions & 2 deletions src/aiida_sssp_workflow/workflows/convergence/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ def configuration(self):

return self.ctx.configuration


@property
def pseudos(self):
"""Syntax sugar for self.ctx.pseudos"""
Expand Down Expand Up @@ -279,7 +278,6 @@ def get_builder(
if ret := is_valid_cutoff_list(cutoff_list):
raise ValueError(ret)


builder.cutoff_list = orm.List(list=cutoff_list)
builder.clean_workdir = orm.Bool(clean_workdir)

Expand Down
24 changes: 11 additions & 13 deletions src/aiida_sssp_workflow/workflows/convergence/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def compute_xy(
report = ConvergenceReport.construct(**report_dict)

reference_node = orm.load_node(report.reference.uuid)
band_structure_r: orm.BandsData = reference_node.outputs.bands.band_structure
band_structure_r: orm.BandsData = reference_node.outputs.bands.band_structure
band_parameters_r: orm.Dict = reference_node.outputs.bands.band_parameters

bandsdata_r = {
Expand All @@ -173,7 +173,7 @@ def compute_xy(
}

# smearing width is from degauss
smearing = reference_node.inputs.bands.pw.parameters.get_dict()['SYSTEM']['degauss']
smearing = reference_node.inputs.bands.pw.parameters.get_dict()["SYSTEM"]["degauss"]
fermi_shift = reference_node.inputs.fermi_shift.value

# always do smearing on high bands and not include the spin since we didn't turn on the spin for all
Expand All @@ -189,15 +189,14 @@ def compute_xy(
if node_point.exit_status != 0:
# TODO: log to a warning file for where the node is not finished_okay
continue

x = node_point.wavefunction_cutoff
xs.append(x)

node = orm.load_node(node_point.uuid)
node = orm.load_node(node_point.uuid)
band_structure_p: orm.BandsData = node.outputs.bands.band_structure
band_parameters_p: orm.Dict = node.outputs.bands.band_parameters


# The raw implementation of `get_bands_distance` is in `aiida_sssp_workflow/calculations/bands_distance.py`
bandsdata_p = {
"number_of_electrons": band_parameters_p["number_of_electrons"],
Expand All @@ -223,14 +222,13 @@ def compute_xy(
# eta_c is the y, others are write into as metadata
ys_eta_c.append(eta_c)
ys_max_diff_c.append(max_diff_c)


return {
'xs': xs,
'ys': ys_eta_c,
'ys_eta_c': ys_eta_c,
'ys_max_diff_c': ys_max_diff_c,
'metadata': {
'unit': unit,
}
"xs": xs,
"ys": ys_eta_c,
"ys_eta_c": ys_eta_c,
"ys_max_diff_c": ys_max_diff_c,
"metadata": {
"unit": unit,
},
}
20 changes: 10 additions & 10 deletions src/aiida_sssp_workflow/workflows/convergence/cohesive_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def prepare_evaluate_builder(self, ecutwfc, ecutrho) -> ProcessBuilder:

return builder


def compute_xy(
node: orm.Node,
) -> dict[str, Any]:
Expand All @@ -188,30 +189,29 @@ def compute_xy(

reference_node = orm.load_node(report.reference.uuid)
output_parameters_r: orm.Dict = reference_node.outputs.output_parameters
y_ref = output_parameters_r['cohesive_energy_per_atom']
y_ref = output_parameters_r["cohesive_energy_per_atom"]

xs = []
ys = []
for node_point in report.convergence_list:
if node_point.exit_status != 0:
# TODO: log to a warning file for where the node is not finished_okay
continue

x = node_point.wavefunction_cutoff
xs.append(x)

node = orm.load_node(node_point.uuid)
output_parameters_p: orm.Dict = node.outputs.output_parameters

y = (output_parameters_p['cohesive_energy_per_atom'] - y_ref) / y_ref * 100
y = (output_parameters_p["cohesive_energy_per_atom"] - y_ref) / y_ref * 100
ys.append(y)

return {
'xs': xs,
'ys': ys,
'ys_relative_diff': ys,
'metadata': {
'unit': '%',
}
"xs": xs,
"ys": ys,
"ys_relative_diff": ys,
"metadata": {
"unit": "%",
},
}

20 changes: 10 additions & 10 deletions src/aiida_sssp_workflow/workflows/convergence/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def prepare_evaluate_builder(self, ecutwfc, ecutrho) -> ProcessBuilder:

return builder


def compute_xy(
node: orm.Node,
) -> dict[str, Any]:
Expand All @@ -139,37 +140,36 @@ def compute_xy(

reference_node = orm.load_node(report.reference.uuid)
output_parameters_r: orm.Dict = reference_node.outputs.output_parameters
ref_V0, ref_B0, ref_B1 = output_parameters_r['birch_murnaghan_results']


ref_V0, ref_B0, ref_B1 = output_parameters_r["birch_murnaghan_results"]

xs = []
ys_nu = []
for node_point in report.convergence_list:
if node_point.exit_status != 0:
# TODO: log to a warning file for where the node is not finished_okay
continue

x = node_point.wavefunction_cutoff
xs.append(x)

node = orm.load_node(node_point.uuid)
output_parameters_p: orm.Dict = node.outputs.output_parameters

V0, B0, B1 = output_parameters_p['birch_murnaghan_results']
V0, B0, B1 = output_parameters_p["birch_murnaghan_results"]

y_nu = rel_errors_vec_length(ref_V0, ref_B0, ref_B1, V0, B0, B1)

ys_nu.append(y_nu)

return {
'xs': xs,
'ys': ys_nu,
'metadata': {
'unit': 'n/a',
}
"xs": xs,
"ys": ys_nu,
"metadata": {
"unit": "n/a",
},
}


# def compute_xy_epsilon(
# report: ConvergenceReport,
# ) -> dict[str, Any]:
Expand Down
20 changes: 10 additions & 10 deletions src/aiida_sssp_workflow/workflows/convergence/phonon_frequencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def prepare_evaluate_builder(self, ecutwfc, ecutrho):

return builder


def compute_xy(
node: orm.Node,
) -> dict[str, Any]:
Expand All @@ -190,7 +191,7 @@ def compute_xy(
if node_point.exit_status != 0:
# TODO: log to a warning file for where the node is not finished_okay
continue

x = node_point.wavefunction_cutoff
xs.append(x)

Expand Down Expand Up @@ -231,13 +232,12 @@ def compute_xy(
ys_relative_max_diff.append(relative_max_diff)

return {
'xs': xs,
'ys': ys_relative_diff,
'ys_relative_diff': ys_relative_diff,
'ys_omega_max': ys_omega_max,
'ys_relative_max_diff': ys_relative_max_diff,
'metadata': {
'unit_default': '%',
}
"xs": xs,
"ys": ys_relative_diff,
"ys_relative_diff": ys_relative_diff,
"ys_omega_max": ys_omega_max,
"ys_relative_max_diff": ys_relative_max_diff,
"metadata": {
"unit_default": "%",
},
}

Loading

0 comments on commit ee6b19f

Please sign in to comment.