
Commit

self consistent cpp addition
KexinFeng committed Jul 8, 2022
1 parent bbb258a commit 5be12f7
Showing 6 changed files with 39 additions and 0 deletions.
7 changes: 7 additions & 0 deletions include/mxnet/c_api.h
@@ -1276,6 +1276,13 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
                                      NDArrayHandle* var_handles,
                                      uint32_t* reqs_array,
                                      NDArrayHandle* grad_handles);
/*!
 * \brief mark nonleaf NDArrays as variables during deferred computation
 * \param nleaf_handles handles of the nonleaf NDArrays to be marked
 * \param num_nleafs number of nonleaf NDArrays
 * \param cnt_var count of existing marked nonleaf variables
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var);
/*!
 * \brief unmark nonleaf NDArrays to free the memory
 * \param num_var number of variable NDArrays
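Not part of the commit, but a minimal usage sketch of the new entry point may help; the helper name MarkTwoNonleafs and the handles x_handle/y_handle are placeholders, and the error path relies on the existing MXGetLastError() call:

#include <cstdio>
#include <mxnet/c_api.h>

// Hypothetical helper: mark two nonleaf NDArrays that were recorded under
// deferred compute as variables, numbering them starting from cnt_var = 0.
int MarkTwoNonleafs(NDArrayHandle x_handle, NDArrayHandle y_handle) {
  NDArrayHandle nleafs[2] = {x_handle, y_handle};
  if (MXNDArrayMarkDCVariables(nleafs, /*num_nleafs=*/2, /*cnt_var=*/0) != 0) {
    // The C API returns -1 on failure; the message is retrieved separately.
    std::fprintf(stderr, "MXNDArrayMarkDCVariables failed: %s\n", MXGetLastError());
    return -1;
  }
  return 0;
}
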
2 changes: 2 additions & 0 deletions include/mxnet/imperative.h
@@ -290,6 +290,8 @@ class Imperative {
  void MarkVariables(const std::vector<NDArray*>& variables,
                     const std::vector<uint32_t>& grad_reqs,
                     const std::vector<NDArray*>& gradients);
  /*! \brief mark nonleaf variables during DC for computing gradients. */
  void MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars);
  /*! \brief unmark nonleaf variables to free the memory. */
  void DropGrads(const std::vector<NDArray*>& variables);
  /*! \brief compute the gradient of outputs w.r.t variables. */
2 changes: 2 additions & 0 deletions include/mxnet/ndarray.h
@@ -351,6 +351,8 @@ class NDArray {
  bool fresh_out_grad() const;
  /*! \return updated grad state in autograd_entry_ */
  void set_fresh_out_grad(bool state) const;
  /*! \brief copy the autograd_entry_ from src NDArray */
  void copy_autograd_entry_(const NDArray* src);
  /*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
   * Throws an exception if the indices array shape is inconsistent
   * Returns false if the indices array is empty(nnz = 0) for csr/row_sparse
12 changes: 12 additions & 0 deletions src/c_api/c_api_ndarray.cc
@@ -495,3 +495,15 @@ int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle* output_handles,
  *out = s;
  API_END_HANDLE_ERROR(delete s;);
}

int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var) {
  API_BEGIN();
  std::vector<NDArray*> nleafs;
  nleafs.reserve(num_nleafs);
  for (int i = 0; i < num_nleafs; ++i) {
    NDArray* array = reinterpret_cast<NDArray*>(nleaf_handles[i]);
    nleafs.emplace_back(array);
  }
  Imperative::Get()->MarkDCVariables(nleafs, cnt_var);
  API_END();
}
12 changes: 12 additions & 0 deletions src/imperative/imperative.cc
@@ -171,6 +171,18 @@ void Imperative::MarkVariables(const std::vector<NDArray*>& variables,
  }
}

void Imperative::MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars) {
  for (NDArray* nleaf : nleafs) {
    if (Imperative::DCInfo::IsNone(*nleaf)) {
      LOG(WARNING) << "The marked node doesn't have deferred compute history.";
    } else {
      nnvm::ObjectPtr node = nleaf->deferredcompute_entry_.node;
      node->attrs.dict["mark_id"] = std::to_string(cnt_vars);
    }
    cnt_vars++;
  }
}

// Unmark the variables to free the memory.
void Imperative::DropGrads(const std::vector<NDArray*>& variables) {
  for (auto variable : variables) {
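As an illustration of what the marking does (an assumed usage sketch, not code from the commit; MarkRecordedNonleafs is a hypothetical caller), each marked nonleaf's deferred-compute node receives a "mark_id" attribute that continues counting from cnt_vars:

#include <vector>
#include <mxnet/imperative.h>
#include <mxnet/ndarray.h>

// Hypothetical caller: assume three nonleaf variables were marked earlier and
// that x and y were produced while deferred compute was recording.
void MarkRecordedNonleafs(mxnet::NDArray* x, mxnet::NDArray* y) {
  std::vector<mxnet::NDArray*> nleafs = {x, y};
  mxnet::Imperative::Get()->MarkDCVariables(nleafs, /*cnt_vars=*/3);
  // Effect: x's node now carries attrs.dict["mark_id"] == "3" and y's == "4".
  // An array without deferred-compute history only triggers the warning above,
  // but it still consumes an id because cnt_vars is incremented unconditionally.
}
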
4 changes: 4 additions & 0 deletions src/ndarray/ndarray.cc
@@ -513,6 +513,10 @@ void NDArray::set_fresh_out_grad(bool state) const {
  info.fresh_out_grad = state;
}

void NDArray::copy_autograd_entry_(const NDArray* src) {
  autograd_entry_ = nnvm::NodeEntry{src->autograd_entry_.node, 0, 0};
}

#if MXNET_USE_ONEDNN == 1

bool NDArray::Chunk::IsDNNL() const {
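For context, a hedged sketch of the new helper's effect (ShareAutogradEntry is a made-up wrapper, not MXNet API): the destination array is pointed at the source's autograd node, with the output index and version reset to 0, so autograd treats it as the first output of that node:

#include <mxnet/ndarray.h>

// Hypothetical wrapper: give dst the same autograd head as src so gradients
// recorded for src's node are attributed to dst as well.
void ShareAutogradEntry(const mxnet::NDArray& src, mxnet::NDArray* dst) {
  dst->copy_autograd_entry_(&src);  // output index and version are reset to 0
}
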
