Skip to content

Commit

Permalink
Fix log density spline extrapolation (#133)
Browse files Browse the repository at this point in the history
* Add extrapolation limits to DensitySpline

* Handle log/exp cutoffs better

* Comment and clean up extrapolation code

* Set extrapolation cutoff limit to r = r_N
  • Loading branch information
msricher authored Dec 17, 2024
1 parent 511a702 commit b742618
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ cython_debug/
.vscode/

# Raw data files
atomdb/datasets/*/raw/
atomdb/datasets/*/db/

# Generated documentation
docs/source/api/
Expand Down
48 changes: 22 additions & 26 deletions atomdb/species.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,23 @@

r"""AtomDB, a database of atomic and ionic properties."""

from dataclasses import dataclass, field, asdict

from glob import glob

from importlib import import_module

import json

import re
from dataclasses import asdict, dataclass, field
from importlib import import_module
from numbers import Integral

from os import makedirs, path

from msgpack import packb, unpackb

from msgpack_numpy import encode, decode

import numpy as np

from numpy import ndarray

import pooch
import re
import requests

from msgpack import packb, unpackb
from msgpack_numpy import decode, encode
from numpy import ndarray
from scipy.interpolate import CubicSpline

from atomdb.utils import DEFAULT_DATASET, DEFAULT_DATAPATH, DEFAULT_REMOTE
from atomdb.periodic import element_symbol, Element

from atomdb.periodic import Element, element_symbol
from atomdb.utils import DEFAULT_DATAPATH, DEFAULT_DATASET, DEFAULT_REMOTE

__all__ = [
"Species",
Expand Down Expand Up @@ -166,7 +154,9 @@ def __init__(self, x, y, log=False):
self._log = log
self._obj = CubicSpline(
x,
np.log(y) if log else y,
# Clip y values to >= ε^2 if using log because they have to be above 0;
# having them be at least ε^2 seems to work based on my testing
np.log(y.clip(min=np.finfo(float).eps ** 2)) if log else y,
axis=0,
bc_type="not-a-knot",
extrapolate=True,
Expand All @@ -192,7 +182,9 @@ def __call__(self, x, deriv=0):
if not (0 <= deriv <= 2):
raise ValueError(f"Invalid derivative order {deriv}; must be 0 <= `deriv` <= 2")
elif self._log:
y = np.exp(self._obj(x))
# Get y = exp(log y). We'll handle errors from small log y values later.
with np.errstate(over="ignore"):
y = np.exp(self._obj(x))
if deriv == 1:
# d(ρ(r)) = d(log(ρ(r))) * ρ(r)
dlogy = self._obj(x, nu=1)
Expand All @@ -201,9 +193,13 @@ def __call__(self, x, deriv=0):
# d^2(ρ(r)) = d^2(log(ρ(r))) * ρ(r) + [d(ρ(r))]^2/ρ(r)
dlogy = self._obj(x, nu=1)
d2logy = self._obj(x, nu=2)
y = d2logy.flatten() * y + dlogy.flatten() ** 2 * y
y = d2logy.flatten() * y + dlogy.flatten() ** 2 / y
else:
y = self._obj(x, nu=deriv)
# Handle errors from the y = exp(log y) operation -- set NaN to zero
np.nan_to_num(y, nan=0., copy=False)
# Cutoff value: assume y(x) is zero where x > final given point x_n
y[x > self._obj.x[-1]] = 0
return y


Expand All @@ -218,7 +214,7 @@ def default(self, obj):
return JSONEncoder.default(self, obj)


class _AtomicOrbitals(object):
class _AtomicOrbitals:
"""Atomic orbitals class."""

def __init__(self, data) -> None:
Expand Down Expand Up @@ -883,13 +879,13 @@ def datafile(
url=f"{remotepath}{dataset.lower()}/db/repodata.txt",
known_hash=None,
path=path.join(datapath, dataset.lower(), "db"),
fname=f"repo_data.txt",
fname="repo_data.txt",
)
# if the file is not found or remote was not valid, use the local repodata file
except (requests.exceptions.HTTPError, ValueError):
repodata = path.join(datapath, dataset.lower(), "db", "repo_data.txt")

with open(repodata, "r") as f:
with open(repodata) as f:
data = f.read()
files = re.findall(rf"\b{elem}+_{charge}+_{mult}+_{nexc}\.msg\b", data)
species_list = []
Expand Down

0 comments on commit b742618

Please sign in to comment.