Skip to content

Commit

Permalink
use precise summation
Browse files Browse the repository at this point in the history
  • Loading branch information
loriab committed Sep 13, 2024
1 parent e1b2206 commit 18aff83
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
5 changes: 2 additions & 3 deletions qcmanybody/tests/test_computer_het4_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,10 @@ def het_tetramer():
15,
id="4b_nocp_rtd_ee"),
])
def test_nbody_het4_grad(mbe_keywords, anskeyE, anskeyG, bodykeys, outstrs, calcinfo_nmbe, het_tetramer, request, mbe_data_grad_dtz):
def test_nbody_het4_grad(mbe_keywords, anskeyE, anskeyG, bodykeys, outstrs, calcinfo_nmbe, het_tetramer, request, mbe_data_grad_dtz, monkeypatch):
_inner = request.node.name.split("[")[1].split("]")[0]
kwdsln, pattern, progln = _inner, "", "psi4"
print("LANE", kwdsln, pattern, progln)
os.environ["QCMANYBODY_EMBEDDING_CHARGES"] = "1"
monkeypatch.setenv("QCMANYBODY_EMBEDDING_CHARGES", "1")

mbe_keywords = ManyBodyKeywords(**mbe_keywords)
mbe_data_grad_dtz["molecule"] = het_tetramer
Expand Down
24 changes: 17 additions & 7 deletions qcmanybody/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import ast
import json
import math
import re
import string
from typing import Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, Tuple, Union
Expand Down Expand Up @@ -222,13 +223,22 @@ def sum_cluster_data(
shape = find_shape(data[first_key])
ret = shaped_zero(shape)

for frag, bas in compute_list:
egh = data[labeler(mc_level_lbl, frag, bas)]

if vmfc:
sign = (-1) ** (nb - len(frag))

ret += sign * egh
precise_sum_func = math.fsum if isinstance(ret, float) else np.sum
ret = precise_sum_func(
(((-1) ** (nb - len(frag))) if vmfc else 1) * (data[labeler(mc_level_lbl, frag, bas)])
for frag, bas in compute_list
)

# A more readable format for the above but not ammenable to using specialty summation functions
# ```
# for frag, bas in compute_list:
# egh = data[labeler(mc_level_lbl, frag, bas)]
#
# if vmfc:
# sign = (-1) ** (nb - len(frag))
#
# ret += sign * egh
# ```

return ret

Expand Down

0 comments on commit 18aff83

Please sign in to comment.