From 5e512e3dde836f3e228602c879552ee4fa861bc7 Mon Sep 17 00:00:00 2001 From: robinzyb <38876805+robinzyb@users.noreply.github.com> Date: Fri, 12 Jul 2024 13:50:59 +0200 Subject: [PATCH] remove the dependency of property num_atoms on the cube file content remove the dependency of property num_atoms on the cube file content. When writing a new cube file, the num_atoms property is obtained from self.stc which is ase Atoms object --- cp2kdata/cube/cube.py | 69 +++++++++++++++++++++++------------- tests/test_cube/test_cube.py | 4 ++- 2 files changed, 48 insertions(+), 25 deletions(-) diff --git a/cp2kdata/cube/cube.py b/cp2kdata/cube/cube.py index d165f64..24bc543 100644 --- a/cp2kdata/cube/cube.py +++ b/cp2kdata/cube/cube.py @@ -1,23 +1,16 @@ -from cp2kdata.utils import file_content, interpolate_spline -from cp2kdata.utils import au2A, au2eV -from cp2kdata.cell import Cp2kCell +import os +from copy import deepcopy + import numpy as np import matplotlib.pyplot as plt -import os from scipy import fft from ase import Atom, Atoms from monty.json import MSONable -from copy import deepcopy import asciichartpy as acp - -def square_wave_filter(x, l, cell_z): - half_l = l/2 - x_1st, x_2nd = np.array_split(x, 2) - y_1st = np.heaviside(half_l - np.abs(x_1st), 0)/l - y_2nd = np.heaviside(half_l - np.abs(x_2nd-cell_z), 0)/l - y = np.concatenate([y_1st, y_2nd]) - return y +from cp2kdata.utils import file_content, interpolate_spline +from cp2kdata.utils import au2A, au2eV +from cp2kdata.cell import Cp2kCell # parse cp2kcube @@ -119,11 +112,12 @@ def get_pav(self, axis="z", interpolate=False): def get_mav(self, l1, l2=0, ncov=1, interpolate=False): axis = "z" pav_x, pav = self.get_pav(axis=axis, interpolate=interpolate) - theta_1_fft = fft.fft(square_wave_filter(pav_x, l1, self.cell_z)) + theta_1_fft = fft.fft(self.square_wave_filter(pav_x, l1, self.cell_z)) pav_fft = fft.fft(pav) mav_fft = pav_fft*theta_1_fft*self.cell_z/len(pav_x) if ncov == 2: - theta_2_fft = fft.fft(square_wave_filter(pav_x, l2, self.cell_z)) + theta_2_fft = fft.fft( + self.square_wave_filter(pav_x, l2, self.cell_z)) mav_fft = mav_fft*theta_2_fft*self.cell_z/len(pav_x) mav = fft.ifft(mav_fft) return pav_x, np.real(mav) @@ -137,6 +131,15 @@ def quick_plot(self, axis="z", interpolate=False, output_dir="./"): plt.legend() plt.savefig(os.path.join(output_dir, "pav.png"), dpi=100) + @staticmethod + def square_wave_filter(x, l, cell_z): + half_l = l/2 + x_1st, x_2nd = np.array_split(x, 2) + y_1st = np.heaviside(half_l - np.abs(x_1st), 0)/l + y_2nd = np.heaviside(half_l - np.abs(x_2nd-cell_z), 0)/l + y = np.concatenate([y_1st, y_2nd]) + return y + class Cp2kCube(MSONable): # add MSONable use as_dict and from_dict @@ -144,7 +147,7 @@ class Cp2kCube(MSONable): Documentation for the Cp2kCube class. """ - def __init__(self, fname=None, cube_vals=None, cell=None, stc=None): + def __init__(self, fname: str = None, cube_vals: np.ndarray = None, cell: Cp2kCell = None, stc: Atoms = None): print("Warning: This is New Cp2kCube Class, if you want to use old Cp2kCube") print("try, from cp2kdata.cube.cube import Cp2kCubeOld") print("New Cp2kCube return raw values in cp2k cube file") @@ -159,7 +162,7 @@ def __init__(self, fname=None, cube_vals=None, cell=None, stc=None): else: self.cell = cell if stc is None: - self.stc = self.get_stc() + self.stc = self._parse_stc() else: self.stc = stc @@ -179,6 +182,12 @@ def read_cell(self): @property def num_atoms(self): + return len(self.stc) + + def _parse_num_atoms(self): + """ + be used to parse the number of atoms from the cube file only + """ line = file_content(self.file, 2) num_atoms = int(line.split()[0]) return num_atoms @@ -215,9 +224,10 @@ def __sub__(self, others): raise RuntimeError("Unspported Class") return other_copy - def get_stc(self): + def _parse_stc(self): + num_atoms = self._parse_num_atoms() atom_list = [] - for i in range(self.num_atoms): + for i in range(num_atoms): stc_vals = file_content(self.file, (6+i, 6+i+1)) stc_vals = stc_vals.split() atom = Atom( @@ -231,12 +241,14 @@ def get_stc(self): stc.set_cell(self.cell.cell_matrix*au2A) return stc + def get_stc(self): + return self.stc + def copy(self): return deepcopy(self) def get_pav(self, axis='z', interpolate=False): - # do the planar average along specific axis lengths = self.cell.get_cell_lengths() grid_point = self.cell.grid_point @@ -247,7 +259,7 @@ def get_pav(self, axis='z', interpolate=False): length = lengths[0] np.testing.assert_array_equal( - self.cell.get_cell_angles()[[1,2]], + self.cell.get_cell_angles()[[1, 2]], np.array([90.0, 90.0]), err_msg=f"The axis x is not perpendicular to yz plane, the pav can not be used!" ) @@ -269,7 +281,7 @@ def get_pav(self, axis='z', interpolate=False): length = lengths[2] np.testing.assert_array_equal( - self.cell.get_cell_angles()[[0,1]], + self.cell.get_cell_angles()[[0, 1]], np.array([90.0, 90.0]), err_msg=f"The axis z is not perpendicular to xy plane, the pav can not be used!" ) @@ -297,11 +309,11 @@ def get_mav(self, l1, l2=0, ncov=1, interpolate=False, axis="z"): length = cell_length[axis] pav_x, pav = self.get_pav(axis=axis, interpolate=interpolate) - theta_1_fft = fft.fft(square_wave_filter(pav_x, l1, length)) + theta_1_fft = fft.fft(self.square_wave_filter(pav_x, l1, length)) pav_fft = fft.fft(pav) mav_fft = pav_fft*theta_1_fft*length/len(pav_x) if ncov == 2: - theta_2_fft = fft.fft(square_wave_filter(pav_x, l2, length)) + theta_2_fft = fft.fft(self.square_wave_filter(pav_x, l2, length)) mav_fft = mav_fft*theta_2_fft*length/len(pav_x) mav = fft.ifft(mav_fft) return pav_x, np.real(mav) @@ -443,6 +455,15 @@ def read_cube_vals(fname, num_atoms, grid_point): cube_vals = cube_vals.reshape(grid_point) return cube_vals + @staticmethod + def square_wave_filter(x, l, cell_z): + half_l = l/2 + x_1st, x_2nd = np.array_split(x, 2) + y_1st = np.heaviside(half_l - np.abs(x_1st), 0)/l + y_2nd = np.heaviside(half_l - np.abs(x_2nd-cell_z), 0)/l + y = np.concatenate([y_1st, y_2nd]) + return y + class Cp2kCubeTraj: def __init__(cube_dir, prefix): diff --git a/tests/test_cube/test_cube.py b/tests/test_cube/test_cube.py index 14e8ddc..3156fe5 100644 --- a/tests/test_cube/test_cube.py +++ b/tests/test_cube/test_cube.py @@ -1,8 +1,10 @@ -from cp2kdata import Cp2kCube import os + import pytest import numpy as np +from cp2kdata import Cp2kCube + path_prefix = "tests/test_cube/" cube_list = [Cp2kCube(os.path.join(path_prefix, "Si_bulk8-v_hartree-1_0.cube"))]