From d6858fe23af29cb28379723175895c9d02fdca51 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Fri, 17 Jun 2022 21:22:14 -0500 Subject: [PATCH] Make Elemwise.infer_shape return TensorType-ed values --- aesara/tensor/elemwise.py | 8 +++++--- tests/tensor/test_elemwise.py | 17 +++++++++++++++-- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/aesara/tensor/elemwise.py b/aesara/tensor/elemwise.py index dd3c25783f..0515ddaec7 100644 --- a/aesara/tensor/elemwise.py +++ b/aesara/tensor/elemwise.py @@ -1,5 +1,5 @@ from copy import copy -from typing import Tuple, Union +from typing import List, Tuple, Union import numpy as np @@ -29,6 +29,7 @@ float_dtypes, lvector, ) +from aesara.tensor.var import TensorVariable from aesara.utils import uniq @@ -802,7 +803,7 @@ def perform(self, node, inputs, output_storage): else: storage[0] = variable - def infer_shape(self, fgraph, node, i_shapes): + def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]: if len(node.outputs) > 1: from aesara.tensor.basic_opt import ShapeError @@ -813,7 +814,8 @@ def infer_shape(self, fgraph, node, i_shapes): out_shape = aesara.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True) - return [out_shape] + # The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s + return [tuple(as_tensor_variable(s) for s in out_shape)] def _c_all(self, node, nodename, inames, onames, sub): # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code` diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 0cc6e8658c..d4638e5b74 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -26,6 +26,7 @@ bmatrix, bscalar, discrete_dtypes, + lscalar, matrix, scalar, tensor, @@ -815,8 +816,8 @@ def test_partial_static_shape_info(self): assert len(res_shape) == 1 assert len(res_shape[0]) == 2 - assert res_shape[0][0].data == 1 - assert res_shape[0][1].data == 1 + assert aesara.get_scalar_constant_value(res_shape[0][0]) == 1 + assert aesara.get_scalar_constant_value(res_shape[0][1]) == 1 def test_multi_output(self): class CustomElemwise(Elemwise): @@ -841,6 +842,18 @@ def make_node(self, *args): with pytest.raises(ShapeError): z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape]) + def test_shape_types(self): + x = tensor(np.float64, (None, 1)) + y = tensor(np.float64, (50, 10)) + + z = x * y + + assert isinstance(z.owner.op, Elemwise) + + (out_shape,) = z.owner.op.infer_shape(None, z.owner, [(lscalar(), 1), (50, 10)]) + + assert all(isinstance(v.type, TensorType) for v in out_shape) + def test_not_implemented_elemwise_grad(): # Regression test for unimplemented gradient in an Elemwise Op.