Skip to content

Commit

Permalink
add integration by range for Cp2kCube class
Browse files Browse the repository at this point in the history
  • Loading branch information
robinzyb committed Jul 12, 2024
1 parent 4722358 commit 5c56022
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pub-pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
--wheel
--outdir dist/
.
- name: Publish package distributions to PyPI
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@release/v1
25 changes: 15 additions & 10 deletions cp2kdata/cell.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from ase.geometry.cell import cellpar_to_cell
from ase.geometry.cell import cell_to_cellpar
import numpy.typing as npt
from copy import deepcopy


import numpy as np
import numpy.typing as npt
from numpy.linalg import LinAlgError
from copy import deepcopy
from ase.geometry.cell import cellpar_to_cell
from ase.geometry.cell import cell_to_cellpar

from cp2kdata.log import get_logger

logger = get_logger(__name__)

class Cp2kCell:
def __init__(
Expand Down Expand Up @@ -37,7 +42,7 @@ def __init__(
[0, 0, cell_param]
]
)
print("input cell_param is a float, the cell is assumed to be cubic")
logger.info("Input cell_param is a float, the cell is assumed to be cubic")
elif cell_param.shape == (3,):
self.cell_matrix = np.array(
[
Expand All @@ -46,24 +51,24 @@ def __init__(
[0, 0, cell_param[2]]
]
)
print("the length of input cell_param is 3, "
logger.info("The length of input cell_param is 3, "
"the cell is assumed to be orthorhombic")
elif cell_param.shape == (6,):
self.cell_matrix = cellpar_to_cell(cell_param)
print("the length of input cell_param is 6, "
logger.info("The length of input cell_param is 6, "
"the Cp2kCell assumes it is [a, b, c, alpha, beta, gamma], "
"which will be converted to cell matrix")
elif cell_param.shape == (3, 3):
self.cell_matrix = cell_param
print("input cell_param is a matrix with shape of (3,3), "
logger.info("Input cell_param is a matrix with shape of (3,3), "
"the cell is read as is")
else:
raise ValueError("The input cell_param is not supported")

if (grid_point is None) and (grid_spacing_matrix is None):
self.grid_point = None
self.grid_spacing_matrix = None
print("No grid point information")
logger.info("No grid point information")
elif (grid_point is None) and (grid_spacing_matrix is not None):
self.grid_spacing_matrix = grid_spacing_matrix
self.grid_point = np.round(
Expand Down Expand Up @@ -91,7 +96,7 @@ def get_dv(self):
try:
return np.linalg.det(self.grid_spacing_matrix)
except LinAlgError as ae:
print("No grid point information is available")
logger.exception("No grid point information is available")

def get_cell_param(self):
return self.cell_param
Expand Down
72 changes: 68 additions & 4 deletions cp2kdata/cube/cube.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from copy import deepcopy

import numpy as np
Expand All @@ -10,7 +9,9 @@
import asciichartpy as acp

from cp2kdata.log import get_logger
from cp2kdata.utils import file_content, interpolate_spline
from cp2kdata.utils import file_content
from cp2kdata.utils import interpolate_spline
from cp2kdata.utils import find_closet_idx_by_value
from cp2kdata.units import au2A, au2eV
from cp2kdata.cell import Cp2kCell

Expand Down Expand Up @@ -231,9 +232,72 @@ def write_cube(self, fname, comments='#'):
if grid_point[2] % 6 != 0:
fw.write('\n')

def get_integration(self):
def get_integration(self,
start_x: float=None,
end_x: float=None,
start_y: float=None,
end_y: float=None,
start_z: float=None,
end_z: float=None
)-> float:

logger.info("Start to calculate the integration of the cube file")


_cell_angles = self.cell.get_cell_angles()
_grid_point = self.cell.grid_point
_gs_matrix = self.cell.grid_spacing_matrix

_x_array = np.arange(0, _grid_point[0])*_gs_matrix[0][0]
_y_array = np.arange(0, _grid_point[1])*_gs_matrix[1][1]
_z_array = np.arange(0, _grid_point[2])*_gs_matrix[2][2]

if (start_x is not None) or (end_x is not None) or (start_y is not None) or (end_y is not None) or (start_z is not None) or (end_z is not None):
if np.all(_cell_angles == 90.0):
if start_x is None:
_idx_start_x = None
else:
_idx_start_x = find_closet_idx_by_value(_x_array, start_x)
if end_x is None:
_idx_end_x = None
else:
_idx_end_x = find_closet_idx_by_value(_x_array, end_x)
if start_y is None:
_idx_start_y = None
else:
_idx_start_y = find_closet_idx_by_value(_y_array, start_y)
if end_y is None:
_idx_end_y = None
else:
_idx_end_y = find_closet_idx_by_value(_y_array, end_y)
if start_z is None:
_idx_start_z = None
else:
_idx_start_z = find_closet_idx_by_value(_z_array, start_z)
if end_z is None:
_idx_end_z = None
else:
_idx_end_z = find_closet_idx_by_value(_z_array, end_z)

_slice_x = slice(_idx_start_x, _idx_end_x)
_slice_y = slice(_idx_start_y, _idx_end_y)
_slice_z = slice(_idx_start_z, _idx_end_z)
else:
raise ValueError("To use integration by range, all cell angles should be 90 degree")
else:
_slice_x = slice(None)
_slice_y = slice(None)
_slice_z = slice(None)


logger.info(f"The integration range for x is from {_x_array[_slice_x][0]:.3f} Bohr to {_x_array[_slice_x][-1]:.3f} Bohr")
logger.info(f"The integration range for y is from {_y_array[_slice_y][0]:.3f} Bohr to {_y_array[_slice_y][-1]:.3f} Bohr")
logger.info(f"The integration range for z is from {_z_array[_slice_z][0]:.3f} Bohr to {_z_array[_slice_z][-1]:.3f} Bohr")

_cube_vals_to_integrate = self.cube_vals[_slice_x, _slice_y, _slice_z]
dv = self.cell.get_dv()
result = np.sum(self.cube_vals)*dv
result = np.sum(_cube_vals_to_integrate)*dv

return result

def get_cell(self):
Expand Down
5 changes: 4 additions & 1 deletion cp2kdata/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
level = logging._nameToLevel.get(level_name, logging.INFO)

# format to include timestamp and module
logging.basicConfig(format='CP2KDATA| %(levelname)-8s %(name)-40s: %(message)s', level=level)
if level_name == 'DEBUG':
logging.basicConfig(format='CP2KDATA| %(asctime)s - %(levelname)-8s %(name)-40s: %(message)s', level=level)
else:
logging.basicConfig(format='CP2KDATA| %(message)s', level=level)
# suppress transitions logging
# logging.getLogger('transitions.core').setLevel(logging.WARNING)

Expand Down

0 comments on commit 5c56022

Please sign in to comment.