From 023e11487762731117d73d14bc98b2f74226f800 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Sat, 30 Sep 2023 13:56:20 +0100 Subject: [PATCH] Refactoring special function implementations (#112) * refactoring special function implementations * fix out of bounds jnp.take in eri * additional binom test cases * fix test expect --- pyscf_ipu/experimental/integrals.py | 105 +++++++------------------- pyscf_ipu/experimental/special.py | 110 ++++++++++++++++++++++++++++ pyscf_ipu/experimental/units.py | 4 + test/test_numerics.py | 51 +++++++++++-- 4 files changed, 184 insertions(+), 86 deletions(-) create mode 100644 pyscf_ipu/experimental/special.py diff --git a/pyscf_ipu/experimental/integrals.py b/pyscf_ipu/experimental/integrals.py index c6f0dfa..74224cf 100644 --- a/pyscf_ipu/experimental/integrals.py +++ b/pyscf_ipu/experimental/integrals.py @@ -6,101 +6,47 @@ import jax.numpy as jnp import numpy as np -from jax import jit, lax, tree_map, vmap +from jax import jit, tree_map, vmap from jax.ops import segment_sum -from jax.scipy.special import gammainc, gammaln from .basis import Basis from .orbital import batch_orbitals from .primitive import Primitive, product -from .types import Float3, FloatN, FloatNx3, FloatNxN, IntN +from .special import binom, binom_factor, factorial, factorial2, gammanu +from .types import Float3, FloatN, FloatNx3, FloatNxN +from .units import LMAX -# Maximum value an individual component of the angular momentum lmn can take -# Used for static ahead-of-time compilation of functions involving lmn. -LMAX = 4 - -""" -Special functions used in integral evaluation """ +JAX implementation for integrals over Gaussian basis functions. Based upon the +closed-form expressions derived in + Taketa, H., Huzinaga, S., & O-ohata, K. (1966). Gaussian-expansion methods for + molecular integrals. Journal of the physical society of Japan, 21(11), 2313-2324. 
+ -def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN: - def body_fun(i, val): - return val * jnp.where(i <= n, i, 1) - - return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) - - -def factorial_gamma(n: IntN) -> IntN: - """Appoximate factorial by evaluating the gamma function in log-space. - - This approximation is exact for small integers (n < 10). - """ - approx = jnp.exp(gammaln(n + 1)) - return jnp.rint(approx) - - -factorial = factorial_fori - - -def factorial2(n: IntN, nmax: int = 2 * LMAX) -> IntN: - def body_fun(i, val): - return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1) - - return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) - +Hereafter referred to as the "THO paper" -def binom(x: IntN, y: IntN, nmax: int = LMAX) -> IntN: - bang = partial(factorial, nmax=nmax) - c = x * bang(x - 1) / (bang(y) * bang(x - y)) - return jnp.where(x == y, 1, c) +Related work: - -def gammanu(nu: IntN, t: FloatN, epsilon: float = 1e-10) -> FloatN: - """ - eq 2.11 from THO but simplified using SymPy and converted to jax - - t, u = symbols("t u", real=True, positive=True) - nu = Symbol("nu", integer=True, nonnegative=True) - - expr = simplify(integrate(u ** (2 * nu) * exp(-t * u**2), (u, 0, 1))) - f = lambdify((nu, t), expr, modules="scipy") - ?f - - We evaulate this in log-space to avoid overflow/nan - """ - t = jnp.maximum(t, epsilon) - x = nu + 0.5 - gn = jnp.log(0.5) - x * jnp.log(t) + jnp.log(gammainc(x, t)) + gammaln(x) - return jnp.exp(gn) - - -@partial(vmap, in_axes=(0, None, None, None, None)) -def binom_factor(s: IntN, i: int, j: int, pa: float, pb: float): - """ - Eq. 15 from Augspurger JD, Dykstra CE. General quantum mechanical operators. An +[1] Augspurger JD, Dykstra CE. General quantum mechanical operators. An open-ended approach for one-electron integrals with Gaussian bases. Journal of computational chemistry. 1990 Jan;11(1):105-11. 
- """ - def term(t): - return binom(i, s - t) * binom(j, t) * pa ** (i - s + t) * pb ** (j - t) - - def body_fun(t, val): - mask = (t <= s) & (t >= (s - i)) & (t <= j) - return val + jnp.where(mask, term(t), 0.0) - - return lax.fori_loop(0, LMAX + 1, body_fun, 0.0) +[2] PyQuante: +""" @partial(vmap, in_axes=(0, 0, 0, 0, None)) -def overlap_axis(i: int, j: int, pa: int, pb: int, alpha: float) -> float: - ii = jnp.arange(LMAX + 1) - out = binom_factor(2 * ii, i, j, pa, pb) - out *= factorial2(2 * ii - 1) / (2 * alpha) ** ii - mask = ii <= jnp.floor_divide(i + j, 2) - out = jnp.where(mask, out, 0.0) +def overlap_axis(i: int, j: int, a: float, b: float, alpha: float) -> float: + idx = [(s, t) for s in range(LMAX + 1) for t in range(2 * s + 1)] + s, t = jnp.array(idx, dtype=jnp.uint32).T + out = binom(i, 2 * s - t) * binom(j, t) + out *= a ** (i - (2 * s - t)) * b ** (j - t) + out *= factorial2(2 * s - 1) / (2 * alpha) ** s + + mask = (2 * s - i <= t) & (t <= j) + out = jnp.where(mask, out, 0) return jnp.sum(out) @@ -179,7 +125,7 @@ def g_term(l1, l2, pa, pb, cp): index = i - 2 * r - u g = ( jnp.power(-1, i + u) - * binom_factor(i, l1, l2, pa, pb) + * jnp.take(binom_factor(l1, l2, pa, pb), i) * factorial(i) * jnp.power(cp, index - u) * jnp.power(epsilon, r + u) @@ -251,7 +197,7 @@ def H(l1, l2, a, b, i, r, gamma): # Note this should match THO Eq 3.5 but that seems to incorrectly show a # 1/(4 gamma) ^(i- 2r) term which is inconsistent with Eq 2.22. # Using (4 gamma)^(r - i) matches the reported expressions for H_L - u = factorial(i) * binom_factor(i, l1, l2, a, b) + u = factorial(i) * jnp.take(binom_factor(l1, l2, a, b, 2 * LMAX), i) v = factorial(r) * factorial(i - 2 * r) * (4 * gamma) ** (i - r) return u / v @@ -269,6 +215,7 @@ def c_term(la, lb, lc, ld, pa, pb, qc, qd, qp): return segment_sum(c, index, num_segments=4 * LMAX + 1) # Manual vmap over cartesian axes (x, y, z) as ran into possible bug. 
+ # See https://github.com/graphcore-research/pyscf-ipu/issues/105 args = [a.lmn, b.lmn, c.lmn, d.lmn, pa, pb, qc, qd, qp] Ci, Cj, Ck = [c_term(*[v.at[i].get() for v in args]) for i in range(3)] diff --git a/pyscf_ipu/experimental/special.py b/pyscf_ipu/experimental/special.py new file mode 100644 index 0000000..f07369e --- /dev/null +++ b/pyscf_ipu/experimental/special.py @@ -0,0 +1,110 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from functools import partial + +import jax.numpy as jnp +import numpy as np +from jax import lax +from jax.ops import segment_sum +from jax.scipy.special import betaln, gammainc, gammaln + +from .types import FloatN, IntN +from .units import LMAX + + +def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN: + def body_fun(i, val): + return val * jnp.where(i <= n, i, 1) + + return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) + + +def factorial_gamma(n: IntN) -> IntN: + """Approximate factorial by evaluating the gamma function in log-space. + + This approximation is exact for small integers (n < 10). 
+ """ + approx = jnp.exp(gammaln(n + 1)) + return jnp.rint(approx) + + +def factorial_lookup(n: IntN, nmax: int = LMAX) -> IntN: + N = np.cumprod(np.arange(1, nmax + 1)) + N = np.insert(N, 0, 1) + N = jnp.array(N, dtype=jnp.uint32) + return N.at[n.astype(jnp.uint32)].get() + + +factorial = factorial_gamma + + +def factorial2_fori(n: IntN, nmax: int = 2 * LMAX) -> IntN: + def body_fun(i, val): + return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1) + + return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) + + +def factorial2_lookup(n: IntN, nmax: int = 2 * LMAX) -> IntN: + stop = nmax + 1 if nmax % 2 == 0 else nmax + 2 + N = np.arange(1, stop).reshape(-1, 2) + N = np.cumprod(N, axis=0).reshape(-1) + N = np.insert(N, 0, 1) + N = jnp.array(N) + n = jnp.maximum(n, 0) + return N.at[n].get() + + +factorial2 = factorial2_lookup + + +def binom_beta(x: IntN, y: IntN) -> IntN: + approx = 1.0 / ((x + 1) * jnp.exp(betaln(x - y + 1, y + 1))) + return jnp.rint(approx) + + +def binom_fori(x: IntN, y: IntN, nmax: int = LMAX) -> IntN: + bang = partial(factorial_fori, nmax=nmax) + c = x * bang(x - 1) / (bang(y) * bang(x - y)) + return jnp.where(x == y, 1, c) + + +def binom_lookup(x: IntN, y: IntN, nmax: int = LMAX) -> IntN: + bang = partial(factorial_lookup, nmax=nmax) + c = x * bang(x - 1) / (bang(y) * bang(x - y)) + return jnp.where(x == y, 1, c) + + +binom = binom_lookup + + +def gammanu(nu: IntN, t: FloatN, epsilon: float = 1e-10) -> FloatN: + """ + eq 2.11 from THO but simplified using SymPy and converted to jax + + t, u = symbols("t u", real=True, positive=True) + nu = Symbol("nu", integer=True, nonnegative=True) + + expr = simplify(integrate(u ** (2 * nu) * exp(-t * u**2), (u, 0, 1))) + f = lambdify((nu, t), expr, modules="scipy") + ?f + + We evaluate this in log-space to avoid overflow/nan + """ + t = jnp.maximum(t, epsilon) + x = nu + 0.5 + gn = jnp.log(0.5) - x * jnp.log(t) + jnp.log(gammainc(x, t)) + gammaln(x) + return jnp.exp(gn) + + +def binom_factor(i: 
int, j: int, a: float, b: float, lmax: int = LMAX) -> FloatN: + """ + Eq. 15 from Augspurger JD, Dykstra CE. General quantum mechanical operators. An + open-ended approach for one-electron integrals with Gaussian bases. Journal of + computational chemistry. 1990 Jan;11(1):105-11. + + """ + s, t = jnp.tril_indices(lmax + 1) + out = binom(i, s - t) * binom(j, t) * a ** (i - (s - t)) * b ** (j - t) + mask = ((s - i) <= t) & (t <= j) + out = jnp.where(mask, out, 0.0) + return segment_sum(out, s, num_segments=lmax + 1) diff --git a/pyscf_ipu/experimental/units.py b/pyscf_ipu/experimental/units.py index 4099f40..93de8ed 100644 --- a/pyscf_ipu/experimental/units.py +++ b/pyscf_ipu/experimental/units.py @@ -1,6 +1,10 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from jaxtyping import Array +# Maximum value an individual component of the angular momentum lmn can take +# Used for static ahead-of-time compilation of functions involving lmn. +LMAX = 4 + BOHR_PER_ANGSTROM = 0.529177210903 diff --git a/test/test_numerics.py b/test/test_numerics.py index 1f52af4..5e6545c 100644 --- a/test/test_numerics.py +++ b/test/test_numerics.py @@ -1,14 +1,51 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
import jax.numpy as jnp +import pytest from numpy.testing import assert_allclose -from pyscf_ipu.experimental.integrals import factorial_fori, factorial_gamma -from pyscf_ipu.experimental.numerics import compare_fp32_to_fp64 +from pyscf_ipu.experimental.special import ( + binom_beta, + binom_fori, + binom_lookup, + factorial2_fori, + factorial2_lookup, + factorial_fori, + factorial_gamma, + factorial_lookup, +) def test_factorial(): - n = 16 - x = jnp.arange(n, dtype=jnp.float32) - y_fori = compare_fp32_to_fp64(factorial_fori)(x, n) - y_gamma = compare_fp32_to_fp64(factorial_gamma)(x) - assert_allclose(y_fori, y_gamma, 1e-2) + x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8]) + expect = jnp.array([1, 2, 6, 24, 120, 720, 5040, 40320]) + assert_allclose(factorial_fori(x, x[-1]), expect) + assert_allclose(factorial_lookup(x, x[-1]), expect) + assert_allclose(factorial_gamma(x), expect) + + +def test_factorial2(): + x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8]) + expect = jnp.array([1, 2, 3, 8, 15, 48, 105, 384]) + assert_allclose(factorial2_fori(x), expect) + assert_allclose(factorial2_fori(0), 1) + + assert_allclose(factorial2_lookup(x), expect) + assert_allclose(factorial2_lookup(0), 1) + + +@pytest.mark.parametrize("binom_func", [binom_beta, binom_fori, binom_lookup]) +def test_binom(binom_func): + x = jnp.array([4, 4, 4, 4]) + y = jnp.array([1, 2, 3, 4]) + expect = jnp.array([4, 6, 4, 1]) + assert_allclose(binom_func(x, y), expect) + + zero = jnp.array([0]) + assert_allclose(binom_func(zero, y), jnp.zeros_like(x)) + assert_allclose(binom_func(x, zero), jnp.ones_like(y)) + assert_allclose(binom_func(y, y), jnp.ones_like(y)) + + one = jnp.array([1]) + assert_allclose(binom_func(one, one), one) + assert_allclose(binom_func(zero, -one), zero) + assert_allclose(binom_func(zero, zero), one)