Skip to content

Commit

Permalink
Make exposed inputs optional, change validation logic, change return …
Browse files Browse the repository at this point in the history
…logic
  • Loading branch information
zooks97 committed Jun 7, 2021
1 parent 1d391a6 commit c219666
Showing 1 changed file with 131 additions and 47 deletions.
178 changes: 131 additions & 47 deletions aiida_quantumespresso/workflows/pdos.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,42 @@ def validate_inputs(value, _):
- Check that the `Emin`, `Emax` and `DeltaE` inputs are the same for the `dos` and `projwfc` namespaces.
- Check that `Emin` and `Emax` are provided in case `align_to_fermi` is set to `True`.
"""
# pylint: disable=too-many-return-statements
# Check that either the `scf` input or `nscf.pw.parent_folder` is provided.
import warnings
if 'scf' in value and 'parent_folder' in value['nscf']['pw']:
warnings.warn(
'Both the `scf` and `nscf.pw.parent_folder` inputs were provided. The SCF calculation will '
'be run with the inputs provided in `scf` and the `nscf.pw.parent_folder` will be ignored.'
)
elif not 'scf' in value and not 'parent_folder' in value['nscf']['pw']:
return 'Specifying either the `scf` or `nscf.pw.parent_folder` input is required.'
if 'scf' not in value:
if 'nscf' not in value:
return 'Specifying either the `scf` or `nscf` with `nscf.pw.parent_folder` input is required.'
if 'parent_folder' not in value['nscf']['pw']:
return 'Specifying either the `scf` or `nscf.pw.parent_folder` input is required.'
else:
if 'nscf' not in value:
warnings.warn(
'Ony `scf` inputs were provided. This is not the recommended method for DOS/PDOS '
'calculations, as the `nscf` calculation gives a denser k-grid for better BZ integration.'
)
elif 'parent_folder' in value['nscf']['pw']:
warnings.warn(
'Both the `scf` and `nscf.pw.parent_folder` inputs were provided. The SCF calculation will '
'be run with the inputs provided in `scf` and the `nscf.pw.parent_folder` will be ignored.'
)

if 'dos' not in value and 'projwfc' not in value:
return 'Specifying either the `dos` or `projwfc` input is required.'

# This is really hacky
if 'serial_clean' in value:
if not value['serial_clean'].value and ('dos' not in value or 'projwfc' not in value):
return 'Cannot run in parallel if either `dos` or `projwfc` input is not provided.'
else:
if 'dos' not in value or 'projwfc' not in value:
return 'Cannot run in parallel if either `dos` or `projwfc` input is not provided.'

for par in ['Emin', 'Emax', 'DeltaE']:
if value['dos']['parameters']['DOS'].get(par, None) != value['projwfc']['parameters']['PROJWFC'].get(par, None):
return f'The `{par}`` parameter has to be equal for the `dos` and `projwfc` inputs.'
if 'dos' in value and 'projwfc' in value:
for par in ['Emin', 'Emax', 'DeltaE']:
if value['dos']['parameters']['DOS'].get(par,
None) != value['projwfc']['parameters']['PROJWFC'].get(par, None):
return f'The `{par}`` parameter has to be equal for the `dos` and `projwfc` inputs.'

if value.get('align_to_fermi', False):
for par in ['Emin', 'Emax']:
Expand Down Expand Up @@ -186,6 +209,7 @@ def clean_workchain_calcs(workchain):
ProjwfcCalculation = plugins.CalculationFactory('quantumespresso.projwfc')


# pylint: disable=too-many-public-methods
class PdosWorkChain(ProtocolMixin, WorkChain):
"""A WorkChain to compute Total & Partial Density of States of a structure, using Quantum Espresso."""

Expand Down Expand Up @@ -245,7 +269,9 @@ def define(cls, spec):
exclude=('clean_workdir', 'pw.structure'),
namespace_options={
'help': 'Inputs for the `PwBaseWorkChain` of the `nscf` calculation.',
'validator': validate_nscf
'validator': validate_nscf,
'required': False,
'populate_defaults': False
}
)
spec.expose_inputs(
Expand All @@ -255,7 +281,9 @@ def define(cls, spec):
namespace_options={
'help': ('Input parameters for the `dos.x` calculation. Note that the `Emin`, `Emax` and `DeltaE` '
'values have to match with those in the `projwfc` inputs.'),
'validator': validate_dos
'validator': validate_dos,
'required': False,
'populate_defaults': False
}
)
spec.expose_inputs(
Expand All @@ -265,7 +293,9 @@ def define(cls, spec):
namespace_options={
'help': ('Input parameters for the `projwfc.x` calculation. Note that the `Emin`, `Emax` and `DeltaE` '
'values have to match with those in the `dos` inputs.'),
'validator': validate_projwfc
'validator': validate_projwfc,
'required': False,
'populate_defaults': False
}
)
spec.inputs.validator = validate_inputs
Expand All @@ -276,16 +306,22 @@ def define(cls, spec):
cls.run_scf,
cls.inspect_scf,
),
cls.run_nscf,
cls.inspect_nscf,
if_(cls.should_run_nscf)(
cls.run_nscf,
cls.inspect_nscf
),
if_(cls.serial_clean)(
cls.run_dos_serial,
cls.inspect_dos_serial,
cls.run_projwfc_serial,
cls.inspect_projwfc_serial
).else_(
if_(cls.should_run_dos)(
cls.run_dos_serial,
cls.inspect_dos_serial,
),
if_(cls.should_run_projwfc)(
cls.run_projwfc_serial,
cls.inspect_projwfc_serial
)
).else_( # Should only get here if both `dos` and `projwfc` inputs provided
cls.run_pdos_parallel,
cls.inspect_pdos_parallel,
cls.inspect_pdos_parallel
),
cls.results,
)
Expand Down Expand Up @@ -316,45 +352,68 @@ def get_protocol_filepath(cls):

@classmethod
def get_builder_from_protocol(
cls, pw_code, dos_code, projwfc_code, structure, protocol=None, overrides=None, **kwargs
cls,
pw_code,
structure,
dos_code=None,
projwfc_code=None,
do_scf=True,
do_nscf=True,
protocol=None,
overrides=None,
**kwargs
):
"""Return a builder prepopulated with inputs selected according to the chosen protocol.
:param pw_code: the ``Code`` instance configured for the ``quantumespresso.pw`` plugin.
:param dos_code: the ``Code`` instance configured for the ``quantumespresso.dos`` plugin.
:param projwfc_code: the ``Code`` instance configured for the ``quantumespresso.projwfc`` plugin.
:param structure: the ``StructureData`` instance to use.
:param dos_code: the ``Code`` instance configured for the ``quantumespresso.dos`` plugin, if not specified, not
run.
:param projwfc_code: the ``Code`` instance configured for the ``quantumespresso.projwfc`` plugin, if not
specified, not run.
:param protocol: protocol to use, if not specified, the default will be used.
:param overrides: optional dictionary of inputs to override the defaults of the protocol.
:param kwargs: additional keyword arguments that will be passed to the ``get_builder_from_protocol`` of all the
sub processes that are called by this workchain.
:return: a process builder instance with all inputs defined ready for launch.
"""


inputs = cls.get_protocol_inputs(protocol, overrides)

args = (pw_code, structure, protocol)
scf = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('scf', None), **kwargs)
scf['pw'].pop('structure', None)
scf.pop('clean_workdir', None)
nscf = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('nscf', None), **kwargs)
nscf['pw'].pop('structure', None)
nscf['pw']['parameters']['SYSTEM'].pop('smearing', None)
nscf['pw']['parameters']['SYSTEM'].pop('degauss', None)
nscf.pop('clean_workdir', None)
if do_scf:
scf = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('scf', None), **kwargs)
scf['pw'].pop('structure', None)
scf.pop('clean_workdir', None)
if do_nscf:
nscf = PwBaseWorkChain.get_builder_from_protocol(*args, overrides=inputs.get('nscf', None), **kwargs)
nscf['pw'].pop('structure', None)
nscf['pw']['parameters']['SYSTEM'].pop('smearing', None)
nscf['pw']['parameters']['SYSTEM'].pop('degauss', None)
nscf.pop('clean_workdir', None)

builder = cls.get_builder()
builder.structure = structure
builder.clean_workdir = orm.Bool(inputs['clean_workdir'])
builder.scf = scf
builder.nscf = nscf
builder.dos.code = dos_code # pylint: disable=no-member
builder.dos.parameters = orm.Dict(dict=inputs.get('dos', {}).get('parameters')) # pylint: disable=no-member
builder.dos.metadata = inputs.get('dos', {}).get('metadata') # pylint: disable=no-member
builder.projwfc.code = projwfc_code # pylint: disable=no-member
builder.projwfc.parameters = orm.Dict(dict=inputs.get('projwfc', {}).get('parameters')) # pylint: disable=no-member
builder.projwfc.metadata = inputs.get('projwfc', {}).get('metadata') # pylint: disable=no-member
if do_scf:
builder.scf = scf
else:
builder.pop('scf')
if do_nscf:
builder.nscf = nscf
else:
builder.pop('nscf')
if dos_code is not None:
builder.dos.code = dos_code # pylint: disable=no-member
builder.dos.parameters = orm.Dict(dict=inputs.get('dos', {}).get('parameters')) # pylint: disable=no-member
builder.dos.metadata = inputs.get('dos', {}).get('metadata') # pylint: disable=no-member
else:
builder.pop('dos')
if projwfc_code is not None:
builder.projwfc.code = projwfc_code # pylint: disable=no-member
builder.projwfc.parameters = orm.Dict(dict=inputs.get('projwfc', {}).get('parameters')) # pylint: disable=no-member
builder.projwfc.metadata = inputs.get('projwfc', {}).get('metadata') # pylint: disable=no-member
else:
builder.pop('projwfc')

return builder

Expand Down Expand Up @@ -394,12 +453,23 @@ def run_scf(self):
def inspect_scf(self):
"""Verify that the SCF calculation finished successfully."""
workchain = self.ctx.workchain_scf
if not workchain.is_finished_ok:

acceptable_statuses = ['WARNING_ELECTRONIC_CONVERGENCE_NOT_REACHED']

if workchain.is_excepted or workchain.is_killed:
self.report('SCF PwBaseWorkChain was excepted or killed')
return self.exit_codes.ERROR_SUB_PROCESS_FAILED_SCF

if workchain.is_failed and workchain.exit_status not in PwBaseWorkChain.get_exit_statuses(acceptable_statuses):
self.report(f'SCF PwBaseWorkChain failed with exit status {workchain.exit_status}')
return self.exit_codes.ERROR_SUB_PROCESS_FAILED_SCF

self.ctx.scf_parent_folder = workchain.outputs.remote_folder

def should_run_nscf(self):
"""Return whether the work chain should run an NSCF calculation."""
return 'nscf' in self.inputs

def run_nscf(self):
"""Run an NSCF calculation, to generate eigenvalues with a denser k-point mesh.
Expand Down Expand Up @@ -447,10 +517,17 @@ def inspect_nscf(self):
self.ctx.nscf_parent_folder = workchain.outputs.remote_folder
self.ctx.nscf_fermi = workchain.outputs.output_parameters.dict.fermi_energy

