pre-commit formatting
kavanase committed Mar 26, 2024
1 parent d93738c commit 2e42df7
Showing 1 changed file with 65 additions and 102 deletions.
167 changes: 65 additions & 102 deletions easyunfold/procar.py
@@ -56,13 +56,9 @@ def _read(self, fobj, parsed_kpoints=None):
fobj.seek(0)
_header = fobj.readline()
# Read the NK, NB and NIONS that are integers
_total_nkpts, nbands, nion = [
int(token) for token in re.sub(r"[^0-9]", " ", fobj.readline()).split()
]
_total_nkpts, nbands, nion = [int(token) for token in re.sub(r'[^0-9]', ' ', fobj.readline()).split()]
if nion != self.nion:
raise ValueError(
f"Mismatch in number of ions in PROCARs supplied: ({nion} vs {self.nion})!"
)
raise ValueError(f'Mismatch in number of ions in PROCARs supplied: ({nion} vs {self.nion})!')

# Count the number of data lines, these lines do not have any alphabets
proj_data, energies, kvecs, kweights, occs = [], [], [], [], []
@@ -73,12 +69,10 @@ def _read(self, fobj, parsed_kpoints=None):
# Counter for the number of sections in the PROCAR
section_counter = 1
_last_kid = 0
this_procar_parsed_kpoints = (
set()
) # set with tuples of parsed (kvec tuple, section_counter) for this PROCAR
this_procar_parsed_kpoints = set() # set with tuples of parsed (kvec tuple, section_counter) for this PROCAR
while line:
if line.startswith(" k-point"):
line = re.sub(r"(\d)-", r"\1 -", line)
if line.startswith(' k-point'):
line = re.sub(r'(\d)-', r'\1 -', line)
tokens = line.strip().split()
_kid = int(tokens[1])

@@ -87,14 +81,9 @@ def _read(self, fobj, parsed_kpoints=None):
section_counter += 1
_last_kid = _kid

kvec = tuple(
round(float(val), 5)
for val in tokens[-6:-3] # tuple to make it hashable
) # round to 5 decimal places to ensure proper kpoint matching
if (
kvec not in parsed_kpoints
and (kvec, section_counter) not in this_procar_parsed_kpoints
):
kvec = tuple(round(float(val), 5) for val in tokens[-6:-3] # tuple to make it hashable
) # round to 5 decimal places to ensure proper kpoint matching
if (kvec not in parsed_kpoints and (kvec, section_counter) not in this_procar_parsed_kpoints):
this_procar_parsed_kpoints.add((kvec, section_counter))
kvecs.append(list(kvec))
kweights.append(float(tokens[-1]))
@@ -104,20 +93,16 @@ def _read(self, fobj, parsed_kpoints=None):
line = fobj.readline()
continue

elif (
not re.search(r"[a-zA-Z]", line)
and line.strip()
and len(line.strip().split()) - 2 == len(self.proj_names)
):
elif (not re.search(r'[a-zA-Z]', line) and line.strip() and len(line.strip().split()) - 2 == len(self.proj_names)):
# only parse data if line is expected length, in case of LORBIT >= 12
proj_data.append([float(token) for token in line.strip().split()[1:-1]])

elif line.startswith("band"):
elif line.startswith('band'):
tokens = line.strip().split()
energies.append(float(tokens[4]))
occs.append(float(tokens[-1]))

elif line.startswith("tot"):
elif line.startswith('tot'):
tot_count += 1

line = fobj.readline()
@@ -128,10 +113,8 @@ def _read(self, fobj, parsed_kpoints=None):
elif tot_count == len(occs):
self._is_soc = False
else:
raise ValueError(
f"Number of lines starting with 'tot' ({tot_count}) in PROCAR does not match expected "
f"values ({4*len(occs)} or {len(occs)})!"
)
raise ValueError(f"Number of lines starting with 'tot' ({tot_count}) in PROCAR does not match expected "
f'values ({4*len(occs)} or {len(occs)})!')

occs = np.array(occs)
kvecs = np.array(kvecs)
@@ -140,9 +123,7 @@ def _read(self, fobj, parsed_kpoints=None):

proj_data = np.array(proj_data, dtype=float)

nkpts = len(
kvecs
) # redetermine nkpts in case some were skipped due to already being parsed
nkpts = len(kvecs) # redetermine nkpts in case some were skipped due to already being parsed

# For spin-polarised calcs, the data from the second (down) spin are located after that of the first (up) spin
# Hence, the number of spins is simply the number of sections
@@ -156,20 +137,16 @@ def _read(self, fobj, parsed_kpoints=None):

# Reshape the array
if self._is_soc is False:
proj_data = proj_data.reshape(
(
self.nspins,
nkpts // self.nspins,
nbands,
self.nion,
len(self.proj_names),
)
)
proj_data = proj_data.reshape((
self.nspins,
nkpts // self.nspins,
nbands,
self.nion,
len(self.proj_names),
))
proj_xyz = None
else:
proj_data = proj_data.reshape(
(self.nspins, nkpts, nbands, 4, self.nion, len(self.proj_names))
)
proj_data = proj_data.reshape((self.nspins, nkpts, nbands, 4, self.nion, len(self.proj_names)))
# Split the data into xyz projection and total
proj_xyz = proj_data[:, :, :, 1:, :, :]
proj_data = proj_data[:, :, :, 0, :, :]
@@ -183,12 +160,7 @@ def _read(self, fobj, parsed_kpoints=None):
proj_xyz /= proj_sum

# Update the parsed kpoints
parsed_kpoints.update(
{
kvec_section_counter_tuple[0]
for kvec_section_counter_tuple in this_procar_parsed_kpoints
}
)
parsed_kpoints.update({kvec_section_counter_tuple[0] for kvec_section_counter_tuple in this_procar_parsed_kpoints})

return (
self.nspins,
@@ -219,15 +191,11 @@ def _read_header_nion_proj_names(self, fobj):
fobj.seek(0)
self.header = fobj.readline()
# Read the NK, NB and NIONS that are integers
_nkpts, _nbands, self.nion = [
int(token) for token in re.sub(r"[^0-9]", " ", fobj.readline()).split()
]
_nkpts, _nbands, self.nion = [int(token) for token in re.sub(r'[^0-9]', ' ', fobj.readline()).split()]
self.proj_names = None # projection names

for line in fobj:
if re.match(
r"^ion.*tot", line
): # only the first "ion" line, in case of LORBIT >= 12
if re.match(r'^ion.*tot', line): # only the first "ion" line, in case of LORBIT >= 12
self.proj_names = line.strip().split()[1:-1]
break
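As an aside, a self-contained sketch of the header parsing in this hunk: the example 'ion ... tot' line imitates the layout VASP writes to PROCAR (illustrative only, not taken from the diff), while the regex and the slice are the ones used above.

import re

# Illustrative PROCAR projection header (assumed VASP layout): the first token
# is 'ion', the last is 'tot', and everything in between is an orbital name.
line = 'ion      s     py     pz     px    dxy    dyz    dz2    dxz  x2-y2    tot\n'

proj_names = None
if re.match(r'^ion.*tot', line):  # only the first "ion" line, as in the code above
    proj_names = line.strip().split()[1:-1]

print(proj_names)  # ['s', 'py', 'pz', 'px', 'dxy', 'dyz', 'dz2', 'dxz', 'x2-y2']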

@@ -237,14 +205,13 @@ def read(self, fobjs_or_paths):
def open_file(fobj_or_path):
if isinstance(fobj_or_path, (str, Path)):
if os.path.exists(fobj_or_path):
return zopen(fobj_or_path, mode="rt") # closed later
if os.path.exists(f"{fobj_or_path}.gz"):
return zopen(f"{fobj_or_path}.gz", mode="rt")
return zopen(fobj_or_path, mode='rt') # closed later
if os.path.exists(f'{fobj_or_path}.gz'):
return zopen(f'{fobj_or_path}.gz', mode='rt')

raise FileNotFoundError( # else raise error
f"File not found: {fobj_or_path} – PROCAR(.gz) file needed for "
f"parsing atomic projections!"
)
f'File not found: {fobj_or_path} – PROCAR(.gz) file needed for '
f'parsing atomic projections!')
return fobj_or_path # already a file-like object, just return it

parsed_kpoints = None
@@ -270,9 +237,7 @@ def open_file(fobj_or_path):
parsed_kpoints,
) = self._read(fobj, parsed_kpoints=parsed_kpoints)
if current_nspins is not None and current_nspins != nspins:
raise ValueError(
f"Mismatch in number of spins in PROCARs supplied: ({nspins} vs {current_nspins})!"
)
raise ValueError(f'Mismatch in number of spins in PROCARs supplied: ({nspins} vs {current_nspins})!')

if isinstance(fobj_or_path, (str, Path)):
fobj.close() # if file was opened in this loop, close it
@@ -285,7 +250,7 @@ def open_file(fobj_or_path):
proj_data_list.append(proj_data)
proj_xyz_list.append(proj_xyz)
if len(fobjs_or_paths) > 1: # print progress if reading multiple files
print(f"Finished parsing PROCAR {i + 1}/{len(fobjs_or_paths)}")
print(f'Finished parsing PROCAR {i + 1}/{len(fobjs_or_paths)}')

# Combine along the nkpts axis:
# for occs, eigenvalues, proj_data and proj_xyz, nbands (axis = 2) could differ, so set missing values to zero:
@@ -297,7 +262,7 @@ def open_file(fobj_or_path):
array_list[i] = np.pad(
arr,
((0, 0), (0, 0), (0, max_nbands - arr.shape[2])),
mode="constant",
mode='constant',
)
elif len(arr.shape) == 5: # proj_data_list
array_list[i] = np.pad(
@@ -309,7 +274,7 @@ def open_file(fobj_or_path):
(0, 0),
(0, 0),
),
mode="constant",
mode='constant',
)
elif len(arr.shape) == 6: # proj_xyz_list
array_list[i] = np.pad(
@@ -322,10 +287,10 @@ def open_file(fobj_or_path):
(0, 0),
(0, 0),
),
mode="constant",
mode='constant',
)
else:
raise ValueError("Unexpected array shape encountered!")
raise ValueError('Unexpected array shape encountered!')

self.nbands = max_nbands
self.occs = np.concatenate(occs_list, axis=1)
@@ -338,9 +303,7 @@ def open_file(fobj_or_path):

self.nkpts = self.kvecs.shape[1]

def get_projection(
self, atom_idx: List[int], proj: Union[List[str], str], weight_by_k=False
):
def get_projection(self, atom_idx: List[int], proj: Union[List[str], str], weight_by_k=False):
"""
Get projection for specific atoms and specific projectors
@@ -353,7 +316,7 @@ def get_projection(
"""
atom_mask = [iatom in atom_idx for iatom in range(self.nion)]
assert any(atom_mask)
if proj == "all":
if proj == 'all':
out = self.proj_data[:, :, :, atom_mask, :].sum(axis=(-1, -2))
else:
if isinstance(proj, str):
@@ -363,15 +326,15 @@ def get_projection(

# replace any instance of "p" with "px,py,pz" and "d" with "dxy,dyz,dz2,dxz,dx2-y2"
def _replace_p_d(single_proj):
if single_proj == "p":
return ["px", "py", "pz"]
if single_proj == "d":
if single_proj == 'p':
return ['px', 'py', 'pz']
if single_proj == 'd':
return [
"dxy",
"dyz",
"dz2",
"dxz",
"x2-y2",
'dxy',
'dyz',
'dz2',
'dxz',
'x2-y2',
] # dx2-y2 labelled differently in VASP
# PROCAR

@@ -393,25 +356,25 @@ def _replace_p_d(single_proj):
def as_dict(self) -> dict:
"""Convert the object into a dictionary representation (so it can be saved to json)"""
output = {
"@module": self.__class__.__module__,
"@class": self.__class__.__name__,
"@version": __version__,
'@module': self.__class__.__module__,
'@class': self.__class__.__name__,
'@version': __version__,
}
for key in [
"_is_soc",
"eigenvalues",
"kvecs",
"kweights",
"nbands",
"nkpts",
"nspins",
"nion",
"occs",
"proj_names",
"proj_data",
"header",
"proj_xyz",
"normalise",
'_is_soc',
'eigenvalues',
'kvecs',
'kweights',
'nbands',
'nkpts',
'nspins',
'nion',
'occs',
'proj_names',
'proj_data',
'header',
'proj_xyz',
'normalise',
]:
output[key] = getattr(self, key)
return output
@@ -429,7 +392,7 @@ def from_dict(cls, d):
"""

def decode_dict(subdict):
if isinstance(subdict, dict) and "@module" in subdict:
if isinstance(subdict, dict) and '@module' in subdict:
return MontyDecoder().process_decoded(subdict)
return subdict

@@ -438,7 +401,7 @@ def decode_dict(subdict):

# set the instance variables directly from the dictionary
for key, value in d_decoded.items():
if key in ["@module", "@class", "@version"]:
if key in ['@module', '@class', '@version']:
continue
setattr(instance, key, value)

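For orientation, a minimal usage sketch of the parser API touched by this commit. The class name Procar, the no-argument constructor and the PROCAR file names are assumptions; only read(), get_projection(), as_dict() and from_dict() (with the signatures shown above) appear in the diff, and the import path is inferred from the file name easyunfold/procar.py.

# Hypothetical usage sketch; class name, constructor and file paths are assumed.
from easyunfold.procar import Procar

procar = Procar()  # constructor arguments are not shown in this diff
procar.read(['PROCAR.scf', 'PROCAR.band'])  # gzipped copies (.gz) are picked up automatically

# Sum the p-orbital weight on the first two ions for every (spin, k-point, band);
# 'p' is expanded to ['px', 'py', 'pz'] by _replace_p_d above.
p_weight = procar.get_projection(atom_idx=[0, 1], proj='p')
print(p_weight.shape)  # roughly (nspins, nkpts, nbands)

# Round-trip through the dict representation, e.g. for JSON caching.
procar_2 = Procar.from_dict(procar.as_dict())

If the real constructor already accepts the file paths, the explicit read() call is of course redundant.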
