Skip to content

Commit

Permalink
use M=16, K=32, N=16 in test
Browse files Browse the repository at this point in the history
  • Loading branch information
amarin16 committed Nov 6, 2024
1 parent de23a1a commit ca2cc8f
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions onnxruntime/test/providers/cpu/math/matmul_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,9 @@ TEST(MathOpTest, MatMul_float8E4M3FN) {
// test.AddInput<MLFloat16>("B", {2, 2}, FloatsToMLFloat16s({1.0f, 1.0f, 1.0f, 1.0f}));
// test.AddOutput<MLFloat16>("Y", {2, 2}, FloatsToMLFloat16s({2.0f, 2.0f, 2.0f, 2.0f}));

auto ones_32 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
test.AddInput<MLFloat16>("A", {2, 16}, FloatsToMLFloat16s(ones_32));
test.AddInput<MLFloat16>("B", {16, 2}, FloatsToMLFloat16s(ones_32));
test.AddOutput<MLFloat16>("Y", {2, 2}, FloatsToMLFloat16s({16.0f, 16.0f, 16.0f, 16.0f}));

test.AddInput<MLFloat16>("A", {16, 32}, FloatsToMLFloat16s(std::vector<float>(16 * 32, 1.0f)));
test.AddInput<MLFloat16>("B", {32, 16}, FloatsToMLFloat16s(std::vector<float>(32 * 16, 1.0f)));
test.AddOutput<MLFloat16>("Y", {16, 16}, FloatsToMLFloat16s(std::vector<float>(16 * 16, 16.0f)));

// test.AddInput<MLFloat16>("B", {4, 3}, FloatsToMLFloat16s({10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f}));
// test.AddInput<MLFloat16>("B", {4, 3}, FloatsToMLFloat16s({17.f, 19.f, 21.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f}));
Expand Down

0 comments on commit ca2cc8f

Please sign in to comment.