Implement gradient for vector repetitions
ricardoV94 committed Feb 5, 2025
1 parent 4ac1e63 commit 43d2438
Showing 2 changed files with 71 additions and 55 deletions.
103 changes: 54 additions & 49 deletions pytensor/tensor/extra_ops.py
@@ -646,7 +646,12 @@ class Repeat(Op):
 
     __props__ = ("axis",)
 
-    def __init__(self, axis=None):
+    def __init__(self, axis: int | None = None):
+        if axis is not None:
+            if not isinstance(axis, int) or axis < 0:
+                raise ValueError(
+                    f"Repeat only accepts positive integer axis or None, got {axis}"
+                )
         self.axis = axis
 
     def make_node(self, x, repeats):
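
The stricter constructor pairs with the axis normalization added to `repeat()` further down: callers resolve negative axes via `normalize_axis_index` before the Op is built, which is what allows the negative-axis arithmetic to be deleted from `grad`. An illustrative sketch of the new contract (not part of the commit):

    from pytensor.tensor.extra_ops import Repeat

    Repeat(axis=1)     # accepted: non-negative integer
    Repeat(axis=None)  # accepted: repeat over the flattened input
    Repeat(axis=-1)    # raises ValueError: axis must be normalized by the caller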
@@ -687,48 +692,51 @@ def make_node(self, x, repeats):
             out_shape = list(x.type.shape)
             out_shape[self.axis] = None
 
-        out_type = TensorType(
-            x.dtype, shape=tuple(1 if s == 1 else None for s in out_shape)
-        )
-
+        out_type = TensorType(x.dtype, shape=out_shape)
         return Apply(self, [x, repeats], [out_type()])
 
     def perform(self, node, inputs, output_storage):
-        x = inputs[0]
-        repeats = inputs[1]
-        z = output_storage[0]
-        z[0] = np.repeat(x, repeats=repeats, axis=self.axis)
+        [x, repeats] = inputs
+        output_storage[0][0] = np.repeat(x, repeats=repeats, axis=self.axis)
 
     def connection_pattern(self, node):
        return [[True], [False]]
 
     def grad(self, inputs, gout):
         (x, repeats) = inputs
         (gz,) = gout
+        axis = self.axis
         if repeats.ndim == 0:
+            # When repeats is a scalar (same number of reps for all elements),
+            # we can split the repetitions into their own axis with reshape
+            # and sum them back to the original element location
+            sum_axis = x.ndim if axis is None else axis + 1
+            shape = list(x.shape)
+            shape.insert(sum_axis, repeats)
+            gx = gz.reshape(shape).sum(axis=sum_axis)
+
+        elif repeats.ndim == 1:
+            # To sum the gradients that belong to the same repeated x,
+            # we create a repeated eye and dot product it with the gradient.
+            axis_size = x.size if self.axis is None else x.shape[self.axis]
+            tiled_eye = repeat(ptb.eye(axis_size), repeats, axis=0)
+
             if self.axis is None:
-                axis = x.ndim
+                gx = gz @ tiled_eye
+                # Undo the ravelling when axis=None
+                gx = gx.reshape(x.shape)
             else:
-                if self.axis >= 0:
-                    axis = self.axis + 1
-                else:
-                    axis = self.axis + x.ndim + 1
-
-            shape = [x.shape[k] for k in range(x.ndim)]
-            shape.insert(axis, repeats)
+                # Place gradient axis at end for dot product
+                gx = ptb.moveaxis(gz, self.axis, -1)
+                gx = gx @ tiled_eye
+                # Place gradient back into the correct axis
+                gx = ptb.moveaxis(gx, -1, self.axis)
 
-            return [
-                gz.reshape(shape, ndim=x.ndim + 1).sum(axis=axis),
-                DisconnectedType()(),
-            ]
-        elif repeats.ndim == 1:
-            # For this implementation, we would need to specify the length
-            # of repeats in order to split gz in the right way to sum
-            # the good part.
-            raise NotImplementedError()
         else:
             raise ValueError()
 
+        return [gx, DisconnectedType()()]
+
     def infer_shape(self, fgraph, node, ins_shapes):
         i0_shapes = ins_shapes[0]
         repeats = node.inputs[1]
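
A minimal numpy sketch of the two gradient tricks used in `grad` above; this is illustrative only and assumes nothing beyond plain numpy:

    import numpy as np

    # Vector repeats: each row of the repeated identity matrix is the one-hot
    # of the input element that produced that output position, so
    # "gz @ tiled_eye" sums the output cotangents back onto their sources.
    x = np.array([1.0, 2.0, 3.0])
    repeats = np.array([2, 1, 3])
    gz = np.ones(repeats.sum())                              # cotangent of np.repeat(x, repeats)
    tiled_eye = np.repeat(np.eye(x.size), repeats, axis=0)   # shape (6, 3)
    gx = gz @ tiled_eye
    assert np.allclose(gx, repeats)                          # d(sum(y))/dx[i] == repeats[i]

    # Scalar repeats: the copies are contiguous, so reshaping them into their
    # own axis and summing over it undoes the repetition.
    gz = np.ones(x.size * 2)                                 # cotangent of np.repeat(x, 2)
    gx = gz.reshape(x.size, 2).sum(axis=1)
    assert np.allclose(gx, 2.0)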
Expand Down Expand Up @@ -757,7 +765,7 @@ def infer_shape(self, fgraph, node, ins_shapes):
return [out_shape]


def repeat(x, repeats, axis=None):
def repeat(a, repeats, axis=None):
"""Repeat elements of an array.
It returns an array which has the same shape as `x`, except along the given
@@ -782,51 +790,48 @@
     .. versionadded:: 0.6
 
     """
+    a = ptb.as_tensor_variable(a)
+
+    if axis is not None:
+        axis = normalize_axis_index(axis, a.ndim)
+
     repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)
 
     if repeats.ndim > 1:
         raise ValueError("The dimension of repeats should not exceed 1.")
 
     if repeats.ndim == 1 and not repeats.broadcastable[0]:
-        return Repeat(axis=axis)(x, repeats)
+        # We only use the Repeat Op for vector repeats
+        return Repeat(axis=axis)(a, repeats)
     else:
         if repeats.ndim == 1:
             repeats = repeats[0]
 
-        if x.dtype == "uint64":
+        if a.dtype == "uint64":
             raise TypeError("repeat doesn't support dtype uint64")
 
         if axis is None:
             axis = 0
-            x = x.flatten()
-        else:
-            if axis >= x.ndim:
-                raise ValueError("Axis should not exceed x.ndim-1.")
-            if axis < 0:
-                axis = x.ndim + axis
+            a = a.flatten()
 
-        shape = [x.shape[i] for i in range(x.ndim)]
+        repeat_shape = list(a.shape)
 
-        # shape_ is the shape of the intermediate tensor which has
+        # alloc_shape is the shape of the intermediate tensor which has
         # an additional dimension comparing to x. We use alloc to
         # allocate space for this intermediate tensor to replicate x
         # along that additional dimension.
-        shape_ = shape[:]
-        shape_.insert(axis + 1, repeats)
+        alloc_shape = repeat_shape[:]
+        alloc_shape.insert(axis + 1, repeats)
 
-        # shape is now the shape of output, where shape[axis] becomes
+        # repeat_shape is now the shape of output, where shape[axis] becomes
         # shape[axis]*repeats.
-        shape[axis] = shape[axis] * repeats
-
-        # dims_ is the dimension of that intermediate tensor.
-        dims_ = list(np.arange(x.ndim))
-        dims_.insert(axis + 1, "x")
+        repeat_shape[axis] = repeat_shape[axis] * repeats
 
         # After the original tensor is duplicated along the additional
-        # dimension, we reshape it to the expected output shape, and
-        # return the output z.
-        z = ptb.alloc(x.dimshuffle(*dims_), *shape_).reshape(shape)
-        return z
+        # dimension, we reshape it to the expected output shape
+        return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
+            repeat_shape
+        )
 
 
 class Bartlett(Op):
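
For intuition, the alloc-based fast path for scalar repeats has a direct numpy analogue (a sketch, not the committed code): insert a length-`repeats` axis right after `axis`, broadcast into it, then merge it back with a reshape.

    import numpy as np

    a = np.arange(6).reshape(2, 3)
    axis, repeats = 1, 4

    alloc_shape = list(a.shape)
    alloc_shape.insert(axis + 1, repeats)    # (2, 3, 4): new axis after `axis`

    repeat_shape = list(a.shape)
    repeat_shape[axis] *= repeats            # (2, 12): final output shape

    out = np.broadcast_to(np.expand_dims(a, axis + 1), alloc_shape)
    out = out.reshape(repeat_shape)
    assert np.array_equal(out, np.repeat(a, repeats, axis=axis))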
23 changes: 17 additions & 6 deletions tests/tensor/test_extra_ops.py
@@ -635,12 +635,23 @@ def test_infer_shape(self, ndim, dtype):
             self.op_class,
         )
 
-    @pytest.mark.parametrize("ndim", range(3))
-    def test_grad(self, ndim):
-        a = np.random.random((10,) * ndim).astype(config.floatX)
-
-        for axis in self._possible_axis(ndim):
-            utt.verify_grad(lambda x: Repeat(axis=axis)(x, 3), [a])
+    @pytest.mark.parametrize("x_ndim", [2, 3], ids=lambda x: f"x_ndim={x}")
+    @pytest.mark.parametrize("repeats_ndim", [0, 1], ids=lambda r: f"repeats_ndim={r}")
+    @pytest.mark.parametrize("axis", [None, 0, 1], ids=lambda a: f"axis={a}")
+    def test_grad(self, x_ndim, repeats_ndim, axis):
+        rng = np.random.default_rng(
+            [653, x_ndim, 2 if axis is None else axis, repeats_ndim]
+        )
+        x_test = rng.normal(size=np.arange(3, 3 + x_ndim))
+        if repeats_ndim == 0:
+            repeats_size = ()
+        else:
+            repeats_size = (x_test.shape[axis] if axis is not None else x_test.size,)
+        repeats = rng.integers(1, 6, size=repeats_size)
+        utt.verify_grad(
+            lambda x: Repeat(axis=axis)(x, repeats),
+            [x_test],
+        )
 
     def test_broadcastable(self):
         x = TensorType(config.floatX, shape=(None, 1, None))()
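
The net effect of the commit, as a hedged usage sketch (names follow the public pytensor API; the printed dtype depends on floatX): gradients now flow through `repeat` with a vector of repeats instead of raising NotImplementedError.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    y = pt.repeat(x, np.array([2, 1, 3]), axis=0)  # vector repeats, length == x.shape[0]
    g = pytensor.grad(y.sum(), x)                  # previously raised NotImplementedError

    f = pytensor.function([x], g)
    print(f(np.zeros(3, dtype=pytensor.config.floatX)))  # [2. 1. 3.]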
