Skip to content

Commit

Permalink
Omit scalar for batch=1
Browse files Browse the repository at this point in the history
  • Loading branch information
amirbawab committed Jun 14, 2019
1 parent 6a58f17 commit c86edb9
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/nn-builder/src/arch/layers/dense.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,10 @@ wabt::ExprList* FullyConnectedLayer::Backward(wabt::Var input_begin, wabt::Var t
{vi32_1, vi32_2, vi32_3, vi32_4, vi32_5, vf32_1, v128_1}));
END_TIME(C_1)
START_TIME()
Merge(e, NetworkModel()->Snippets().matrix->MatrixScalar(dW_,
MakeF32Const(1.0f/ NetworkModel()->TrainingBatchSize()), dW_,
{vi32_1, vi32_2, vf32_1}));
if(NetworkModel()->TrainingBatchSize() > 1) {
Merge(e, NetworkModel()->Snippets().matrix->MatrixScalar(dW_, MakeF32Const(1.0f / NetworkModel()->TrainingBatchSize()),
dW_, {vi32_1, vi32_2, vf32_1}));
}
END_TIME(C_2)

// D) db[l] = (1/m) dZ[l]
Expand All @@ -210,9 +211,10 @@ wabt::ExprList* FullyConnectedLayer::Backward(wabt::Var input_begin, wabt::Var t
{vi32_1, vi32_2, vi32_3, vf32_1, v128_1}));
END_TIME(D_1)
START_TIME()
Merge(e, NetworkModel()->Snippets().matrix->MatrixScalar(db_,
MakeF32Const(1.0f/NetworkModel()->TrainingBatchSize()),
db_, {vi32_1, vi32_2, vf32_1}));
if(NetworkModel()->TrainingBatchSize() > 1) {
Merge(e, NetworkModel()->Snippets().matrix->MatrixScalar(db_, MakeF32Const(1.0f/NetworkModel()->TrainingBatchSize()),
db_, {vi32_1, vi32_2, vf32_1}));
}
END_TIME(D_2)

if(LayerIndex() > 1) {
Expand Down

0 comments on commit c86edb9

Please sign in to comment.