Skip to content

Commit 236e50d

Browse files
authored
Use lapack func instead of scipy.linalg.cholesky (#1487)
* Use lapack func instead of `scipy.linalg.cholesky` * Now skips 2D checks in perform * Updated the default arguments for `check_finite` to false to match documentation * Add benchmark test case * Refactor out _cholesky helper, add empty test * Remove array and `potrf` copies * Update test_cholesky_raises_on_nan_input
1 parent 7886cf8 commit 236e50d

File tree

3 files changed

+71
-23
lines changed

3 files changed

+71
-23
lines changed

pytensor/tensor/slinalg.py

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(
3737
self,
3838
*,
3939
lower: bool = True,
40-
check_finite: bool = True,
40+
check_finite: bool = False,
4141
on_error: Literal["raise", "nan"] = "raise",
4242
overwrite_a: bool = False,
4343
):
@@ -67,29 +67,55 @@ def make_node(self, x):
6767
def perform(self, node, inputs, outputs):
6868
[x] = inputs
6969
[out] = outputs
70-
try:
71-
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
72-
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
73-
if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
74-
out[0] = scipy_linalg.cholesky(
75-
x.T,
76-
lower=not self.lower,
77-
check_finite=self.check_finite,
78-
overwrite_a=True,
79-
).T
80-
else:
81-
out[0] = scipy_linalg.cholesky(
82-
x,
83-
lower=self.lower,
84-
check_finite=self.check_finite,
85-
overwrite_a=self.overwrite_a,
86-
)
8770

88-
except scipy_linalg.LinAlgError:
89-
if self.on_error == "raise":
90-
raise
71+
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x,))
72+
73+
# Quick return for square empty array
74+
if x.size == 0:
75+
out[0] = np.empty_like(x, dtype=potrf.dtype)
76+
return
77+
78+
if self.check_finite and not np.isfinite(x).all():
79+
if self.on_error == "nan":
80+
out[0] = np.full(x.shape, np.nan, dtype=potrf.dtype)
81+
return
9182
else:
83+
raise ValueError("array must not contain infs or NaNs")
84+
85+
# Squareness check
86+
if x.shape[0] != x.shape[1]:
87+
raise ValueError(
88+
"Input array is expected to be square but has " f"the shape: {x.shape}."
89+
)
90+
91+
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
92+
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
93+
c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"]
94+
if c_contiguous_input:
95+
x = x.T
96+
lower = not self.lower
97+
overwrite_a = True
98+
else:
99+
lower = self.lower
100+
overwrite_a = self.overwrite_a
101+
102+
c, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True)
103+
104+
if info != 0:
105+
if self.on_error == "nan":
92106
out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype)
107+
elif info > 0:
108+
raise scipy_linalg.LinAlgError(
109+
f"{info}-th leading minor of the array is not positive definite"
110+
)
111+
elif info < 0:
112+
raise ValueError(
113+
f"LAPACK reported an illegal value in {-info}-th argument "
114+
f'on entry to "POTRF".'
115+
)
116+
else:
117+
# Transpose result if input was transposed
118+
out[0] = c.T if c_contiguous_input else c
93119

94120
def L_op(self, inputs, outputs, gradients):
95121
"""
@@ -201,7 +227,9 @@ def cholesky(
201227
202228
"""
203229

204-
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
230+
return Blockwise(
231+
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
232+
)(x)
205233

206234

207235
class SolveBase(Op):

tests/link/numba/test_slinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def test_cholesky_raises_on_nan_input():
465465

466466
x = pt.tensor(dtype=floatX, shape=(3, 3))
467467
x = x.T.dot(x)
468-
g = pt.linalg.cholesky(x)
468+
g = pt.linalg.cholesky(x, check_finite=True)
469469
f = pytensor.function([x], g, mode="NUMBA")
470470

471471
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):

tests/tensor/test_slinalg.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,26 @@ def test_cholesky():
7474
check_upper_triangular(pd, ch_f)
7575

7676

77+
def test_cholesky_performance(benchmark):
78+
rng = np.random.default_rng(utt.fetch_seed())
79+
r = rng.standard_normal((10, 10)).astype(config.floatX)
80+
pd = np.dot(r, r.T)
81+
x = matrix()
82+
chol = cholesky(x)
83+
ch_f = function([x], chol)
84+
benchmark(ch_f, pd)
85+
86+
87+
def test_cholesky_empty():
88+
empty = np.empty([0, 0], dtype=config.floatX)
89+
x = matrix()
90+
chol = cholesky(x)
91+
ch_f = function([x], chol)
92+
ch = ch_f(empty)
93+
assert ch.size == 0
94+
assert ch.dtype == config.floatX
95+
96+
7797
def test_cholesky_indef():
7898
x = matrix()
7999
mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)

0 commit comments

Comments
 (0)