Rewrite scalar solve to division #1453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged (2 commits, Jun 26, 2025)
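In brief: for a (1, 1) system a * x = b the unique solution is x = b / a, so any Blockwise solve whose core matrix is known to be scalar can be rewritten to elementwise division. cho_solve receives a Cholesky factor c with c @ c.T == a, so its scalar form is b / c**2, and a triangular solve with unit_diagonal=True reduces to b. A companion sqrt/sqr simplification cleans up the chains this exposes (e.g. the Cholesky of a (1, 1) matrix becomes sqrt, and b / sqrt(a)**2 then simplifies).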
46 changes: 46 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
@@ -47,8 +47,10 @@
from pytensor.tensor.slinalg import (
    BlockDiagonal,
    Cholesky,
    CholeskySolve,
    Solve,
    SolveBase,
    SolveTriangular,
    _bilinear_solve_discrete_lyapunov,
    block_diag,
    cholesky,
@@ -908,6 +910,11 @@
        return None

    [input] = node.inputs

    # Check if input is a (1, 1) matrix
    if all(input.type.broadcastable[-2:]):
        return [pt.sqrt(input)]

    # Check for use of pt.diag first
    if (
        input.owner
@@ -1020,3 +1027,42 @@
        k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
    }
    return replacements


@register_stabilize
@register_canonicalize
@node_rewriter([Blockwise])
def scalar_solve_to_division(fgraph, node):
    """
    Replace solve(a, b) with b / a if a is a (1, 1) matrix
    """

    core_op = node.op.core_op
    if not isinstance(core_op, SolveBase):
        return None

    a, b = node.inputs
    old_out = node.outputs[0]
    if not all(a.broadcastable[-2:]):
        return None

    # Special handling for different types of solve
    match core_op:
        case SolveTriangular():
            # Corner case: if user asked for a triangular solve with a unit diagonal, a is taken to be 1
            new_out = b / a if not core_op.unit_diagonal else b
        case CholeskySolve():
            new_out = b / a**2
        case Solve():
            new_out = b / a
        case _:
            raise NotImplementedError(
                f"Unsupported core_op type: {type(core_op)} in scalar_solve_to_division"
            )

    if core_op.b_ndim == 1:
        new_out = new_out.squeeze(-1)

    copy_stack_trace(old_out, new_out)

    return [new_out]
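The three cases can be sanity-checked numerically outside of PyTensor. A minimal SciPy/NumPy sketch (illustration only, not part of the PR):

import numpy as np
from scipy.linalg import cho_solve, solve, solve_triangular

a = np.array([[2.5]])  # any positive (1, 1) matrix
b = np.array([3.0])

# Solve / SolveTriangular: a 1x1 solve is just division by the scalar
np.testing.assert_allclose(solve(a, b), b / a[0, 0])
np.testing.assert_allclose(solve_triangular(a, b), b / a[0, 0])

# unit_diagonal=True ignores the stored diagonal, so the result is b itself
np.testing.assert_allclose(solve_triangular(a, b, unit_diagonal=True), b)

# CholeskySolve receives a factor c with c @ c.T == a, so the solve is b / c**2
c = np.linalg.cholesky(a)
np.testing.assert_allclose(cho_solve((c, True), b), b / c[0, 0] ** 2)

# And the companion rewrite: the Cholesky factor of a (1, 1) matrix is its sqrt
np.testing.assert_allclose(np.linalg.cholesky(a), np.sqrt(a))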
31 changes: 31 additions & 0 deletions pytensor/tensor/rewriting/math.py
@@ -400,6 +400,37 @@ def local_exp_log(fgraph, node):
        return [exp(x)]


@register_canonicalize
@register_specialize
@node_rewriter([sqrt, sqr])
def local_sqrt_sqr(fgraph, node):
    x = node.inputs[0]

    if not (x.owner and isinstance(x.owner.op, Elemwise)):
        return

    prev_op = x.owner.op.scalar_op
    node_op = node.op.scalar_op

    # Case for sqrt(sqr(x)) -> |x|
    if isinstance(prev_op, ps.Sqr) and isinstance(node_op, ps.Sqrt):
        new_out = pt_abs(x.owner.inputs[0])
        old_out = node.outputs[0]

        # Handle potential integer to float cast by sqrt
        if new_out.dtype != old_out.dtype:
            new_out = cast(new_out, old_out.dtype)
        return [new_out]

    # Case for sqr(sqrt(x)) -> x (nan where x < 0)
    if isinstance(prev_op, ps.Sqrt) and isinstance(node_op, ps.Sqr):
        x = x.owner.inputs[0]
        old_out = node.outputs[0]
        new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype))

        return [new_out]


@register_specialize
@node_rewriter([exp, expm1])
def local_exp_log_nan_switch(fgraph, node):
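Why the two directions get different replacements is easiest to see numerically; a plain NumPy illustration (not part of the PR):

import numpy as np

x = np.array([-2.0, 3.0])

# sqrt(sqr(x)) equals |x| for every real x, so rewriting to abs is exact
np.testing.assert_allclose(np.sqrt(x**2), np.abs(x))

# sqr(sqrt(x)) equals x only where x >= 0; for negative x it is nan,
# which the switch(ge(x, 0), x, nan) replacement preserves
with np.errstate(invalid="ignore"):
    y = np.sqrt(x) ** 2
assert np.isnan(y[0]) and np.isclose(y[1], 3.0)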
43 changes: 35 additions & 8 deletions tests/tensor/rewriting/test_linalg.py
@@ -29,6 +29,7 @@
from pytensor.tensor.slinalg import (
    BlockDiagonal,
    Cholesky,
    CholeskySolve,
    Solve,
    SolveBase,
    SolveTriangular,
@@ -920,14 +921,6 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
    nodes = f_rewritten.maker.fgraph.apply_nodes
    assert any(isinstance(node.op, Cholesky) for node in nodes)

    # Case 2 : eye is degenerate
    x = pt.scalar("x")
    y = pt.eye(1) * x
    z_cholesky = pt.linalg.cholesky(y)
    f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes
    assert any(isinstance(node.op, Cholesky) for node in nodes)


def test_slogdet_specialization():
    x, a = pt.dmatrix("x"), np.random.rand(20, 20)
@@ -993,3 +986,37 @@ def test_slogdet_specialization():
    f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
    nodes = f.maker.fgraph.apply_nodes
    assert not any(isinstance(node.op, SLogDet) for node in nodes)


@pytest.mark.parametrize(
    "Op, fn",
    [
        (Solve, pt.linalg.solve),
        (SolveTriangular, pt.linalg.solve_triangular),
        (CholeskySolve, pt.linalg.cho_solve),
    ],
)
def test_scalar_solve_to_division_rewrite(Op, fn):
    rng = np.random.default_rng(sum(map(ord, "scalar_solve_to_division_rewrite")))

    a = pt.dmatrix("a", shape=(1, 1))
    b = pt.dvector("b")

    if Op is CholeskySolve:
        # cho_solve expects a tuple (c, lower) as the first input
        c = fn((pt.linalg.cholesky(a), True), b, b_ndim=1)
    else:
        c = fn(a, b, b_ndim=1)

    f = function([a, b], c, mode="FAST_RUN")
    nodes = f.maker.fgraph.apply_nodes

    assert not any(isinstance(node.op, Op) for node in nodes)

    a_val = rng.normal(size=(1, 1)).astype(pytensor.config.floatX)
    b_val = rng.normal(size=(1,)).astype(pytensor.config.floatX)

    c_val = np.linalg.solve(a_val, b_val)
    np.testing.assert_allclose(
        f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
    )
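To watch the rewrite fire outside the test suite, a hypothetical interactive check (assuming a standard PyTensor install; dprint is PyTensor's graph printer):

import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.dmatrix("a", shape=(1, 1))
b = pt.dvector("b")
f = pytensor.function([a, b], pt.linalg.solve(a, b, b_ndim=1), mode="FAST_RUN")

pytensor.dprint(f)  # compiled graph contains a division, no Solve/Blockwise op
print(f(np.array([[2.0]]), np.array([4.0])))  # [2.]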
39 changes: 39 additions & 0 deletions tests/tensor/rewriting/test_math.py
@@ -2031,6 +2031,45 @@ def test_exp_log_nested(self, nested_expression, expected_switches):
        assert len(ops_graph) == expected_switches


class TestSqrSqrt:
    def setup_method(self):
        mode = get_default_mode()
        self.mode = mode.including(
            "local_sqrt_sqr",
        ).excluding("fusion")
        self.rng = np.random.default_rng()

    def test_sqr_sqrt(self):
        # sqrt(x) ** 2 -> x where x >= 0, nan elsewhere
        x = pt.tensor("x", shape=(None, None))
        out = sqr(sqrt(x))
        out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])

        expected = switch(
            ge(x, np.zeros((1, 1), dtype="int8")),
            x,
            np.full((1, 1), np.nan, dtype=x.type.dtype),
        )

        assert equal_computations([out], [expected])

    def test_sqrt_sqr(self):
        # sqrt(x ** 2) -> |x|
        x = pt.tensor("x", shape=(None, None))
        out = sqrt(sqr(x))
        out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])

        assert equal_computations([out], [pt_abs(x)])

    def test_sqrt_sqr_integer_upcast(self):
        # sqrt upcasts integer inputs to float, so the abs rewrite must cast
        x = ivector("x")
        out = sqrt(sqr(x))
        dtype = out.type.dtype
        out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])

        expected = pt.cast(pt_abs(x), dtype=dtype)
        assert equal_computations([out], [expected])


class TestLocalSwitchSink:
    def setup_method(self):
        # condition values