@@ -170,8 +170,8 @@ def quaternion_weights(
170
170
quaternions in its last dimension.
171
171
quaternion2: A tensor of shape `[A1, ... , An, 4]` storing normalized
172
172
quaternions in its last dimension.
173
- percent: A `float` or a tensor with a shape broadcastable to the shape `[A1,
174
- ... , An]` .
173
+ percent: A `float` or tensor with shape broadcastable to the shape of input
174
+ vectors .
175
175
eps: A `float` used to make operations safe. When left as None, the function
176
176
automatically picks the best epsilon based on the dtype and the operation.
177
177
name: A name for this op. Defaults to "quaternion_weights".
@@ -198,7 +198,7 @@ def quaternion_weights(
198
198
tensor = quaternion2 , tensor_name = "quaternion2" , has_dim_equals = (- 1 , 4 ))
199
199
shape .compare_batch_dimensions (
200
200
tensors = (quaternion1 , quaternion2 , percent ),
201
- last_axes = ( - 2 , - 2 , - 1 ) ,
201
+ last_axes = - 1 ,
202
202
broadcast_compatible = True ,
203
203
tensor_names = ("quaternion1" , "quaternion2" , "percent" ))
204
204
quaternion1 = asserts .assert_normalized (quaternion1 )
@@ -266,7 +266,7 @@ def vector_weights(vector1: type_alias.TensorLike,
266
266
tensor_names = ("vector1" , "vector2" ))
267
267
shape .compare_batch_dimensions (
268
268
tensors = (vector1 , vector2 , percent ),
269
- last_axes = ( - 2 , - 2 , - 1 ) ,
269
+ last_axes = - 1 ,
270
270
broadcast_compatible = True ,
271
271
tensor_names = ("vector1" , "vector2" , "percent" ))
272
272
normalized1 = tf .nn .l2_normalize (vector1 , axis = - 1 )
0 commit comments