Skip to content

Commit

Permalink
Refactoring special function implementations (#112)
Browse files Browse the repository at this point in the history
* refactoring special function implementations

* fix out of bounds jnp.take in eri

* additional binom test cases

* fix test expect
  • Loading branch information
hatemhelal authored Sep 30, 2023
1 parent ec94c8b commit 023e114
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 86 deletions.
105 changes: 26 additions & 79 deletions pyscf_ipu/experimental/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,101 +6,47 @@

import jax.numpy as jnp
import numpy as np
from jax import jit, lax, tree_map, vmap
from jax import jit, tree_map, vmap
from jax.ops import segment_sum
from jax.scipy.special import gammainc, gammaln

from .basis import Basis
from .orbital import batch_orbitals
from .primitive import Primitive, product
from .types import Float3, FloatN, FloatNx3, FloatNxN, IntN
from .special import binom, binom_factor, factorial, factorial2, gammanu
from .types import Float3, FloatN, FloatNx3, FloatNxN
from .units import LMAX

# Maximum value an individual component of the angular momentum lmn can take
# Used for static ahead-of-time compilation of functions involving lmn.
LMAX = 4

"""
Special functions used in integral evaluation
"""
JAX implementation for integrals over Gaussian basis functions. Based upon the
closed-form expressions derived in
Taketa, H., Huzinaga, S., & O-ohata, K. (1966). Gaussian-expansion methods for
molecular integrals. Journal of the physical society of Japan, 21(11), 2313-2324.
<https://doi.org/10.1143/JPSJ.21.2313>
def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN:
def body_fun(i, val):
return val * jnp.where(i <= n, i, 1)

return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n))


def factorial_gamma(n: IntN) -> IntN:
"""Appoximate factorial by evaluating the gamma function in log-space.
This approximation is exact for small integers (n < 10).
"""
approx = jnp.exp(gammaln(n + 1))
return jnp.rint(approx)


factorial = factorial_fori


def factorial2(n: IntN, nmax: int = 2 * LMAX) -> IntN:
def body_fun(i, val):
return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1)

return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n))

Hereafter referred to as the "THO paper"
def binom(x: IntN, y: IntN, nmax: int = LMAX) -> IntN:
bang = partial(factorial, nmax=nmax)
c = x * bang(x - 1) / (bang(y) * bang(x - y))
return jnp.where(x == y, 1, c)
Related work:

def gammanu(nu: IntN, t: FloatN, epsilon: float = 1e-10) -> FloatN:
"""
eq 2.11 from THO but simplified using SymPy and converted to jax
t, u = symbols("t u", real=True, positive=True)
nu = Symbol("nu", integer=True, nonnegative=True)
expr = simplify(integrate(u ** (2 * nu) * exp(-t * u**2), (u, 0, 1)))
f = lambdify((nu, t), expr, modules="scipy")
?f
We evaulate this in log-space to avoid overflow/nan
"""
t = jnp.maximum(t, epsilon)
x = nu + 0.5
gn = jnp.log(0.5) - x * jnp.log(t) + jnp.log(gammainc(x, t)) + gammaln(x)
return jnp.exp(gn)


@partial(vmap, in_axes=(0, None, None, None, None))
def binom_factor(s: IntN, i: int, j: int, pa: float, pb: float):
"""
Eq. 15 from Augspurger JD, Dykstra CE. General quantum mechanical operators. An
[1] Augspurger JD, Dykstra CE. General quantum mechanical operators. An
open-ended approach for one-electron integrals with Gaussian bases. Journal of
computational chemistry. 1990 Jan;11(1):105-11.
<https://doi.org/10.1002/jcc.540110113>
"""
def term(t):
return binom(i, s - t) * binom(j, t) * pa ** (i - s + t) * pb ** (j - t)

def body_fun(t, val):
mask = (t <= s) & (t >= (s - i)) & (t <= j)
return val + jnp.where(mask, term(t), 0.0)

