From e1765bfdfe68ee32b4839e41bae008d7b7b6c571 Mon Sep 17 00:00:00 2001 From: Bonan Zhu Date: Tue, 23 Jan 2024 14:42:31 +0800 Subject: [PATCH] Update band handling --- aiida_user_addons/vworkflows/new_bands.py | 218 +++++----------------- 1 file changed, 48 insertions(+), 170 deletions(-) diff --git a/aiida_user_addons/vworkflows/new_bands.py b/aiida_user_addons/vworkflows/new_bands.py index 8715696..cf355a3 100644 --- a/aiida_user_addons/vworkflows/new_bands.py +++ b/aiida_user_addons/vworkflows/new_bands.py @@ -82,9 +82,7 @@ def define(cls, spec): relax_work = WorkflowFactory(cls._relax_wk_string) base_work = WorkflowFactory(cls._base_wk_string) - spec.input( - "structure", help="The input structure", valid_type=orm.StructureData - ) + spec.input("structure", help="The input structure", valid_type=orm.StructureData) spec.input( "bs_kpoints", help="Explicit kpoints for the bands. Will not generate kpoints if supplied.", @@ -209,18 +207,12 @@ def define(cls, spec): required=False, help="Primitive structure used for band structure calculations", ) - spec.output( - "band_structure", required=False, help="Computed band structure with labels" - ) - spec.output( - "seekpath_parameters", help="Parameters used by seekpath", required=False - ) + spec.output("band_structure", required=False, help="Computed band structure with labels") + spec.output("seekpath_parameters", help="Parameters used by seekpath", required=False) spec.output("dos", required=False) spec.output("projectors", required=False) - spec.exit_code( - 501, "ERROR_SUB_PROC_RELAX_FAILED", message="Relaxation workchain failed" - ) + spec.exit_code(501, "ERROR_SUB_PROC_RELAX_FAILED", message="Relaxation workchain failed") spec.exit_code(502, "ERROR_SUB_PROC_SCF_FAILED", message="SCF workchain failed") spec.exit_code( 503, @@ -296,9 +288,7 @@ def should_generate_path(self): Seekpath should only run if no explicit bands is provided or we are just running for DOS, in which case the original structure is used. """ - return "bs_kpoints" not in self.inputs and ( - not self.inputs.get("only_dos", False) - ) + return "bs_kpoints" not in self.inputs and (not self.inputs.get("only_dos", False)) def generate_path(self): """ @@ -322,9 +312,7 @@ def generate_path(self): else: # Using sumo interface inputs = { - "line_density": self.inputs.get( - "line_density", orm.Float(self.DEFAULT_LINE_DENSITY) - ), + "line_density": self.inputs.get("line_density", orm.Float(self.DEFAULT_LINE_DENSITY)), "symprec": self.inputs.get("symprec", orm.Float(self.DEFAULT_SYMPREC)), "mode": orm.Str(mode), "metadata": {"call_link_label": "sumo_kpath"}, @@ -336,36 +324,26 @@ def generate_path(self): # For magnetic structures, create different kinds for the analysis in case that the # symmetry should be lowered. This also makes sure that the magnetic moments are consistent if magmom: - decorate_result = magnetic_structure_decorate( - self.ctx.current_structure, orm.List(list=magmom) - ) + decorate_result = magnetic_structure_decorate(self.ctx.current_structure, orm.List(list=magmom)) decorated = decorate_result["structure"] # Run seekpath on the decorated structure kpath_results = func(decorated, **inputs) decorated_primitive = kpath_results["primitive_structure"] # Convert back to undecorated structures and add consistent magmom input - dedecorate_result = magnetic_structure_dedecorate( - decorated_primitive, decorate_result["mapping"] - ) + dedecorate_result = magnetic_structure_dedecorate(decorated_primitive, decorate_result["mapping"]) self.ctx.magmom = dedecorate_result["magmom"].get_list() self.ctx.current_structure = dedecorate_result["structure"] else: kpath_results = func(self.ctx.current_structure, **inputs) self.ctx.current_structure = kpath_results["primitive_structure"] - if not np.allclose( - self.ctx.current_structure.cell, current_structure_backup.cell - ): + if not np.allclose(self.ctx.current_structure.cell, current_structure_backup.cell): if self.inputs.scf.get("kpoints"): self.report( "The primitive structure is not the same as the input structure but explicty kpoints are supplied - aborting the workchain." ) - return ( - self.exit_codes.ERROR_INPUT_STRUCTURE_NOT_PRIMITIVE - ) # pylint: disable=no-member - self.report( - "The primitive structure is not the same as the input structure - using the former for all calculations from now." - ) + return self.exit_codes.ERROR_INPUT_STRUCTURE_NOT_PRIMITIVE # pylint: disable=no-member + self.report("The primitive structure is not the same as the input structure - using the former for all calculations from now.") self.ctx.bs_kpoints = kpath_results["explicit_kpoints"] self.out("primitive_structure", self.ctx.current_structure) if "parameters" in kpath_results: @@ -391,9 +369,7 @@ def run_scf(self): # Ensure that writing the CHGCAR file is on pdict = inputs.parameters.get_dict() - if (pdict[OVERRIDE_NAMESPACE].get("lcharg") == False) or ( - pdict[OVERRIDE_NAMESPACE].get("LCHARG") == False - ): + if (pdict[OVERRIDE_NAMESPACE].get("lcharg") == False) or (pdict[OVERRIDE_NAMESPACE].get("LCHARG") == False): pdict[OVERRIDE_NAMESPACE]["lcharg"] = True inputs.parameters = orm.Dict(dict=pdict) self.report("Correction: setting LCHARG to True") @@ -401,9 +377,7 @@ def run_scf(self): # Take magmom from the context, in case that the magmom is rearranged in the primitive cell magmom = self.ctx.get("magmom") if magmom: - inputs.parameters = nested_update_dict_node( - inputs.parameters, {OVERRIDE_NAMESPACE: {"magmom": magmom}} - ) + inputs.parameters = nested_update_dict_node(inputs.parameters, {OVERRIDE_NAMESPACE: {"magmom": magmom}}) running = self.submit(base_work, **inputs) self.report(f"Running SCF calculation {running}") @@ -439,9 +413,7 @@ def run_bands_dos(self): inputs.chgcar = self.ctx.chgcar if not (inputs.get("restart_folder") or inputs.get("chgcar")): - raise RuntimeError( - "One of the restart_folder or chgcar must be set for non-scf calculations" - ) + raise RuntimeError("One of the restart_folder or chgcar must be set for non-scf calculations") running = {} @@ -449,18 +421,12 @@ def run_bands_dos(self): if (only_dos is None) or (only_dos.value is False): if "bands" in self.inputs: - bands_input = AttributeDict( - self.exposed_inputs(base_work, namespace="bands") - ) + bands_input = AttributeDict(self.exposed_inputs(base_work, namespace="bands")) else: bands_input = AttributeDict( { - "settings": orm.Dict( - dict={"parser_settings": {"add_bands": True}} - ), - "parameters": orm.Dict( - dict={"charge": {"constant_charge": True}} - ), + "settings": orm.Dict(dict={"parser_settings": {"add_bands": True}}), + "parameters": orm.Dict(dict={"charge": {"constant_charge": True}}), } ) @@ -503,25 +469,19 @@ def run_bands_dos(self): if ("dos_kpoints_density" in self.inputs) or ("dos" in self.inputs): if "dos" in self.inputs: - dos_input = AttributeDict( - self.exposed_inputs(base_work, namespace="dos") - ) + dos_input = AttributeDict(self.exposed_inputs(base_work, namespace="dos")) else: dos_input = AttributeDict( { "settings": orm.Dict(dict={"add_dos": True}), - "parameters": orm.Dict( - dict={"charge": {"constant_charge": True}} - ), + "parameters": orm.Dict(dict={"charge": {"constant_charge": True}}), } ) # Use the supplied kpoints density for DOS if "dos_kpoints_density" in self.inputs: dos_kpoints = orm.KpointsData() dos_kpoints.set_cell_from_structure(self.ctx.current_structure) - dos_kpoints.set_kpoints_mesh_from_density( - self.inputs.dos_kpoints_density.value * 2 * np.pi - ) + dos_kpoints.set_kpoints_mesh_from_density(self.inputs.dos_kpoints_density.value * 2 * np.pi) dos_input.kpoints = dos_kpoints # Special treatment - combine the parameters @@ -568,9 +528,7 @@ def inspect_bands_dos(self): if "bands_workchain" in self.ctx: bands = self.ctx.bands_workchain if not bands.is_finished_ok: - self.report( - f"Bands calculation finished with error, exit_status: {bands}" - ) + self.report(f"Bands calculation finished with error, exit_status: {bands}") exit_code = self.exit_codes.ERROR_SUB_PROC_BANDS_FAILED self.out( "band_structure", @@ -582,9 +540,7 @@ def inspect_bands_dos(self): if "dos_workchain" in self.ctx: dos = self.ctx.dos_workchain if not dos.is_finished_ok: - self.report( - f"DOS calculation finished with error, exit_status: {dos.exit_status}" - ) + self.report(f"DOS calculation finished with error, exit_status: {dos.exit_status}") exit_code = self.exit_codes.ERROR_SUB_PROC_DOS_FAILED # Attach outputs @@ -614,11 +570,7 @@ def on_terminated(self): pass if cleaned_calcs: - self.report( - "cleaned remote folders of calculations: {}".format( - " ".join(map(str, cleaned_calcs)) - ) - ) + self.report("cleaned remote folders of calculations: {}".format(" ".join(map(str, cleaned_calcs)))) @calcfunction @@ -639,9 +591,7 @@ def seekpath_structure_analysis(structure, **kwargs): from aiida.tools import get_explicit_kpoints_path # All keyword arugments should be `Data` node instances of base type and so should have the `.value` attribute - unwrapped_kwargs = { - key: node.value for key, node in kwargs.items() if isinstance(node, orm.Data) - } + unwrapped_kwargs = {key: node.value for key, node in kwargs.items() if isinstance(node, orm.Data)} return get_explicit_kpoints_path(structure, **unwrapped_kwargs) @@ -721,9 +671,7 @@ def define(cls, spec): help="Number of kpoints per split, INCLUDING the weighted SCF kpoints.", required=True, ) - spec.input( - "structure", help="The input structure", valid_type=orm.StructureData - ) + spec.input("structure", help="The input structure", valid_type=orm.StructureData) spec.expose_inputs( relax_work, namespace="relax", @@ -769,16 +717,10 @@ def define(cls, spec): required=False, help="Primitive structure used for band structure calculations", ) - spec.output( - "band_structure", required=False, help="Computed band structure with labels" - ) - spec.output( - "seekpath_parameters", help="Parameters used by seekpath", required=False - ) + spec.output("band_structure", required=False, help="Computed band structure with labels") + spec.output("seekpath_parameters", help="Parameters used by seekpath", required=False) - spec.exit_code( - 501, "ERROR_SUB_PROC_RELAX_FAILED", message="Relaxation workchain failed" - ) + spec.exit_code(501, "ERROR_SUB_PROC_RELAX_FAILED", message="Relaxation workchain failed") spec.exit_code(502, "ERROR_SUB_PROC_SCF_FAILED", message="SCF workchain failed") spec.exit_code( 503, @@ -805,33 +747,20 @@ def make_splitted_kpoints(self): if "kpoints" in self.inputs.scf: scf_kpoints = self.inputs.scf.kpoints # Relaxation workchain has kpoints output - elif ( - "workchain_relax" in self.ctx - and "kpoints" in self.ctx["workchain_relax"].outputs - ): + elif "workchain_relax" in self.ctx and "kpoints" in self.ctx["workchain_relax"].outputs: scf_kpoints = self.ctx.workchain_relax.outputs.kpoints - self.report( - f"Using output from <{self.ctx.workchain_relax}> for SCF kpoints." - ) + self.report(f"Using output from <{self.ctx.workchain_relax}> for SCF kpoints.") # Parse from relaxation output elif "workchain_relax" in self.ctx: # Try getting the kpoints from the retrieved folder scf_kpoints = extract_kpoints_from_calc(self.ctx.workchain_relax) - self.report( - f"Extracted SCF kpoints from retrieved vasprun.xml of <{self.ctx.workchain_relax}>." - ) + self.report(f"Extracted SCF kpoints from retrieved vasprun.xml of <{self.ctx.workchain_relax}>.") else: - self.report( - "No valid SCF kpoints is avaliable to use. Please define scf.kpoints explicitly!" - ) - return ( - self.exit_codes.ERROR_NO_VALID_SCF_KPOINTS_INPUT - ) # pylint: disable=no-member + self.report("No valid SCF kpoints is avaliable to use. Please define scf.kpoints explicitly!") + return self.exit_codes.ERROR_NO_VALID_SCF_KPOINTS_INPUT # pylint: disable=no-member # Number of kpoints per split, NOT including the SCF kpoints - per_split = orm.Int( - self.inputs.kpoints_per_split.value - scf_kpoints.get_kpoints().shape[0] - ) + per_split = orm.Int(self.inputs.kpoints_per_split.value - scf_kpoints.get_kpoints().shape[0]) kpoints_for_calc = split_kpoints(scf_kpoints, full_kpoints, per_split) self.ctx.kpoints_for_calc = kpoints_for_calc @@ -848,14 +777,10 @@ def run_scf_multi(self): # Ensure that the bands are parsed if "settings" not in inputs: - inputs.settings = orm.Dict( - dict={"parser_settings": {"add_bands": True}} - ) + inputs.settings = orm.Dict(dict={"parser_settings": {"add_bands": True}}) else: # Merge with 'parser_settings' - inputs.settings = nested_update_dict_node( - inputs.settings, {"parser_settings": {"add_bands": True}} - ) + inputs.settings = nested_update_dict_node(inputs.settings, {"parser_settings": {"add_bands": True}}) # Swap the kpoints the the one with zero-weight parts inputs.kpoints = value @@ -863,9 +788,7 @@ def run_scf_multi(self): inputs.metadata.call_link_label = f"bandstructure_split_{idx:03d}" inputs.structure = self.ctx.current_structure running = self.submit(workflow_class, **inputs) - self.report( - f"launching {workflow_class.__name__}<{running.pk}> for split #{idx}" - ) + self.report(f"launching {workflow_class.__name__}<{running.pk}> for split #{idx}") self.to_context(workchains=append_(running)) def inspect_and_combine_bands(self): @@ -879,14 +802,10 @@ def inspect_and_combine_bands(self): self.report("At least one calculation did not have zero return code!") # Extract the bands information - self.report( - f"Extracting output bandstructure from {len(self.ctx.workchains)} workchains." - ) + self.report(f"Extracting output bandstructure from {len(self.ctx.workchains)} workchains.") kwargs = {} for work in workchains: - link_label = ( - work.get_incoming(link_type=LinkType.CALL_WORK).one().link_label - ) + link_label = work.get_incoming(link_type=LinkType.CALL_WORK).one().link_label link_idx = int(link_label.split("_")[-1]) kwargs[f"band_{link_idx:03d}"] = work.outputs.bands kwargs[f"kpoint_{link_idx:03d}"] = work.inputs.kpoints @@ -905,9 +824,7 @@ def split_kpoints(scf_kpoints, band_kpoints, kpn_per_split): return _split_kpoints(scf_kpoints, band_kpoints, kpn_per_split) -def _split_kpoints( - scf_kpoints: orm.KpointsData, band_kpoints: orm.KpointsData, kpn_per_split: orm.Int -): +def _split_kpoints(scf_kpoints: orm.KpointsData, band_kpoints: orm.KpointsData, kpn_per_split: orm.Int): """ Split the kpoints into multiple one and combined with SCF kpoints @@ -920,9 +837,7 @@ def _split_kpoints( # Split the kpoints kpn_per_split = int(kpn_per_split) - kpt_splits = [ - band_kpn[i : i + kpn_per_split] for i in range(0, nband_kpts, kpn_per_split) - ] + kpt_splits = [band_kpn[i : i + kpn_per_split] for i in range(0, nband_kpts, kpn_per_split)] splitted_kpoints = {} for isplit, skpts in enumerate(kpt_splits): @@ -975,17 +890,11 @@ def combine_bands_data(bs_kpoints, **kwargs): Returns a `BandsData` by combining the zero-weighted bands from each calculation. """ - kpoints_list = [ - [node, int(key.split("_")[1])] - for key, node in kwargs.items() - if "kpoint" in key - ] + kpoints_list = [[node, int(key.split("_")[1])] for key, node in kwargs.items() if "kpoint" in key] kpoints_list.sort(key=lambda x: x[1]) kpoints_list = [item[0] for item in kpoints_list] - bands_list = [ - [node, int(key.split("_")[1])] for key, node in kwargs.items() if "band" in key - ] + bands_list = [[node, int(key.split("_")[1])] for key, node in kwargs.items() if "band" in key] bands_list.sort(key=lambda x: x[1]) bands_list = [item[0] for item in bands_list] @@ -1044,9 +953,7 @@ def _combine_bands_data( # Sanity check all valid kpoints should combine into the original path all_kpoints = np.concatenate(kpoints_combine, axis=0) if not np.allclose(all_kpoints, bs_kpoints.get_kpoints()): - raise ValueError( - "The k-path segements do not much the original path when combined!" - ) + raise ValueError("The k-path segements do not much the original path when combined!") # Compose the node band_data = orm.BandsData() @@ -1076,47 +983,18 @@ def _extract_kpoints_from_retrieved(retrieved): """ Extract explicity kpoints from a finished calculation """ - tmpdir = Path(mkdtemp()) - if "vasprun.xml" in retrieved.list_object_names(): - with retrieved.open("vasprun.xml", mode="r") as fsrc: - with open(tmpdir / "vasprun.xml", mode="w") as fdst: - shutil.copyfileobj(fsrc, fdst) - elif "vasprun.xml.gz" in retrieved.list_object_names(): - with retrieved.open("vasprun.xml.gz", mode="rb") as fsrc: - with GzipFile(fileobj=fsrc, mode="rb") as gobj: - with open(tmpdir / "vasprun.xml", mode="wb") as fdst: - shutil.copyfileobj(gobj, fdst) - else: - raise RuntimeError("No valid vasprun.xml file to use!!") - - new_format = False - try: - # NOTE should be deprecated!!! - parser = VasprunParser(file_path=str(tmpdir / "vasprun.xml")) - except TypeError: - # Use newer version - use file_obj instead - new_format = True - with open(tmpdir / "vasprun.xml") as fh: - parser = VasprunParser(handler=fh) + with retrieved.base.repository.open("vasprun.xml", "rb") as fh: + parser = VasprunParser(handler=fh) vkpoints = parser.kpoints if vkpoints["mode"] != "explicit": raise ValueError("Only explicity kpoints is supported!") - if new_format: - kpoints_array = vkpoints["points"] - weights_array = vkpoints["weights"] - else: - kpoints_array = np.stack( - [kpt.get_point() for kpt in vkpoints["points"]], axis=0 - ) - weights_array = np.array([kpt.get_weight() for kpt in vkpoints["points"]]) + kpoints_array = vkpoints["points"] + weights_array = vkpoints["weights"] kpoints_data = orm.KpointsData() kpoints_data.set_kpoints(kpoints=kpoints_array, weights=weights_array) - # Remove the directory tree - shutil.rmtree(tmpdir) - return kpoints_data