diff --git a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
index c0eaa737e72..00580d213d9 100644
--- a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
+++ b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+#include "oneflow/core/common/container_util.h"
 #include "oneflow/core/framework/op_expr_grad_function.h"
 #include "oneflow/core/framework/op_builder.h"
 #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
@@ -269,19 +270,27 @@ class BroadcastMinMax : public BroadcastBinaryGrad {
         const auto& x_shape = *(x->shape());
         const Shape& left_extended_x_shape =
             CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes());
-        const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);
-        const std::vector<int32_t> x_axis =
-            std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
-        broad_x_ = JUST(functional::BroadcastLike(x, out_grads.at(0), x_axis));
+        if (left_extended_x_shape == out_shape) {
+          broad_x_ = JUST(functional::ReshapeLike(x, JUST(VectorAt(out_grads, 0))));
+        } else {
+          const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);
+          const std::vector<int32_t> x_axis =
+              std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
+          broad_x_ = JUST(functional::BroadcastLike(x, JUST(VectorAt(out_grads, 0)), x_axis));
+        }
       }
       if (ctx->broadcast_y) {
         const auto& y_shape = *(y->shape());
         const Shape& left_extended_y_shape =
             CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes());
-        const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);
-        const std::vector<int32_t> y_axis =
-            std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
-        broad_y_ = JUST(functional::BroadcastLike(y, out_grads.at(0), y_axis));
+        if (left_extended_y_shape == out_shape) {
+          broad_y_ = JUST(functional::ReshapeLike(y, JUST(VectorAt(out_grads, 0))));
+        } else {
+          const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);
+          const std::vector<int32_t> y_axis =
+              std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
+          broad_y_ = JUST(functional::BroadcastLike(y, JUST(VectorAt(out_grads, 0)), y_axis));
+        }
       }
       const auto& broad_grads =
           JUST(elementwise_grad_functor_(out_grads.at(0), broad_x_, broad_y_));
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp
index c00d2ac8d20..baf634f1679 100644
--- a/oneflow/core/functional/impl/array_functor.cpp
+++ b/oneflow/core/functional/impl/array_functor.cpp
@@ -427,25 +427,36 @@ class BroadcastLikeFunctor {
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
                            const std::shared_ptr<one::Tensor>& like,
                            const std::vector<int32_t>& broadcast_axes) const {
+    const Shape& x_shape = *x->shape();
+    const Shape& like_shape = *like->shape();
+    if (x_shape == like_shape) { return x; }
     MutableAttrMap attrs;
     if (broadcast_axes.empty()) {
-      int64_t like_ndim = like->shape()->NumAxes();
-      int64_t x_ndim = x->shape()->NumAxes();
+      int64_t like_ndim = like_shape.NumAxes();
+      int64_t x_ndim = x_shape.NumAxes();
       int64_t num_prepend = like_ndim - x_ndim;
       std::vector<int64_t> prepend_shape(num_prepend, 1);
-      std::vector<int32_t> broadcast_axes;
-      for (int i = 0; i < x_ndim; ++i) { prepend_shape.emplace_back(x->shape()->At(i)); }
+      std::vector<int32_t> broadcast_axes;
+      for (int i = 0; i < x_ndim; ++i) { prepend_shape.emplace_back(x_shape.At(i)); }
       for (int i = 0; i < num_prepend; ++i) { broadcast_axes.emplace_back(i); }
       for (int i = num_prepend; i < prepend_shape.size(); ++i) {
-        if (prepend_shape[i] != like->shape()->At(i)) {
-          if (prepend_shape[i] == 1) { broadcast_axes.emplace_back(i); }
-          CHECK_GE_OR_RETURN(prepend_shape[i], 1)
-              << Error::RuntimeError() << "output with shape " << x->shape()->ToString()
-              << " doesn't match the broadcast shape " << like->shape()->ToString();
+        if (prepend_shape[i] != like_shape.At(i)) {
+          if (prepend_shape[i] == 1) {
+            broadcast_axes.emplace_back(i);
+          } else {
+            return Error::RuntimeError() << "The expanded size of the tensor "
+                                         << "(" << like_shape.At(i) << ")"
+                                         << " must match the existing size (" << prepend_shape[i]
+                                         << ") at non-singleton dimension " << i
+                                         << ". Target sizes: " << like_shape.ToString()
+                                         << ". Tensor sizes: " << x_shape.ToString();
+          }
         }
       }
+      JUST(attrs.SetAttr<std::vector<int32_t>>("broadcast_axes", broadcast_axes));
+    } else {
+      JUST(attrs.SetAttr<std::vector<int32_t>>("broadcast_axes", broadcast_axes));
     }
-    JUST(attrs.SetAttr<std::vector<int32_t>>("broadcast_axes", broadcast_axes));
     return OpInterpUtil::Dispatch<Tensor>(*op_, {x, JUST(like->detach())}, attrs);
   }
 
diff --git a/python/oneflow/test/exceptions/test_array_functor.py b/python/oneflow/test/exceptions/test_array_functor.py
index 991b8e90d31..595768957b1 100644
--- a/python/oneflow/test/exceptions/test_array_functor.py
+++ b/python/oneflow/test/exceptions/test_array_functor.py
@@ -31,7 +31,7 @@ def test_broadcast_like_runtime_error(test_case):
             like = flow.ones((2, 2, 2), dtype=flow.float32, requires_grad=True)
             y = flow.broadcast_like(x, like)
         test_case.assertTrue(
-            "doesn't match the broadcast shape" in str(context.exception)
+            "The expanded size of the tensor" in str(context.exception)
         )
 
     def test_concat_index_error(test_case):
diff --git a/python/oneflow/test/modules/test_max.py b/python/oneflow/test/modules/test_max.py
index 546a22a5ddd..919eabec6bc 100644
--- a/python/oneflow/test/modules/test_max.py
+++ b/python/oneflow/test/modules/test_max.py
@@ -98,6 +98,12 @@ def test_max_broadcast_dtype_promotion(test_case):
         y = random_tensor(ndim, *b_dims, dtype=int).to(device)
         return torch.max(x, y)
 
+    @autotest(n=3, auto_backward=True, check_graph=True)
+    def test_max_with_diff_size(test_case):
+        x = random_tensor(3, 1, 1, 4)
+        y = random_tensor(2, 1, 4)
+        return torch.max(x, y)
+
 
 if __name__ == "__main__":
     unittest.main()
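
For reference, a minimal usage sketch of the behavior these changes target (shapes and variable names below are illustrative, not taken from the PR, and it assumes an OneFlow build that includes this branch). One input differs from the broadcast output only by leading singleton axes, which is the case the new `ReshapeLike` fast path in `BroadcastMinMax` covers, and the `broadcast_like` call is deliberately incompatible so it surfaces the new PyTorch-style message.

```python
import oneflow as flow

# Elementwise max over broadcast-compatible shapes (1, 1, 4) and (1, 4),
# mirroring the new test_max_with_diff_size case.
x = flow.rand(1, 1, 4, requires_grad=True)
y = flow.rand(1, 4, requires_grad=True)
z = flow.max(x, y)   # broadcast elementwise maximum, shape (1, 1, 4)
z.sum().backward()   # gradients are reduced back to the input shapes
print(x.grad.shape, y.grad.shape)  # (1, 1, 4) and (1, 4)

# Incompatible broadcast_like: dimension 2 is 4 in the input but 2 in `like`,
# so BroadcastLikeFunctor now reports the PyTorch-style error message.
try:
    flow.broadcast_like(flow.ones(2, 1, 4), flow.ones(2, 2, 2))
except Exception as e:
    print(e)  # "The expanded size of the tensor (2) must match the existing size (4) ..."
```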