Skip to content

Commit

Permalink
Inline past_d1 and past_d2
Browse files Browse the repository at this point in the history
  • Loading branch information
scemama committed Jun 13, 2024
1 parent 7ecc086 commit acc0b97
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 73 deletions.
94 changes: 36 additions & 58 deletions src/cipsi/selection.irp.f
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ subroutine fill_buffer_$DOUBLE(i_generator, sp, h1, h2, bannedOrb, banned, fock_
double precision, external :: diag_H_mat_elem_fock
double precision :: E_shift
double precision :: s_weight(N_states,N_states)
PROVIDE dominant_dets_of_cfgs N_dominant_dets_of_cfgs
PROVIDE dominant_dets_of_cfgs N_dominant_dets_of_cfgs thresh_sym excitation_ref hf_bitmask elec_alpha_num
do jstate=1,N_states
do istate=1,N_states
s_weight(istate,jstate) = dsqrt(selection_weight(istate)*selection_weight(jstate))
Expand Down Expand Up @@ -746,7 +746,7 @@ subroutine fill_buffer_$DOUBLE(i_generator, sp, h1, h2, bannedOrb, banned, fock_
do istate=1,N_states
delta_E = E0(istate) - Hii + E_shift
alpha_h_psi = mat(istate, p1, p2)
if (alpha_h_psi == 0.d0) cycle
if (dabs(alpha_h_psi) < mo_integrals_threshold) cycle
val = alpha_h_psi + alpha_h_psi
tmp = dsqrt(delta_E * delta_E + val * val)
Expand Down Expand Up @@ -1000,18 +1000,36 @@ subroutine splash_pq(mask, sp, det, i_gen, N_sel, bannedOrb, banned, mat, intere
if(nt == 4) then
call get_d2(det(1,1,i), phasemask, bannedOrb, banned, mat, mask, h, p, sp, psi_selectors_coef_transp(1, interesting(i)))
else if(nt == 3) then
call get_d1(det(1,1,i), phasemask, bannedOrb, banned, mat, mask, h, p, sp, psi_selectors_coef_transp(1, interesting(i))) !, hij_cache)
call get_d1(det(1,1,i), phasemask, bannedOrb, banned, mat, mask, h, p, sp, psi_selectors_coef_transp(1, interesting(i)), hij_cache)
else
call get_d0(det(1,1,i), phasemask, bannedOrb, banned, mat, mask, h, p, sp, psi_selectors_coef_transp(1, interesting(i)), hij_cache)
end if
else if(nt == 4) then
call bitstring_to_list_in_selection(mobMask(1,1), p(1,1), p(0,1), N_int)
call bitstring_to_list_in_selection(mobMask(1,2), p(1,2), p(0,2), N_int)
call past_d2(banned, p, sp)
if(sp == 3) then
do j=1,p(0,2)
do ii=1,p(0,1)
banned(p(ii,1), p(j,2),1) = .true.
end do
end do
else
do ii=1,p(0, sp)
do j=1,ii-1
banned(p(j,sp), p(ii,sp),1) = .true.
banned(p(ii,sp), p(j,sp),1) = .true.
end do
end do
end if
else if(nt == 3) then
call bitstring_to_list_in_selection(mobMask(1,1), p(1,1), p(0,1), N_int)
call bitstring_to_list_in_selection(mobMask(1,2), p(1,2), p(0,2), N_int)
call past_d1(bannedOrb, p)
do ii = 1, p(0, 1)
bannedOrb(p(ii, 1), 1) = .true.
end do
do ii = 1, p(0, 2)
bannedOrb(p(ii, 2), 2) = .true.
end do
end if
end do
Expand Down Expand Up @@ -1042,6 +1060,7 @@ subroutine get_d2(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
integer :: bant
bant = 1
PROVIDE mo_integrals_threshold
tip = p(0,1) * p(0,2)
ma = sp
Expand All @@ -1067,7 +1086,7 @@ subroutine get_d2(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
p2 = p(i2, ma)
hij = mo_two_e_integral(p1, p2, h1, h2) - mo_two_e_integral(p2, p1, h1, h2)
if (hij == 0.d0) cycle
if (dabs(hij) < mo_integrals_threshold) cycle
hij = hij * get_phase_bi(phasemask, ma, ma, h1, p1, h2, p2, N_int)
Expand Down Expand Up @@ -1097,7 +1116,7 @@ subroutine get_d2(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
p1 = p(turn2(i), 1)
hij = mo_two_e_integral(p1, p2, h1, h2)
if (hij /= 0.d0) then
if (dabs(hij) > mo_integrals_threshold) then
hij = hij * get_phase_bi(phasemask, 1, 2, h1, p1, h2, p2, N_int)
!DIR$ LOOP COUNT AVG(4)
do k=1,N_states
Expand Down Expand Up @@ -1125,7 +1144,7 @@ subroutine get_d2(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
p1 = p(i1, ma)
p2 = p(i2, ma)
hij = mo_two_e_integral(p1, p2, h1, h2) - mo_two_e_integral(p2,p1, h1, h2)
if (hij == 0.d0) cycle
if (dabs(hij) < mo_integrals_threshold) cycle
hij = hij * get_phase_bi(phasemask, ma, ma, h1, p1, h2, p2, N_int)
!DIR$ LOOP COUNT AVG(4)
Expand All @@ -1147,7 +1166,7 @@ subroutine get_d2(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
p2 = p(i, ma)
hij = mo_two_e_integral(p1, p2, h1, h2)
if (hij == 0.d0) cycle
if (dabs(hij) < mo_integrals_threshold) cycle
hij = hij * get_phase_bi(phasemask, mi, ma, h1, p1, h2, p2, N_int)
if (puti < putj) then
Expand Down Expand Up @@ -1184,7 +1203,7 @@ subroutine get_d2(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
end
subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs, hij_cache)
use bitmasks
implicit none
Expand All @@ -1195,7 +1214,7 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
double precision, intent(in) :: coefs(N_states)
double precision, intent(inout) :: mat(N_states, mo_num, mo_num)
integer, intent(in) :: h(0:2,2), p(0:4,2), sp
! double precision, intent(in) :: hij_cache(mo_num, mo_num, 2)
double precision, intent(in) :: hij_cache(mo_num, mo_num, 2)
double precision, external :: get_phase_bi, mo_two_e_integral
logical :: ok
Expand Down Expand Up @@ -1237,13 +1256,11 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
p1 = p(1,ma)
p2 = p(2,ma)
if(.not. bannedOrb(puti, mi)) then
call get_mo_two_e_integrals(hfix,p1,p2,mo_num,hij_cache1(1,1),mo_integrals_map)
call get_mo_two_e_integrals(hfix,p2,p1,mo_num,hij_cache1(1,2),mo_integrals_map)
tmp_row = 0d0
do putj=1, hfix-1
if(lbanned(putj, ma)) cycle
if(banned(putj, puti,bant)) cycle
hij = hij_cache1(putj,1) - hij_cache1(putj,2)
hij = hij_cache(hfix,putj,1) - hij_cache(putj,hfix,1)
if (hij /= 0.d0) then
hij = hij * get_phase_bi(phasemask, ma, ma, putj, p1, hfix, p2, N_int)
!DIR$ LOOP COUNT AVG(4)
Expand All @@ -1255,7 +1272,7 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs)
do putj=hfix+1, mo_num
if(lbanned(putj, ma)) cycle
if(banned(putj, puti,bant)) cycle
hij = hij_cache1(putj,2) - hij_cache1(putj,1)
hij = hij_cache(putj,hfix,1) - hij_cache(hfix,putj,1)
if (hij /= 0.d0) then
hij = hij * get_phase_bi(phasemask, ma, ma, hfix, p1, putj, p2, N_int)
!DIR$ LOOP COUNT AVG(4)
Expand Down Expand Up @@ -1466,6 +1483,7 @@ subroutine get_d0(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs,
integer, parameter :: bant=1
PROVIDE mo_integrals_threshold
if(sp == 3) then ! AB
h1 = p(1,1)
Expand All @@ -1482,7 +1500,7 @@ subroutine get_d0(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs,
phase = get_phase_bi(phasemask, 1, 2, h1, p1, h2, p2, N_int)
hij = hij_cache(p2,p1,1) * phase
end if
if (hij == 0.d0) cycle
if (dabs(hij) < mo_integrals_threshold) cycle
!DIR$ LOOP COUNT AVG(4)
do k=1,N_states
mat(k, p1, p2) = mat(k, p1, p2) + coefs(k) * hij ! HOTSPOT
Expand All @@ -1501,10 +1519,10 @@ subroutine get_d0(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs,
if(puti == p1 .or. putj == p2 .or. puti == p2 .or. putj == p1) then
call apply_particles(mask, sp,puti,sp,putj, det, ok, N_int)
call i_h_j(gen, det, N_int, hij)
if (hij == 0.d0) cycle
if (dabs(hij) < mo_integrals_threshold) cycle
else
hij = hij_cache(putj,puti,1) - hij_cache(putj,puti,2)
if (hij == 0.d0) cycle
if (dabs(hij) < mo_integrals_threshold) cycle
hij = hij * get_phase_bi(phasemask, sp, sp, puti, p1 , putj, p2, N_int)
end if
!DIR$ LOOP COUNT AVG(4)
Expand All @@ -1518,46 +1536,6 @@ subroutine get_d0(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs,
end
subroutine past_d1(bannedOrb, p)
use bitmasks
implicit none
logical, intent(inout) :: bannedOrb(mo_num, 2)
integer, intent(in) :: p(0:4, 2)
integer :: i,s
do s = 1, 2
do i = 1, p(0, s)
bannedOrb(p(i, s), s) = .true.
end do
end do
end
subroutine past_d2(banned, p, sp)
use bitmasks
implicit none
logical, intent(inout) :: banned(mo_num, mo_num)
integer, intent(in) :: p(0:4, 2), sp
integer :: i,j
if(sp == 3) then
do j=1,p(0,2)
do i=1,p(0,1)
banned(p(i,1), p(j,2)) = .true.
end do
end do
else
do i=1,p(0, sp)
do j=1,i-1
banned(p(j,sp), p(i,sp)) = .true.
banned(p(i,sp), p(j,sp)) = .true.
end do
end do
end if
end
subroutine spot_isinwf(mask, det, i_gen, N, banned, fullMatch, interesting)
use bitmasks
implicit none
Expand Down
29 changes: 14 additions & 15 deletions src/mo_two_e_ints/map_integrals.irp.f
Original file line number Diff line number Diff line change
Expand Up @@ -192,19 +192,6 @@ double precision function get_two_e_integral(i,j,k,l,map)
end


double precision function mo_two_e_integral(i,j,k,l)
implicit none
BEGIN_DOC
! Returns one integral <ij|kl> in the MO basis
END_DOC
integer, intent(in) :: i,j,k,l
double precision :: get_two_e_integral
PROVIDE mo_two_e_integrals_in_map mo_integrals_cache
!DIR$ FORCEINLINE
mo_two_e_integral = get_two_e_integral(i,j,k,l,mo_integrals_map)
return
end

subroutine get_mo_two_e_integrals(j,k,l,sze,out_val,map)
use map_module
implicit none
Expand All @@ -223,8 +210,6 @@ subroutine get_mo_two_e_integrals(j,k,l,sze,out_val,map)
integer(key_kind) :: p,q,r,s,i2
PROVIDE mo_two_e_integrals_in_map mo_integrals_cache



if (banned_excitation(j,l)) then
out_val(1:sze) = 0.d0
return
Expand Down Expand Up @@ -351,6 +336,20 @@ subroutine get_mo_two_e_integrals(j,k,l,sze,out_val,map)

end

double precision function mo_two_e_integral(i,j,k,l)
implicit none
BEGIN_DOC
! Returns one integral <ij|kl> in the MO basis
END_DOC
integer, intent(in) :: i,j,k,l
double precision :: get_two_e_integral
PROVIDE mo_two_e_integrals_in_map mo_integrals_cache
!DIR$ FORCEINLINE
mo_two_e_integral = get_two_e_integral(i,j,k,l,mo_integrals_map)
return
end


subroutine get_mo_two_e_integrals_cache(j,k,l,sze,out_val)
use map_module
implicit none
Expand Down

0 comments on commit acc0b97

Please sign in to comment.