diff --git a/easyunfold/procar.py b/easyunfold/procar.py index 21b382a..2d75f53 100644 --- a/easyunfold/procar.py +++ b/easyunfold/procar.py @@ -56,13 +56,9 @@ def _read(self, fobj, parsed_kpoints=None): fobj.seek(0) _header = fobj.readline() # Read the NK, NB and NIONS that are integers - _total_nkpts, nbands, nion = [ - int(token) for token in re.sub(r"[^0-9]", " ", fobj.readline()).split() - ] + _total_nkpts, nbands, nion = [int(token) for token in re.sub(r'[^0-9]', ' ', fobj.readline()).split()] if nion != self.nion: - raise ValueError( - f"Mismatch in number of ions in PROCARs supplied: ({nion} vs {self.nion})!" - ) + raise ValueError(f'Mismatch in number of ions in PROCARs supplied: ({nion} vs {self.nion})!') # Count the number of data lines, these lines do not have any alphabets proj_data, energies, kvecs, kweights, occs = [], [], [], [], [] @@ -73,12 +69,10 @@ def _read(self, fobj, parsed_kpoints=None): # Counter for the number of sections in the PROCAR section_counter = 1 _last_kid = 0 - this_procar_parsed_kpoints = ( - set() - ) # set with tuples of parsed (kvec tuple, section_counter) for this PROCAR + this_procar_parsed_kpoints = set() # set with tuples of parsed (kvec tuple, section_counter) for this PROCAR while line: - if line.startswith(" k-point"): - line = re.sub(r"(\d)-", r"\1 -", line) + if line.startswith(' k-point'): + line = re.sub(r'(\d)-', r'\1 -', line) tokens = line.strip().split() _kid = int(tokens[1]) @@ -87,14 +81,9 @@ def _read(self, fobj, parsed_kpoints=None): section_counter += 1 _last_kid = _kid - kvec = tuple( - round(float(val), 5) - for val in tokens[-6:-3] # tuple to make it hashable - ) # round to 5 decimal places to ensure proper kpoint matching - if ( - kvec not in parsed_kpoints - and (kvec, section_counter) not in this_procar_parsed_kpoints - ): + kvec = tuple(round(float(val), 5) for val in tokens[-6:-3] # tuple to make it hashable + ) # round to 5 decimal places to ensure proper kpoint matching + if (kvec not in parsed_kpoints and (kvec, section_counter) not in this_procar_parsed_kpoints): this_procar_parsed_kpoints.add((kvec, section_counter)) kvecs.append(list(kvec)) kweights.append(float(tokens[-1])) @@ -104,20 +93,16 @@ def _read(self, fobj, parsed_kpoints=None): line = fobj.readline() continue - elif ( - not re.search(r"[a-zA-Z]", line) - and line.strip() - and len(line.strip().split()) - 2 == len(self.proj_names) - ): + elif (not re.search(r'[a-zA-Z]', line) and line.strip() and len(line.strip().split()) - 2 == len(self.proj_names)): # only parse data if line is expected length, in case of LORBIT >= 12 proj_data.append([float(token) for token in line.strip().split()[1:-1]]) - elif line.startswith("band"): + elif line.startswith('band'): tokens = line.strip().split() energies.append(float(tokens[4])) occs.append(float(tokens[-1])) - elif line.startswith("tot"): + elif line.startswith('tot'): tot_count += 1 line = fobj.readline() @@ -128,10 +113,8 @@ def _read(self, fobj, parsed_kpoints=None): elif tot_count == len(occs): self._is_soc = False else: - raise ValueError( - f"Number of lines starting with 'tot' ({tot_count}) in PROCAR does not match expected " - f"values ({4*len(occs)} or {len(occs)})!" - ) + raise ValueError(f"Number of lines starting with 'tot' ({tot_count}) in PROCAR does not match expected " + f'values ({4*len(occs)} or {len(occs)})!') occs = np.array(occs) kvecs = np.array(kvecs) @@ -140,9 +123,7 @@ def _read(self, fobj, parsed_kpoints=None): proj_data = np.array(proj_data, dtype=float) - nkpts = len( - kvecs - ) # redetermine nkpts in case some were skipped due to already being parsed + nkpts = len(kvecs) # redetermine nkpts in case some were skipped due to already being parsed # For spin-polarised calcs, the data from the second (down) spin are located after that of the first (up) spin # Hence, the number of spins is simply the number of sections @@ -156,20 +137,16 @@ def _read(self, fobj, parsed_kpoints=None): # Reshape the array if self._is_soc is False: - proj_data = proj_data.reshape( - ( - self.nspins, - nkpts // self.nspins, - nbands, - self.nion, - len(self.proj_names), - ) - ) + proj_data = proj_data.reshape(( + self.nspins, + nkpts // self.nspins, + nbands, + self.nion, + len(self.proj_names), + )) proj_xyz = None else: - proj_data = proj_data.reshape( - (self.nspins, nkpts, nbands, 4, self.nion, len(self.proj_names)) - ) + proj_data = proj_data.reshape((self.nspins, nkpts, nbands, 4, self.nion, len(self.proj_names))) # Split the data into xyz projection and total proj_xyz = proj_data[:, :, :, 1:, :, :] proj_data = proj_data[:, :, :, 0, :, :] @@ -183,12 +160,7 @@ def _read(self, fobj, parsed_kpoints=None): proj_xyz /= proj_sum # Update the parsed kpoints - parsed_kpoints.update( - { - kvec_section_counter_tuple[0] - for kvec_section_counter_tuple in this_procar_parsed_kpoints - } - ) + parsed_kpoints.update({kvec_section_counter_tuple[0] for kvec_section_counter_tuple in this_procar_parsed_kpoints}) return ( self.nspins, @@ -219,15 +191,11 @@ def _read_header_nion_proj_names(self, fobj): fobj.seek(0) self.header = fobj.readline() # Read the NK, NB and NIONS that are integers - _nkpts, _nbands, self.nion = [ - int(token) for token in re.sub(r"[^0-9]", " ", fobj.readline()).split() - ] + _nkpts, _nbands, self.nion = [int(token) for token in re.sub(r'[^0-9]', ' ', fobj.readline()).split()] self.proj_names = None # projection names for line in fobj: - if re.match( - r"^ion.*tot", line - ): # only the first "ion" line, in case of LORBIT >= 12 + if re.match(r'^ion.*tot', line): # only the first "ion" line, in case of LORBIT >= 12 self.proj_names = line.strip().split()[1:-1] break @@ -237,14 +205,13 @@ def read(self, fobjs_or_paths): def open_file(fobj_or_path): if isinstance(fobj_or_path, (str, Path)): if os.path.exists(fobj_or_path): - return zopen(fobj_or_path, mode="rt") # closed later - if os.path.exists(f"{fobj_or_path}.gz"): - return zopen(f"{fobj_or_path}.gz", mode="rt") + return zopen(fobj_or_path, mode='rt') # closed later + if os.path.exists(f'{fobj_or_path}.gz'): + return zopen(f'{fobj_or_path}.gz', mode='rt') raise FileNotFoundError( # else raise error - f"File not found: {fobj_or_path} – PROCAR(.gz) file needed for " - f"parsing atomic projections!" - ) + f'File not found: {fobj_or_path} – PROCAR(.gz) file needed for ' + f'parsing atomic projections!') return fobj_or_path # already a file-like object, just return it parsed_kpoints = None @@ -270,9 +237,7 @@ def open_file(fobj_or_path): parsed_kpoints, ) = self._read(fobj, parsed_kpoints=parsed_kpoints) if current_nspins is not None and current_nspins != nspins: - raise ValueError( - f"Mismatch in number of spins in PROCARs supplied: ({nspins} vs {current_nspins})!" - ) + raise ValueError(f'Mismatch in number of spins in PROCARs supplied: ({nspins} vs {current_nspins})!') if isinstance(fobj_or_path, (str, Path)): fobj.close() # if file was opened in this loop, close it @@ -285,7 +250,7 @@ def open_file(fobj_or_path): proj_data_list.append(proj_data) proj_xyz_list.append(proj_xyz) if len(fobjs_or_paths) > 1: # print progress if reading multiple files - print(f"Finished parsing PROCAR {i + 1}/{len(fobjs_or_paths)}") + print(f'Finished parsing PROCAR {i + 1}/{len(fobjs_or_paths)}') # Combine along the nkpts axis: # for occs, eigenvalues, proj_data and proj_xyz, nbands (axis = 2) could differ, so set missing values to zero: @@ -297,7 +262,7 @@ def open_file(fobj_or_path): array_list[i] = np.pad( arr, ((0, 0), (0, 0), (0, max_nbands - arr.shape[2])), - mode="constant", + mode='constant', ) elif len(arr.shape) == 5: # proj_xyz_list array_list[i] = np.pad( @@ -309,7 +274,7 @@ def open_file(fobj_or_path): (0, 0), (0, 0), ), - mode="constant", + mode='constant', ) elif len(arr.shape) == 6: # proj_xyz_list array_list[i] = np.pad( @@ -322,10 +287,10 @@ def open_file(fobj_or_path): (0, 0), (0, 0), ), - mode="constant", + mode='constant', ) else: - raise ValueError("Unexpected array shape encountered!") + raise ValueError('Unexpected array shape encountered!') self.nbands = max_nbands self.occs = np.concatenate(occs_list, axis=1) @@ -338,9 +303,7 @@ def open_file(fobj_or_path): self.nkpts = self.kvecs.shape[1] - def get_projection( - self, atom_idx: List[int], proj: Union[List[str], str], weight_by_k=False - ): + def get_projection(self, atom_idx: List[int], proj: Union[List[str], str], weight_by_k=False): """ Get projection for specific atoms and specific projectors @@ -353,7 +316,7 @@ def get_projection( """ atom_mask = [iatom in atom_idx for iatom in range(self.nion)] assert any(atom_mask) - if proj == "all": + if proj == 'all': out = self.proj_data[:, :, :, atom_mask, :].sum(axis=(-1, -2)) else: if isinstance(proj, str): @@ -363,15 +326,15 @@ def get_projection( # replace any instance of "p" with "px,py,pz" and "d" with "dxy,dyz,dz2,dxz,dx2-y2" def _replace_p_d(single_proj): - if single_proj == "p": - return ["px", "py", "pz"] - if single_proj == "d": + if single_proj == 'p': + return ['px', 'py', 'pz'] + if single_proj == 'd': return [ - "dxy", - "dyz", - "dz2", - "dxz", - "x2-y2", + 'dxy', + 'dyz', + 'dz2', + 'dxz', + 'x2-y2', ] # dx2-y2 labelled differently in VASP # PROCAR @@ -393,25 +356,25 @@ def _replace_p_d(single_proj): def as_dict(self) -> dict: """Convert the object into a dictionary representation (so it can be saved to json)""" output = { - "@module": self.__class__.__module__, - "@class": self.__class__.__name__, - "@version": __version__, + '@module': self.__class__.__module__, + '@class': self.__class__.__name__, + '@version': __version__, } for key in [ - "_is_soc", - "eigenvalues", - "kvecs", - "kweights", - "nbands", - "nkpts", - "nspins", - "nion", - "occs", - "proj_names", - "proj_data", - "header", - "proj_xyz", - "normalise", + '_is_soc', + 'eigenvalues', + 'kvecs', + 'kweights', + 'nbands', + 'nkpts', + 'nspins', + 'nion', + 'occs', + 'proj_names', + 'proj_data', + 'header', + 'proj_xyz', + 'normalise', ]: output[key] = getattr(self, key) return output @@ -429,7 +392,7 @@ def from_dict(cls, d): """ def decode_dict(subdict): - if isinstance(subdict, dict) and "@module" in subdict: + if isinstance(subdict, dict) and '@module' in subdict: return MontyDecoder().process_decoded(subdict) return subdict @@ -438,7 +401,7 @@ def decode_dict(subdict): # set the instance variables directly from the dictionary for key, value in d_decoded.items(): - if key in ["@module", "@class", "@version"]: + if key in ['@module', '@class', '@version']: continue setattr(instance, key, value)