Skip to content

Commit 6dc8101

Browse files
Relax slerp.interpolate() input shape test and make it coherent with docstring.
PiperOrigin-RevId: 503185356
1 parent 98c0d63 commit 6dc8101

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

tensorflow_graphics/math/interpolation/slerp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ def quaternion_weights(
170170
quaternions in its last dimension.
171171
quaternion2: A tensor of shape `[A1, ... , An, 4]` storing normalized
172172
quaternions in its last dimension.
173-
percent: A `float` or a tensor with a shape broadcastable to the shape `[A1,
174-
... , An]`.
173+
percent: A `float` or tensor with shape broadcastable to the shape of input
174+
vectors.
175175
eps: A `float` used to make operations safe. When left as None, the function
176176
automatically picks the best epsilon based on the dtype and the operation.
177177
name: A name for this op. Defaults to "quaternion_weights".
@@ -198,7 +198,7 @@ def quaternion_weights(
198198
tensor=quaternion2, tensor_name="quaternion2", has_dim_equals=(-1, 4))
199199
shape.compare_batch_dimensions(
200200
tensors=(quaternion1, quaternion2, percent),
201-
last_axes=(-2, -2, -1),
201+
last_axes=-1,
202202
broadcast_compatible=True,
203203
tensor_names=("quaternion1", "quaternion2", "percent"))
204204
quaternion1 = asserts.assert_normalized(quaternion1)
@@ -266,7 +266,7 @@ def vector_weights(vector1: type_alias.TensorLike,
266266
tensor_names=("vector1", "vector2"))
267267
shape.compare_batch_dimensions(
268268
tensors=(vector1, vector2, percent),
269-
last_axes=(-2, -2, -1),
269+
last_axes=-1,
270270
broadcast_compatible=True,
271271
tensor_names=("vector1", "vector2", "percent"))
272272
normalized1 = tf.nn.l2_normalize(vector1, axis=-1)

tensorflow_graphics/math/interpolation/tests/slerp_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def test_unnormalized_quaternion_weights_exception_raised(self):
127127
@parameterized.parameters(
128128
((4,), (4,), (1,)),
129129
((None, 4), (None, 4), (None, 1)),
130-
((None, 4), (None, 4), (None, 4)),
130+
((5, 1, 4), (5, 1, 4), (3, 1)),
131131
)
132132
def test_quaternion_weights_exception_not_raised(self, *shapes):
133133
"""Tests that valid input shapes do not raise exceptions for qslerp."""
@@ -140,6 +140,8 @@ def test_quaternion_weights_exception_not_raised(self, *shapes):
140140
(1,)),
141141
("Not all batch dimensions are broadcast-compatible.", (1, 4), (3, 4),
142142
(2,)),
143+
("Not all batch dimensions are broadcast-compatible.", (5, 1, 4),
144+
(5, 1, 4), (3,)),
143145
)
144146
def test_quaternion_weights_exception_raised(self, error_msg, *shapes):
145147
"""Tests that the shape exceptions are properly raised for qslerp."""

0 commit comments

Comments
 (0)