Skip to content

Commit 09a4350

Browse files
Check for SolveBase to catch all cases
1 parent cd70d19 commit 09a4350

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@
4747
from pytensor.tensor.slinalg import (
4848
BlockDiagonal,
4949
Cholesky,
50+
CholeskySolve,
5051
Solve,
5152
SolveBase,
53+
SolveTriangular,
5254
_bilinear_solve_discrete_lyapunov,
5355
block_diag,
5456
cholesky,
@@ -908,6 +910,11 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
908910
return None
909911

910912
[input] = node.inputs
913+
914+
# Check if input is a (1, 1) matrix
915+
if all(input.type.broadcastable[:-2]):
916+
return [pt.sqrt(input)]
917+
911918
# Check for use of pt.diag first
912919
if (
913920
input.owner
@@ -1031,15 +1038,28 @@ def scalar_solve_to_divison(fgraph, node):
10311038
"""
10321039

10331040
core_op = node.op.core_op
1034-
if not isinstance(core_op, Solve):
1041+
if not isinstance(core_op, SolveBase):
10351042
return None
10361043

10371044
a, b = node.inputs
10381045
old_out = node.outputs[0]
10391046
if not all(a.broadcastable[-2:]):
10401047
return None
10411048

1042-
new_out = b / a
1049+
# Special handling for different types of solve
1050+
match core_op:
1051+
case SolveTriangular():
1052+
# Corner case: if user asked for a triangular solve with a unit diagonal, a is taken to be 1
1053+
new_out = b / a if not core_op.unit_diagonal else b
1054+
case CholeskySolve():
1055+
new_out = b / a**2
1056+
case Solve():
1057+
new_out = b / a
1058+
case _:
1059+
raise NotImplementedError(
1060+
f"Unsupported core_op type: {type(core_op)} in scalar_solve_to_divison"
1061+
)
1062+
10431063
if core_op.b_ndim == 1:
10441064
new_out = new_out.squeeze(-1)
10451065

tests/tensor/rewriting/test_linalg.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from pytensor.tensor.slinalg import (
3030
BlockDiagonal,
3131
Cholesky,
32+
CholeskySolve,
3233
Solve,
3334
SolveBase,
3435
SolveTriangular,
@@ -995,18 +996,30 @@ def test_slogdet_specialization():
995996
assert not any(isinstance(node.op, SLogDet) for node in nodes)
996997

997998

998-
def test_scalar_solve_to_division_rewrite():
999+
@pytest.mark.parametrize(
1000+
"Op, fn",
1001+
[
1002+
(Solve, pt.linalg.solve),
1003+
(SolveTriangular, pt.linalg.solve_triangular),
1004+
(CholeskySolve, pt.linalg.cho_solve),
1005+
],
1006+
)
1007+
def test_scalar_solve_to_division_rewrite(Op, fn):
9991008
rng = np.random.default_rng(sum(map(ord, "scalar_solve_to_division_rewrite")))
10001009

10011010
a = pt.dmatrix("a", shape=(1, 1))
10021011
b = pt.dvector("b")
10031012

1004-
c = pt.linalg.solve(a, b, b_ndim=1)
1013+
if Op is CholeskySolve:
1014+
# cho_solve expects a tuple (c, lower) as the first input
1015+
c = fn((pt.linalg.cholesky(a), True), b, b_ndim=1)
1016+
else:
1017+
c = fn(a, b, b_ndim=1)
10051018

10061019
f = function([a, b], c, mode="FAST_RUN")
10071020
nodes = f.maker.fgraph.apply_nodes
10081021

1009-
assert not any(isinstance(node.op, Solve) for node in nodes)
1022+
assert not any(isinstance(node.op, Op) for node in nodes)
10101023

10111024
a_val = rng.normal(size=(1, 1)).astype(pytensor.config.floatX)
10121025
b_val = rng.normal(size=(1,)).astype(pytensor.config.floatX)

0 commit comments

Comments
 (0)