Skip to content

Commit

Permalink
remove the dependency of property num_atoms on the cube file content
Browse files Browse the repository at this point in the history
remove the dependency of property num_atoms on the cube file content.
When writing a new cube file, the num_atoms property is obtained from
self.stc which is ase Atoms object
  • Loading branch information
robinzyb committed Jul 12, 2024
1 parent 221ecdc commit 5e512e3
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 25 deletions.
69 changes: 45 additions & 24 deletions cp2kdata/cube/cube.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
from cp2kdata.utils import file_content, interpolate_spline
from cp2kdata.utils import au2A, au2eV
from cp2kdata.cell import Cp2kCell
import os
from copy import deepcopy

import numpy as np
import matplotlib.pyplot as plt
import os
from scipy import fft
from ase import Atom, Atoms
from monty.json import MSONable
from copy import deepcopy
import asciichartpy as acp


def square_wave_filter(x, l, cell_z):
half_l = l/2
x_1st, x_2nd = np.array_split(x, 2)
y_1st = np.heaviside(half_l - np.abs(x_1st), 0)/l
y_2nd = np.heaviside(half_l - np.abs(x_2nd-cell_z), 0)/l
y = np.concatenate([y_1st, y_2nd])
return y
from cp2kdata.utils import file_content, interpolate_spline
from cp2kdata.utils import au2A, au2eV
from cp2kdata.cell import Cp2kCell

# parse cp2kcube

Expand Down Expand Up @@ -119,11 +112,12 @@ def get_pav(self, axis="z", interpolate=False):
def get_mav(self, l1, l2=0, ncov=1, interpolate=False):
axis = "z"
pav_x, pav = self.get_pav(axis=axis, interpolate=interpolate)
theta_1_fft = fft.fft(square_wave_filter(pav_x, l1, self.cell_z))
theta_1_fft = fft.fft(self.square_wave_filter(pav_x, l1, self.cell_z))

Check warning on line 115 in cp2kdata/cube/cube.py

View check run for this annotation

Codecov / codecov/patch

cp2kdata/cube/cube.py#L115

Added line #L115 was not covered by tests
pav_fft = fft.fft(pav)
mav_fft = pav_fft*theta_1_fft*self.cell_z/len(pav_x)
if ncov == 2:
theta_2_fft = fft.fft(square_wave_filter(pav_x, l2, self.cell_z))
theta_2_fft = fft.fft(

Check warning on line 119 in cp2kdata/cube/cube.py

View check run for this annotation

Codecov / codecov/patch

cp2kdata/cube/cube.py#L119

Added line #L119 was not covered by tests
self.square_wave_filter(pav_x, l2, self.cell_z))
mav_fft = mav_fft*theta_2_fft*self.cell_z/len(pav_x)
mav = fft.ifft(mav_fft)
return pav_x, np.real(mav)
Expand All @@ -137,14 +131,23 @@ def quick_plot(self, axis="z", interpolate=False, output_dir="./"):
plt.legend()
plt.savefig(os.path.join(output_dir, "pav.png"), dpi=100)

@staticmethod
def square_wave_filter(x, l, cell_z):
half_l = l/2
x_1st, x_2nd = np.array_split(x, 2)
y_1st = np.heaviside(half_l - np.abs(x_1st), 0)/l
y_2nd = np.heaviside(half_l - np.abs(x_2nd-cell_z), 0)/l
y = np.concatenate([y_1st, y_2nd])
return y

Check warning on line 141 in cp2kdata/cube/cube.py

View check run for this annotation

Codecov / codecov/patch

cp2kdata/cube/cube.py#L136-L141

Added lines #L136 - L141 were not covered by tests


class Cp2kCube(MSONable):
# add MSONable use as_dict and from_dict
"""
Documentation for the Cp2kCube class.
"""

def __init__(self, fname=None, cube_vals=None, cell=None, stc=None):
def __init__(self, fname: str = None, cube_vals: np.ndarray = None, cell: Cp2kCell = None, stc: Atoms = None):
print("Warning: This is New Cp2kCube Class, if you want to use old Cp2kCube")
print("try, from cp2kdata.cube.cube import Cp2kCubeOld")
print("New Cp2kCube return raw values in cp2k cube file")
Expand All @@ -159,7 +162,7 @@ def __init__(self, fname=None, cube_vals=None, cell=None, stc=None):
else:
self.cell = cell
if stc is None:
self.stc = self.get_stc()
self.stc = self._parse_stc()
else:
self.stc = stc

