Skip to content

Commit

Permalink
hide embedding_charges
Browse files Browse the repository at this point in the history
  • Loading branch information
loriab committed Sep 12, 2024
1 parent 78dc0c7 commit e1b2206
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 2 deletions.
7 changes: 7 additions & 0 deletions qcmanybody/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import math
import os
import string
from collections import Counter, defaultdict
from typing import Any, Dict, Iterable, Literal, Mapping, Sequence, Set, Tuple, Union
Expand Down Expand Up @@ -44,6 +45,12 @@ def __init__(
embedding_charges: Mapping[int, Sequence[float]],
):
self.embedding_charges = embedding_charges
if self.embedding_charges:
if not bool(os.environ.get("QCMANYBODY_EMBEDDING_CHARGES", False)): # obscure until further validation
raise ValueError(
f"Embedding charges for EE-MBE are still in testing. Set environment variable QCMANYBODY_EMBEDDING_CHARGES=1 to use at your own risk."
)

if isinstance(molecule, dict):
mol = Molecule(**molecule)
elif isinstance(molecule, Molecule):
Expand Down
4 changes: 3 additions & 1 deletion qcmanybody/models/manybody_input_pydv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ class ManyBodyKeywords(ProtoModel):
embedding_charges: Optional[Dict[int, List[float]]] = Field(
None,
description="Atom-centered point charges to be used on molecule fragments whose basis sets are not included in "
"the computation. Keys: 1-based index of fragment. Values: list of atom charges for that fragment.",
"the computation. Keys: 1-based index of fragment. Values: list of atom charges for that fragment. "
"At present, QCManyBody will only accept non-None values of this keyword if environment variable "
"QCMANYBODY_EMBEDDING_CHARGES is set.",
# TODO embedding charges should sum to fragment charge, right? enforce?
# TODO embedding charges irrelevant to CP (basis sets always present)?
json_schema_extra={
Expand Down
1 change: 0 additions & 1 deletion qcmanybody/models/manybody_output_pydv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

MAX_NBODY = int(os.environ.get("QCMANYBODY_MAX_NBODY", 5)) # 5 covers tetramers

# TODO: bump up default MAX_NBODY and add some warnings or mitigations so insufficient value doesn't fail at very end at Result formation time

json_schema_extras = {
"energy": {"units": "E_h"},
Expand Down
22 changes: 22 additions & 0 deletions qcmanybody/tests/test_computer_het4_gradient.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import pprint
import re

Expand Down Expand Up @@ -246,6 +247,7 @@ def test_nbody_het4_grad(mbe_keywords, anskeyE, anskeyG, bodykeys, outstrs, calc
_inner = request.node.name.split("[")[1].split("]")[0]
kwdsln, pattern, progln = _inner, "", "psi4"
print("LANE", kwdsln, pattern, progln)
os.environ["QCMANYBODY_EMBEDDING_CHARGES"] = "1"

mbe_keywords = ManyBodyKeywords(**mbe_keywords)
mbe_data_grad_dtz["molecule"] = het_tetramer
Expand Down Expand Up @@ -313,6 +315,26 @@ def test_nbody_het4_grad(mbe_keywords, anskeyE, anskeyG, bodykeys, outstrs, calc
assert re.search(sumstr[stdoutkey], ret.stdout, re.MULTILINE), f"[j] N-Body pattern not found: {sumstr[stdoutkey]}"


@pytest.mark.parametrize("mbe_keywords,errmsg", [
pytest.param(
{"bsse_type": "nocp", "return_total_data": True, "levels": {4: "p4-hf"}, "embedding_charges": {1: [-1.0], 2: [0.5, -0.5], 3: [-0.5, 0.5], 4: [0]}},
"Embedding charges for EE-MBE are still in testing",
id="4b_nocp_rtd_ee_error"),
])
def test_nbody_ee_error(mbe_keywords, errmsg, het_tetramer, mbe_data_grad_dtz):

mbe_keywords = ManyBodyKeywords(**mbe_keywords)
mbe_data_grad_dtz["molecule"] = het_tetramer
mbe_data_grad_dtz["specification"]["driver"] = "gradient"
mbe_data_grad_dtz["specification"]["keywords"] = mbe_keywords
mbe_model = ManyBodyInput(**mbe_data_grad_dtz)

with pytest.raises(ValueError) as e:
ManyBodyComputer.from_manybodyinput(mbe_model)

assert errmsg in str(e.value), e.value


def test_fragmentless_mol(mbe_data_grad_dtz):
het_tetramer_fragmentless = Molecule(
symbols=["F", "H", "F", "H", "H", "He"],
Expand Down

0 comments on commit e1b2206

Please sign in to comment.