Skip to content

Commit

Permalink
Make Elemwise.infer_shape return TensorType-ed values
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jun 18, 2022
1 parent 90a0f73 commit d6858fe
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
8 changes: 5 additions & 3 deletions aesara/tensor/elemwise.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import copy
from typing import Tuple, Union
from typing import List, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -29,6 +29,7 @@
float_dtypes,
lvector,
)
from aesara.tensor.var import TensorVariable
from aesara.utils import uniq


Expand Down Expand Up @@ -802,7 +803,7 @@ def perform(self, node, inputs, output_storage):
else:
storage[0] = variable

def infer_shape(self, fgraph, node, i_shapes):
def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:

if len(node.outputs) > 1:
from aesara.tensor.basic_opt import ShapeError
Expand All @@ -813,7 +814,8 @@ def infer_shape(self, fgraph, node, i_shapes):

out_shape = aesara.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)

return [out_shape]
# The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s
return [tuple(as_tensor_variable(s) for s in out_shape)]

def _c_all(self, node, nodename, inames, onames, sub):
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
Expand Down
17 changes: 15 additions & 2 deletions tests/tensor/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
bmatrix,
bscalar,
discrete_dtypes,
lscalar,
matrix,
scalar,
tensor,
Expand Down Expand Up @@ -815,8 +816,8 @@ def test_partial_static_shape_info(self):

assert len(res_shape) == 1
assert len(res_shape[0]) == 2
assert res_shape[0][0].data == 1
assert res_shape[0][1].data == 1
assert aesara.get_scalar_constant_value(res_shape[0][0]) == 1
assert aesara.get_scalar_constant_value(res_shape[0][1]) == 1

def test_multi_output(self):
class CustomElemwise(Elemwise):
Expand All @@ -841,6 +842,18 @@ def make_node(self, *args):
with pytest.raises(ShapeError):
z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])

def test_shape_types(self):
x = tensor(np.float64, (None, 1))
y = tensor(np.float64, (50, 10))

z = x * y

assert isinstance(z.owner.op, Elemwise)

(out_shape,) = z.owner.op.infer_shape(None, z.owner, [(lscalar(), 1), (50, 10)])

assert all(isinstance(v.type, TensorType) for v in out_shape)


def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op.
Expand Down

0 comments on commit d6858fe

Please sign in to comment.