Skip to content

Commit

Permalink
Test on the main branche of aiida-core
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Jun 21, 2024
1 parent 5c18163 commit 134cd64
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ jobs:
pip install .[dev]
pip uninstall --yes aiida-core
git clone --depth 1 https://github.com/aiidateam/aiida-core.git
uv pip install ./aiida-core
pip install ./aiida-core
- name: Check aiida-core version is the edget ('post' in version tag)
run: |
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ filterwarnings = [
'ignore:Object of type .* not in session, .* operation along .* will not proceed:sqlalchemy.exc.SAWarning',
'ignore:The `Code` class is deprecated.*:aiida.common.warnings.AiidaDeprecationWarning',
'ignore:`CalcJobNode.*` is deprecated.*:aiida.common.warnings.AiidaDeprecationWarning',
'ignore:`WorkChainNode.get_outgoing` is deprecated.*:aiida.common.warnings.AiidaDeprecationWarning',
]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
Expand Down
6 changes: 6 additions & 0 deletions src/aiida_sssp_workflow/workflows/convergence/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,9 @@ def run_reference(self):
ecutwfc, ecutrho = round(ecutwfc), round(ecutrho)
builder = self.prepare_evaluate_builder(ecutwfc=ecutwfc, ecutrho=ecutrho)

# Add link to the called workchain by '{ecutwfc}_{ecutrho}'
builder.metadata.call_link_label = f"cutoffs_{ecutwfc}_{ecutrho}"

running = self.submit(builder)
running.base.extras.set("wavefunction_cutoff", ecutwfc)
running.base.extras.set("charge_density_cutoff", ecutrho)
Expand Down Expand Up @@ -351,6 +354,9 @@ def run_convergence(self):
ecutwfc, ecutrho = round(ecutwfc), round(ecutrho)
builder = self.prepare_evaluate_builder(ecutwfc=ecutwfc, ecutrho=ecutrho)

# Add link to the called workchain by '{ecutwfc}_{ecutrho}'
builder.metadata.call_link_label = f"cutoffs_{ecutwfc}_{ecutrho}"

running = self.submit(builder)
self.report(
f"launching fix ecutrho={ecutrho} [ecutwfc={ecutwfc}] {running.process_label}<{running.pk}>"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def prepare_evaluate_builder(self, ecutwfc, ecutrho):

# pw calculation
builder.scf.kpoints_distance = orm.Float(protocol["kpoints_distance"])
builder.scf.metadata.call_link_label = "phonon_frequencies_scf"
builder.scf.metadata.call_link_label = "scf"
builder.scf.pw["code"] = self.inputs.pw_code
builder.scf.pw["pseudos"] = self.pseudos
builder.scf.pw["parameters"] = orm.Dict(dict=pw_parameters)
Expand All @@ -159,7 +159,7 @@ def prepare_evaluate_builder(self, ecutwfc, ecutrho):
}
}

