Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Semyon1104 committed Jul 18, 2024
1 parent ebee162 commit 12c29fb
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/layers/OutputLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ std::pair<std::vector<std::string>, Tensor> OutputLayer::top_k(
case Type::kFloat: {
auto toppair = top_k_vec<float>(*input.as<float>(), labels_, k);
reslabels = toppair.first;
resvector =
make_tensor(toppair.second, input.get_shape(), input.get_bias());
resvector = resvector = make_tensor(toppair.second);
break;
}
case Type::kInt: {
auto toppair = top_k_vec<int>(*input.as<int>(), labels_, k);
reslabels = toppair.first;
resvector = make_tensor(toppair.second, input.get_shape());
resvector = make_tensor(toppair.second);
break;
}
default: {
Expand Down
32 changes: 32 additions & 0 deletions test/single_layer/test_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,35 @@ TEST(Tensor, cannot_set_bias_with_incorrect_size) {
std::vector<float> incorrect_bias = {0.5F, 1.5F};
ASSERT_ANY_THROW(t.set_bias(incorrect_bias));
}

TEST(Tensor, can_create_multidimensional_tensor) {
Shape sh({2, 3, 2}); // 3D tensor shape
std::vector<int> vals_tensor = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
ASSERT_NO_THROW(make_tensor<int>(vals_tensor, sh));
}

TEST(Tensor, check_get_element_from_multidimensional_tensor) {
Shape sh({2, 3, 2}); // 3D tensor shape
std::vector<int> vals_tensor = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
Tensor t = make_tensor<int>(vals_tensor, sh);
EXPECT_EQ(t.get<int>({1, 2, 1}), 12);
}

TEST(Tensor, cannot_get_element_with_invalid_coordinates) {
Shape sh({2, 3, 2}); // 3D tensor shape
std::vector<int> vals_tensor = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
Tensor t = make_tensor<int>(vals_tensor, sh);
ASSERT_ANY_THROW(t.get<int>({2, 3, 1}));
}
TEST(Tensor, cannot_create_tensor_with_incorrect_shape) {
Shape sh({2, 3});
std::vector<float> vals_tensor = {1.0F, 2.0F, 3.0F, 4.0F,
5.0F}; // Incorrect size
ASSERT_ANY_THROW(make_tensor<float>(vals_tensor, sh));
}

TEST(Tensor, cannot_create_tensor_with_unknown_type) {
std::vector<char> vals_tensor = {'a', 'b', 'c'};
Shape sh({3});
ASSERT_ANY_THROW(make_tensor<char>(vals_tensor, sh));
}

0 comments on commit 12c29fb

Please sign in to comment.