Commit: merge

hanwen-sun committed Feb 6, 2024
1 parent 7e9a018 commit 1ea8845
Showing 7 changed files with 29 additions and 22 deletions.
21 changes: 12 additions & 9 deletions oneflow/api/python/framework/tensor_functions.cpp
@@ -699,7 +699,8 @@ static PyObject* PyTensorObject_local_to_global(PyObject* self, PyObject* args,
   static const char* keywords[6] = {"placement", "sbp", "check_meta", "sync_data", "copy", NULL};
   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO$O!O!O!:local_to_global",
                                    const_cast<char**>(keywords), &placement_obj, &sbp_obj,
-                                   &PyBool_Type, &check_meta_obj, &PyBool_Type, &sync_data_obj, &PyBool_Type, &copy_obj)) {
+                                   &PyBool_Type, &check_meta_obj, &PyBool_Type, &sync_data_obj,
+                                   &PyBool_Type, &copy_obj)) {
     return NULL;
   }
   const bool check_meta = (check_meta_obj == Py_True);
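
For readers unfamiliar with the CPython format string reflowed above: in "|OO$O!O!O!:local_to_global", the | marks the remaining arguments as optional, $ makes everything after it keyword-only, O accepts any Python object unchecked, and O! checks the object against the PyTypeObject passed just before its destination pointer (here PyBool_Type). A minimal standalone sketch of the same parsing pattern, with a hypothetical function name rather than OneFlow code:

#include <Python.h>

// Sketch only: parses f(placement=None, sbp=None, *, check_meta=False, sync_data=True, copy=False).
static PyObject* ParseFlagsSketch(PyObject* self, PyObject* args, PyObject* kwargs) {
  PyObject* placement_obj = Py_None;
  PyObject* sbp_obj = Py_None;
  PyObject* check_meta_obj = Py_False;
  PyObject* sync_data_obj = Py_True;
  PyObject* copy_obj = Py_False;
  static const char* keywords[6] = {"placement", "sbp", "check_meta", "sync_data", "copy", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO$O!O!O!:parse_flags_sketch",
                                   const_cast<char**>(keywords), &placement_obj, &sbp_obj,
                                   &PyBool_Type, &check_meta_obj, &PyBool_Type, &sync_data_obj,
                                   &PyBool_Type, &copy_obj)) {
    return NULL;  // the parser has already set a TypeError
  }
  // Booleans arrive as PyObject*; the surrounding code converts with (obj == Py_True).
  Py_RETURN_NONE;
}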
@@ -722,8 +723,9 @@ static PyObject* PyTensorObject_local_to_global(PyObject* self, PyObject* args,
         << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(sbp_obj)));
     sbp = functional::PyUnpackSbpParallelSequence(sbp_obj);
   }
-  return PyTensor_New(ASSERT_PTR(functional::ToGlobal(
-      tensor, functional::PyUnpackParallelDesc(placement_obj), sbp, {}, check_meta, sync_data, copy)));
+  return PyTensor_New(
+      ASSERT_PTR(functional::ToGlobal(tensor, functional::PyUnpackParallelDesc(placement_obj), sbp,
+                                      {}, check_meta, sync_data, copy)));
   END_HANDLE_ERRORS
 }

@@ -740,10 +742,11 @@ static PyObject* PyTensorObject_global_to_global(PyObject* self, PyObject* args,
   PyObject* check_meta_obj = Py_False;
   PyObject* sync_data_obj = Py_True;
   PyObject* copy_obj = Py_False;
-  static const char* keywords[7] = {"placement", "sbp", "grad_sbp", "check_meta", "sync_data", "copy", NULL};
+  static const char* keywords[7] = {"placement", "sbp", "grad_sbp", "check_meta",
+                                    "sync_data", "copy", NULL};
   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO$OO!O!O!:global_to_global",
                                    const_cast<char**>(keywords), &placement_obj, &sbp_obj,
-                                   &grad_sbp_obj, &PyBool_Type, &check_meta_obj, &PyBool_Type,
+                                   &grad_sbp_obj, &PyBool_Type, &check_meta_obj, &PyBool_Type,
                                    &sync_data_obj, &PyBool_Type, &copy_obj)) {
     return NULL;
   }
@@ -785,8 +788,8 @@ static PyObject* PyTensorObject_global_to_global(PyObject* self, PyObject* args,
   } else if (functional::PySbpParallelSequenceCheck(grad_sbp_obj)) {
     grad_sbp = functional::PyUnpackSbpParallelSequence(grad_sbp_obj);
   }
-  return PyTensor_New(
-      ASSERT_PTR(functional::ToGlobal(tensor, placement, sbp, grad_sbp, check_meta, sync_data, copy)));
+  return PyTensor_New(ASSERT_PTR(
+      functional::ToGlobal(tensor, placement, sbp, grad_sbp, check_meta, sync_data, copy)));
   END_HANDLE_ERRORS
 }

@@ -850,8 +853,8 @@ static PyObject* PyTensorObject_type_as(PyObject* self, PyObject* args, PyObject
   for (int32_t i = 0; i < ndsbp->sbp_parallel_size(); i++) {
     sbp.emplace_back(ndsbp->sbp_parallel(i));
   }
-  return PyTensor_New(
-      ASSERT_PTR(functional::ToGlobal(value_tensor, placement, sbp, {}, true, true, /*copy=*/false)));
+  return PyTensor_New(ASSERT_PTR(
+      functional::ToGlobal(value_tensor, placement, sbp, {}, true, true, /*copy=*/false)));
   END_HANDLE_ERRORS
 }

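Every hunk in this commit reflows a call to the same functional::ToGlobal helper. Its parameter order can be read off the call sites; the following declaration is a sketch inferred from them, with parameter names and exact types assumed rather than copied from the OneFlow headers:

// Inferred from the call sites in this commit; a sketch, not the verbatim declaration.
Maybe<Tensor> ToGlobal(const std::shared_ptr<Tensor>& x,                  // tensor to cast
                       Symbol<ParallelDesc> placement,                    // target placement
                       const std::vector<Symbol<SbpParallel>>& sbp,       // target SBP layout
                       const std::vector<Symbol<SbpParallel>>& grad_sbp,  // SBP applied to the backward cast
                       bool check_meta,  // verify tensor meta consistency across ranks
                       bool sync_data,   // synchronize tensor data during the cast
                       bool copy);       // force a copy even when the layout already matches

The inline /* check_meta */, /* sync_data */, and /*copy=*/ comments at the call sites label these trailing booleans, which is why the reflowed lines below keep each comment attached to its argument.
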
5 changes: 3 additions & 2 deletions oneflow/api/python/utils/tensor_utils.cpp
@@ -318,8 +318,9 @@ Maybe<Tensor> MakeTensorFromOtherTensor(const std::shared_ptr<Tensor>& other,
                                         const bool requires_grad) {
   std::vector<Symbol<SbpParallel>> grad_sbp_tuple;
   bool check_meta = other->is_global() ? false : true;
-  std::shared_ptr<Tensor> tensor = JUST(functional::ToGlobal(
-      other, placement, sbp_tuple, grad_sbp_tuple, check_meta, /* sync_data */ true, /*copy=*/false));
+  std::shared_ptr<Tensor> tensor =
+      JUST(functional::ToGlobal(other, placement, sbp_tuple, grad_sbp_tuple, check_meta,
+                                /* sync_data */ true, /*copy=*/false));
   if (dtype) {
     const Symbol<DType>& dtype_ = JUST(dtype);
     if (tensor->dtype() != dtype_) {
7 changes: 4 additions & 3 deletions oneflow/core/autograd/gradient_funcs/global_cast.cpp
@@ -60,9 +60,10 @@ class LocalToGlobal : public OpExprGradFunction<CastGlobalCaptureState> {
     {
       Symbol<NdSbp> nd_sbp_constraint = ctx->nd_sbp;
       Symbol<ParallelDesc> parallel_desc_constraint = ctx->parallel_desc;
-      out_grad = JUST(functional::ToGlobal(out_grad, parallel_desc_constraint,
-                                           *JUST(GetSbpList(nd_sbp_constraint)), GetNoneSbpList(),
-                                           /* check_meta */ false, /* sync_data */ true, /*copy=*/false));
+      out_grad =
+          JUST(functional::ToGlobal(out_grad, parallel_desc_constraint,
+                                    *JUST(GetSbpList(nd_sbp_constraint)), GetNoneSbpList(),
+                                    /* check_meta */ false, /* sync_data */ true, /*copy=*/false));
     }
     in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {out_grad}));
     return Maybe<void>::Ok();
7 changes: 4 additions & 3 deletions oneflow/core/autograd/gradient_funcs/global_to_global.cpp
@@ -60,9 +60,10 @@ class GlobalToGlobalGradFunction : public OpExprGradFunction<GlobalToGlobalState
     const auto& grad_sbp_list = JUST(GetSbpList(grad_nd_sbp));

     if (LazyMode::is_enabled()) {
-      (*in_grads)[0] = JUST(one::functional::ToGlobal(out_grad, ctx->parallel_desc, *grad_sbp_list,
-                                                      {}, /* check_meta */ false, /* sync_data */ true,
-                                                      /*copy=*/false));
+      (*in_grads)[0] =
+          JUST(one::functional::ToGlobal(out_grad, ctx->parallel_desc, *grad_sbp_list, {},
+                                         /* check_meta */ false, /* sync_data */ true,
+                                         /*copy=*/false));
     } else {
       const auto& grad_grad_sbp_list = JUST(GetSbpList(ctx->nd_sbp));
       (*in_grads)[0] = JUST(one::functional::ToGlobal(out_grad, ctx->parallel_desc, *grad_sbp_list,
5 changes: 3 additions & 2 deletions oneflow/core/functional/impl/global_cast.cpp
@@ -531,8 +531,9 @@ class ToGlobalFunctor {
     } else {
       DeviceType device_type = parallel_desc->device_type();
       if (ccl::IsBroadcastRegistered(device_type)) {
-        tensor = JUST(LocalToGlobal(x, parallel_desc, sbp_parallels, NullOpt, NullOpt,
-                                    local_to_global_op_, check_meta, /* sync_data */ sync_data, copy));
+        tensor =
+            JUST(LocalToGlobal(x, parallel_desc, sbp_parallels, NullOpt, NullOpt,
+                               local_to_global_op_, check_meta, /* sync_data */ sync_data, copy));
       } else {
         // Assuming that the newly adapted hardware device does not support collective
         // communication, since local to global may need to synchronize data (through the
4 changes: 2 additions & 2 deletions oneflow/core/functional/impl/math_functor.cpp
@@ -1625,8 +1625,8 @@ class GlobalHannWindowFunctor {
         result = JUST(ScalarDiv(JUST(ScalarSub(1, JUST(Cos(div_result)), 1)), 2));
       }
     }
-    result = JUST(ToGlobal(result, placement, sbp, {}, /* check_meta */true,
-                           /* sync_data */ true, /*copy=*/false));
+    result = JUST(ToGlobal(result, placement, sbp, {}, /* check_meta */ true,
+                           /* sync_data */ true, /*copy=*/false));
     JUST(result->set_requires_grad(requires_grad));
     return result;
   }
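For context on the GlobalHannWindowFunctor hunk above (whose only change is a space after the /* check_meta */ comment): the expression dividing ScalarSub(1, Cos(div_result), ...) by 2 computes (1 - cos(div_result)) / 2, the standard Hann window. Assuming div_result holds 2*pi*n/(N-1) for the symmetric variant, in LaTeX:

w[n] = \frac{1}{2}\left(1 - \cos\frac{2\pi n}{N - 1}\right), \qquad n = 0, \dots, N - 1

(The periodic variant divides by N instead of N - 1; which one applies depends on the functor's window-length handling, which this hunk does not show.)
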
2 changes: 1 addition & 1 deletion oneflow/core/functional/tensor_index.cpp
@@ -567,7 +567,7 @@ Maybe<void> UnifyInputAndIndicesOnDevice(const std::shared_ptr<Tensor>& x,
       LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);
       tensor_indices[i] = JUST(ToGlobal(tensor_index, placement,
                                         std::vector<Symbol<SbpParallel>>(n, broadcast_sbp),
-                                        grad_sbp_tuple, /*check_meta=*/false, /*sync_data*/true,
+                                        grad_sbp_tuple, /*check_meta=*/false, /*sync_data*/ true,
                                         /*copy=*/false));
     }
   }
