Skip to content

Commit

Permalink
Feat: pop None inputs specified in overrides
Browse files Browse the repository at this point in the history
Fixes #653

Add the possibility of popping input namespaces by specifying
None in the override for the specific namespace. A decorator
is added that generalize the concept to any implementation of
get_builder_from_protocol.
  • Loading branch information
bastonero committed Mar 28, 2024
1 parent ae7d248 commit cc5ea02
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/aiida_quantumespresso/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Decorators for several purposes."""


def remove_none_overrides(func):
def recursively_remove_nones(value):
"""Recursively remove keys with None values from dictionaries."""
if isinstance(value, dict):
return {k: recursively_remove_nones(v) for k, v in value.items() if v is not None}
return value

def remove_keys_from_builder(builder, keys, path=()):
"""Recursively remove specified keys from the builder based on a path."""
if not keys:
return
current_level = keys.pop(0)
if hasattr(builder, current_level):
if keys:
next_attr = getattr(builder, current_level)
remove_keys_from_builder(next_attr, keys, path + (current_level,))
else:
delattr(builder, current_level)

def wrapper(*args, **kwargs):
if 'overrides' in kwargs and kwargs['overrides'] is not None:
original_overrides = kwargs['overrides']

# Identify paths to keys with None values to be removed
paths_to_remove = []
def find_paths(value, path=()):
if isinstance(value, dict):
for k, v in value.items():
if v is None:
paths_to_remove.append(path + (k,))
else:
find_paths(v, path + (k,))
find_paths(original_overrides)

# Recursively remove keys with None values from overrides
cleaned_overrides = recursively_remove_nones(original_overrides)
kwargs['overrides'] = cleaned_overrides

# Call the original function to get the builder
builder = func(*args, **kwargs)

# Remove specified keys from the builder
for path in paths_to_remove:
remove_keys_from_builder(builder, list(path))

return builder
else:
return func(*args, **kwargs)

return wrapper
2 changes: 2 additions & 0 deletions src/aiida_quantumespresso/workflows/pw/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain
from aiida_quantumespresso.utils.decorators import remove_none_overrides

from ..protocols.utils import ProtocolMixin

Expand Down Expand Up @@ -120,6 +121,7 @@ def get_protocol_filepath(cls):
return files(pw_protocols) / 'bands.yaml'

@classmethod
@remove_none_overrides
def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=None, options=None, **kwargs):
"""Return a builder prepopulated with inputs selected according to the chosen protocol.
Expand Down
2 changes: 2 additions & 0 deletions src/aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from aiida_quantumespresso.calculations.functions.create_kpoints_from_distance import create_kpoints_from_distance
from aiida_quantumespresso.common.types import ElectronicType, RestartType, SpinType
from aiida_quantumespresso.utils.defaults.calculation import pw as qe_defaults
from aiida_quantumespresso.utils.decorators import remove_none_overrides

from ..protocols.utils import ProtocolMixin

Expand Down Expand Up @@ -103,6 +104,7 @@ def get_protocol_filepath(cls):
return files(pw_protocols) / 'base.yaml'

@classmethod
@remove_none_overrides
def get_builder_from_protocol(
cls,
code,
Expand Down
26 changes: 26 additions & 0 deletions tests/workflows/protocols/pw/test_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,29 @@ def test_options(fixture_code, generate_structure):
builder.bands.pw.metadata, # pylint: disable=no-member
):
assert subspace['options']['queue_name'] == queue_name, subspace


def test_pop_none_overrides(fixture_code, generate_structure):
"""Test popping `None` input overrides specified in ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()

overrides = {'relax': {'base_final_scf':None}}
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'base_final_scf' not in builder['relax'] # pylint: disable=no-member

overrides = {'relax': None}
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'relax' not in builder # pylint: disable=no-member

overrides = {'relax': {'base':{'pw':{'parameters':{'SYSTEM':{'ecutwfc': None}}}}}}
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'ecutwfc' in builder['relax']['base']['pw']['parameters']['SYSTEM'] # pylint: disable=no-member

overrides = {'relax': {'base':{'pw':{'parameters': None}}}}
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'parameters' not in builder['relax']['base']['pw'] # pylint: disable=no-member
11 changes: 11 additions & 0 deletions tests/workflows/protocols/pw/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,14 @@ def test_options(fixture_code, generate_structure):

assert metadata['options']['queue_name'] == queue_name
assert metadata['options']['withmpi'] == withmpi


def test_pop_none_overrides(fixture_code, generate_structure):
"""Test popping `None` input overrides specified in ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()

overrides = {'kpoints_distance': None}
builder = PwBaseWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert 'kpoints_distance' not in builder # pylint: disable=no-member

0 comments on commit cc5ea02

Please sign in to comment.