builder.phonon.metadata.call_link_label = "phonon_frequencies_ph"
builder.phonon.metadata.call_link_label = "ph"
builder.phonon["qpoints"] = qpoints
builder.phonon.ph["code"] = self.inputs.ph_code
builder.phonon.ph["parameters"] = orm.Dict(dict=ph_parameters)
Expand Down
117 changes: 64 additions & 53 deletions tests/workflows/convergence/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,40 +37,45 @@ def test_caching_bands(

# check the first scf of reference
# The pw calculation
# XXX: use link label to find the expected calcjob node
source_ref_wf = [
p
for p in source_node.called
if p.base.extras.get("wavefunction_cutoff", None) == 30
][0]
source_scf_calcjob_node = source_ref_wf.called[1].called[0].called[1]
source_ref_wf = source_node.get_outgoing().get_node_by_label("cutoffs_30_120")
first_run = source_ref_wf.get_outgoing().get_node_by_label(
"bands_with_factor_3"
)
source_scf_calcjob_node = (
first_run.get_outgoing().get_node_by_label("scf").called[1]
)
assert (
source_scf_calcjob_node.base.extras.get("_aiida_cached_from", None) is None
)

# check the first bands of reference was cached
# The pw calculation
source_band_calcjob_node = source_ref_wf.called[1].called[1].called[0]
source_band_calcjob_node = (
first_run.get_outgoing().get_node_by_label("bands").called[0]
)
assert (
source_band_calcjob_node.base.extras.get("_aiida_cached_from", None) is None
)

# Run again and check it is using caching
_, cached_node = run_get_node(bands_builder)
cached_ref_wf = [
p
for p in cached_node.called
if p.base.extras.get("wavefunction_cutoff", None) == 30
][0]
cached_scf_calcjob_node = cached_ref_wf.called[1].called[0].called[1]
cached_ref_wf = cached_node.get_outgoing().get_node_by_label("cutoffs_30_120")
first_run = cached_ref_wf.get_outgoing().get_node_by_label(
"bands_with_factor_3"
)
cached_scf_calcjob_node = (
first_run.get_outgoing().get_node_by_label("scf").called[1]
)

assert (
cached_scf_calcjob_node.base.extras.get("_aiida_cached_from", None)
== source_scf_calcjob_node.uuid
)
assert not cached_scf_calcjob_node.base.caching.is_valid_cache

cached_band_calcjob_node = cached_ref_wf.called[1].called[1].called[0]
cached_band_calcjob_node = (
first_run.get_outgoing().get_node_by_label("bands").called[0]
)

assert (
cached_band_calcjob_node.base.extras.get("_aiida_cached_from", None)
Expand Down Expand Up @@ -112,30 +117,30 @@ def test_caching_phonon_frequencies(

# check the first scf of reference
# The pw calculation
source_ref_wf = [
p
for p in source_node.called
if p.base.extras.get("wavefunction_cutoff", None) == 30
][0]
source_scf_calcjob_node = source_ref_wf.called[0].called[1]
source_ref_wf = source_node.get_outgoing().get_node_by_label("cutoffs_30_120")

source_scf_calcjob_node = (
source_ref_wf.get_outgoing().get_node_by_label("scf").called[1]
)
assert (
source_scf_calcjob_node.base.extras.get("_aiida_cached_from", None) is None
)

# The ph calculation
source_ph_calcjob_node = source_ref_wf.called[1].called[0]
source_ph_calcjob_node = (
source_ref_wf.get_outgoing().get_node_by_label("ph").called[0]
)
assert (
source_ph_calcjob_node.base.extras.get("_aiida_cached_from", None) is None
)

# Run again and check it is using caching
_, cached_node = run_get_node(phonon_frequencies_builder)
cached_ref_wf = [
p
for p in cached_node.called
if p.base.extras.get("wavefunction_cutoff", None) == 30
][0]
cached_scf_calcjob_node = cached_ref_wf.called[0].called[1]
cached_ref_wf = cached_node.get_outgoing().get_node_by_label("cutoffs_30_120")

cached_scf_calcjob_node = (
cached_ref_wf.get_outgoing().get_node_by_label("scf").called[1]
)

assert (
cached_scf_calcjob_node.base.extras.get("_aiida_cached_from", None)
Expand All @@ -144,7 +149,9 @@ def test_caching_phonon_frequencies(
assert not cached_scf_calcjob_node.base.caching.is_valid_cache

# Run again and check it is using caching
cached_ph_calcjob_node = cached_ref_wf.called[1].called[0]
cached_ph_calcjob_node = (
cached_ref_wf.get_outgoing().get_node_by_label("ph").called[0]
)
assert (
cached_ph_calcjob_node.base.extras.get("_aiida_cached_from", None)
== source_ph_calcjob_node.uuid
Expand Down Expand Up @@ -179,33 +186,43 @@ def test_caching_bands_rerun_pw_prepare(
_, source_node = run_get_node(bands_builder)

# Make the source band calculation invalid cache
source_ref_wf = [
p
for p in source_node.called
if p.base.extras.get("wavefunction_cutoff", None) == 30
][0]
source_band_calcjob_node = source_ref_wf.called[1].called[1].called[0]
source_ref_wf = source_node.get_outgoing().get_node_by_label("cutoffs_30_120")
first_run = source_ref_wf.get_outgoing().get_node_by_label(
"bands_with_factor_3"
)
source_band_calcjob_node = (
first_run.get_outgoing().get_node_by_label("bands").called[0]
)
assert source_band_calcjob_node.is_valid_cache

source_band_calcjob_node.is_valid_cache = False

# Run again and check it is using caching
_, cached_node = run_get_node(bands_builder)

# Check the band work chain finished okay
cached_ref_wf = [
p
for p in cached_node.called
if p.base.extras.get("wavefunction_cutoff", None) == 30
][0]
# Check the band work chain finished okay which means it runs the second scf again
# and get the bands run again and success
cached_ref_wf = cached_node.get_outgoing().get_node_by_label("cutoffs_30_120")
assert cached_ref_wf.called[1].is_finished_ok

cached_band_calcjob_node = cached_ref_wf.called[1].called[1].called[0]
# check that the first scf from cache is not okay therefore trigger the rerun
first_run = cached_ref_wf.get_outgoing().get_node_by_label(
"bands_with_factor_3"
)
cached_band_calcjob_node = first_run.called[1].called[0]
assert (
cached_band_calcjob_node.base.extras.get("_aiida_cached_from", None) is None
)
assert cached_band_calcjob_node.exit_code.status == 305

assert (
first_run.get_outgoing()
.get_node_by_label("bands_rerun")
.called[0]
.exit_code.status
== 0
)


@pytest.mark.slow
@pytest.mark.usefixtures("aiida_profile_clean")
Expand Down Expand Up @@ -240,24 +257,18 @@ def test_caching_phonon_frequencies_rerun_pw_prepare(
_, source_node = run_get_node(phonon_frequencies_builder)

# Make the source ph calculation invalid cache
source_ref_wf = [
p
for p in source_node.called
if p.base.extras.get("wavefunction_cutoff", None) == 30
][0]
source_ph_calcjob_node = source_ref_wf.called[1].called[0]
source_ref_wf = source_node.get_outgoing().get_node_by_label("cutoffs_30_120")
source_ph_calcjob_node = (
source_ref_wf.get_outgoing().get_node_by_label("ph").called[0]
)
assert source_ph_calcjob_node.is_valid_cache

source_ph_calcjob_node.is_valid_cache = False

# Run again and check it is using caching
_, cached_node = run_get_node(phonon_frequencies_builder)

cached_ref_wf = [
p
for p in cached_node.called
if p.base.extras.get("wavefunction_cutoff", None) == 30
][0]
cached_ref_wf = cached_node.get_outgoing().get_node_by_label("cutoffs_30_120")
# Check the ph from rerun pw is finished okay
assert cached_ref_wf.is_finished_ok

Expand Down

0 comments on commit 134cd64

Please sign in to comment.