def should_run_dos(self):
"""Return whether the work chain should run a DOS calculation."""
return 'dos' in self.inputs

def _generate_dos_inputs(self):
"""Run DOS calculation, to generate total Densities of State."""
dos_inputs = AttributeDict(self.exposed_inputs(DosCalculation, 'dos'))
dos_inputs.parent_folder = self.ctx.nscf_parent_folder
if 'nscf' in self.inputs:
dos_inputs.parent_folder = self.ctx.nscf_parent_folder
else:
dos_inputs.parent_folder = self.ctx.scf_parent_folder
dos_parameters = self.inputs.dos.parameters.get_dict()

if dos_parameters.pop('align_to_fermi', False):
Expand Down Expand Up @@ -499,6 +576,10 @@ def inspect_dos_serial(self):
if clean_calcjob_remote(calculation):
self.report(f'cleaned remote folder of DosCalculation<{calculation.pk}>')

def should_run_projwfc(self):
"""Return whether the work chain should run a projwfc calculation."""
return 'projwfc' in self.inputs

def run_projwfc_serial(self):
"""Run Projwfc calculation."""
projwfc_inputs = self._generate_projwfc_inputs()
Expand Down Expand Up @@ -561,9 +642,12 @@ def results(self):
"""Attach the desired output nodes directly as outputs of the workchain."""
self.report('workchain successfully completed')

self.out_many(self.exposed_outputs(self.ctx.workchain_nscf, PwBaseWorkChain, namespace='nscf'))
self.out_many(self.exposed_outputs(self.ctx.calc_dos, DosCalculation, namespace='dos'))
self.out_many(self.exposed_outputs(self.ctx.calc_projwfc, ProjwfcCalculation, namespace='projwfc'))
if 'nscf' in self.inputs:
self.out_many(self.exposed_outputs(self.ctx.workchain_nscf, PwBaseWorkChain, namespace='nscf'))
if 'dos' in self.inputs:
self.out_many(self.exposed_outputs(self.ctx.calc_dos, DosCalculation, namespace='dos'))
if 'projwfc' in self.inputs:
self.out_many(self.exposed_outputs(self.ctx.calc_projwfc, ProjwfcCalculation, namespace='projwfc'))

def on_terminated(self):
"""Clean the working directories of all child calculations if `clean_workdir=True` in the inputs."""
Expand Down

0 comments on commit c219666

Please sign in to comment.