diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index aef363655e..b12d75ae35 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -29,9 +29,11 @@
     cast,
     constant,
     get_underlying_scalar_constant_value,
+    join,
     moveaxis,
     ones_like,
     register_infer_shape,
+    split,
     switch,
     zeros_like,
 )
@@ -99,6 +101,7 @@
 )
 from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.shape import Shape, Shape_i
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
     complex_dtypes,
@@ -167,6 +170,76 @@ def local_0_dot_x(fgraph, node):
     return [constant_zero]
 
 
+@register_stabilize
+@node_rewriter([Blockwise])
+def local_block_diag_dot_to_dot_block_diag(fgraph, node):
+    r"""
+    Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C1), dot(B, C2))``, where C is split to match the blocks.
+
+    block_diag(A, B) assembles a matrix of shape ``(n1 + n2, m1 + m2)`` from blocks of shapes ``(n1, m1)`` and ``(n2, m2)``.
+    Because dense matrix multiplication scales roughly as O(n^3), it is generally cheaper to perform several dot products
+    on the smaller blocks than a single dot on the assembled larger matrix.
+    """
+    if not isinstance(node.op.core_op, BlockDiagonal):
+        return
+
+    def check_for_block_diag(x):
+        return x.owner and (
+            isinstance(x.owner.op, BlockDiagonal)
+            or isinstance(x.owner.op, Blockwise)
+            and isinstance(x.owner.op.core_op, BlockDiagonal)
+        )
+
+    # Check that the BlockDiagonal is an input to a Dot node:
+    clients = list(get_clients_at_depth(fgraph, node, depth=1))
+    if not clients or len(clients) > 1 or not isinstance(clients[0].op, Dot):
+        return
+
+    [dot_node] = clients
+    op = dot_node.op
+    x, y = dot_node.inputs
+
+    if not (check_for_block_diag(x) or check_for_block_diag(y)):
+        return None
+
+    # Case 1: Only one input is BlockDiagonal. In this case, multiply each component of the block diagonal
+    # with the matching slice of the other input and join the results.
+    if check_for_block_diag(x) and not check_for_block_diag(y):
+        components = x.owner.inputs
+        y_splits = split(
+            y,
+            splits_size=[component.shape[-1] for component in components],
+            n_splits=len(components),
+        )
+        new_components = [
+            op(component, y_split) for component, y_split in zip(components, y_splits)
+        ]
+        new_output = join(0, *new_components)
+
+    elif not check_for_block_diag(x) and check_for_block_diag(y):
+        components = y.owner.inputs
+        x_splits = split(
+            x,
+            splits_size=[component.shape[0] for component in components],
+            n_splits=len(components),
+            axis=1,
+        )
+
+        new_components = [
+            op(x_split, component) for component, x_split in zip(components, x_splits)
+        ]
+        new_output = join(1, *new_components)
+
+    # Case 2: Both inputs are BlockDiagonal. Do nothing
+    else:
+        # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
+        # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
+        return None
+
+    copy_stack_trace(node.outputs[0], new_output)
+    return {dot_node.outputs[0]: new_output}
+
+
 @register_canonicalize
 @node_rewriter([DimShuffle])
 def local_lift_transpose_through_dot(fgraph, node):
@@ -2496,7 +2569,6 @@ def add_calculate(num, denum, aslist=False, out_type=None):
     name="add_canonizer_group",
 )
 
-
 register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")
 
 
@@ -3619,7 +3691,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 )
 register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")
 
-
 # log(sigmoid(x) / (1 - sigmoid(x))) -> x
 # i.e logit(sigmoid(x)) -> x
 local_logit_sigmoid = PatternNodeRewriter(
@@ -3633,7 +3704,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
 register_canonicalize(local_logit_sigmoid)
 register_specialize(local_logit_sigmoid)
 
-
 # sigmoid(log(x / (1-x)) -> x
 # i.e., sigmoid(logit(x)) -> x
 local_sigmoid_logit = PatternNodeRewriter(
@@ -3674,7 +3744,6 @@ def local_useless_conj(fgraph, node):
 
 register_specialize(local_polygamma_to_tri_gamma)
 
-
 local_log_kv = PatternNodeRewriter(
     # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
     # During stabilize -x is converted to -1.0 * x
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index c4999fcd33..137b91fb34 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -113,6 +113,7 @@
     simplify_mul,
 )
 from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.type import (
     TensorType,
     cmatrix,
@@ -4654,3 +4655,80 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
         out.eval({a: a_test, b: b_test}, mode=test_mode),
         rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
     )
+
+
+@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
+def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
+    """
+    Test that dot(block_diag(x, y), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
+    """
+    a = tensor("a", shape=(4, 2))
+    b = tensor("b", shape=(2, 4))
+    c = tensor("c", shape=(4, 4))
+    d = tensor("d", shape=(10, 10))
+
+    x = pt.linalg.block_diag(a, b, c)
+
+    if left_multiply:
+        out = x @ d
+    else:
+        out = d @ x
+
+    fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
+    assert not any(
+        isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
+    )
+
+    fn_expected = pytensor.function(
+        [a, b, c, d],
+        out,
+        mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"),
+    )
+
+    rng = np.random.default_rng()
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    np.testing.assert_allclose(
+        fn(a_val, b_val, c_val, d_val),
+        fn_expected(a_val, b_val, c_val, d_val),
+        atol=1e-6 if config.floatX == "float32" else 1e-12,
+        rtol=1e-6 if config.floatX == "float32" else 1e-12,
+    )
+
+
+@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
+@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
+def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
+    rng = np.random.default_rng()
+    a_size = int(rng.uniform(0, size))
+    b_size = int(rng.uniform(0, size - a_size))
+    c_size = size - a_size - b_size
+
+    a = tensor("a", shape=(a_size, a_size))
+    b = tensor("b", shape=(b_size, b_size))
+    c = tensor("c", shape=(c_size, c_size))
+    d = tensor("d", shape=(size,))
+
+    x = pt.linalg.block_diag(a, b, c)
+    out = x @ d
+
+    mode = get_default_mode()
+    if not rewrite:
+        mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
+    fn = pytensor.function([a, b, c, d], out, mode=mode)
+
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    benchmark(
+        fn,
+        a_val,
+        b_val,
+        c_val,
+        d_val,
+    )
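As a quick usage sketch of the new rewrite (assuming a PyTensor build that includes this branch and a compilation mode that runs the stabilize rewrites; the variable names and shapes below are arbitrary), the rewritten graph can be exercised end to end and checked against an explicitly assembled NumPy block-diagonal product:

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import BlockDiagonal

# Two diagonal blocks and a right operand whose row count matches the assembled matrix.
A = pt.tensor("A", shape=(3, 3))
B = pt.tensor("B", shape=(5, 5))
C = pt.tensor("C", shape=(8, 4))

out = pt.linalg.block_diag(A, B) @ C
fn = pytensor.function([A, B, C], out)

# If the rewrite fired, no BlockDiagonal op (plain or wrapped in Blockwise) should
# remain in the compiled graph; the product is carried out block by block instead.
has_block_diag = any(
    isinstance(getattr(node.op, "core_op", node.op), BlockDiagonal)
    for node in fn.maker.fgraph.toposort()
)
print("BlockDiagonal still present after rewriting:", has_block_diag)

# Numerical check against an explicitly assembled block-diagonal matrix.
rng = np.random.default_rng(0)
floatX = pytensor.config.floatX
A_val = rng.normal(size=(3, 3)).astype(floatX)
B_val = rng.normal(size=(5, 5)).astype(floatX)
C_val = rng.normal(size=(8, 4)).astype(floatX)
expected = np.block([[A_val, np.zeros((3, 5))], [np.zeros((5, 3)), B_val]]) @ C_val
np.testing.assert_allclose(fn(A_val, B_val, C_val), expected, atol=1e-6, rtol=1e-4)
```

This mirrors what the added test asserts under `rewrite_mode`; here the default mode is used, so whether the `BlockDiagonal` node disappears depends on that mode including the new stabilize rewrite.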