Skip to content

Commit

Permalink
Merge pull request #65 from robinzyb/devel
Browse files Browse the repository at this point in the history
remove the dependency of property num_atoms on the cube file content
  • Loading branch information
robinzyb committed Jul 12, 2024
2 parents 6e26ed0 + 5e512e3 commit 1883beb
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 25 deletions.
69 changes: 45 additions & 24 deletions cp2kdata/cube/cube.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
from cp2kdata.utils import file_content, interpolate_spline
from cp2kdata.utils import au2A, au2eV
from cp2kdata.cell import Cp2kCell
import os
from copy import deepcopy

import numpy as np
import matplotlib.pyplot as plt
import os
from scipy import fft
from ase import Atom, Atoms
from monty.json import MSONable
from copy import deepcopy
import asciichartpy as acp


def square_wave_filter(x, l, cell_z):
half_l = l/2
x_1st, x_2nd = np.array_split(x, 2)
y_1st = np.heaviside(half_l - np.abs(x_1st), 0)/l
y_2nd = np.heaviside(half_l - np.abs(x_2nd-cell_z), 0)/l
y = np.concatenate([y_1st, y_2nd])
return y
from cp2kdata.utils import file_content, interpolate_spline
from cp2kdata.utils import au2A, au2eV
from cp2kdata.cell import Cp2kCell

# parse cp2kcube

Expand Down Expand Up @@ -119,11 +112,12 @@ def get_pav(self, axis="z", interpolate=False):
def get_mav(self, l1, l2=0, ncov=1, interpolate=False):
axis = "z"
pav_x, pav = self.get_pav(axis=axis, interpolate=interpolate)
theta_1_fft = fft.fft(square_wave_filter(pav_x, l1, self.cell_z))
theta_1_fft = fft.fft(self.square_wave_filter(pav_x, l1, self.cell_z))
pav_fft = fft.fft(pav)
mav_fft = pav_fft*theta_1_fft*self.cell_z/len(pav_x)
if ncov == 2:
theta_2_fft = fft.fft(square_wave_filter(pav_x, l2, self.cell_z))
theta_2_fft = fft.fft(
self.square_wave_filter(pav_x, l2, self.cell_z))
mav_fft = mav_fft*theta_2_fft*self.cell_z/len(pav_x)
mav = fft.ifft(mav_fft)
return pav_x, np.real(mav)
Expand All @@ -137,14 +131,23 @@ def quick_plot(self, axis="z", interpolate=False, output_dir="./"):
plt.legend()
plt.savefig(os.path.join(output_dir, "pav.png"), dpi=100)

@staticmethod
def square_wave_filter(x, l, cell_z):
half_l = l/2
x_1st, x_2nd = np.array_split(x, 2)
y_1st = np.heaviside(half_l - np.abs(x_1st), 0)/l
y_2nd = np.heaviside(half_l - np.abs(x_2nd-cell_z), 0)/l
y = np.concatenate([y_1st, y_2nd])
return y


class Cp2kCube(MSONable):
# add MSONable use as_dict and from_dict
"""
Documentation for the Cp2kCube class.
"""

def __init__(self, fname=None, cube_vals=None, cell=None, stc=None):
def __init__(self, fname: str = None, cube_vals: np.ndarray = None, cell: Cp2kCell = None, stc: Atoms = None):
print("Warning: This is New Cp2kCube Class, if you want to use old Cp2kCube")
print("try, from cp2kdata.cube.cube import Cp2kCubeOld")
print("New Cp2kCube return raw values in cp2k cube file")
Expand All @@ -159,7 +162,7 @@ def __init__(self, fname=None, cube_vals=None, cell=None, stc=None):
else:
self.cell = cell
if stc is None:
self.stc = self.get_stc()
self.stc = self._parse_stc()
else:
self.stc = stc

Expand All @@ -179,6 +182,12 @@ def read_cell(self):

@property
def num_atoms(self):
return len(self.stc)

def _parse_num_atoms(self):
"""
be used to parse the number of atoms from the cube file only
"""
line = file_content(self.file, 2)
num_atoms = int(line.split()[0])
return num_atoms
Expand Down Expand Up @@ -215,9 +224,10 @@ def __sub__(self, others):
raise RuntimeError("Unspported Class")
return other_copy

def get_stc(self):
def _parse_stc(self):
num_atoms = self._parse_num_atoms()
atom_list = []
for i in range(self.num_atoms):
for i in range(num_atoms):
stc_vals = file_content(self.file, (6+i, 6+i+1))
stc_vals = stc_vals.split()
atom = Atom(
Expand All @@ -231,12 +241,14 @@ def get_stc(self):
stc.set_cell(self.cell.cell_matrix*au2A)
return stc

def get_stc(self):
return self.stc

def copy(self):
return deepcopy(self)

def get_pav(self, axis='z', interpolate=False):


# do the planar average along specific axis
lengths = self.cell.get_cell_lengths()
grid_point = self.cell.grid_point
Expand All @@ -247,7 +259,7 @@ def get_pav(self, axis='z', interpolate=False):
length = lengths[0]

np.testing.assert_array_equal(
self.cell.get_cell_angles()[[1,2]],
self.cell.get_cell_angles()[[1, 2]],
np.array([90.0, 90.0]),
err_msg=f"The axis x is not perpendicular to yz plane, the pav can not be used!"
)
Expand All @@ -269,7 +281,7 @@ def get_pav(self, axis='z', interpolate=False):
length = lengths[2]

np.testing.assert_array_equal(
self.cell.get_cell_angles()[[0,1]],
self.cell.get_cell_angles()[[0, 1]],
np.array([90.0, 90.0]),
err_msg=f"The axis z is not perpendicular to xy plane, the pav can not be used!"
)
Expand Down Expand Up @@ -297,11 +309,11 @@ def get_mav(self, l1, l2=0, ncov=1, interpolate=False, axis="z"):
length = cell_length[axis]

pav_x, pav = self.get_pav(axis=axis, interpolate=interpolate)
theta_1_fft = fft.fft(square_wave_filter(pav_x, l1, length))
theta_1_fft = fft.fft(self.square_wave_filter(pav_x, l1, length))
pav_fft = fft.fft(pav)
mav_fft = pav_fft*theta_1_fft*length/len(pav_x)
if ncov == 2:
theta_2_fft = fft.fft(square_wave_filter(pav_x, l2, length))
theta_2_fft = fft.fft(self.square_wave_filter(pav_x, l2, length))
mav_fft = mav_fft*theta_2_fft*length/len(pav_x)
mav = fft.ifft(mav_fft)
return pav_x, np.real(mav)
Expand Down Expand Up @@ -443,6 +455,15 @@ def read_cube_vals(fname, num_atoms, grid_point):
cube_vals = cube_vals.reshape(grid_point)
return cube_vals

@staticmethod
def square_wave_filter(x, l, cell_z):
half_l = l/2
x_1st, x_2nd = np.array_split(x, 2)
y_1st = np.heaviside(half_l - np.abs(x_1st), 0)/l
y_2nd = np.heaviside(half_l - np.abs(x_2nd-cell_z), 0)/l
y = np.concatenate([y_1st, y_2nd])
return y


class Cp2kCubeTraj:
def __init__(cube_dir, prefix):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_cube/test_cube.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from cp2kdata import Cp2kCube
import os

import pytest
import numpy as np

from cp2kdata import Cp2kCube


path_prefix = "tests/test_cube/"
cube_list = [Cp2kCube(os.path.join(path_prefix, "Si_bulk8-v_hartree-1_0.cube"))]
Expand Down

0 comments on commit 1883beb

Please sign in to comment.