Expand All @@ -179,6 +182,12 @@ def read_cell(self):

@property
def num_atoms(self):
return len(self.stc)

def _parse_num_atoms(self):
"""
be used to parse the number of atoms from the cube file only
"""
line = file_content(self.file, 2)
num_atoms = int(line.split()[0])
return num_atoms
Expand Down Expand Up @@ -215,9 +224,10 @@ def __sub__(self, others):
raise RuntimeError("Unspported Class")
return other_copy

def get_stc(self):
def _parse_stc(self):
num_atoms = self._parse_num_atoms()
atom_list = []
for i in range(self.num_atoms):
for i in range(num_atoms):
stc_vals = file_content(self.file, (6+i, 6+i+1))
stc_vals = stc_vals.split()
atom = Atom(
Expand All @@ -231,12 +241,14 @@ def get_stc(self):
stc.set_cell(self.cell.cell_matrix*au2A)
return stc

def get_stc(self):
return self.stc

Check warning on line 245 in cp2kdata/cube/cube.py

View check run for this annotation

Codecov / codecov/patch

cp2kdata/cube/cube.py#L245

Added line #L245 was not covered by tests

def copy(self):
return deepcopy(self)

def get_pav(self, axis='z', interpolate=False):


# do the planar average along specific axis
lengths = self.cell.get_cell_lengths()
grid_point = self.cell.grid_point
Expand All @@ -247,7 +259,7 @@ def get_pav(self, axis='z', interpolate=False):
length = lengths[0]

np.testing.assert_array_equal(
self.cell.get_cell_angles()[[1,2]],
self.cell.get_cell_angles()[[1, 2]],
np.array([90.0, 90.0]),
err_msg=f"The axis x is not perpendicular to yz plane, the pav can not be used!"
)
Expand All @@ -269,7 +281,7 @@ def get_pav(self, axis='z', interpolate=False):
length = lengths[2]

np.testing.assert_array_equal(
self.cell.get_cell_angles()[[0,1]],
self.cell.get_cell_angles()[[0, 1]],
np.array([90.0, 90.0]),
err_msg=f"The axis z is not perpendicular to xy plane, the pav can not be used!"
)
Expand Down Expand Up @@ -297,11 +309,11 @@ def get_mav(self, l1, l2=0, ncov=1, interpolate=False, axis="z"):
length = cell_length[axis]

pav_x, pav = self.get_pav(axis=axis, interpolate=interpolate)
theta_1_fft = fft.fft(square_wave_filter(pav_x, l1, length))
theta_1_fft = fft.fft(self.square_wave_filter(pav_x, l1, length))
pav_fft = fft.fft(pav)
mav_fft = pav_fft*theta_1_fft*length/len(pav_x)
if ncov == 2:
theta_2_fft = fft.fft(square_wave_filter(pav_x, l2, length))
theta_2_fft = fft.fft(self.square_wave_filter(pav_x, l2, length))
mav_fft = mav_fft*theta_2_fft*length/len(pav_x)
mav = fft.ifft(mav_fft)
return pav_x, np.real(mav)
Expand Down Expand Up @@ -443,6 +455,15 @@ def read_cube_vals(fname, num_atoms, grid_point):
cube_vals = cube_vals.reshape(grid_point)
return cube_vals

@staticmethod
def square_wave_filter(x, l, cell_z):
half_l = l/2
x_1st, x_2nd = np.array_split(x, 2)
y_1st = np.heaviside(half_l - np.abs(x_1st), 0)/l
y_2nd = np.heaviside(half_l - np.abs(x_2nd-cell_z), 0)/l
y = np.concatenate([y_1st, y_2nd])
return y


class Cp2kCubeTraj:
def __init__(cube_dir, prefix):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_cube/test_cube.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from cp2kdata import Cp2kCube
import os

import pytest
import numpy as np

from cp2kdata import Cp2kCube


path_prefix = "tests/test_cube/"
cube_list = [Cp2kCube(os.path.join(path_prefix, "Si_bulk8-v_hartree-1_0.cube"))]
Expand Down

0 comments on commit 5e512e3

Please sign in to comment.