diff --git a/src/atomate2/cp2k/builders/defect.py b/src/atomate2/cp2k/builders/defect.py new file mode 100644 index 0000000000..cc829b1d84 --- /dev/null +++ b/src/atomate2/cp2k/builders/defect.py @@ -0,0 +1,1130 @@ +from datetime import datetime +from itertools import groupby +from typing import Dict, Iterator, List, Literal, Optional + +import numpy as np +from emmet.core.electronic_structure import ElectronicStructureDoc +from emmet.core.material import MaterialsDoc +from maggma.builders import Builder +from maggma.stores import Store +from maggma.utils import grouper +from monty.json import MontyDecoder, jsanitize +from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher +from pymatgen.core import Structure +from pymatgen.io.cp2k.inputs import Cp2kInput +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer + +from atomate2.cp2k.schemas.calc_types import TaskType +from atomate2.cp2k.schemas.defect import DefectDoc, DefectiveMaterialDoc +from atomate2.settings import Atomate2Settings + +__author__ = "Nicholas Winner " + + +class DefectBuilder(Builder): + """ + The DefectBuilder collects task documents performed on structures containing a single point defect. + The builder is intended to group tasks corresponding to the same defect (species including charge state), + find the best ones, and perform finite-size defect corrections to create a defect document. These + defect documents can then be assembled into defect phase diagrams using the DefectThermoBuilder. + + In order to make the build process easier, an entry must exist inside of the task doc that identifies it + as a point defect calculation. Currently this is the Pymatgen defect object keyed by "defect". In the future, + this may be changed to having a defect transformation in the transformation history. + + The process is as follows: + + 1.) Find all documents containing the defect query. + 2.) 
Find all documents that do not contain the defect query, and which have DOS and dielectric data already + calculated. These are the candidate bulk tasks. + 3.) For each candidate defect task, attempt to match to a candidate bulk task of the same number of sites + (+/- 1) with the required properties for analysis. Reject defects that do not have a corresponding + bulk calculation. + 4.) Convert (defect, bulk task) doc pairs to DefectDocs + 5.) Post-process and validate defect document + 6.) Update the defect store + """ + + # TODO how to incorporate into settings? + DEFAULT_ALLOWED_DFCT_TASKS = [ + TaskType.Structure_Optimization.value, + ] + + DEFAULT_ALLOWED_BULK_TASKS = [ + TaskType.Structure_Optimization.value, + TaskType.Static.value, + ] + + def __init__( + self, + tasks: Store, + defects: Store, + dielectric: Store, + electronic_structure: Store, + materials: Store, + electrostatic_potentials: Store, + task_validation: Optional[Store] = None, + query: Optional[Dict] = None, + bulk_query: Optional[Dict] = None, + allowed_dfct_types: Optional[List[str]] = DEFAULT_ALLOWED_DFCT_TASKS, + allowed_bulk_types: Optional[List[str]] = DEFAULT_ALLOWED_BULK_TASKS, + task_schema: Literal[ + "cp2k" + ] = "cp2k", # TODO cp2k specific right now, but this will go in common eventually + settings: Dict | None = None, + **kwargs, + ): + """ + Args: + tasks: Store of task documents + defects: Store of defect documents to generate + dielectric: Store of dielectric data + electronic_structure: Store of electronic structure data + materials: Store of materials documents + electrostatic_potentials: Store of electrostatic potential data. These + are generally stored separately from the tasks on GridFS due to their size. + task_validation: Store of task validation documents. If provided, then only tasks that have passed + validation will be considered. + query: dictionary to limit tasks to be analyzed. 
NOT the same as the defect_query property + allowed_task_types: list of task_types that can be processed + settings: EmmetBuildSettings object + """ + + self.tasks = tasks + self.defects = defects + self.materials = materials + self.dielectric = dielectric + self.electronic_structure = electronic_structure + self.electrostatic_potentials = electrostatic_potentials + self.task_validation = task_validation + self._allowed_dfct_types = ( + allowed_dfct_types # TODO How to incorporate into getitems? + ) + self._allowed_bulk_types = ( + allowed_bulk_types # TODO How to incorporate into getitems? + ) + + settings = settings if settings else {} + self.settings = Atomate2Settings(**settings) # TODO don't think this is right + self.query = query if query else {} + self.bulk_query = bulk_query if bulk_query else {} + self.timestamp = None + self._mpid_map = {} + self.task_schema = task_schema + self.kwargs = kwargs + + # TODO Long term, schemas should be part of the matching and grouping process so that a builder can be run on a mixture + self.query.update( + { + "output.@module": f"atomate2.{self.task_schema}.schemas.task", + "output.@class": "TaskDocument", + } + ) + self.bulk_query.update( + { + "output.@module": f"atomate2.{self.task_schema}.schemas.task", + "output.@class": "TaskDocument", + } + ) + self._defect_query = "output.additional_json.info.defect" + + self._required_defect_properties = [ + self._defect_query, + self.tasks.key, + "output.output.energy", + "output.output.structure", + "output.input", + "output.nsites", + "output.cp2k_objects.v_hartree", + ] + + self._required_bulk_properties = [ + self.tasks.key, + "output.output.energy", + "output.output.structure", + "output.input", + "output.cp2k_objects.v_hartree", + "output.output.vbm", + ] + + self._optional_defect_properties = [] + self._optional_bulk_properties = [] + + sources = [ + tasks, + dielectric, + electronic_structure, + materials, + electrostatic_potentials, + ] + if self.task_validation: + 
sources.append(self.task_validation) + super().__init__(sources=sources, targets=[defects], **kwargs) + + @property + def defect_query(self) -> str: + """ + The standard query for defect tasks. + """ + return self._defect_query + + # TODO Hartree pot should be required but only for charged defects + @property + def required_defect_properties(self) -> List: + """ + Properties essential to processing a defect task. + """ + return self._required_defect_properties + + @property + def required_bulk_properties(self) -> List: + """ + Properties essential to processing a bulk task. + """ + return self._required_bulk_properties + + @property + def optional_defect_properties(self) -> List: + """ + Properties that are optional for processing a defect task. + """ + return self._optional_defect_properties + + @property + def optional_bulk_properties(self) -> List: + """ + Properties that are optional for bulk tasks. + """ + return self._optional_bulk_properties + + @property + def mpid_map(self) -> Dict: + return self._mpid_map + + @property + def allowed_dfct_types(self) -> set: + return {TaskType(t) for t in self._allowed_dfct_types} + + @property + def allowed_bulk_types(self) -> set: + return {TaskType(t) for t in self._allowed_bulk_types} + + def ensure_indexes(self): + """ + Ensures indices on the tasks and materials collections + """ + + # Basic search index for tasks + self.tasks.ensure_index(self.tasks.key) + self.tasks.ensure_index("output.last_updated") + self.tasks.ensure_index("output.state") + self.tasks.ensure_index("output.formula_pretty") # TODO is necessary? 
+ + # Search index for materials + self.materials.ensure_index("material_id") + self.materials.ensure_index("last_updated") + self.materials.ensure_index("task_ids") + + # Search index for defects + self.defects.ensure_index("material_id") + self.defects.ensure_index("last_updated") + self.defects.ensure_index("task_ids") + + if self.task_validation: + self.task_validation.ensure_index("task_id") + self.task_validation.ensure_index("valid") + + def prechunk(self, number_splits: int) -> Iterator[Dict]: + + tag_query = {} + if len(self.settings.BUILD_TAGS) > 0 and len(self.settings.EXCLUDED_TAGS) > 0: + tag_query["$and"] = [ + {"tags": {"$in": self.settings.BUILD_TAGS}}, + {"tags": {"$nin": self.settings.EXCLUDED_TAGS}}, + ] + elif len(self.settings.BUILD_TAGS) > 0: + tag_query["tags"] = {"$in": self.settings.BUILD_TAGS} + + # Get defect tasks + temp_query = self.query.copy() + temp_query.update(tag_query) + temp_query.update( + {d: {"$exists": True, "$ne": None} for d in self.required_defect_properties} + ) + temp_query.update({self.defect_query: {"$exists": True}, "state": "successful"}) + defect_tasks = { + doc[self.tasks.key] + for doc in self.tasks.query( + criteria=temp_query, properties=[self.tasks.key] + ) + } + + # Get bulk tasks + temp_query = self.bulk_query.copy() + temp_query.update(tag_query) + temp_query.update({d: {"$exists": True} for d in self.required_bulk_properties}) + temp_query.update( + {self.defect_query: {"$exists": False}, "state": "successful"} + ) + bulk_tasks = { + doc[self.tasks.key] + for doc in self.tasks.query( + criteria=temp_query, properties=[self.tasks.key] + ) + } + + N = np.ceil(len(defect_tasks) / number_splits) + for task_chunk in grouper(defect_tasks, N): + yield {"query": {"task_id": {"$in": task_chunk + list(bulk_tasks)}}} + + def get_items(self) -> Iterator[List[Dict]]: + """ + Gets all items to process into defect documents. 
+ This does no datetime checking; relying on whether + task_ids are included in the Defect Collection. + + The procedure is as follows: + + 1. Get all tasks with standard "defect" query tag + 2. Filter all tasks by skipping tasks which are already in the Defect Store + 3. Get all tasks that could be used as bulk + 4. Filter all bulks which do not have corresponding Dielectric and + ElectronicStructure data (if a band gap exists for that task). + 5. Group defect tasks by defect matching + 6. Given defect object in a group, bundle them with bulk tasks + identified with structure matching + 7. Yield the item bundles + + Returns: + Iterator of (defect document, item bundle, material_id, defect task ids) + + The defect document is an existing defect doc to be updated with new data, or None + + Task bundles are all the tasks that correspond to the same defect and all possible + bulk tasks that could be matched to them. + """ + + self.logger.info("Defect builder started") + self.logger.info( + f"Allowed defect types: {[task_type.value for task_type in self.allowed_dfct_types]}" + ) + self.logger.info( + f"Allowed bulk types: {[task_type.value for task_type in self.allowed_bulk_types]}" + ) + + self.logger.info("Setting indexes") + self.ensure_indexes() + + # Save timestamp to mark buildtime for material documents + self.timestamp = datetime.utcnow() + + self.logger.info("Finding tasks to process") + + ##### Get defect tasks ##### + temp_query = self.query.copy() + temp_query.update( + {d: {"$exists": True, "$ne": None} for d in self.required_defect_properties} + ) + temp_query.update( + {self.defect_query: {"$exists": True}, "output.state": "successful"} + ) + defect_tasks = { + doc[self.tasks.key] + for doc in self.tasks.query( + criteria=temp_query, properties=[self.tasks.key] + ) + } + + # TODO Seems slow + not_allowed = { + doc[self.tasks.key] + for doc in self.tasks.query( + criteria={self.tasks.key: {"$in": list(defect_tasks)}}, + properties=["output.calcs_reversed"], + ) + if 
TaskType(doc["output"]["calcs_reversed"][0]["task_type"]) + not in self.allowed_dfct_types + } + if not_allowed: + self.logger.debug( + f"{len(not_allowed)} defect tasks dropped. Not allowed TaskType" + ) + defect_tasks = defect_tasks - not_allowed + + ##### Get bulk tasks ##### + temp_query = self.bulk_query.copy() + temp_query.update( + {d: {"$exists": True, "$ne": None} for d in self.required_bulk_properties} + ) + temp_query.update( + {self.defect_query: {"$exists": False}, "output.state": "successful"} + ) + bulk_tasks = { + doc[self.tasks.key] + for doc in self.tasks.query( + criteria=temp_query, properties=[self.tasks.key] + ) + } + + # TODO seems slow + not_allowed = { + doc[self.tasks.key] + for doc in self.tasks.query( + criteria={self.tasks.key: {"$in": list(bulk_tasks)}}, + properties=["output.calcs_reversed"], + ) + if TaskType(doc["output"]["calcs_reversed"][0]["task_type"]) + not in self.allowed_bulk_types + } + if not_allowed: + self.logger.debug( + f"{len(not_allowed)} bulk tasks dropped. Not allowed TaskType" + ) + bulk_tasks = bulk_tasks - not_allowed + + # TODO Not the same validation behavior as material builders? 
+ # If validation store exists, find tasks that are invalid and remove them + if self.task_validation: + validated = { + doc[self.tasks.key] + for doc in self.task_validation.query({}, [self.task_validation.key]) + } + + defect_tasks = defect_tasks.intersection(validated) + bulk_tasks = bulk_tasks.intersection(validated) + + invalid_ids = { + doc[self.tasks.key] + for doc in self.task_validation.query( + {"is_valid": False}, [self.task_validation.key] + ) + } + self.logger.info(f"Removing {len(invalid_ids)} invalid tasks") + defect_tasks = defect_tasks - invalid_ids + bulk_tasks = bulk_tasks - invalid_ids + + processed_defect_tasks = { + t_id + for d in self.defects.query({}, ["task_ids"]) + for t_id in d.get("task_ids", []) + } + all_tasks = defect_tasks | bulk_tasks + + self.logger.debug(f"All tasks: {len(all_tasks)}") + self.logger.debug(f"Bulk tasks before filter: {len(bulk_tasks)}") + bulk_tasks = set(filter(self.__preprocess_bulk, bulk_tasks)) + self.logger.debug(f"Bulk tasks after filter: {len(bulk_tasks)}") + self.logger.debug(f"All defect tasks: {len(defect_tasks)}") + unprocessed_defect_tasks = defect_tasks - processed_defect_tasks + + if not unprocessed_defect_tasks: + self.logger.info("No unprocessed defect tasks. Exiting") + return + elif not bulk_tasks: + self.logger.info("No compatible bulk calculations. 
Exiting.") + return + + self.logger.info( + f"Found {len(unprocessed_defect_tasks)} unprocessed defect tasks" + ) + self.logger.info( + f"Found {len(bulk_tasks)} bulk tasks with dielectric properties" + ) + + # Set total for builder bars to have a total + self.total = len(unprocessed_defect_tasks) + + # yield list of defects that are of the same type, matched to an appropriate bulk calc + self.logger.info(f"Starting defect matching.") + + for defect, defect_task_group in self.__filter_and_group_tasks( + unprocessed_defect_tasks + ): + task_ids = self.__match_defects_to_bulks(bulk_tasks, defect_task_group) + if not task_ids: + continue + doc = self.__get_defect_doc(defect) + if doc: + self.logger.info(f"DOC IS {doc.defect.__repr__()}") + item_bundle = self.__get_item_bundle(task_ids) + m = next(iter(task_ids.values()))[1] + material_id = self.mpid_map[m] + yield doc, item_bundle, material_id, defect_task_group + + def process_item(self, items): + """ + Process a group of defect tasks that correspond to the same defect into a single defect + document. If the DefectDoc already exists, then update it and return it. If it does not, + create a new DefectDoc + + Args: + items: (DefectDoc or None, [(defect task dict, bulk task dict, dielectric dict), ... 
] + + returns: the defect document as a dictionary + """ + defect_doc, item_bundle, material_id, task_ids = items + self.logger.info( + f"Processing group of {len(item_bundle)} defects into DefectDoc" + ) + if item_bundle: + for _, (defect_task, bulk_task, dielectric) in item_bundle.items(): + if not defect_doc: + defect_doc = DefectDoc.from_tasks( + defect_task=defect_task, + bulk_task=bulk_task, + dielectric=dielectric, + query=self.defect_query, + key=self.tasks.key, + material_id=material_id, + ) + else: + defect_doc.update_one( + defect_task, + bulk_task, + dielectric, + query=self.defect_query, + key=self.tasks.key, + ) # TODO Atomate2Store wrapper + defect_doc.task_ids = list( + set(task_ids + defect_doc.task_ids) + ) # TODO should I store the bulk id too? + return jsanitize( + defect_doc.dict(), allow_bson=True, enum_values=True, strict=True + ) + return {} + + def update_targets(self, items): + """ + Inserts the new task_types into the task_types collection + """ + + items = [item for item in items if item] + + if len(items) > 0: + self.logger.info(f"Updating {len(items)} defects") + for item in items: + item.update({"_bt": self.timestamp}) + self.defects.remove_docs( + { + "task_ids": item["task_ids"], + } + ) + self.defects.update(items, key="task_ids") + else: + self.logger.info("No items to update") + + def __filter_and_group_tasks(self, tasks): + """ + Groups defect tasks. Tasks are grouped according to the reduced representation + of the defect, and so tasks with different settings (e.g. supercell size, functional) + will be grouped together. + + Args: + tasks: task_ids (according to self.tasks.key) for unprocessed defects + + returns: + [ (defect, [task_ids] ), ...] 
where task_ids correspond to the same defect + """ + + props = [self.defect_query, self.tasks.key, "output.structure"] + + self.logger.debug(f"Finding equivalent tasks for {len(tasks)} defects") + + sm = StructureMatcher(allow_subset=False) # TODO build settings + defects = [ + { + self.tasks.key: t[self.tasks.key], + "defect": self.__get_defect_from_task(t), + "structure": Structure.from_dict(t["output"]["structure"]), + } + for t in self.tasks.query( + criteria={self.tasks.key: {"$in": list(tasks)}}, properties=props + ) + ] + for d in defects: + # TODO remove oxidation state because spins/oxidation cause errors in comparison. + # but they shouldnt if those props are close in value + d["structure"].remove_oxidation_states() + d["defect"].user_charges = [d["structure"].charge] + + def key(x): + s = x["defect"].structure + return get_sg(s), s.composition.reduced_composition + + def are_equal(x, y): + """To decide if defects are equal.""" + if x["structure"].charge != y["structure"].charge: + return False + if x["defect"] == y["defect"]: + return True + return False + + sorted_s_list = sorted(enumerate(defects), key=lambda x: key(x[1])) + all_groups = [] + + # For each pre-grouped list of structures, perform actual matching. 
+ for k, g in groupby(sorted_s_list, key=lambda x: key(x[1])): + unmatched = list(g) + while len(unmatched) > 0: + i, refs = unmatched.pop(0) + matches = [i] + inds = list( + filter( + lambda j: are_equal(refs, unmatched[j][1]), + list(range(len(unmatched))), + ) + ) + matches.extend([unmatched[i][0] for i in inds]) + unmatched = [ + unmatched[i] for i in range(len(unmatched)) if i not in inds + ] + all_groups.append( + ( + defects[i]["defect"], + [defects[i][self.tasks.key] for i in matches], + ) + ) + + self.logger.debug(f"{len(all_groups)} groups") + return all_groups + + def __get_defect_from_task(self, task): + """ + Using the defect_query property, retrieve a pymatgen defect object from the task document + """ + defect = unpack(self.defect_query, task) + return MontyDecoder().process_decoded(defect) + + def __get_defect_doc(self, defect): + """ + Given a defect, find the DefectDoc corresponding to it in the defects store if it exists + + returns: DefectDoc or None + """ + material_id = self._get_mpid(defect.structure) + docs = [ + DefectDoc(**doc) + for doc in self.defects.query( + criteria={"material_id": material_id}, properties=None + ) + ] + for doc in docs: + if self.__defect_match(defect, doc.defect): + return doc + return None + + def __defect_match(self, x, y): + """Match two defects, including their charges""" + sm = StructureMatcher() + if x.user_charges[0] != y.user_charges[0]: + return False + + # Elem. changes needed to distinguish ghost vacancies + if x.element_changes == y.element_changes and sm.fit( + x.defect_structure, y.defect_structure + ): + return True + + return False + + # TODO should move to returning dielectric doc or continue returning the total diel tensor? + def __get_dielectric(self, key): + """ + Given a bulk task's task_id, find the material_id, and then use it to query the dielectric store + and retrieve the total dielectric tensor for defect analysis. 
If no dielectric exists, as would + be the case for metallic systems, return None. + """ + for diel in self.dielectric.query( + criteria={"material_id": key}, properties=["total"] + ): + return diel["total"] + return None + + # TODO retrieving the electrostatic potential is by far the most expensive part of the builder. Any way to reduce? + def __get_item_bundle(self, task_ids): + """ + Gets a group of items that can be processed together into a defect document. + + Args: + task_ids: dict keyed by run type whose values are + (defect task id, bulk task id) pairs + + returns: dict {run type: (defect task dict, bulk_task_dict, dielectric dict)} + """ + return { + rt: ( + self.tasks.query_one(criteria={self.tasks.key: pairs[0]}, load=True), + self.tasks.query_one(criteria={self.tasks.key: pairs[1]}, load=True), + self.__get_dielectric(self._mpid_map[pairs[1]]), + ) + for rt, pairs in task_ids.items() + } + + def _get_mpid(self, structure): + """ + Given a structure, determine if an equivalent structure exists, with a material_id, + in the materials store. + + Args: + structure: Candidate structure + + returns: material_id, if one exists, else None + """ + sga = SpacegroupAnalyzer( + structure, symprec=self.settings.SYMPREC + ) # TODO Add angle tolerance + mats = self.materials.query( + criteria={ + "chemsys": structure.composition.chemical_system, + }, + properties=["structure", "material_id"], + ) + # TODO could more than one material match true? 
+ sm = StructureMatcher( + primitive_cell=True, comparator=ElementComparator() + ) # TODO add tolerances + for m in mats: + if sm.fit(structure, Structure.from_dict(m["structure"])): + return m["material_id"] + return None + + def __match_defects_to_bulks(self, bulk_ids, defect_ids) -> list[tuple]: + """ + Given task_ids of bulk and defect tasks, match the defects to a bulk task that has + commensurate: + - Composition + - Number of sites + - Symmetry + """ + self.logger.debug(f"Finding bulk/defect task combinations.") + self.logger.debug(f"Bulk tasks: {bulk_ids}") + self.logger.debug(f"Defect tasks: {defect_ids}") + + # TODO mongo projection on array doesn't work (see above) + props = [ + self.tasks.key, + self.defect_query, + "output.input", + "output.nsites", + "output.output.structure", + "output.output.energy", + "output.calcs_reversed", + ] + defects = list( + self.tasks.query( + criteria={self.tasks.key: {"$in": list(defect_ids)}}, properties=props + ) + ) + ps = self.__get_pristine_supercell(defects[0]) + ps.remove_oxidation_states() # TODO might cause problems + bulks = list( + self.tasks.query( + criteria={ + self.tasks.key: {"$in": list(bulk_ids)}, + "output.formula_pretty": jsanitize(ps.composition.reduced_formula), + }, + properties=props, + ) + ) + + pairs = [ + (defect, bulk) + for bulk in bulks + for defect in defects + if self.__are_bulk_and_defect_commensurate(bulk, defect) + ] + self.logger.debug(f"Found {len(pairs)} commensurate bulk/defect pairs") + + def key(x): + return -x[0]["output"]["nsites"], x[0]["output"]["output"]["energy"] + + def _run_type(x): + return x[0]["output"]["calcs_reversed"][0]["run_type"] + + rt_pairs = {} + for rt, group in groupby(pairs, key=_run_type): + rt_pairs[rt] = [ + (defect[self.tasks.key], bulk[self.tasks.key]) + for defect, bulk in sorted(list(group), key=key) + ] + + # Return only the first (best) pair for each rt + return {rt: lst[0] for rt, lst in rt_pairs.items()} + + # TODO Checking for same dft settings 
(e.g. OT/diag) is a little cumbersome. + # Maybe, in future, task doc can be defined to have OT/diag as part of input summary + # for fast querying + def __are_bulk_and_defect_commensurate(self, b, d): + """ + Check if a bulk and defect task are commensurate. + + Checks for: + 1. Same run type. + 2. Same pristine structures with no supercell reduction + 3. Compatible DFT settings + """ + # TODO add settings + sm = StructureMatcher( + ltol=1e-3, + stol=0.1, + angle_tol=1, + primitive_cell=False, + scale=True, + attempt_supercell=False, + allow_subset=False, + comparator=ElementComparator(), + ) + rtb = b.get("output").get("input").get("xc").split("+U")[0] + rtd = d.get("output").get("input").get("xc").split("+U")[0] + baux = { + dat["element"]: dat.get("auxiliary_basis") + for dat in b["output"]["input"]["atomic_kind_info"]["atomic_kinds"].values() + } + daux = { + dat["element"]: dat.get("auxiliary_basis") + for dat in d["output"]["input"]["atomic_kind_info"]["atomic_kinds"].values() + } + + if rtb == rtd: + if sm.fit( + self.__get_pristine_supercell(d), self.__get_pristine_supercell(b) + ): + cib = Cp2kInput.from_dict( + b["output"]["calcs_reversed"][0]["input"]["cp2k_input"] + ) + cid = Cp2kInput.from_dict( + d["output"]["calcs_reversed"][0]["input"]["cp2k_input"] + ) + bis_ot = cib.check("force_eval/dft/scf/ot") + dis_ot = cid.check("force_eval/dft/scf/ot") + if (bis_ot and dis_ot) or (not bis_ot and not dis_ot): + for el in baux: + if baux[el] != daux[el]: + return False + return True + return False + + def __preprocess_bulk(self, task): + """ + Given a TaskDoc that could be a bulk for defect analysis, check to see if it can be used. 
Bulk + tasks must have: + + (1) Correspond to an existing material_id in the materials store + (2) If the bulk is not a metal, then the dielectric tensor must exist in the dielectric store + (3) If bulk is not a metal, electronic structure document must exist in the store + + """ + self.logger.debug(f"Preprocessing bulk task {task}") + t = next( + self.tasks.query( + criteria={self.tasks.key: task}, + properties=["output.output.structure", "mpid"], + ) + ) + + struc = Structure.from_dict( + t.get("output").get("output").get("structure") + ) # TODO specific to atomate2 + mpid = self._get_mpid(struc) + if not mpid: + self.logger.debug(f"No material id found for bulk task {task}") + return False + self._mpid_map[task] = mpid + self.logger.debug(f"Material ID: {mpid}") + + elec = self.electronic_structure.query_one( + properties=["band_gap"], criteria={self.electronic_structure.key: mpid} + ) + if not elec: + self.logger.debug(f"Electronic structure data not found for {mpid}") + return False + + # TODO right now pulling dos from electronic structure, should just pull summary document + if elec["band_gap"] > 0: + diel = self.__get_dielectric(mpid) + if not diel: + self.logger.info( + f"Task {task} for {mpid} ({struc.composition.reduced_formula}) requires " + f"dielectric properties, but none found in dielectric store" + ) + return False + + return True + + def __get_pristine_supercell(self, task): + """ + Given a task document for a defect calculation, retrieve the un-defective, pristine supercell. + - If defect transform exists, the following transform's input will be returned + - If no follow up transform exists, the calculation input will be returned + + If defect cannot be found in task, return the input structure. 
+ + scale_matrix = np.array(scaling_matrix, int) + if scale_matrix.shape != (3, 3): + scale_matrix = np.array(scale_matrix * np.eye(3), int) + new_lattice = Lattice(np.dot(scale_matrix, self._lattice.matrix)) + """ + d = unpack(query=self.defect_query, d=task) + out_structure = MontyDecoder().process_decoded( + task["output"]["output"]["structure"] + ) + if d: + defect = MontyDecoder().process_decoded(d) + s = defect.structure.copy() + sc_mat = out_structure.lattice.matrix.dot(np.linalg.inv(s.lattice.matrix)) + s.make_supercell(sc_mat.round()) + return s + else: + return out_structure + + +class DefectiveMaterialBuilder(Builder): + + """ + This builder creates collections of the DefectThermoDoc object. + + (1) Find all DefectDocs that correspond to the same bulk material + given by material_id + (2) Create a new DefectThermoDoc for all of those documents + (3) Insert/Update the defect_thermos store with the new documents + """ + + def __init__( + self, + defects: Store, + defect_thermos: Store, + materials: Store, + query: Optional[Dict] = None, + **kwargs, + ): + """ + Args: + defects: Store of defect documents (generated by DefectBuilder) + defect_thermos: Store of DefectThermoDocs to generate. + materials: Store of MaterialDocs to construct phase diagram + electronic_structures: Store of DOS objects + query: dictionary to limit tasks to be analyzed + """ + + self.defects = defects + self.defect_thermos = defect_thermos + self.materials = materials + + self.query = query if query else {} + self.timestamp = None + self.kwargs = kwargs + + super().__init__( + sources=[defects, materials], targets=[defect_thermos], **kwargs + ) + + def ensure_indexes(self): + """ + Ensures indices on the collections + """ + + # Basic search index for defects + self.defects.ensure_index("material_id") + self.defects.ensure_index("defect_id") + + # Search index for defect thermos + self.defect_thermos.ensure_index("material_id") + + # TODO need to only process new tasks. 
Fast builder so currently is OK for small collections + def get_items(self) -> Iterator[List[Dict]]: + """ + Gets items to process into DefectThermoDocs. + + returns: + iterator yielding tuples containing: + - group of DefectDocs belonging to the same bulk material as indexed by material_id, + - materials in the chemsys of the bulk material for constructing phase diagram + - Dos of the bulk material for constructing phase diagrams/getting doping + + """ + + self.logger.info("Defect thermo builder started") + self.logger.info("Setting indexes") + self.ensure_indexes() + + # Save timestamp to mark build time for defect thermo documents + self.timestamp = datetime.utcnow() + + # Get all tasks + self.logger.info("Finding tasks to process") + temp_query = dict(self.query) + temp_query["state"] = "successful" + + # unprocessed_defect_tasks = all_tasks - processed_defect_tasks + + all_docs = [doc for doc in self.defects.query(self.query)] + + self.logger.debug(f"Found {len(all_docs)} defect docs to process") + + def filterfunc(x): + if not self.materials.query_one( + criteria={"material_id": x["material_id"]}, properties=None + ): + self.logger.debug( + f"No material with MPID={x['material_id']} in the material store" + ) + return False + return True + defect = MontyDecoder().process_decoded(x["defect"]) + for el in defect.element_changes: + if el not in self.thermo: + self.logger.debug(f"No entry for {el} in Thermo Store") + return False + + return True + + for key, group in groupby( + filter(filterfunc, sorted(all_docs, key=lambda x: x["material_id"])), + key=lambda x: x["material_id"], + ): + try: + yield list(group) + except LookupError as exception: + raise exception + + def process_item(self, defects): + """ + Process a group of defects belonging to the same material into a defect thermo doc + """ + defect_docs = [DefectDoc(**d) for d in defects] + self.logger.info(f"Processing {len(defect_docs)} defects") + defect_thermo_doc = DefectiveMaterialDoc.from_docs( + 
defect_docs, material_id=defect_docs[0].material_id + ) + return defect_thermo_doc.dict() + + def update_targets(self, items): + """ + Inserts the new DefectThermoDocs into the defect_thermos store + """ + items = [item for item in items if item] + for item in items: + item.update({"_bt": self.timestamp}) + + if len(items) > 0: + self.logger.info(f"Updating {len(items)} defect thermo docs") + self.defect_thermos.update( + docs=jsanitize(items, allow_bson=True, enum_values=True, strict=True), + key=self.defect_thermos.key, + ) + else: + self.logger.info("No items to update") + + def __get_electronic_structure(self, material_id): + """ + Gets the electronic structure of the bulk material + """ + self.logger.info(f"Getting electronic structure for {material_id}") + + # TODO This is updated to return the whole query because a.t.m. the + # DOS part of the electronic builder isn't working, so I'm using + # this to pull direct from the store of dos objects with no processing. + dosdoc = self.electronic_structures.query_one( + criteria={self.electronic_structures.key: material_id}, + properties=None, + ) + t_id = ElectronicStructureDoc(**dosdoc).dos.total["1"].task_id + dos = self.dos.query_one( + criteria={"task_id": int(t_id)}, properties=None + ) # TODO MPID str/int issues + return dos + + def __get_materials(self, key) -> List: + """ + Given a group of DefectDocs, use the bulk material_id to get materials in the chemsys from the + materials store. 
class DefectValidator(Builder):
    """
    Builder that runs defect validation on tasks carrying defect information.

    Tasks are selected by the existence of the ``defect_query`` field and the
    validation results are written to the ``defect_validation`` store.
    """

    def __init__(
        self,
        tasks: Store,
        defect_validation: Store,
        chunk_size: int = 1000,
        defect_query="output.additional_json.info.defect",
    ):
        """
        Args:
            tasks: Store of task documents (source).
            defect_validation: Store receiving the validation documents (target).
            chunk_size: Chunk size forwarded to the base Builder.
            defect_query: Dotted field path whose existence marks a task as a
                defect calculation.
        """
        self.tasks = tasks
        self.defect_validation = defect_validation
        self.chunk_size = chunk_size
        self.defect_query = defect_query
        super().__init__(
            sources=tasks, targets=defect_validation, chunk_size=chunk_size
        )

    def get_items(self):
        """Yield every task document that contains the defect field."""
        self.logger.info("Getting tasks")
        criteria = {self.defect_query: {"$exists": True}}
        self.logger.info(f"{self.tasks.count(criteria=criteria)} to process")
        # Previously the criteria were only used for the count above and every
        # task in the store was yielded; restrict the stream to defect tasks.
        yield from self.tasks.query(criteria=criteria)

    def process_item(self, item):
        """Validate one task and return a BSON-safe dict of the result."""
        from atomate2.cp2k.schemas.defect import DefectValidation

        tid = item[self.tasks.key]
        # NOTE(review): DefectValidation as defined in schemas/defect.py exposes
        # process_entry(...) but no process_task(...); confirm this call target.
        return jsanitize(
            DefectValidation.process_task(item, tid).dict(),
            allow_bson=True,
            enum_values=True,
            strict=True,
        )

    def update_targets(self, items: List):
        """
        Insert the validation documents into the defect_validation store.
        """
        items = [item for item in items if item]
        if items:
            self.logger.info(f"Updating {len(items)} defects")
            self.defect_validation.update(items, key=self.defect_validation.key)
        else:
            self.logger.info("No items to update")
        return super().update_targets(items)
# TODO SHOULD GO IN COMMON
def get_sg(struc, symprec=0.01) -> int:
    """
    Return the space-group number of ``struc`` using a loose tolerance.

    The number is the second element of ``struc.get_space_group_info``;
    any failure while obtaining it yields ``-1`` instead of raising.
    """
    try:
        sg_info = struc.get_space_group_info(symprec=symprec)
        number = sg_info[1]
    except Exception:
        number = -1
    return number
@dataclass
class DefectHybridStaticFlowMaker(HybridStaticFlowMaker):
    """Hybrid static flow that uses the defect makers for both steps."""

    pbe_maker: BaseCp2kMaker = field(default_factory=DefectStaticMaker)
    # default_factory instead of default=<instance>: a maker instance used as a
    # plain default is shared by every flow maker (and is rejected as an
    # unhashable default by dataclasses on Python >= 3.11).
    hybrid_maker: BaseCp2kMaker = field(
        default_factory=lambda: DefectHybridStaticMaker(
            copy_cp2k_kwargs={"additional_cp2k_files": ("info.json",)}
        )
    )


@dataclass
class DefectHybridRelaxFlowMaker(HybridRelaxFlowMaker):
    """Hybrid relax flow: defect static PBE step, then defect hybrid relax."""

    pbe_maker: BaseCp2kMaker = field(default_factory=DefectStaticMaker)
    hybrid_maker: BaseCp2kMaker = field(
        default_factory=lambda: DefectHybridRelaxMaker(
            copy_cp2k_kwargs={"additional_cp2k_files": ("info.json",)}
        )
    )


@dataclass
class DefectHybridCellOptFlowMaker(HybridCellOptFlowMaker):
    """Hybrid cell-opt flow: defect static PBE step, then defect hybrid cell opt."""

    pbe_maker: BaseCp2kMaker = field(default_factory=DefectStaticMaker)
    hybrid_maker: BaseCp2kMaker = field(
        default_factory=lambda: DefectHybridCellOptMaker(
            copy_cp2k_kwargs={"additional_cp2k_files": ("info.json",)}
        )
    )


# TODO close to being able to put this in common. Just need a switch that decides
# which core flow/job to use based on software
@dataclass
class FormationEnergyMaker(Maker):
    """
    Run a collection of defect jobs and (possibly) the bulk supercell
    for determination of defect formation energies.

    Parameters
    ----------
    name: This flow's name. i.e. "defect formation energy"
    run_bulk: whether to run the bulk supercell as a static ("static")
        calculation, a full relaxation ("relax"), or to skip it (False)
    hybrid_functional: If provided, this activates hybrid version of the
        workflow. Provide functional as a parameter that the input set
        can recognize. e.g. "PBE0" or "HSE06"
    initialize_with_pbe: If hybrid functional is provided, this enables
        the use of a static PBE run before the hybrid calc to provide a
        starting guess for CP2K HF module.
    supercell_matrix: If provided, the defect supercell will be created
        by this 3x3 matrix. Else other parameters will be used.
    max_atoms: Maximum number of atoms allowed in the supercell.
    min_atoms: Minimum number of atoms allowed in the supercell.
    min_length: Minimum length of the smallest supercell lattice
        vector.
    force_diagonal: If True, return a transformation with a
        diagonal transformation matrix.
    """

    name: str = "defect formation energy"
    run_bulk: Literal["static", "relax"] | bool = field(default="static")
    hybrid_functional: str | None = field(default=None)
    initialize_with_pbe: bool = field(default=True)

    supercell_matrix: NDArray = field(default=None)
    min_atoms: int = field(default=80)
    max_atoms: int = field(default=240)
    min_length: int = field(default=10)
    force_diagonal: bool = field(default=False)

    def __post_init__(self):
        """Build the bulk maker (if requested) and the defect maker."""
        if self.run_bulk == "relax":
            if self.hybrid_functional:
                # The hybrid bulk relaxation needs the *flow* maker, as in the
                # static branch below: ``initialize_with_pbe`` is a flow-level
                # option, not a field of the DefectHybridCellOptMaker job maker.
                # NOTE(review): assumes HybridCellOptFlowMaker takes the same
                # keywords as HybridStaticFlowMaker — confirm.
                self.bulk_maker = DefectHybridCellOptFlowMaker(
                    name="bulk hybrid relax",
                    initialize_with_pbe=self.initialize_with_pbe,
                    hybrid_functional=self.hybrid_functional,
                    # no structure perturbation for the pristine bulk cell
                    hybrid_maker=DefectHybridCellOptMaker(
                        transformations=None,
                        copy_cp2k_kwargs={"additional_cp2k_files": ("info.json",)},
                    ),
                )
            else:
                self.bulk_maker = DefectCellOptMaker(
                    name="bulk relax", transformations=None
                )

        elif self.run_bulk == "static":
            if self.hybrid_functional:
                self.bulk_maker = DefectHybridStaticFlowMaker(
                    name="bulk hybrid static",
                    initialize_with_pbe=self.initialize_with_pbe,
                    hybrid_functional=self.hybrid_functional,
                )
            else:
                self.bulk_maker = DefectStaticMaker(name="bulk static")

        if self.hybrid_functional:
            self.def_maker = DefectHybridRelaxFlowMaker(
                hybrid_functional=self.hybrid_functional,
                initialize_with_pbe=self.initialize_with_pbe,
            )
            job_makers = [self.def_maker.pbe_maker, self.def_maker.hybrid_maker]
        else:
            self.def_maker = DefectRelaxMaker()
            job_makers = [self.def_maker]

        # Push the supercell settings down to every maker that builds a cell.
        supercell_settings = {
            "supercell_matrix": self.supercell_matrix,
            "max_atoms": self.max_atoms,
            "min_atoms": self.min_atoms,
            "min_length": self.min_length,
            "force_diagonal": self.force_diagonal,
        }
        for job_maker in job_makers:
            for attribute, value in supercell_settings.items():
                setattr(job_maker, attribute, value)

    def make(
        self,
        defects: Iterable[Defect],
        charges: bool | Iterable[int] = False,
        dielectric: NDArray | int | float | None = None,
        prev_cp2k_dir: str | Path | None = None,
        collect_outputs: bool = True,
    ):
        """Make a flow to run multiple defects in order to calculate their formation
        energy diagram.

        Parameters
        ----------
        defects: list[Defect]
            List of defects objects to calculate the formation energy diagram for.
        charges: bool | Iterable[int]
            True to use each defect's own charge states, an iterable of charges
            to apply to every defect, or falsy to run only the neutral state.
        dielectric: NDArray | int | float | None
            Dielectric constant/tensor forwarded to the output collection job.
        prev_cp2k_dir: str | Path | None
            If provided, this acts as prev_dir for the bulk calculation only
        collect_outputs: bool
            Whether to append the job that gathers defect and bulk outputs
            (only done when a bulk calculation is run).

        Returns
        -------
        flow: Flow
            The workflow to calculate the formation energy diagram.
        """
        # ``defects`` is walked several times below; materialize it so that a
        # one-shot iterator is not exhausted after the first pass.
        defects = list(defects)
        jobs = []
        defect_outputs: dict[str, dict[int, tuple[Defect, OutputReference]]] = {
            defect.name: {} for defect in defects
        }  # TODO DEFECT NAMES ARE NOT UNIQUE HASHES
        bulk_structure = ensure_defects_same_structure(defects)

        # Explicit None test: the truth value of a 3x3 array is ambiguous.
        if self.supercell_matrix is not None:
            sc_mat = self.supercell_matrix
        else:
            sc_mat = get_sc_fromstruct(
                bulk_structure,
                self.min_atoms,
                self.max_atoms,
                self.min_length,
                self.force_diagonal,
            )

        bulk_job = None
        if self.run_bulk:
            bulk_job = self.bulk_maker.make(
                bulk_structure * sc_mat, prev_cp2k_dir=prev_cp2k_dir
            )
            jobs.append(bulk_job)

        for defect in defects:
            if charges is True:
                chgs = defect.get_charge_states()
            else:
                chgs = charges if charges else [0]
            for charge in chgs:
                # each job gets its own copy carrying exactly one charge
                dfct = deepcopy(defect)
                dfct.user_charges = [charge]
                defect_job = self.def_maker.make(dfct)
                jobs.append(defect_job)
                defect_outputs[defect.name][int(charge)] = (defect, defect_job.output)

        collect_job = None
        if bulk_job is not None and defects and collect_outputs:
            collect_job = collect_defect_outputs(
                defect_outputs=defect_outputs,
                bulk_output=bulk_job.output,
                dielectric=dielectric,
            )
            jobs.append(collect_job)

        return Flow(
            jobs=jobs,
            name=self.name,
            output=collect_job.output if collect_job is not None else None,
        )


# TODO this is totally code agnostic and should be in common
@job
def collect_defect_outputs(
    defect_outputs: Mapping[str, Mapping[int, OutputReference]],
    bulk_output: OutputReference,
    dielectric: NDArray | int | float | None,
) -> dict:
    """Collect all the outputs from the defect calculations.

    Builds one DefectEntry per (defect, charge) and, when a dielectric is
    available, the Freysoldt correction data for it.

    Parameters
    ----------
    defect_outputs:
        Mapping of defect name -> charge -> (defect, output) pair, as
        assembled by ``FormationEnergyMaker.make``.
    bulk_output:
        The output of the bulk supercell calculation.
    dielectric:
        The dielectric constant used to construct the formation energy diagram.

    Returns
    -------
    dict
        ``{"results": {defect_name: {"defect", "defect_entries", "fnv_plots"}}}``
    """
    outputs: dict[str, dict[str, dict]] = {"results": {}}

    # ``is None`` rather than truthiness: a dielectric tensor is an array.
    apply_correction = dielectric is not None
    if apply_correction:
        bulk_locpot = VolumetricData.from_dict(bulk_output.cp2k_objects["v_hartree"])
    else:
        bulk_locpot = None
        logger.warning(
            "Dielectric constant not provided. Defect formation energies will be uncorrected."
        )

    for defect_name, defects_with_charges in defect_outputs.items():
        defect = None
        defect_entries = []
        fnv_plots = {}
        for charge, defect_and_output in defects_with_charges.items():
            defect, output_with_charge = defect_and_output
            logger.info(f"Processing {defect_name} with charge state={charge}")
            # NOTE(review): the entry pairs the bulk structure with the
            # defect-minus-bulk energy difference, unchanged from before —
            # confirm this is the intended supercell entry.
            defect_entry = DefectEntry(
                defect=defect,
                charge_state=charge,
                sc_entry=ComputedStructureEntry(
                    structure=bulk_output.structure,
                    energy=output_with_charge.output.energy - bulk_output.output.energy,
                ),
            )
            defect_entries.append(defect_entry)
            if apply_correction:
                # The reference potential must come from the bulk calculation;
                # it was previously read from the defect output as well.
                plot_data = defect_entry.get_freysoldt_correction(
                    defect_locpot=VolumetricData.from_dict(
                        output_with_charge.cp2k_objects["v_hartree"]
                    ),
                    bulk_locpot=bulk_locpot,
                    dielectric=dielectric,
                )
                fnv_plots[int(charge)] = plot_data
        outputs["results"][defect_name] = dict(
            defect=defect, defect_entries=defect_entries, fnv_plots=fnv_plots
        )
    return outputs
@dataclass
class BaseDefectMaker(BaseCp2kMaker):
    """
    Base maker for CP2K jobs on a defect (or plain structure) supercell.

    Given a Defect, the supercell is generated from the supercell settings
    below, the single user charge (default 0) is applied, and the defect
    object plus supercell matrix are recorded under "info.json" in
    ``write_additional_data``. Given a Structure, it is copied and run with
    its own charge.
    """

    # Fresh copy per instance: returning DEFECT_TASK_DOC itself would make every
    # maker share (and be able to mutate) the module-level dict.
    task_document_kwargs: dict = field(
        default_factory=lambda: dict(DEFECT_TASK_DOC)
    )
    supercell_matrix: NDArray = field(default=None)
    min_atoms: int = field(default=80)
    max_atoms: int = field(default=240)
    min_length: int = field(default=10)
    force_diagonal: bool = field(default=False)

    @cp2k_job
    def make(self, defect: Defect | Structure, prev_cp2k_dir: str | Path | None = None):
        """
        Create the CP2K job for a defect or a structure.

        Parameters
        ----------
        defect: Defect | Structure
            Defect to embed in a supercell, or a ready-made structure.
        prev_cp2k_dir: str | Path | None
            Previous CP2K directory forwarded to the base maker.

        Raises
        ------
        ValueError
            If the defect carries more than one user charge.
        """
        if isinstance(defect, Defect):
            is_vacancy = isinstance(defect, Vacancy)

            structure = defect.get_supercell_structure(
                sc_mat=self.supercell_matrix,
                # vacancies keep a dummy site so the basis can stay in place
                dummy_species=defect.site.species if is_vacancy else None,
                min_atoms=self.min_atoms,
                max_atoms=self.max_atoms,
                min_length=self.min_length,
                force_diagonal=self.force_diagonal,
            )

            if is_vacancy:
                # presumably the dummy site is appended last — verify against
                # Defect.get_supercell_structure
                structure.add_site_property(
                    "ghost", [False] * (len(structure.sites) - 1) + [True]
                )

            if defect.user_charges:
                if len(defect.user_charges) > 1:
                    raise ValueError(
                        "Multiple user charges found. Individual defect jobs can only contain 1."
                    )
                charge = defect.user_charges[0]
            else:
                charge = 0

            # provenance stuff
            recursive_update(
                self.write_additional_data,
                {
                    "info.json": {
                        "defect": deepcopy(defect),
                        "sc_mat": self.supercell_matrix,
                    }
                },
            )

        else:
            structure = deepcopy(defect)
            charge = structure.charge

        structure.set_charge(charge)
        return super().make.original(
            self, structure=structure, prev_cp2k_dir=prev_cp2k_dir
        )
@dataclass
class DefectHybridRelaxMaker(BaseDefectMaker):
    """
    Hybrid-functional relaxation maker for point defects.

    Defaults to the "PBE0" functional, the DefectHybridRelaxSetGenerator
    input set, and a PerturbStructureTransformation with distance 0.01.
    """

    name: str = "defect hybrid relax"
    hybrid_functional: str = "PBE0"
    input_set_generator: Cp2kInputGenerator = field(
        default_factory=DefectHybridRelaxSetGenerator
    )
    transformations: tuple[str, ...] = ("PerturbStructureTransformation",)
    transformation_params: tuple[dict, ...] | None = ({"distance": 0.01},)
| None = field( + default=({"distance": 0.01},) + ) diff --git a/src/atomate2/cp2k/schemas/defect.py b/src/atomate2/cp2k/schemas/defect.py new file mode 100644 index 0000000000..e78d01da96 --- /dev/null +++ b/src/atomate2/cp2k/schemas/defect.py @@ -0,0 +1,522 @@ +from datetime import datetime +from typing import Callable, ClassVar, Dict, List, Mapping, Set, Tuple, Type, TypeVar + +import numpy as np +from monty.json import MontyDecoder +from monty.tempfile import ScratchDir +from pydantic import BaseModel, Field +from pymatgen.analysis.defects.core import Adsorbate, Defect +from pymatgen.analysis.defects.corrections.freysoldt import ( + get_freysoldt2d_correction, + get_freysoldt_correction, +) +from pymatgen.analysis.defects.finder import DefectSiteFinder +from pymatgen.analysis.defects.thermo import DefectEntry, MultiFormationEnergyDiagram +from pymatgen.analysis.phase_diagram import PhaseDiagram +from pymatgen.core import Element +from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry +from pymatgen.io.cp2k.utils import get_truncated_coulomb_cutoff +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer + +from atomate2 import SETTINGS +from atomate2.common.schemas.structure import StructureMetadata +from atomate2.cp2k.schemas.calc_types.enums import RunType +from atomate2.cp2k.schemas.task import Cp2kObject, TaskDocument + +__all__ = ["DefectDoc"] + +T = TypeVar("T", bound="DefectDoc") +S = TypeVar("S", bound="DefectiveMaterialDoc") +V = TypeVar("V", bound="DefectValidation") + + +class DefectDoc(StructureMetadata): + """ + A document used to represent a single defect. e.g. a O vacancy with a -2 charge. + This document can contain an arbitrary number of defect entries, originating from + pairs (defect and bulk) of calculations. This document provides access to the "best" + calculation of each run_type. 
+ """ + + property_name: ClassVar[str] = "defect" + defect: Defect = Field( + None, description="Pymatgen defect object for this defect doc" + ) + charge: int = Field(None, description="Charge state for this defect") + name: str = Field( + None, description="Name of this defect as generated by the defect object" + ) + material_id: str = Field( + None, description="Unique material ID for the bulk material" + ) # TODO Change to MPID + defect_ids: Mapping[RunType, str] = Field( + None, description="Map run types of defect entry to task id" + ) + bulk_ids: Mapping[RunType, str] = Field( + None, description="Map run types of bulk entry to task id" + ) + task_ids: List[str] = Field( + None, description="All defect task ids used in creating this defect doc." + ) + defect_entries: Mapping[RunType, DefectEntry] = Field( + None, description="Dictionary for tracking entries for CP2K calculations" + ) + bulk_entries: Mapping[RunType, ComputedStructureEntry] = Field( + None, description="Computed structure entry for the bulk calc." + ) + vbm: Mapping[RunType, float] = Field( + None, + description="VBM for bulk task of each run type. 
def update_one(
    self, defect_task, bulk_task, dielectric, query="defect", key="task_id"
):
    """
    Fold one (defect task, bulk task) pair into this document.

    For the run type of the defect task, the stored entry is replaced when the
    new supercell has more atoms, or the same number of atoms and a lower
    energy. The defect task id is always added to ``task_ids``.

    Args:
        defect_task: task dict for the defect calculation
        bulk_task: task dict for the bulk calculation
        dielectric: dielectric data passed on to the entry construction
        query: dotted path to the defect object inside the defect task
        key: name of the task-id field in both task dicts
    """
    # Only the modification time moves. ``created_at`` used to be reset here as
    # well, which erased the original creation time on every update.
    self.last_updated = datetime.now()

    defect = self.get_defect_from_task(query=query, task=defect_task)
    d_id = defect_task[key]
    b_id = bulk_task[key]
    defect_task = TaskDocument(**defect_task["output"])
    bulk_task = TaskDocument(**bulk_task["output"])  # TODO Atomate2Store
    defect_entry, valid = self.get_defect_entry_from_tasks(
        defect_task, bulk_task, defect, dielectric
    )
    bulk_entry = self.get_bulk_entry_from_task(bulk_task)

    rt = defect_task.calcs_reversed[0].run_type
    existing = self.defect_entries[rt] if rt in self.defect_entries else None

    current_largest_sc = (
        existing.sc_entry.composition.num_atoms if existing is not None else 0
    )
    potential_largest_sc = defect_entry.sc_entry.composition.num_atoms

    is_larger = potential_largest_sc > current_largest_sc
    is_lower_energy_tie = (
        existing is not None
        and potential_largest_sc == current_largest_sc
        and defect_entry.sc_entry.energy < existing.sc_entry.energy
    )
    if is_larger or is_lower_energy_tie:
        self.defect_entries[rt] = defect_entry
        self.defect_ids[rt] = d_id
        self.bulk_entries[rt] = bulk_entry
        self.bulk_ids[rt] = b_id
        self.vbm[rt] = bulk_task.output.vbm
        self.valid[rt] = valid

    self.task_ids = list(set(self.task_ids) | {d_id})
in zip( + defect_tasks, bulk_tasks, dielectrics + ): + self.update_one( + defect_task=defect_task, + bulk_task=bulk_task, + dielectric=dielectric, + query=query, + ) + + @classmethod + def from_tasks( + cls: Type[T], + defect_task, + bulk_task, + dielectric, + query="defect", + key="task_id", + material_id=None, + ) -> T: + """ + The standard way to create this document. + Args: + tasks: A list of defect,bulk task pairs which will be used to construct a + series of DefectEntry objects. + query: How to retrieve the defect object stored in the task. + """ + defect_task_id = defect_task[key] + defect = cls.get_defect_from_task(query=query, task=defect_task) + defect_task = TaskDocument(**defect_task["output"]) + bulk_task_id = bulk_task[key] + bulk_task = TaskDocument(**bulk_task["output"]) + + # Metadata + last_updated = datetime.now() + created_at = datetime.now() + + rt = defect_task.calcs_reversed[0].run_type + + metadata = {} + defect_entry, valid = cls.get_defect_entry_from_tasks( + defect_task, bulk_task, defect, dielectric + ) + valid = {rt: valid} + defect_entries = {rt: defect_entry} + bulk_entries = {rt: cls.get_bulk_entry_from_task(bulk_task)} + vbm = {rt: bulk_task.output.vbm} + + metadata["defect_origin"] = ( + "intrinsic" + if all( + el in defect_entries[rt].defect.structure.composition + for el in defect_entries[rt].defect.element_changes + ) + else "extrinsic" + ) + + data = { + "defect_entries": defect_entries, + "bulk_entries": bulk_entries, + "defect_ids": {rt: defect_task_id}, + "bulk_ids": {rt: bulk_task_id}, + "last_updated": last_updated, + "created_at": created_at, + "task_ids": [defect_task_id], + "material_id": material_id, + "defect": defect_entries[rt].defect, + "charge": defect_entries[rt].charge_state, + "name": defect_entries[rt].defect.name, + "vbm": vbm, + "metadata": metadata, + "valid": valid, + } + prim = SpacegroupAnalyzer( + defect_entries[rt].defect.structure + ).get_primitive_standard_structure() + 
@classmethod
def get_correction_from_parameters(cls, parameters) -> Tuple[Dict, Dict]:
    """
    Run every correction scheme on ``parameters`` and merge the results.

    Calls ``get_freysoldt_correction`` then ``get_freysoldt2d_correction`` on
    the class; each returns a (corrections, metadata) pair of dicts, and later
    results overwrite earlier ones on key collisions.

    Returns:
        Tuple of (merged corrections, merged metadata).
    """
    merged_corrections: Dict = {}
    merged_metadata: Dict = {}
    for method_name in ("get_freysoldt_correction", "get_freysoldt2d_correction"):
        scheme = getattr(cls, method_name)
        energy_terms, details = scheme(parameters)
        merged_corrections.update(energy_terms)
        merged_metadata.update(details)
    return merged_corrections, merged_metadata
parameters.get("2d"): + result = get_freysoldt_correction( + q=parameters["charge_state"], + dielectric=np.array( + parameters["dielectric"] + ), # TODO pmg-analysis expects np array here + defect_locpot=parameters["defect_v_hartree"], + bulk_locpot=parameters["bulk_v_hartree"], + defect_frac_coords=parameters["defect_frac_sc_coords"], + ) + return {"freysoldt": result.correction_energy}, result.metadata + return {}, {} + + @classmethod + def get_freysoldt2d_correction(cls, parameters): + + from pymatgen.io.vasp.outputs import VolumetricData as VaspVolumetricData + + if parameters["charge_state"] and parameters.get("2d"): + eps_parallel = ( + parameters["dielectric"][0][0] + parameters["dielectric"][1][1] + ) / 2 + eps_perp = parameters["dielectric"][2][2] + dielectric = (eps_parallel - 1) / (1 - 1 / eps_perp) + with ScratchDir("."): + + # TODO builder ensure structures are commensurate, but the + # sxdefectalign2d requires exact match between structures + # (to about 6 digits of precision). 
@classmethod
def get_defect_from_task(cls, query, task):
    """
    Retrieve the defect object stored in a task.

    ``query`` is a dot-separated key path; the value found at that path in
    ``task`` is passed through MontyDecoder and returned.
    """
    key_path = query.split(".")
    serialized = unpack(key_path, task)
    decoder = MontyDecoder()
    return decoder.process_decoded(serialized)
class DefectValidation(BaseModel):
    """Validate a task document for defect processing."""

    MAX_ATOMIC_RELAXATION: float = Field(
        0.02,
        description="Threshold for the mean absolute displacement of atoms outside a defect's radius of isolution",
    )

    DESORPTION_DISTANCE: float = Field(
        3, description="Distance to consider adsorbate as desorbed"
    )

    def process_entry(self, parameters) -> Dict:
        """
        Gets a dictionary of {validator: result}. Result true for passing,
        false for failing.
        """
        v = {}
        v.update(self._atomic_relaxation(parameters))
        v.update(self._desorption(parameters))
        return v

    def _atomic_relaxation(self, parameters):
        """
        Returns false if the mean displacement outside the isolation radius is greater
        than the cutoff.

        The isolation sphere is centred on the defect position and its radius
        is ``get_truncated_coulomb_cutoff`` of the initial structure.
        """
        in_struc = parameters["initial_defect_structure"]
        out_struc = parameters["final_defect_structure"]

        # get_sites_in_sphere takes a Cartesian point; the stored defect
        # position is fractional, so convert it first (as _desorption does).
        center = out_struc.lattice.get_cartesian_coords(
            parameters["defect_frac_sc_coords"]
        )
        sites = out_struc.get_sites_in_sphere(
            center,
            get_truncated_coulomb_cutoff(in_struc),
            include_index=True,
        )
        inside_sphere = {site.index for site in sites}

        distances_outside = [
            site.distance(in_struc[i])
            for i, site in enumerate(out_struc)
            if i not in inside_sphere
        ]
        # Nothing lies outside the sphere: there is no displacement to judge
        # (also avoids taking the mean of an empty array).
        if not distances_outside:
            return {"atomic_relaxation": True}

        if np.mean(distances_outside) > self.MAX_ATOMIC_RELAXATION:
            return {"atomic_relaxation": False}
        return {"atomic_relaxation": True}

    def _desorption(self, parameters):
        """
        Returns false if an adsorbate defect site is farther than
        DESORPTION_DISTANCE from every other site. Non-adsorbate defects
        always pass.
        """
        if isinstance(parameters["defect"], Adsorbate):
            out_struc = parameters["final_defect_structure"]
            # locate the site sitting (within 0.1) at the defect position
            defect_site = out_struc.get_sites_in_sphere(
                out_struc.lattice.get_cartesian_coords(
                    parameters["defect_frac_sc_coords"]
                ),
                0.1,
                include_index=True,
            )[0]
            distances = [
                defect_site.distance(site)
                for i, site in enumerate(out_struc)
                if i != defect_site.index
            ]
            if all(d > self.DESORPTION_DISTANCE for d in distances):
                return {"desorption": False}
        return {"desorption": True}
description="Metadata for this object") + + @classmethod + def from_docs(cls: Type["S"], defect_docs: DefectDoc, material_id: str) -> S: + return cls( + defect_docs=defect_docs, + material_id=material_id, + last_updated=max(d.last_updated for d in defect_docs), + created_at=datetime.now(), + ) + + @property + def element_set(self) -> set: + els = {Element(e) for e in self.defect_docs[0].defect.structure.symbol_set} + for d in self.defect_docs: + els = els | set(d.defect.element_changes.keys()) + return els + + def get_formation_energy_diagram( + self, + run_type: RunType | str, + atomic_entries: List[ComputedEntry], + phase_diagram: PhaseDiagram, + filters: List[Callable] | None = None, + ) -> MultiFormationEnergyDiagram: + + filters = filters if filters else [lambda _: True] + els: Set[Element] = set() + defect_entries = [] + bulk_entries = [] + vbms = [] + if isinstance(run_type, str): + run_type = RunType(run_type) + for doc in filter(lambda x: all(f(x) for f in filters), self.defect_docs): + if doc.defect_entries.get(run_type): + els = els | set(doc.defect.element_changes.keys()) + defect_entries.append(doc.defect_entries.get(run_type)) + bulk_entries.append(doc.bulk_entries.get(run_type)) + vbms.append(doc.vbm.get(run_type)) + + return MultiFormationEnergyDiagram.with_atomic_entries( + bulk_entry=bulk_entries[0], + defect_entries=defect_entries, + atomic_entries=atomic_entries, + phase_diagram=phase_diagram, + vbm=vbms[0], + ) + + +def unpack(query, d): + if not query: + return d + if isinstance(d, List): + return unpack(query[1:], d.__getitem__(int(query.pop(0)))) + return unpack(query[1:], d.__getitem__(query.pop(0))) diff --git a/src/atomate2/cp2k/sets/base.py b/src/atomate2/cp2k/sets/base.py index eafc50bd5c..c98a274618 100644 --- a/src/atomate2/cp2k/sets/base.py +++ b/src/atomate2/cp2k/sets/base.py @@ -172,7 +172,7 @@ class Cp2kInputGenerator(InputGenerator): user_input_settings: dict = field(default_factory=dict) user_kpoints_settings: dict | Kpoints = 
# ---- src/atomate2/cp2k/sets/defect.py ----
# NOTE(review): the flattened patch text in front of this module also carried
# a hunk for src/atomate2/cp2k/sets/base.py that switches the default of
# `use_structure_charge` from False to True. That fragment is incomplete in
# this chunk and is not reproduced here.
"""Module defining defect input set generators."""

from __future__ import annotations

import logging
from dataclasses import dataclass

from atomate2.cp2k.sets.core import (
    CellOptSetGenerator,
    HybridCellOptSetGenerator,
    HybridRelaxSetGenerator,
    HybridStaticSetGenerator,
    RelaxSetGenerator,
    StaticSetGenerator,
)

logger = logging.getLogger(__name__)

# Input updates applied on top of the parent generator's updates by every
# defect generator below.
DEFECT_SET_UPDATES = {"print_v_hartree": True, "print_pdos": True, "print_dos": True}


@dataclass
class DefectStaticSetGenerator(StaticSetGenerator):
    """``StaticSetGenerator`` that additionally applies ``DEFECT_SET_UPDATES``."""

    def get_input_updates(self, *args, **kwargs) -> dict:
        """Return the parent's input updates overlaid with ``DEFECT_SET_UPDATES``."""
        # Merge into a new dict so the mapping returned by the parent is
        # left untouched.
        return {**super().get_input_updates(*args, **kwargs), **DEFECT_SET_UPDATES}


@dataclass
class DefectRelaxSetGenerator(RelaxSetGenerator):
    """``RelaxSetGenerator`` that additionally applies ``DEFECT_SET_UPDATES``."""

    def get_input_updates(self, *args, **kwargs) -> dict:
        """Return the parent's input updates overlaid with ``DEFECT_SET_UPDATES``."""
        return {**super().get_input_updates(*args, **kwargs), **DEFECT_SET_UPDATES}


@dataclass
class DefectCellOptSetGenerator(CellOptSetGenerator):
    """``CellOptSetGenerator`` that additionally applies ``DEFECT_SET_UPDATES``."""

    def get_input_updates(self, *args, **kwargs) -> dict:
        """Return the parent's input updates overlaid with ``DEFECT_SET_UPDATES``."""
        return {**super().get_input_updates(*args, **kwargs), **DEFECT_SET_UPDATES}


@dataclass
class DefectHybridStaticSetGenerator(HybridStaticSetGenerator):
    """``HybridStaticSetGenerator`` that additionally applies ``DEFECT_SET_UPDATES``."""

    def get_input_updates(self, *args, **kwargs) -> dict:
        """Return the parent's input updates overlaid with ``DEFECT_SET_UPDATES``."""
        return {**super().get_input_updates(*args, **kwargs), **DEFECT_SET_UPDATES}


@dataclass
class DefectHybridRelaxSetGenerator(HybridRelaxSetGenerator):
    """``HybridRelaxSetGenerator`` that additionally applies ``DEFECT_SET_UPDATES``."""

    def get_input_updates(self, *args, **kwargs) -> dict:
        """Return the parent's input updates overlaid with ``DEFECT_SET_UPDATES``."""
        return {**super().get_input_updates(*args, **kwargs), **DEFECT_SET_UPDATES}


@dataclass
class DefectHybridCellOptSetGenerator(HybridCellOptSetGenerator):
    """``HybridCellOptSetGenerator`` that additionally applies ``DEFECT_SET_UPDATES``."""

    def get_input_updates(self, *args, **kwargs) -> dict:
        """Return the parent's input updates overlaid with ``DEFECT_SET_UPDATES``."""
        return {**super().get_input_updates(*args, **kwargs), **DEFECT_SET_UPDATES}


# ---- tests/cp2k/sets/test_defect.py ----
import pytest


def test_input_generators(si_structure):
    """Every defect generator must request DOS/PDOS and Hartree-cube printing.

    Only the six generators the defect module actually defines are exercised;
    importing a name the module does not provide would abort the test with an
    ImportError before any assertion runs.
    """
    from atomate2.cp2k.sets.defect import (
        DefectCellOptSetGenerator,
        DefectHybridCellOptSetGenerator,
        DefectHybridRelaxSetGenerator,
        DefectHybridStaticSetGenerator,
        DefectRelaxSetGenerator,
        DefectStaticSetGenerator,
    )

    generators = (
        DefectStaticSetGenerator(),
        DefectRelaxSetGenerator(),
        DefectCellOptSetGenerator(),
        DefectHybridStaticSetGenerator(),
        DefectHybridRelaxSetGenerator(),
        DefectHybridCellOptSetGenerator(),
    )

    # check that all generators give the correct printing
    for gen in generators:
        input_set = gen.get_input_set(si_structure)
        assert input_set.cp2k_input.check(
            "FORCE_EVAL/DFT/PRINT/PDOS"
        ) or input_set.cp2k_input.check("FORCE_EVAL/DFT/PRINT/DOS")
        assert input_set.cp2k_input.check("FORCE_EVAL/DFT/PRINT/V_HARTREE_CUBE")