Skip to content

Commit f226317

Browse files
Rewrite scalar solve to division
1 parent 236e50d commit f226317

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@
4747
from pytensor.tensor.slinalg import (
4848
BlockDiagonal,
4949
Cholesky,
50+
CholeskySolve,
5051
Solve,
5152
SolveBase,
53+
SolveTriangular,
5254
_bilinear_solve_discrete_lyapunov,
5355
block_diag,
5456
cholesky,
@@ -908,6 +910,11 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
908910
return None
909911

910912
[input] = node.inputs
913+
914+
# Check if input is a (1, 1) matrix
915+
if all(input.type.broadcastable[-2:]):
916+
return [pt.sqrt(input)]
917+
911918
# Check for use of pt.diag first
912919
if (
913920
input.owner
@@ -1020,3 +1027,42 @@ def slogdet_specialization(fgraph, node):
10201027
k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
10211028
}
10221029
return replacements
1030+
1031+
1032+
@register_stabilize
1033+
@register_canonicalize
1034+
@node_rewriter([Blockwise])
1035+
def scalar_solve_to_divison(fgraph, node):
1036+
"""
1037+
Replace solve(a, b) with b / a if a is a (1, 1) matrix
1038+
"""
1039+
1040+
core_op = node.op.core_op
1041+
if not isinstance(core_op, SolveBase):
1042+
return None
1043+
1044+
a, b = node.inputs
1045+
old_out = node.outputs[0]
1046+
if not all(a.broadcastable[-2:]):
1047+
return None
1048+
1049+
# Special handling for different types of solve
1050+
match core_op:
1051+
case SolveTriangular():
1052+
# Corner case: if user asked for a triangular solve with a unit diagonal, a is taken to be 1
1053+
new_out = b / a if not core_op.unit_diagonal else b
1054+
case CholeskySolve():
1055+
new_out = b / a**2
1056+
case Solve():
1057+
new_out = b / a
1058+
case _:
1059+
raise NotImplementedError(
1060+
f"Unsupported core_op type: {type(core_op)} in scalar_solve_to_divison"
1061+
)
1062+
1063+
if core_op.b_ndim == 1:
1064+
new_out = new_out.squeeze(-1)
1065+
1066+
copy_stack_trace(old_out, new_out)
1067+
1068+
return [new_out]

tests/tensor/rewriting/test_linalg.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from pytensor.tensor.slinalg import (
3030
BlockDiagonal,
3131
Cholesky,
32+
CholeskySolve,
3233
Solve,
3334
SolveBase,
3435
SolveTriangular,
@@ -993,3 +994,37 @@ def test_slogdet_specialization():
993994
f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
994995
nodes = f.maker.fgraph.apply_nodes
995996
assert not any(isinstance(node.op, SLogDet) for node in nodes)
997+
998+
999+
@pytest.mark.parametrize(
    "Op, fn",
    [
        (Solve, pt.linalg.solve),
        (SolveTriangular, pt.linalg.solve_triangular),
        (CholeskySolve, pt.linalg.cho_solve),
    ],
)
def test_scalar_solve_to_division_rewrite(Op, fn):
    """Check that solves against a (1, 1) matrix are rewritten to a division.

    The compiled graph must not contain the solve op (bare or wrapped in a
    Blockwise), and the numerical result must match ``np.linalg.solve``.
    """
    rng = np.random.default_rng(sum(map(ord, "scalar_solve_to_division_rewrite")))

    a = pt.dmatrix("a", shape=(1, 1))
    b = pt.dvector("b")

    if Op is CholeskySolve:
        # cho_solve expects a tuple (c, lower) as the first input
        c = fn((pt.linalg.cholesky(a), True), b, b_ndim=1)
    else:
        c = fn(a, b, b_ndim=1)

    f = function([a, b], c, mode="FAST_RUN")
    nodes = f.maker.fgraph.apply_nodes

    # The solve may appear either directly or as the core op of a Blockwise
    assert not any(
        isinstance(node.op, Op) or isinstance(getattr(node.op, "core_op", None), Op)
        for node in nodes
    )

    a_val = rng.normal(size=(1, 1)).astype(pytensor.config.floatX)
    b_val = rng.normal(size=(1,)).astype(pytensor.config.floatX)

    if Op is CholeskySolve:
        # The Cholesky factorization is only defined for positive definite
        # input; a negative draw would give NaN and fail for unrelated reasons
        a_val = np.abs(a_val)

    c_val = np.linalg.solve(a_val, b_val)
    np.testing.assert_allclose(
        f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
    )

0 commit comments

Comments
 (0)