Skip to content

Commit

Permalink
Support nested dependency in model (aiidalab#1147)
Browse files Browse the repository at this point in the history
* Support nested dependency in the model
* replace "." with "__" in the code model key to avoid being treated as a nested model
  • Loading branch information
superstar54 authored Feb 11, 2025
1 parent 0331057 commit c6e51a7
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 32 deletions.
16 changes: 9 additions & 7 deletions src/aiidalab_qe/app/submission/global_settings/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def update_global_codes(self):

def update_active_codes(self):
for identifier, code_model in self.get_models():
if identifier != "quantumespresso.pw":
if identifier != "quantumespresso__pw":
code_model.deactivate()
properties = self._get_properties()
for identifier, code_names in self.plugin_mapping.items():
Expand Down Expand Up @@ -85,6 +85,8 @@ def add_global_model(
base_code_model = None
default_calc_job_plugin = code_model.default_calc_job_plugin
name = default_calc_job_plugin.split(".")[-1]
# "." in the model key means nested models
model_key = default_calc_job_plugin.replace(".", "__")

if not self.has_model(default_calc_job_plugin):
if default_calc_job_plugin == "quantumespresso.pw":
Expand All @@ -100,20 +102,20 @@ def add_global_model(
description=name,
default_calc_job_plugin=default_calc_job_plugin,
)
self.add_model(default_calc_job_plugin, base_code_model)
self.add_model(model_key, base_code_model)

if identifier not in self.plugin_mapping:
self.plugin_mapping[identifier] = [default_calc_job_plugin]
self.plugin_mapping[identifier] = [model_key]
else:
self.plugin_mapping[identifier].append(default_calc_job_plugin)
self.plugin_mapping[identifier].append(model_key)

return base_code_model

def check_resources(self):
if not self.has_model("quantumespresso.pw"):
if not self.has_model("quantumespresso__pw"):
return

pw_code_model = self.get_model("quantumespresso.pw")
pw_code_model = self.get_model("quantumespresso__pw")
protocol = self.input_parameters.get("workchain", {}).get("protocol", "fast")

if not self.input_structure or not pw_code_model.selected:
Expand Down Expand Up @@ -207,7 +209,7 @@ def _get_properties(self) -> list[str]:

def _check_submission_blockers(self):
# No pw code selected
pw_code_model = self.get_model("quantumespresso.pw")
pw_code_model = self.get_model("quantumespresso__pw")
if pw_code_model and not pw_code_model.selected:
yield ("No pw code selected")

Expand Down
10 changes: 5 additions & 5 deletions src/aiidalab_qe/app/submission/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,15 @@ def _create_builder(self, parameters) -> ProcessBuilderNamespace:
codes = parameters["codes"]["global"]["codes"]

builder.relax.base.pw.metadata.options.resources = {
"num_machines": codes.get("quantumespresso.pw")["nodes"],
"num_mpiprocs_per_machine": codes.get("quantumespresso.pw")[
"num_machines": codes.get("quantumespresso__pw")["nodes"],
"num_mpiprocs_per_machine": codes.get("quantumespresso__pw")[
"ntasks_per_node"
],
"num_cores_per_mpiproc": codes.get("quantumespresso.pw")["cpus_per_task"],
"num_cores_per_mpiproc": codes.get("quantumespresso__pw")["cpus_per_task"],
}
mws = codes.get("quantumespresso.pw")["max_wallclock_seconds"]
mws = codes.get("quantumespresso__pw")["max_wallclock_seconds"]
builder.relax.base.pw.metadata.options["max_wallclock_seconds"] = mws
parallelization = codes["quantumespresso.pw"]["parallelization"]
parallelization = codes["quantumespresso__pw"]["parallelization"]
builder.relax.base.pw.parallelization = orm.Dict(dict=parallelization)

return builder
Expand Down
10 changes: 7 additions & 3 deletions src/aiidalab_qe/common/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ def add_models(self, models: dict[str, T]):
self.add_model(identifier, model)

def get_model(self, identifier) -> T:
if self.has_model(identifier):
return self._models[identifier]
keys = identifier.split(".", 1)
if self.has_model(keys[0]):
if len(keys) == 1:
return self._models[identifier]
else:
return self._models[keys[0]].get_model(keys[1])
raise ValueError(f"Model with identifier '{identifier}' not found.")

def get_models(self) -> t.Iterable[tuple[str, T]]:
Expand All @@ -61,7 +65,7 @@ def _link_model(self, model: T):
if not hasattr(model, "dependencies"):
return
for dependency in model.dependencies:
dependency_parts = dependency.split(".")
dependency_parts = dependency.rsplit(".", 1)
if len(dependency_parts) == 1: # from parent
target_model = self
trait = dependency
Expand Down
6 changes: 3 additions & 3 deletions src/aiidalab_qe/common/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,9 @@ def update(self):
if self.override:
return
for _, code_model in self.get_models():
default_calc_job_plugin = code_model.default_calc_job_plugin
if default_calc_job_plugin in self.global_codes:
code_resources: dict = self.global_codes[default_calc_job_plugin] # type: ignore
model_key = code_model.default_calc_job_plugin.replace(".", "__")
if model_key in self.global_codes:
code_resources: dict = self.global_codes[model_key] # type: ignore
code_model.set_model_state(code_resources)

def get_model_state(self):
Expand Down
2 changes: 1 addition & 1 deletion src/aiidalab_qe/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def get_builder_from_protocol(
protocol = parameters["workchain"]["protocol"]

relax_builder = PwRelaxWorkChain.get_builder_from_protocol(
code=codes["global"]["codes"].get("quantumespresso.pw")["code"],
code=codes["global"]["codes"].get("quantumespresso__pw")["code"],
structure=structure,
protocol=protocol,
relax_type=RelaxType(parameters["workchain"]["relax_type"]),
Expand Down
10 changes: 5 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,9 +430,9 @@ def app(pw_code, dos_code, projwfc_code, projwfc_bands_code):

# set up codes
global_model = app.submit_model.get_model("global")
global_model.get_model("quantumespresso.pw").activate()
global_model.get_model("quantumespresso.dos").activate()
global_model.get_model("quantumespresso.projwfc").activate()
global_model.get_model("quantumespresso__pw").activate()
global_model.get_model("quantumespresso__dos").activate()
global_model.get_model("quantumespresso__projwfc").activate()

global_model.set_selected_codes(
{
Expand Down Expand Up @@ -502,7 +502,7 @@ def _submit_app_generator(

app.submit_model.input_structure = generate_structure_data()
app.submit_model.get_model("global").get_model(
"quantumespresso.pw"
"quantumespresso__pw"
).num_cpus = 2

return app
Expand Down Expand Up @@ -812,7 +812,7 @@ def _generate_qeapp_workchain(

# step 3 setup code and resources
app.submit_model.get_model("global").get_model(
"quantumespresso.pw"
"quantumespresso__pw"
).num_cpus = 4
parameters = app.submit_model.get_model_state()
builder = app.submit_model._create_builder(parameters)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_codes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def test_code_not_selected(submit_app_generator):
"""Test if there is an error when the code is not selected."""
app: WizardApp = submit_app_generator(properties=["dos"])
model = app.submit_model
model.get_model("global").get_model("quantumespresso.dos").selected = None
model.get_model("global").get_model("quantumespresso__dos").selected = None
# Check builder construction passes without an error
parameters = model.get_model_state()
model._create_builder(parameters)
Expand Down Expand Up @@ -38,7 +38,7 @@ def test_update_codes_display(app: WizardApp):
assert global_resources.code_widgets["dos"].layout.display == "none"
model.input_parameters = {"workchain": {"properties": ["pdos"]}}
global_model.update_active_codes()
assert global_model.get_model("quantumespresso.dos").is_active is True
assert global_model.get_model("quantumespresso__dos").is_active is True
assert global_resources.code_widgets["dos"].layout.display == "block"


Expand All @@ -54,7 +54,7 @@ def test_check_submission_blockers(app: WizardApp):
assert len(model.internal_submission_blockers) == 0

# set dos code to None, will introduce another blocker
dos_code = model.get_model("global").get_model("quantumespresso.dos")
dos_code = model.get_model("global").get_model("quantumespresso__dos")
dos_value = dos_code.selected
dos_code.selected = None
model.update_submission_blockers()
Expand All @@ -71,7 +71,7 @@ def test_qeapp_computational_resources_widget(app: WizardApp):
app.submit_step.render()
global_model = app.submit_model.get_model("global")
global_resources = app.submit_step.global_resources
pw_code_model = global_model.get_model("quantumespresso.pw")
pw_code_model = global_model.get_model("quantumespresso__pw")
pw_code_widget = global_resources.code_widgets["pw"]
assert pw_code_widget.parallelization.npool.layout.display == "none"
pw_code_model.parallelization_override = True
Expand Down
2 changes: 1 addition & 1 deletion tests/test_submit_qe_workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_warning_messages(
submit_model = app.submit_model
global_model = submit_model.get_model("global")

pw_code = global_model.get_model("quantumespresso.pw")
pw_code = global_model.get_model("quantumespresso__pw")

# we increase the resources, so we should have the Warning-3
pw_code.num_cpus = len(os.sched_getaffinity(0))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ codes:
override: false
global:
codes:
quantumespresso.dos:
quantumespresso__dos:
cpus: 1
cpus_per_task: 1
max_wallclock_seconds: 43200
nodes: 1
ntasks_per_node: 1
quantumespresso.projwfc:
quantumespresso__projwfc:
cpus: 1
cpus_per_task: 1
max_wallclock_seconds: 43200
nodes: 1
ntasks_per_node: 1
quantumespresso.pw:
quantumespresso__pw:
cpus: 2
cpus_per_task: 1
max_wallclock_seconds: 43200
Expand Down

0 comments on commit c6e51a7

Please sign in to comment.