Rewrite batched dots that do not reduce as multiplication
ricardoV94 committed Jan 28, 2025
1 parent f86a0dc commit 911c6a3
Showing 3 changed files with 146 additions and 9 deletions.
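
The idea behind the change: when the contracted dimension of a batched dot product has static length 1, the reduction sums a single element, so the dot is equivalent to a broadcasted elementwise multiplication. A minimal NumPy sketch of that identity (shapes picked purely for illustration):

import numpy as np

# Batched matrix-matrix product whose contracted ("k") dimension is 1:
# the sum over k is a no-op, so the result equals broadcasted multiplication.
a = np.random.normal(size=(5, 3, 1))  # (batch, m, k) with k == 1
b = np.random.normal(size=(5, 1, 4))  # (batch, k, n) with k == 1
np.testing.assert_allclose(a @ b, a * b)  # both results have shape (5, 3, 4)
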
42 changes: 34 additions & 8 deletions pytensor/tensor/math.py
@@ -29,7 +29,7 @@
     stack,
     switch,
 )
-from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import (
     CAReduce,
     Elemwise,
@@ -2726,6 +2726,22 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))


+# Predefine all batched variations of Dot
+_inner_prod = Blockwise(
+    _dot,
+    signature="(n),(n)->()",
+)
+
+_matrix_vec_prod = Blockwise(
+    _dot,
+    signature="(m,k),(k)->(m)",
+)
+
+_vec_matrix_prod = Blockwise(
+    _dot,
+    signature="(k),(k,n)->(n)",
+)
+
 _matrix_matrix_matmul = Blockwise(
     _dot,
     signature="(m,k),(k,n)->(m,n)",
@@ -2795,14 +2811,24 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None


 @_vectorize_node.register(Dot)
-def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
+def vectorize_node_dot(op, node, batched_x, batched_y):
     old_x, old_y = node.inputs
-    if old_x.type.ndim == 2 and old_y.type.ndim == 2:
-        # If original input is equivalent to a matrix-matrix product,
-        # return specialized Matmul Op to avoid unnecessary new Ops.
-        return matmul(batched_x, batched_y).owner
-    else:
-        return vectorize_node_fallback(op, node, batched_x, batched_y)
+    old_x_ndim = old_x.type.ndim
+    old_y_ndim = old_y.type.ndim
+    match (old_x_ndim, old_y_ndim):
+        case (1, 1):
+            batch_op = _inner_prod
+        case (2, 1):
+            batch_op = _matrix_vec_prod
+        case (1, 2):
+            batch_op = _vec_matrix_prod
+        case (2, 2):
+            batch_op = _matrix_matrix_matmul
+        case _:
+            raise ValueError(
+                f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
+            )
+    return batch_op(batched_x, batched_y).owner


 def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
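
With the dispatch above, vectorizing a core dot lands on one of the predefined Blockwise ops instead of a generic fallback. A rough sketch of how that surfaces from user code (variable names and shapes are illustrative, not part of the commit):

import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

x = pt.tensor("x", shape=(3,))
y = pt.tensor("y", shape=(3,))
out = pt.dot(x, y)  # core 1D-1D dot (inner product)

bx = pt.tensor("bx", shape=(5, 3))  # batch of 5 vectors
by = pt.tensor("by", shape=(5, 3))
bout = vectorize_graph(out, {x: bx, y: by})
print(bout.owner.op)  # expected: a Blockwise of Dot with signature (n),(n)->()
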
60 changes: 60 additions & 0 deletions pytensor/tensor/rewriting/math.py
@@ -44,6 +44,10 @@
     Prod,
     Sum,
     _conj,
+    _inner_prod,
+    _matrix_matrix_matmul,
+    _matrix_vec_prod,
+    _vec_matrix_prod,
     add,
     digamma,
     dot,
@@ -242,6 +246,62 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
     return None


+@register_canonicalize
+@register_specialize
+@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul])
+def local_blockwise_dot_to_mul(fgraph, node):
+    """Rewrite blockwise dots that correspond to multiplication without summation.
+
+    We don't touch the regular dot, to not interfere with the BLAS optimizations.
+    """
+    a, b = node.inputs
+    a_static_shape = a.type.shape
+    b_static_shape = b.type.shape
+    core_a_ndim = len(node.op.inputs_sig[0])
+    core_b_ndim = len(node.op.inputs_sig[1])
+
+    if core_a_ndim > 2 or core_b_ndim > 2:
+        # Shouldn't happen, but here just in case
+        return None
+
+    if core_b_ndim == 1:
+        if a_static_shape[-1] == 1 or b_static_shape[-1] == 1:
+            if core_a_ndim == 1:
+                # inner product: (..., 1) * (..., 1) -> (...)
+                # just squeeze the last dimensions of a and b
+                new_a = a.squeeze(-1)
+                new_b = b.squeeze(-1)
+            else:
+                # matrix vector product: (..., m, 1) * (..., 1) -> (..., m)
+                # the last dimension of b is already aligned for the elemwise multiplication
+                # after we squeeze the last dimension of a
+                new_a = a.squeeze(-1)
+                new_b = b
+        else:
+            return None
+
+    else:
+        if a_static_shape[-1] == 1 or b_static_shape[-2] == 1:
+            if core_a_ndim == 1:
+                # vector_matrix product: (..., 1) * (..., 1, n) -> (..., n)
+                # the last dimension of a is already aligned for the elemwise multiplication
+                # after we squeeze the one to last dimension of b
+                new_a = a
+                new_b = b.squeeze(-2)
+            else:
+                # matrix matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
+                # the dimensions of a and b are already aligned for the elemwise multiplication
+                new_a = a
+                new_b = b
+        else:
+            return None
+
+    new_a = copy_stack_trace(a, new_a)
+    new_b = copy_stack_trace(b, new_b)
+    new_out = copy_stack_trace(node.out, mul(new_a, new_b))
+    return [new_out]
+
+
 def is_inverse_pair(node_op, prev_op, inv_pair):
     """
     Given two consecutive operations, check if they are the
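
Seen end to end, the rewrite should make a batched matmul with a static inner dimension of 1 compile down to an elementwise multiplication. A hedged sketch, with made-up shapes, assuming the default compilation mode:

import pytensor
import pytensor.tensor as pt

a = pt.tensor("a", shape=(7, 3, 1))  # (..., m, k) with k == 1
b = pt.tensor("b", shape=(7, 1, 4))  # (..., k, n) with k == 1
out = a @ b  # Blockwise matmul with core signature (m,k),(k,n)->(m,n)

fn = pytensor.function([a, b], out)
pytensor.dprint(fn)  # the compiled graph should show Mul rather than a Blockwise dot
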
53 changes: 52 additions & 1 deletion tests/tensor/rewriting/test_math.py
@@ -16,7 +16,8 @@
 from pytensor.compile.mode import Mode, get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp, deep_copy_op
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Apply, equal_computations
+from pytensor.graph import vectorize_graph
+from pytensor.graph.basic import Apply, ancestors, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import (
     SequentialNodeRewriter,
@@ -4590,3 +4591,53 @@ def test_pow_1_rewrite(shape):

     x_val = np.random.random(shape).astype(config.floatX)
     np.testing.assert_allclose(z.eval({x: x_val}), f(x_val))
+
+
+@pytest.mark.parametrize(
+    "a_shape,b_shape",
+    [
+        ((1,), (1,)),
+        ((3, 1), (1,)),
+        ((1,), (1, 3)),
+        ((3, 1), (1, 3)),
+    ],
+    ids=str,
+)
+@pytest.mark.parametrize("batched", (False, True))
+def test_local_dot_to_mul(batched, a_shape, b_shape):
+    a = tensor("a", shape=a_shape)
+    b = tensor("b", shape=b_shape)
+
+    out = dot(a, b)
+    if batched:
+        batch_a = tensor("batch_a", shape=(1, 5, *a_shape))
+        batch_b = tensor("batch_b", shape=(7, 1, *b_shape))
+        out = vectorize_graph(out, {a: batch_a, b: batch_b})
+        a = batch_a
+        b = batch_b
+
+    assert (
+        sum(
+            isinstance(var.owner.op, (Blockwise | Dot))
+            for var in ancestors([out])
+            if var.owner
+        )
+        == 1
+    )
+
+    # For now rewrite only applies to Batched Dots
+    rewritten_out = rewrite_graph(out)
+    assert rewritten_out.type.shape == out.type.shape
+    assert sum(
+        isinstance(var.owner.op, (Blockwise | Dot))
+        for var in ancestors([rewritten_out])
+        if var.owner
+    ) == (0 if batched else 1)
+
+    a_test = np.random.normal(size=a.type.shape).astype(a.type.dtype)
+    b_test = np.random.normal(size=b.type.shape).astype(b.type.dtype)
+    test_mode = Mode(linker="py", optimizer=None)
+    np.testing.assert_allclose(
+        out.eval({a: a_test, b: b_test}, mode=test_mode),
+        rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
+    )
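
As the batched=False cases assert, the rewrite deliberately leaves the core (non-Blockwise) Dot untouched so BLAS-related rewrites are not disturbed. A small sketch of that behaviour (import path and shapes assumed for illustration):

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

a = pt.tensor("a", shape=(3, 1))
b = pt.tensor("b", shape=(1, 3))
out = pt.dot(a, b)  # plain 2D dot, no batch dimensions

rewritten = rewrite_graph(out)
print(rewritten.owner.op)  # expected to still be a Dot, not an elementwise Mul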
