Skip to content

Commit

Permalink
Optimized DotRT for batch=1
Browse files Browse the repository at this point in the history
  • Loading branch information
amirbawab committed Jun 14, 2019
1 parent c86edb9 commit 7de33c8
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 15 deletions.
50 changes: 45 additions & 5 deletions src/nn-builder/src/snippet/matrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -625,11 +625,6 @@ wabt::ExprList* MatrixSnippetSimd::MatrixDotRT(nn::ds::NDArray *lhs, nn::snippet
uint32_t width_remainder = lhs_width_bytes % WASMPP_V128_SIZE;
uint32_t simd_width_bytes = lhs_width_bytes - width_remainder;

// Cannot optimize if matrix width bytes is too small
if(lhs_width_bytes < WASMPP_V128_SIZE) {
return MatrixSnippet::MatrixDotRT(lhs, rhs, dst, locals);
}

auto rhs_rows = locals[0];
auto lhs_col_rhs_rows = locals[1];
auto lhs_row_offset = locals[2];
Expand All @@ -638,7 +633,52 @@ wabt::ExprList* MatrixSnippetSimd::MatrixDotRT(nn::ds::NDArray *lhs, nn::snippet
auto res_cell = locals[5];
auto res_128 = locals[6];

// Handle special case where the number of columns is 1
// and rhs has more than 4 elements
wabt::ExprList* e = new wabt::ExprList();
if(lhs->Shape()[1] == 1 && rhs_height_bytes >= WASMPP_V128_SIZE) {

uint32_t height_remainder = rhs_height_bytes % WASMPP_V128_SIZE;
uint32_t simd_height_bytes = rhs_height_bytes - height_remainder;

Merge(e, MakeLocalSet(lhs_row_offset, MakeI32Const(lhs->Memory()->Begin())));
Merge(e, GenerateRangeLoop(label_manager_, dst_row_offset, dst->Memory()->Begin(), dst->Memory()->End(), rhs_height_bytes, {}, [&](BlockBody* b1) {
// Reset rhs pointer to top row
if(rhs.HasBeginVar()) {
b1->Insert(MakeLocalSet(rhs_row_offset, MakeLocalGet(rhs.Var())));
} else {
b1->Insert(MakeLocalSet(rhs_row_offset, MakeI32Const(rhs.Array()->Memory()->Begin())));
}

// Apply SIMD while possible
b1->Insert(GenerateRangeLoop(label_manager_, rhs_rows, 0, simd_height_bytes, simd_type_size, {}, [&](BlockBody* b2) {
auto lhs_op = MakeUnary(Opcode::F32X4Splat, MakeF32Load(MakeLocalGet(lhs_row_offset)));
auto rhs_op = MakeV128Load(MakeBinary(Opcode::I32Add, MakeLocalGet(rhs_rows), MakeLocalGet(rhs_row_offset)));
auto dest_addr = MakeBinary(Opcode::I32Add, MakeLocalGet(dst_row_offset), MakeLocalGet(rhs_rows));
b2->Insert(MakeV128Store(dest_addr, MakeBinary(Opcode::F32X4Mul, lhs_op, rhs_op)));
}));

// Fallback to regular computation
if(height_remainder > 0) {
b1->Insert(GenerateDoWhileLoop(label_manager_, rhs_rows, rhs_height_bytes, type_size, {}, [&](BlockBody* b2) {
auto lhs_op = MakeF32Load(MakeLocalGet(lhs_row_offset));
auto rhs_op = MakeF32Load(MakeBinary(Opcode::I32Add, MakeLocalGet(rhs_rows), MakeLocalGet(rhs_row_offset)));
auto dest_addr = MakeBinary(Opcode::I32Add, MakeLocalGet(dst_row_offset), MakeLocalGet(rhs_rows));
b2->Insert(MakeF32Store(dest_addr, MakeBinary(Opcode::F32Mul, lhs_op, rhs_op)));
}));
}

// Move lhs pointer to next row
b1->Insert(GenerateCompoundAssignment(lhs_row_offset, Opcode::I32Add, MakeI32Const(lhs_width_bytes)));
}));
return e;
}

// Cannot optimize if matrix width bytes is too small
if(lhs_width_bytes < WASMPP_V128_SIZE) {
return MatrixSnippet::MatrixDotRT(lhs, rhs, dst, locals);
}

Merge(e, MakeLocalSet(lhs_row_offset, MakeI32Const(lhs->Memory()->Begin())));
Merge(e, GenerateRangeLoop(label_manager_, dst_row_offset, dst->Memory()->Begin(), dst->Memory()->End(), rhs_height_bytes, {}, [&](BlockBody* b1) {
if(rhs.HasBeginVar()) {
Expand Down
68 changes: 64 additions & 4 deletions src/nn-builder/tests/matrix_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ void MatrixSnippetTest::MatrixVectorAddition_test_1() {
ADD_NN_TEST(module_manager_, "MatrixVectorAddition_1", Type::I32, Type::I32, Type::I32, Type::I32);
}

void MatrixSnippetTest::MatrixRowSum_test_1() {
void MatrixSnippetTest::MatrixHorizontalSum_test_1() {
NN_TEST() {
uint32_t rows = 5;
uint32_t cols = 10;
Expand Down Expand Up @@ -362,7 +362,7 @@ void MatrixSnippetTest::MatrixRowSum_test_1() {
MakeI32Const(dst->Shape()[1])
}));
};
ADD_NN_TEST(module_manager_, "MatrixRowSum_1", Type::I32, Type::I32, Type::I32, Type::F32, Type::V128);
ADD_NN_TEST(module_manager_, "MatrixHorizontalSum_1", Type::I32, Type::I32, Type::I32, Type::F32, Type::V128);
}

void MatrixSnippetSimdTest::MatrixAdditionSimd_test_1() {
Expand Down Expand Up @@ -521,7 +521,7 @@ void MatrixSnippetSimdTest::MatrixVectorAdditionSimd_test_1() {
ADD_NN_TEST(module_manager_, "MatrixVectorAdditionSimd_1", Type::I32, Type::I32, Type::I32, Type::I32);
}

void MatrixSnippetSimdTest::MatrixRowSumSimd_test_1() {
void MatrixSnippetSimdTest::MatrixHorizontalSumSimd_test_1() {
NN_TEST() {
uint32_t rows = 57;
uint32_t cols = 101;
Expand Down Expand Up @@ -570,7 +570,7 @@ void MatrixSnippetSimdTest::MatrixRowSumSimd_test_1() {
MakeI32Const(dst->Shape()[1])
}));
};
ADD_NN_TEST(module_manager_, "MatrixRowSumSimd_1", Type::I32, Type::I32, Type::I32, Type::F32, Type::V128);
ADD_NN_TEST(module_manager_, "MatrixHorizontalSumSimd_1", Type::I32, Type::I32, Type::I32, Type::F32, Type::V128);
}

void MatrixSnippetSimdTest::MatrixDotRTSimd_test_1() {
Expand Down Expand Up @@ -646,6 +646,66 @@ void MatrixSnippetSimdTest::MatrixDotRTSimd_test_1() {
Type::F32, Type::V128);
}

void MatrixSnippetSimdTest::MatrixDotRTSimd_test_2() {
NN_TEST() {
uint32_t lhs_rows = 111;
uint32_t lhs_cols = 1;
uint32_t rhs_rows = 231;
uint32_t rhs_cols = lhs_cols;
uint32_t dst_rows = lhs_rows;
uint32_t dst_cols = rhs_rows;

NEW_MATRIX(lhs, lhs_rows, lhs_cols);
NEW_MATRIX(rhs, rhs_rows, rhs_cols);
NEW_MATRIX(dst, dst_rows, dst_cols);
NEW_MATRIX(expected, dst_rows, dst_cols);

std::vector<std::vector<float>> mat1(lhs_rows, std::vector<float>(lhs_cols, 0));
std::vector<std::vector<float>> mat2(rhs_rows, std::vector<float>(rhs_cols, 0));
std::vector<std::vector<float>> res(dst_rows, std::vector<float>(dst_cols, 0));
float val = 1.2;
for (uint32_t row = 0; row < lhs_rows; row++) {
for (uint32_t col = 0; col < lhs_cols; col++) {
f.Insert(MakeF32Store(MakeI32Const(lhs->GetLinearIndex({row, col})), MakeF32Const(val)));
mat1[row][col] = val;
val++;
}
}
for (uint32_t row = 0; row < rhs_rows; row++) {
for (uint32_t col = 0; col < rhs_cols; col++) {
f.Insert(MakeF32Store(MakeI32Const(rhs->GetLinearIndex({row, col})), MakeF32Const(val)));
mat2[row][col] = val;
val++;
}
}

for (auto i = 0; i < lhs_rows; ++i) {
for (auto j = 0; j < rhs_rows; ++j) {
float res_val = 0;
for (auto k = 0; k < lhs_cols; k++) {
res_val += mat1[i][k] * mat2[j][k];
}
res[i][j] = res_val;
}
}
for (uint32_t row = 0; row < dst_rows; row++) {
for (uint32_t col = 0; col < dst_cols; col++) {
f.Insert(MakeF32Store(MakeI32Const(expected->GetLinearIndex({row, col})), MakeF32Const(res[row][col])));
}
}

f.Insert(matrix_snippet_simd_.MatrixDotRT(lhs, rhs, dst, locals));
f.Insert(MakeCall(test_builtins_->assert_matrix_eq, {
MakeI32Const(dst->Memory()->Begin()),
MakeI32Const(expected->Memory()->Begin()),
MakeI32Const(dst->Shape()[0]),
MakeI32Const(dst->Shape()[1])
}));
};
ADD_NN_TEST(module_manager_, "MatrixDotRTSimd_2", Type::I32, Type::I32, Type::I32, Type::I32, Type::I32,
Type::F32, Type::V128);
}

} // namespace test
} // namespace nn

10 changes: 6 additions & 4 deletions src/nn-builder/tests/matrix_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class MatrixSnippetTest {
TestBuiltins* test_builtins_;
public:
MatrixSnippetTest(wasmpp::ModuleManager* module_manager, TestBuiltins* test_builtins) :
module_manager_(module_manager), test_builtins_(test_builtins), matrix_snippet_(&module_manager->Label()) {}
module_manager_(module_manager), test_builtins_(test_builtins), matrix_snippet_(&module_manager->Label(), nullptr) {}
void MatrixAddition_test_1();
void MatrixSubtraction_test_1();
void MatrixMultiplication_test_1();
Expand All @@ -24,7 +24,7 @@ class MatrixSnippetTest {
void MatrixDotLT_test_1();
void MatrixDotRT_test_1();
void MatrixVectorAddition_test_1();
void MatrixRowSum_test_1();
void MatrixHorizontalSum_test_1();
};

class MatrixSnippetSimdTest {
Expand All @@ -34,14 +34,16 @@ class MatrixSnippetSimdTest {
TestBuiltins* test_builtins_;
public:
MatrixSnippetSimdTest(wasmpp::ModuleManager* module_manager, TestBuiltins* test_builtins) :
module_manager_(module_manager), test_builtins_(test_builtins), matrix_snippet_simd_(&module_manager->Label()) {}
module_manager_(module_manager), test_builtins_(test_builtins),
matrix_snippet_simd_(&module_manager->Label(), nullptr) {}
void MatrixAdditionSimd_test_1();
void MatrixSubtractionSimd_test_1();
void MatrixMultiplicationSimd_test_1();
void MatrixScalarSimd_test_1();
void MatrixDotRTSimd_test_1();
void MatrixDotRTSimd_test_2();
void MatrixVectorAdditionSimd_test_1();
void MatrixRowSumSimd_test_1();
void MatrixHorizontalSumSimd_test_1();
};

} // namespace test
Expand Down
5 changes: 3 additions & 2 deletions src/nn-builder/tests/test-builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ int main(int argc, char *argv[]) {
matrix_snippet_test.MatrixDotLT_test_1();
matrix_snippet_test.MatrixDotRT_test_1();
matrix_snippet_test.MatrixVectorAddition_test_1();
matrix_snippet_test.MatrixRowSum_test_1();
matrix_snippet_test.MatrixHorizontalSum_test_1();

// Create matrix simd tests
nn::test::MatrixSnippetSimdTest matrix_snippet_simd_test(&module_manager, &test_builtins);
Expand All @@ -87,8 +87,9 @@ int main(int argc, char *argv[]) {
matrix_snippet_simd_test.MatrixMultiplicationSimd_test_1();
matrix_snippet_simd_test.MatrixScalarSimd_test_1();
matrix_snippet_simd_test.MatrixDotRTSimd_test_1();
matrix_snippet_simd_test.MatrixDotRTSimd_test_2();
matrix_snippet_simd_test.MatrixVectorAdditionSimd_test_1();
matrix_snippet_simd_test.MatrixRowSumSimd_test_1();
matrix_snippet_simd_test.MatrixHorizontalSumSimd_test_1();

assert(module_manager.Validate());
if(!output_file.empty()) {
Expand Down

0 comments on commit 7de33c8

Please sign in to comment.