-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactoring special function implementations (#112)
* refactoring special function implementations * fix out of bounds jnp.take in eri * additional binom test cases * fix test expect
- Loading branch information
1 parent
ec94c8b
commit 023e114
Showing
4 changed files
with
184 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from functools import partial | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
from jax import lax | ||
from jax.ops import segment_sum | ||
from jax.scipy.special import betaln, gammainc, gammaln | ||
|
||
from .types import FloatN, IntN | ||
from .units import LMAX | ||
|
||
|
||
def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN: | ||
def body_fun(i, val): | ||
return val * jnp.where(i <= n, i, 1) | ||
|
||
return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) | ||
|
||
|
||
def factorial_gamma(n: IntN) -> IntN: | ||
"""Appoximate factorial by evaluating the gamma function in log-space. | ||
This approximation is exact for small integers (n < 10). | ||
""" | ||
approx = jnp.exp(gammaln(n + 1)) | ||
return jnp.rint(approx) | ||
|
||
|
||
def factorial_lookup(n: IntN, nmax: int = LMAX) -> IntN: | ||
N = np.cumprod(np.arange(1, nmax + 1)) | ||
N = np.insert(N, 0, 1) | ||
N = jnp.array(N, dtype=jnp.uint32) | ||
return N.at[n.astype(jnp.uint32)].get() | ||
|
||
|
||
factorial = factorial_gamma | ||
|
||
|
||
def factorial2_fori(n: IntN, nmax: int = 2 * LMAX) -> IntN: | ||
def body_fun(i, val): | ||
return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1) | ||
|
||
return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) | ||
|
||
|
||
def factorial2_lookup(n: IntN, nmax: int = 2 * LMAX) -> IntN: | ||
stop = nmax + 1 if nmax % 2 == 0 else nmax + 2 | ||
N = np.arange(1, stop).reshape(-1, 2) | ||
N = np.cumprod(N, axis=0).reshape(-1) | ||
N = np.insert(N, 0, 1) | ||
N = jnp.array(N) | ||
n = jnp.maximum(n, 0) | ||
return N.at[n].get() | ||
|
||
|
||
factorial2 = factorial2_lookup | ||
|
||
|
||
def binom_beta(x: IntN, y: IntN) -> IntN: | ||
approx = 1.0 / ((x + 1) * jnp.exp(betaln(x - y + 1, y + 1))) | ||
return jnp.rint(approx) | ||
|
||
|
||
def binom_fori(x: IntN, y: IntN, nmax: int = LMAX) -> IntN: | ||
bang = partial(factorial_fori, nmax=nmax) | ||
c = x * bang(x - 1) / (bang(y) * bang(x - y)) | ||
return jnp.where(x == y, 1, c) | ||
|
||
|
||
def binom_lookup(x: IntN, y: IntN, nmax: int = LMAX) -> IntN: | ||
bang = partial(factorial_lookup, nmax=nmax) | ||
c = x * bang(x - 1) / (bang(y) * bang(x - y)) | ||
return jnp.where(x == y, 1, c) | ||
|
||
|
||
binom = binom_lookup | ||
|
||
|
||
def gammanu(nu: IntN, t: FloatN, epsilon: float = 1e-10) -> FloatN: | ||
""" | ||
eq 2.11 from THO but simplified using SymPy and converted to jax | ||
t, u = symbols("t u", real=True, positive=True) | ||
nu = Symbol("nu", integer=True, nonnegative=True) | ||
expr = simplify(integrate(u ** (2 * nu) * exp(-t * u**2), (u, 0, 1))) | ||
f = lambdify((nu, t), expr, modules="scipy") | ||
?f | ||
We evaulate this in log-space to avoid overflow/nan | ||
""" | ||
t = jnp.maximum(t, epsilon) | ||
x = nu + 0.5 | ||
gn = jnp.log(0.5) - x * jnp.log(t) + jnp.log(gammainc(x, t)) + gammaln(x) | ||
return jnp.exp(gn) | ||
|
||
|
||
def binom_factor(i: int, j: int, a: float, b: float, lmax: int = LMAX) -> FloatN: | ||
""" | ||
Eq. 15 from Augspurger JD, Dykstra CE. General quantum mechanical operators. An | ||
open-ended approach for one-electron integrals with Gaussian bases. Journal of | ||
computational chemistry. 1990 Jan;11(1):105-11. | ||
<https://doi.org/10.1002/jcc.540110113> | ||
""" | ||
s, t = jnp.tril_indices(lmax + 1) | ||
out = binom(i, s - t) * binom(j, t) * a ** (i - (s - t)) * b ** (j - t) | ||
mask = ((s - i) <= t) & (t <= j) | ||
out = jnp.where(mask, out, 0.0) | ||
return segment_sum(out, s, num_segments=lmax + 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,51 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
import jax.numpy as jnp | ||
import pytest | ||
from numpy.testing import assert_allclose | ||
|
||
from pyscf_ipu.experimental.integrals import factorial_fori, factorial_gamma | ||
from pyscf_ipu.experimental.numerics import compare_fp32_to_fp64 | ||
from pyscf_ipu.experimental.special import ( | ||
binom_beta, | ||
binom_fori, | ||
binom_lookup, | ||
factorial2_fori, | ||
factorial2_lookup, | ||
factorial_fori, | ||
factorial_gamma, | ||
factorial_lookup, | ||
) | ||
|
||
|
||
def test_factorial(): | ||
n = 16 | ||
x = jnp.arange(n, dtype=jnp.float32) | ||
y_fori = compare_fp32_to_fp64(factorial_fori)(x, n) | ||
y_gamma = compare_fp32_to_fp64(factorial_gamma)(x) | ||
assert_allclose(y_fori, y_gamma, 1e-2) | ||
x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8]) | ||
expect = jnp.array([1, 2, 6, 24, 120, 720, 5040, 40320]) | ||
assert_allclose(factorial_fori(x, x[-1]), expect) | ||
assert_allclose(factorial_lookup(x, x[-1]), expect) | ||
assert_allclose(factorial_gamma(x), expect) | ||
|
||
|
||
def test_factorial2(): | ||
x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8]) | ||
expect = jnp.array([1, 2, 3, 8, 15, 48, 105, 384]) | ||
assert_allclose(factorial2_fori(x), expect) | ||
assert_allclose(factorial2_fori(0), 1) | ||
|
||
assert_allclose(factorial2_lookup(x), expect) | ||
assert_allclose(factorial2_lookup(0), 1) | ||
|
||
|
||
@pytest.mark.parametrize("binom_func", [binom_beta, binom_fori, binom_lookup]) | ||
def test_binom(binom_func): | ||
x = jnp.array([4, 4, 4, 4]) | ||
y = jnp.array([1, 2, 3, 4]) | ||
expect = jnp.array([4, 6, 4, 1]) | ||
assert_allclose(binom_func(x, y), expect) | ||
|
||
zero = jnp.array([0]) | ||
assert_allclose(binom_func(zero, y), jnp.zeros_like(x)) | ||
assert_allclose(binom_func(x, zero), jnp.ones_like(y)) | ||
assert_allclose(binom_func(y, y), jnp.ones_like(y)) | ||
|
||
one = jnp.array([1]) | ||
assert_allclose(binom_func(one, one), one) | ||
assert_allclose(binom_func(zero, -one), zero) | ||
assert_allclose(binom_func(zero, zero), one) |