Skip to content

Commit

Permalink
Remove redundant fill in SPMM kernel (#3166)
Browse files Browse the repository at this point in the history
* remove redundant fill

* trigger ci
  • Loading branch information
VoVAllen authored and BarclayII committed Jul 21, 2021
1 parent 5f5a6ef commit d15582b
Showing 1 changed file with 0 additions and 16 deletions.
16 changes: 0 additions & 16 deletions src/array/cpu/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
if (reduce == "sum") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
DType *out_off = out.Ptr<DType>();
std::fill(out_off, out_off + csr.num_rows * dim, 0);
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
});
});
Expand All @@ -33,8 +31,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
DType *out_off = out.Ptr<DType>();
IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
if (Op::use_lhs) std::fill(argX, argX + csr.num_rows * dim, 0);
if (Op::use_rhs) std::fill(argW, argW + csr.num_rows * dim, 0);
if (reduce == "max") {
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
Expand Down Expand Up @@ -66,11 +62,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
if (reduce == "sum") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
// TODO(Israt): Ideally the for loop should go over num_ntypes
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
DType *out_off = vec_out[out_node_tids[etype]].Ptr<DType>();
std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, 0);
}
/* Call SpMM for each relation type */
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
const dgl_type_t src_id = ufeat_node_tids[etype];
Expand All @@ -86,13 +77,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
} else if (reduce == "max" || reduce == "min") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
// TODO(Israt): Ideally the for loop should go over num_ntypes
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
if (Op::use_lhs) std::fill(argX, argX + vec_csr[etype].num_rows * dim, 0);
if (Op::use_rhs) std::fill(argW, argW + vec_csr[etype].num_rows * dim, 0);
}
/* Call SpMM for each relation type */
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
const dgl_type_t src_id = ufeat_node_tids[etype];
Expand Down

0 comments on commit d15582b

Please sign in to comment.