diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ad857a0..7d2a6a4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -90,7 +90,7 @@ jobs: pip install .[dev] pip uninstall --yes aiida-core git clone --depth 1 https://github.com/aiidateam/aiida-core.git - uv pip install ./aiida-core + pip install ./aiida-core - name: Check aiida-core version is the edget ('post' in version tag) run: | diff --git a/pyproject.toml b/pyproject.toml index 4cb067f..5504437 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ filterwarnings = [ 'ignore:Object of type .* not in session, .* operation along .* will not proceed:sqlalchemy.exc.SAWarning', 'ignore:The `Code` class is deprecated.*:aiida.common.warnings.AiidaDeprecationWarning', 'ignore:`CalcJobNode.*` is deprecated.*:aiida.common.warnings.AiidaDeprecationWarning', + 'ignore:`WorkChainNode.get_outgoing` is deprecated.*:aiida.common.warnings.AiidaDeprecationWarning', ] markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", diff --git a/src/aiida_sssp_workflow/workflows/convergence/_base.py b/src/aiida_sssp_workflow/workflows/convergence/_base.py index 0ee27a7..7a591cc 100644 --- a/src/aiida_sssp_workflow/workflows/convergence/_base.py +++ b/src/aiida_sssp_workflow/workflows/convergence/_base.py @@ -310,6 +310,9 @@ def run_reference(self): ecutwfc, ecutrho = round(ecutwfc), round(ecutrho) builder = self.prepare_evaluate_builder(ecutwfc=ecutwfc, ecutrho=ecutrho) + # Add link to the called workchain by '{ecutwfc}_{ecutrho}' + builder.metadata.call_link_label = f"cutoffs_{ecutwfc}_{ecutrho}" + running = self.submit(builder) running.base.extras.set("wavefunction_cutoff", ecutwfc) running.base.extras.set("charge_density_cutoff", ecutrho) @@ -351,6 +354,9 @@ def run_convergence(self): ecutwfc, ecutrho = round(ecutwfc), round(ecutrho) builder = self.prepare_evaluate_builder(ecutwfc=ecutwfc, ecutrho=ecutrho) + # Add link to the called workchain by '{ecutwfc}_{ecutrho}' + builder.metadata.call_link_label = f"cutoffs_{ecutwfc}_{ecutrho}" + running = self.submit(builder) self.report( f"launching fix ecutrho={ecutrho} [ecutwfc={ecutwfc}] {running.process_label}<{running.pk}>" diff --git a/src/aiida_sssp_workflow/workflows/convergence/phonon_frequencies.py b/src/aiida_sssp_workflow/workflows/convergence/phonon_frequencies.py index c24b723..2380936 100644 --- a/src/aiida_sssp_workflow/workflows/convergence/phonon_frequencies.py +++ b/src/aiida_sssp_workflow/workflows/convergence/phonon_frequencies.py @@ -139,7 +139,7 @@ def prepare_evaluate_builder(self, ecutwfc, ecutrho): # pw calculation builder.scf.kpoints_distance = orm.Float(protocol["kpoints_distance"]) - builder.scf.metadata.call_link_label = "phonon_frequencies_scf" + builder.scf.metadata.call_link_label = "scf" builder.scf.pw["code"] = self.inputs.pw_code builder.scf.pw["pseudos"] = self.pseudos builder.scf.pw["parameters"] = orm.Dict(dict=pw_parameters) @@ -159,7 +159,7 @@ def prepare_evaluate_builder(self, ecutwfc, ecutrho): } } - builder.phonon.metadata.call_link_label = "phonon_frequencies_ph" + builder.phonon.metadata.call_link_label = "ph" builder.phonon["qpoints"] = qpoints builder.phonon.ph["code"] = self.inputs.ph_code builder.phonon.ph["parameters"] = orm.Dict(dict=ph_parameters) diff --git a/tests/workflows/convergence/test_caching.py b/tests/workflows/convergence/test_caching.py index 5f63269..7c3f35f 100644 --- a/tests/workflows/convergence/test_caching.py +++ b/tests/workflows/convergence/test_caching.py @@ -37,32 +37,35 @@ def test_caching_bands( # check the first scf of reference # The pw calculation - # XXX: use link label to find the expected calcjob node - source_ref_wf = [ - p - for p in source_node.called - if p.base.extras.get("wavefunction_cutoff", None) == 30 - ][0] - source_scf_calcjob_node = source_ref_wf.called[1].called[0].called[1] + source_ref_wf = source_node.get_outgoing().get_node_by_label("cutoffs_30_120") + first_run = source_ref_wf.get_outgoing().get_node_by_label( + "bands_with_factor_3" + ) + source_scf_calcjob_node = ( + first_run.get_outgoing().get_node_by_label("scf").called[1] + ) assert ( source_scf_calcjob_node.base.extras.get("_aiida_cached_from", None) is None ) # check the first bands of reference was cached # The pw calculation - source_band_calcjob_node = source_ref_wf.called[1].called[1].called[0] + source_band_calcjob_node = ( + first_run.get_outgoing().get_node_by_label("bands").called[0] + ) assert ( source_band_calcjob_node.base.extras.get("_aiida_cached_from", None) is None ) # Run again and check it is using caching _, cached_node = run_get_node(bands_builder) - cached_ref_wf = [ - p - for p in cached_node.called - if p.base.extras.get("wavefunction_cutoff", None) == 30 - ][0] - cached_scf_calcjob_node = cached_ref_wf.called[1].called[0].called[1] + cached_ref_wf = cached_node.get_outgoing().get_node_by_label("cutoffs_30_120") + first_run = cached_ref_wf.get_outgoing().get_node_by_label( + "bands_with_factor_3" + ) + cached_scf_calcjob_node = ( + first_run.get_outgoing().get_node_by_label("scf").called[1] + ) assert ( cached_scf_calcjob_node.base.extras.get("_aiida_cached_from", None) @@ -70,7 +73,9 @@ def test_caching_bands( ) assert not cached_scf_calcjob_node.base.caching.is_valid_cache - cached_band_calcjob_node = cached_ref_wf.called[1].called[1].called[0] + cached_band_calcjob_node = ( + first_run.get_outgoing().get_node_by_label("bands").called[0] + ) assert ( cached_band_calcjob_node.base.extras.get("_aiida_cached_from", None) @@ -112,30 +117,30 @@ def test_caching_phonon_frequencies( # check the first scf of reference # The pw calculation - source_ref_wf = [ - p - for p in source_node.called - if p.base.extras.get("wavefunction_cutoff", None) == 30 - ][0] - source_scf_calcjob_node = source_ref_wf.called[0].called[1] + source_ref_wf = source_node.get_outgoing().get_node_by_label("cutoffs_30_120") + + source_scf_calcjob_node = ( + source_ref_wf.get_outgoing().get_node_by_label("scf").called[1] + ) assert ( source_scf_calcjob_node.base.extras.get("_aiida_cached_from", None) is None ) # The ph calculation - source_ph_calcjob_node = source_ref_wf.called[1].called[0] + source_ph_calcjob_node = ( + source_ref_wf.get_outgoing().get_node_by_label("ph").called[0] + ) assert ( source_ph_calcjob_node.base.extras.get("_aiida_cached_from", None) is None ) # Run again and check it is using caching _, cached_node = run_get_node(phonon_frequencies_builder) - cached_ref_wf = [ - p - for p in cached_node.called - if p.base.extras.get("wavefunction_cutoff", None) == 30 - ][0] - cached_scf_calcjob_node = cached_ref_wf.called[0].called[1] + cached_ref_wf = cached_node.get_outgoing().get_node_by_label("cutoffs_30_120") + + cached_scf_calcjob_node = ( + cached_ref_wf.get_outgoing().get_node_by_label("scf").called[1] + ) assert ( cached_scf_calcjob_node.base.extras.get("_aiida_cached_from", None) @@ -144,7 +149,9 @@ def test_caching_phonon_frequencies( assert not cached_scf_calcjob_node.base.caching.is_valid_cache # Run again and check it is using caching - cached_ph_calcjob_node = cached_ref_wf.called[1].called[0] + cached_ph_calcjob_node = ( + cached_ref_wf.get_outgoing().get_node_by_label("ph").called[0] + ) assert ( cached_ph_calcjob_node.base.extras.get("_aiida_cached_from", None) == source_ph_calcjob_node.uuid @@ -179,12 +186,13 @@ def test_caching_bands_rerun_pw_prepare( _, source_node = run_get_node(bands_builder) # Make the source band calculation invalid cache - source_ref_wf = [ - p - for p in source_node.called - if p.base.extras.get("wavefunction_cutoff", None) == 30 - ][0] - source_band_calcjob_node = source_ref_wf.called[1].called[1].called[0] + source_ref_wf = source_node.get_outgoing().get_node_by_label("cutoffs_30_120") + first_run = source_ref_wf.get_outgoing().get_node_by_label( + "bands_with_factor_3" + ) + source_band_calcjob_node = ( + first_run.get_outgoing().get_node_by_label("bands").called[0] + ) assert source_band_calcjob_node.is_valid_cache source_band_calcjob_node.is_valid_cache = False @@ -192,20 +200,29 @@ def test_caching_bands_rerun_pw_prepare( # Run again and check it is using caching _, cached_node = run_get_node(bands_builder) - # Check the band work chain finished okay - cached_ref_wf = [ - p - for p in cached_node.called - if p.base.extras.get("wavefunction_cutoff", None) == 30 - ][0] + # Check the band work chain finished okay which means it runs the second scf again + # and get the bands run again and success + cached_ref_wf = cached_node.get_outgoing().get_node_by_label("cutoffs_30_120") assert cached_ref_wf.called[1].is_finished_ok - cached_band_calcjob_node = cached_ref_wf.called[1].called[1].called[0] + # check that the first scf from cache is not okay therefore trigger the rerun + first_run = cached_ref_wf.get_outgoing().get_node_by_label( + "bands_with_factor_3" + ) + cached_band_calcjob_node = first_run.called[1].called[0] assert ( cached_band_calcjob_node.base.extras.get("_aiida_cached_from", None) is None ) assert cached_band_calcjob_node.exit_code.status == 305 + assert ( + first_run.get_outgoing() + .get_node_by_label("bands_rerun") + .called[0] + .exit_code.status + == 0 + ) + @pytest.mark.slow @pytest.mark.usefixtures("aiida_profile_clean") @@ -240,12 +257,10 @@ def test_caching_phonon_frequencies_rerun_pw_prepare( _, source_node = run_get_node(phonon_frequencies_builder) # Make the source ph calculation invalid cache - source_ref_wf = [ - p - for p in source_node.called - if p.base.extras.get("wavefunction_cutoff", None) == 30 - ][0] - source_ph_calcjob_node = source_ref_wf.called[1].called[0] + source_ref_wf = source_node.get_outgoing().get_node_by_label("cutoffs_30_120") + source_ph_calcjob_node = ( + source_ref_wf.get_outgoing().get_node_by_label("ph").called[0] + ) assert source_ph_calcjob_node.is_valid_cache source_ph_calcjob_node.is_valid_cache = False @@ -253,11 +268,7 @@ def test_caching_phonon_frequencies_rerun_pw_prepare( # Run again and check it is using caching _, cached_node = run_get_node(phonon_frequencies_builder) - cached_ref_wf = [ - p - for p in cached_node.called - if p.base.extras.get("wavefunction_cutoff", None) == 30 - ][0] + cached_ref_wf = cached_node.get_outgoing().get_node_by_label("cutoffs_30_120") # Check the ph from rerun pw is finished okay assert cached_ref_wf.is_finished_ok