Skip to content

Commit

Permalink
Refactor CTCGreedyDecoderSeqLenLayerTest, GreedyDecoderLayerTest, CTC…
Browse files Browse the repository at this point in the history
…LossLayerTest (openvinotoolkit#19842)

* Refactor CTCGreedyDecoderSeqLenLayerTest

* Refactor usingGreedyDecoderLayerTest

* Refactor CTCLossLayerTest
  • Loading branch information
olpipi authored Sep 19, 2023
1 parent f926e0e commit 57df7a4
Show file tree
Hide file tree
Showing 12 changed files with 495 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,32 @@
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include "single_layer_tests/ctc_greedy_decoder.hpp"
#include "single_op_tests/ctc_greedy_decoder.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace LayerTestsDefinitions;
using namespace ngraph::helpers;

namespace {
using ov::test::CTCGreedyDecoderLayerTest;

// Common params
const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16
const std::vector<ov::element::Type> model_type = {
ov::element::f32,
ov::element::f16
};
std::vector<bool> mergeRepeated{true, false};

std::vector<std::vector<ov::Shape>> input_shapes_static = {
{{ 50, 3, 3 }},
{{ 50, 3, 7 }},
{{ 50, 3, 8 }},
{{ 50, 3, 16 }},
{{ 50, 3, 128 }},
{{ 50, 3, 49 }},
{{ 50, 3, 55 }},
{{ 1, 1, 16 }}};

const auto basicCases = ::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(std::vector<size_t>({ 50, 3, 3 }),
std::vector<size_t>({ 50, 3, 7 }),
std::vector<size_t>({ 50, 3, 8 }),
std::vector<size_t>({ 50, 3, 16 }),
std::vector<size_t>({ 50, 3, 128 }),
std::vector<size_t>({ 50, 3, 49 }),
std::vector<size_t>({ 50, 3, 55 }),
std::vector<size_t>({ 1, 1, 16 })),
::testing::ValuesIn(model_type),
::testing::ValuesIn(ov::test::static_shapes_to_test_representation(input_shapes_static)),
::testing::ValuesIn(mergeRepeated),
::testing::Values(ov::test::utils::DEVICE_CPU));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,30 @@
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include "single_layer_tests/ctc_greedy_decoder_seq_len.hpp"
#include "single_op_tests/ctc_greedy_decoder_seq_len.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace LayerTestsDefinitions;
using namespace ngraph::helpers;

namespace {
using ov::test::CTCGreedyDecoderSeqLenLayerTest;

std::vector<std::vector<size_t>> inputShape{{1, 1, 1}, {1, 6, 10}, {3, 3, 16}, {5, 3, 55}};
std::vector<std::vector<ov::Shape>> shapes1 = {{{1, 1, 1}},
{{1, 6, 10}},
{{3, 3, 16}},
{{5, 3, 55}}};

const std::vector<InferenceEngine::Precision> probPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16
const std::vector<ov::element::Type> probPrecisions = {
ov::element::f32,
ov::element::f16
};
const std::vector<InferenceEngine::Precision> idxPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64
const std::vector<ov::element::Type> idxPrecisions = {
ov::element::i32,
ov::element::i64
};

std::vector<bool> mergeRepeated{true, false};

const auto basicCases = ::testing::Combine(
::testing::ValuesIn(inputShape),
::testing::ValuesIn(ov::test::static_shapes_to_test_representation({shapes1})),
::testing::Values(10),
::testing::ValuesIn(probPrecisions),
::testing::ValuesIn(idxPrecisions),
Expand All @@ -37,9 +37,12 @@ INSTANTIATE_TEST_SUITE_P(smoke_set1, CTCGreedyDecoderSeqLenLayerTest,
basicCases,
CTCGreedyDecoderSeqLenLayerTest::getTestCaseName);

std::vector<std::vector<ov::Shape>> shapes2 = {{{2, 8, 11}},
{{4, 10, 55}}};

INSTANTIATE_TEST_SUITE_P(smoke_set2, CTCGreedyDecoderSeqLenLayerTest,
::testing::Combine(
::testing::ValuesIn(std::vector<std::vector<size_t>>{{2, 8, 11}, {4, 10, 55}}),
::testing::ValuesIn(ov::test::static_shapes_to_test_representation({shapes2})),
::testing::ValuesIn(std::vector<int>{5, 100}),
::testing::ValuesIn(probPrecisions),
::testing::ValuesIn(idxPrecisions),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,25 @@

#include <vector>

#include "single_layer_tests/ctc_loss.hpp"

using namespace LayerTestsDefinitions;
#include "single_op_tests/ctc_loss.hpp"

namespace {
using ov::test::CTCLossLayerTest;

const std::vector<InferenceEngine::Precision> fPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16
const std::vector<ov::element::Type> f_type = {
ov::element::f32,
ov::element::f16
};
const std::vector<InferenceEngine::Precision> iPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64
const std::vector<ov::element::Type> i_type = {
ov::element::i32,
ov::element::i64
};

const std::vector<bool> preprocessCollapseRepeated = {true, false};
const std::vector<bool> ctcMergeRepeated = {true, false};
const std::vector<bool> unique = {true, false};

const auto ctcLossArgsSubset1 = ::testing::Combine(
::testing::Values(std::vector<size_t>({2, 3, 3})), // logits shape
::testing::ValuesIn(std::vector<std::vector<int>>({{2, 3}, {3, 3}})), // logits length
::testing::ValuesIn(std::vector<std::vector<std::vector<int>>>(
{{{0, 1, 0}, {1, 0, 1}}, {{0, 1, 2}, {1, 1, 1}}})), // labels
Expand All @@ -38,13 +36,13 @@ const auto ctcLossArgsSubset1 = ::testing::Combine(
INSTANTIATE_TEST_SUITE_P(smoke_Set1, CTCLossLayerTest,
::testing::Combine(
ctcLossArgsSubset1,
::testing::ValuesIn(fPrecisions),
::testing::ValuesIn(iPrecisions),
::testing::Values(ov::test::static_shapes_to_test_representation({{2, 3, 3}})),
::testing::ValuesIn(f_type),
::testing::ValuesIn(i_type),
::testing::Values(ov::test::utils::DEVICE_CPU)),
CTCLossLayerTest::getTestCaseName);

const auto ctcLossArgsSubset2 = ::testing::Combine(
::testing::Values(std::vector<size_t>({3, 6, 8})), // logits shape
::testing::ValuesIn(std::vector<std::vector<int>>({{6, 5, 6}, {5, 5, 5}})), // logits length
::testing::ValuesIn(std::vector<std::vector<std::vector<int>>>(
{{{4, 1, 2, 3, 4, 5}, {5, 4, 3, 0, 1, 0}, {2, 1, 3, 1, 3, 0}},
Expand All @@ -59,8 +57,9 @@ const auto ctcLossArgsSubset2 = ::testing::Combine(
INSTANTIATE_TEST_SUITE_P(smoke_Set2, CTCLossLayerTest,
::testing::Combine(
ctcLossArgsSubset2,
::testing::ValuesIn(fPrecisions),
::testing::ValuesIn(iPrecisions),
::testing::Values(ov::test::static_shapes_to_test_representation({{3, 6, 8}})),
::testing::ValuesIn(f_type),
::testing::ValuesIn(i_type),
::testing::Values(ov::test::utils::DEVICE_CPU)),
CTCLossLayerTest::getTestCaseName);
} // namespace
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "shared_test_classes/single_op/ctc_greedy_decoder.hpp"

namespace ov {
namespace test {
TEST_P(CTCGreedyDecoderLayerTest, Inference) {
run();
};
} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "shared_test_classes/single_op/ctc_greedy_decoder_seq_len.hpp"

namespace ov {
namespace test {
TEST_P(CTCGreedyDecoderSeqLenLayerTest, Inference) {
run();
};
} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "shared_test_classes/single_op/ctc_loss.hpp"

namespace ov {
namespace test {
TEST_P(CTCLossLayerTest, Inference) {
run();
}
} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <string>
#include <tuple>
#include <vector>

#include "shared_test_classes/base/ov_subgraph.hpp"

namespace ov {
namespace test {
typedef std::tuple<
ov::element::Type, // Model type
std::vector<InputShape>, // Input shapes
bool, // Merge repeated
std::string // Device name
> ctcGreedyDecoderParams;

class CTCGreedyDecoderLayerTest
: public testing::WithParamInterface<ctcGreedyDecoderParams>,
virtual public ov::test::SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<ctcGreedyDecoderParams>& obj);
protected:
void SetUp() override;
};
} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <string>
#include <tuple>
#include <vector>

#include "shared_test_classes/base/ov_subgraph.hpp"

namespace ov {
namespace test {
typedef std::tuple<
std::vector<InputShape>, // Input shape
int, // Sequence lengths
ov::element::Type, // Probabilities precision
ov::element::Type, // Indices precision
int, // Blank index
bool, // Merge repeated
std::string // Device name
> ctcGreedyDecoderSeqLenParams;

class CTCGreedyDecoderSeqLenLayerTest
: public testing::WithParamInterface<ctcGreedyDecoderSeqLenParams>,
virtual public ov::test::SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<ctcGreedyDecoderSeqLenParams>& obj);

protected:
void SetUp() override;
};

} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <string>
#include <tuple>
#include <vector>

#include "shared_test_classes/base/ov_subgraph.hpp"

namespace ov {
namespace test {
typedef std::tuple<
std::vector<int>, // logits length
std::vector<std::vector<int>>, // labels
std::vector<int>, // labels length
int, // blank index
bool, // preprocessCollapseRepeated
bool, // ctcMergeRepeated
bool // Unique
> CTCLossParamsSubset;

typedef std::tuple<
CTCLossParamsSubset,
std::vector<InputShape>, // Input shapes
ov::element::Type, // Float point precision
ov::element::Type, // Integer precision
std::string // Device name
> CTCLossParams;

class CTCLossLayerTest : public testing::WithParamInterface<CTCLossParams>,
virtual public ov::test::SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<CTCLossParams> &obj);

protected:
void SetUp() override;
};

} // namespace test
} // namespace ov
Loading

0 comments on commit 57df7a4

Please sign in to comment.