return lax.fori_loop(0, LMAX + 1, body_fun, 0.0)
[2] PyQuante: <https://github.com/rpmuller/pyquante2/>
"""


@partial(vmap, in_axes=(0, 0, 0, 0, None))
def overlap_axis(i: int, j: int, pa: int, pb: int, alpha: float) -> float:
ii = jnp.arange(LMAX + 1)
out = binom_factor(2 * ii, i, j, pa, pb)
out *= factorial2(2 * ii - 1) / (2 * alpha) ** ii
mask = ii <= jnp.floor_divide(i + j, 2)
out = jnp.where(mask, out, 0.0)
def overlap_axis(i: int, j: int, a: float, b: float, alpha: float) -> float:
idx = [(s, t) for s in range(LMAX + 1) for t in range(2 * s + 1)]
s, t = jnp.array(idx, dtype=jnp.uint32).T
out = binom(i, 2 * s - t) * binom(j, t)
out *= a ** (i - (2 * s - t)) * b ** (j - t)
out *= factorial2(2 * s - 1) / (2 * alpha) ** s

mask = (2 * s - i <= t) & (t <= j)
out = jnp.where(mask, out, 0)
return jnp.sum(out)


Expand Down Expand Up @@ -179,7 +125,7 @@ def g_term(l1, l2, pa, pb, cp):
index = i - 2 * r - u
g = (
jnp.power(-1, i + u)
* binom_factor(i, l1, l2, pa, pb)
* jnp.take(binom_factor(l1, l2, pa, pb), i)
* factorial(i)
* jnp.power(cp, index - u)
* jnp.power(epsilon, r + u)
Expand Down Expand Up @@ -251,7 +197,7 @@ def H(l1, l2, a, b, i, r, gamma):
# Note this should match THO Eq 3.5 but that seems to incorrectly show a
# 1/(4 gamma) ^(i- 2r) term which is inconsistent with Eq 2.22.
# Using (4 gamma)^(r - i) matches the reported expressions for H_L
u = factorial(i) * binom_factor(i, l1, l2, a, b)
u = factorial(i) * jnp.take(binom_factor(l1, l2, a, b, 2 * LMAX), i)
v = factorial(r) * factorial(i - 2 * r) * (4 * gamma) ** (i - r)
return u / v

Expand All @@ -269,6 +215,7 @@ def c_term(la, lb, lc, ld, pa, pb, qc, qd, qp):
return segment_sum(c, index, num_segments=4 * LMAX + 1)

# Manual vmap over cartesian axes (x, y, z) as ran into possible bug.
# See https://github.com/graphcore-research/pyscf-ipu/issues/105
args = [a.lmn, b.lmn, c.lmn, d.lmn, pa, pb, qc, qd, qp]
Ci, Cj, Ck = [c_term(*[v.at[i].get() for v in args]) for i in range(3)]

Expand Down
110 changes: 110 additions & 0 deletions pyscf_ipu/experimental/special.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.ops import segment_sum
from jax.scipy.special import betaln, gammainc, gammaln

from .types import FloatN, IntN
from .units import LMAX


def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN:
def body_fun(i, val):
return val * jnp.where(i <= n, i, 1)

return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n))


def factorial_gamma(n: IntN) -> IntN:
"""Appoximate factorial by evaluating the gamma function in log-space.
This approximation is exact for small integers (n < 10).
"""
approx = jnp.exp(gammaln(n + 1))
return jnp.rint(approx)


def factorial_lookup(n: IntN, nmax: int = LMAX) -> IntN:
N = np.cumprod(np.arange(1, nmax + 1))
N = np.insert(N, 0, 1)
N = jnp.array(N, dtype=jnp.uint32)
return N.at[n.astype(jnp.uint32)].get()


factorial = factorial_gamma


def factorial2_fori(n: IntN, nmax: int = 2 * LMAX) -> IntN:
def body_fun(i, val):
return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1)

return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n))


def factorial2_lookup(n: IntN, nmax: int = 2 * LMAX) -> IntN:
stop = nmax + 1 if nmax % 2 == 0 else nmax + 2
N = np.arange(1, stop).reshape(-1, 2)
N = np.cumprod(N, axis=0).reshape(-1)
N = np.insert(N, 0, 1)
N = jnp.array(N)
n = jnp.maximum(n, 0)
return N.at[n].get()


factorial2 = factorial2_lookup


def binom_beta(x: IntN, y: IntN) -> IntN:
approx = 1.0 / ((x + 1) * jnp.exp(betaln(x - y + 1, y + 1)))
return jnp.rint(approx)


def binom_fori(x: IntN, y: IntN, nmax: int = LMAX) -> IntN:
bang = partial(factorial_fori, nmax=nmax)
c = x * bang(x - 1) / (bang(y) * bang(x - y))
return jnp.where(x == y, 1, c)


def binom_lookup(x: IntN, y: IntN, nmax: int = LMAX) -> IntN:
bang = partial(factorial_lookup, nmax=nmax)
c = x * bang(x - 1) / (bang(y) * bang(x - y))
return jnp.where(x == y, 1, c)


binom = binom_lookup


def gammanu(nu: IntN, t: FloatN, epsilon: float = 1e-10) -> FloatN:
"""
eq 2.11 from THO but simplified using SymPy and converted to jax
t, u = symbols("t u", real=True, positive=True)
nu = Symbol("nu", integer=True, nonnegative=True)
expr = simplify(integrate(u ** (2 * nu) * exp(-t * u**2), (u, 0, 1)))
f = lambdify((nu, t), expr, modules="scipy")
?f
We evaulate this in log-space to avoid overflow/nan
"""
t = jnp.maximum(t, epsilon)
x = nu + 0.5
gn = jnp.log(0.5) - x * jnp.log(t) + jnp.log(gammainc(x, t)) + gammaln(x)
return jnp.exp(gn)


def binom_factor(i: int, j: int, a: float, b: float, lmax: int = LMAX) -> FloatN:
"""
Eq. 15 from Augspurger JD, Dykstra CE. General quantum mechanical operators. An
open-ended approach for one-electron integrals with Gaussian bases. Journal of
computational chemistry. 1990 Jan;11(1):105-11.
<https://doi.org/10.1002/jcc.540110113>
"""
s, t = jnp.tril_indices(lmax + 1)
out = binom(i, s - t) * binom(j, t) * a ** (i - (s - t)) * b ** (j - t)
mask = ((s - i) <= t) & (t <= j)
out = jnp.where(mask, out, 0.0)
return segment_sum(out, s, num_segments=lmax + 1)
4 changes: 4 additions & 0 deletions pyscf_ipu/experimental/units.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from jaxtyping import Array

# Maximum value an individual component of the angular momentum lmn can take
# Used for static ahead-of-time compilation of functions involving lmn.
LMAX = 4

BOHR_PER_ANGSTROM = 0.529177210903


Expand Down
51 changes: 44 additions & 7 deletions test/test_numerics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,51 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import jax.numpy as jnp
import pytest
from numpy.testing import assert_allclose

from pyscf_ipu.experimental.integrals import factorial_fori, factorial_gamma
from pyscf_ipu.experimental.numerics import compare_fp32_to_fp64
from pyscf_ipu.experimental.special import (
binom_beta,
binom_fori,
binom_lookup,
factorial2_fori,
factorial2_lookup,
factorial_fori,
factorial_gamma,
factorial_lookup,
)


def test_factorial():
n = 16
x = jnp.arange(n, dtype=jnp.float32)
y_fori = compare_fp32_to_fp64(factorial_fori)(x, n)
y_gamma = compare_fp32_to_fp64(factorial_gamma)(x)
assert_allclose(y_fori, y_gamma, 1e-2)
x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
expect = jnp.array([1, 2, 6, 24, 120, 720, 5040, 40320])
assert_allclose(factorial_fori(x, x[-1]), expect)
assert_allclose(factorial_lookup(x, x[-1]), expect)
assert_allclose(factorial_gamma(x), expect)


def test_factorial2():
x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
expect = jnp.array([1, 2, 3, 8, 15, 48, 105, 384])
assert_allclose(factorial2_fori(x), expect)
assert_allclose(factorial2_fori(0), 1)

assert_allclose(factorial2_lookup(x), expect)
assert_allclose(factorial2_lookup(0), 1)


@pytest.mark.parametrize("binom_func", [binom_beta, binom_fori, binom_lookup])
def test_binom(binom_func):
x = jnp.array([4, 4, 4, 4])
y = jnp.array([1, 2, 3, 4])
expect = jnp.array([4, 6, 4, 1])
assert_allclose(binom_func(x, y), expect)

zero = jnp.array([0])
assert_allclose(binom_func(zero, y), jnp.zeros_like(x))
assert_allclose(binom_func(x, zero), jnp.ones_like(y))
assert_allclose(binom_func(y, y), jnp.ones_like(y))

one = jnp.array([1])
assert_allclose(binom_func(one, one), one)
assert_allclose(binom_func(zero, -one), zero)
assert_allclose(binom_func(zero, zero), one)

0 comments on commit 023e114

Please sign in to comment.