@@ -37,7 +37,7 @@ def __init__(
37
37
self ,
38
38
* ,
39
39
lower : bool = True ,
40
- check_finite : bool = True ,
40
+ check_finite : bool = False ,
41
41
on_error : Literal ["raise" , "nan" ] = "raise" ,
42
42
overwrite_a : bool = False ,
43
43
):
@@ -67,29 +67,55 @@ def make_node(self, x):
67
67
def perform (self , node , inputs , outputs ):
68
68
[x ] = inputs
69
69
[out ] = outputs
70
- try :
71
- # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
72
- # If we have a `C_CONTIGUOUS` array we transpose to benefit from it
73
- if self .overwrite_a and x .flags ["C_CONTIGUOUS" ]:
74
- out [0 ] = scipy_linalg .cholesky (
75
- x .T ,
76
- lower = not self .lower ,
77
- check_finite = self .check_finite ,
78
- overwrite_a = True ,
79
- ).T
80
- else :
81
- out [0 ] = scipy_linalg .cholesky (
82
- x ,
83
- lower = self .lower ,
84
- check_finite = self .check_finite ,
85
- overwrite_a = self .overwrite_a ,
86
- )
87
70
88
- except scipy_linalg .LinAlgError :
89
- if self .on_error == "raise" :
90
- raise
71
+ (potrf ,) = scipy_linalg .get_lapack_funcs (("potrf" ,), (x ,))
72
+
73
+ # Quick return for square empty array
74
+ if x .size == 0 :
75
+ out [0 ] = np .empty_like (x , dtype = potrf .dtype )
76
+ return
77
+
78
+ if self .check_finite and not np .isfinite (x ).all ():
79
+ if self .on_error == "nan" :
80
+ out [0 ] = np .full (x .shape , np .nan , dtype = potrf .dtype )
81
+ return
91
82
else :
83
+ raise ValueError ("array must not contain infs or NaNs" )
84
+
85
+ # Squareness check
86
+ if x .shape [0 ] != x .shape [1 ]:
87
+ raise ValueError (
88
+ "Input array is expected to be square but has " f"the shape: { x .shape } ."
89
+ )
90
+
91
+ # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
92
+ # If we have a `C_CONTIGUOUS` array we transpose to benefit from it
93
+ c_contiguous_input = self .overwrite_a and x .flags ["C_CONTIGUOUS" ]
94
+ if c_contiguous_input :
95
+ x = x .T
96
+ lower = not self .lower
97
+ overwrite_a = True
98
+ else :
99
+ lower = self .lower
100
+ overwrite_a = self .overwrite_a
101
+
102
+ c , info = potrf (x , lower = lower , overwrite_a = overwrite_a , clean = True )
103
+
104
+ if info != 0 :
105
+ if self .on_error == "nan" :
92
106
out [0 ] = np .full (x .shape , np .nan , dtype = node .outputs [0 ].type .dtype )
107
+ elif info > 0 :
108
+ raise scipy_linalg .LinAlgError (
109
+ f"{ info } -th leading minor of the array is not positive definite"
110
+ )
111
+ elif info < 0 :
112
+ raise ValueError (
113
+ f"LAPACK reported an illegal value in { - info } -th argument "
114
+ f'on entry to "POTRF".'
115
+ )
116
+ else :
117
+ # Transpose result if input was transposed
118
+ out [0 ] = c .T if c_contiguous_input else c
93
119
94
120
def L_op (self , inputs , outputs , gradients ):
95
121
"""
@@ -201,7 +227,9 @@ def cholesky(
201
227
202
228
"""
203
229
204
- return Blockwise (Cholesky (lower = lower , on_error = on_error ))(x )
230
+ return Blockwise (
231
+ Cholesky (lower = lower , on_error = on_error , check_finite = check_finite )
232
+ )(x )
205
233
206
234
207
235
class SolveBase (Op ):
0 commit comments