Skip to content

Commit

Permalink
refine subgraph test for specific case
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jan 16, 2024
1 parent aa53261 commit 0274e76
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
3 changes: 1 addition & 2 deletions src/plugins/intel_cpu/src/nodes/reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ void Reshape::initSupportedPrimitiveDescriptors() {
bool canBeInPlace = true;

// CVS-81059 : disable inPlace in following case since it won't be satisfied by framework
if ((!isConstant() && getParentEdgeAt(0)->getParent()->isConstant()) ||
(getParentEdgeAt(0)->getParent()->getChildEdges().size() != 1))
if (!isConstant() && getParentEdgeAt(0)->getParent()->isConstant())
canBeInPlace = false;

NodeConfig config;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,39 +91,46 @@ TEST_F(InPlaceReshapeFromConstantCheck, smoke_CPU_InPlaceReshapeFromConstantChec
* params[0] params[1]
* \ /
* \ /
* add----result1
* add---reshape2---result2
* |
* reshape
* reshape1
* |
* MVN
* |
* result2
* result1
*
* If parent of reshape have more than one edges, reshape in place should be not applicable.
* If parent of reshape is shared, reshape in place should be not applicable.
* This is becuase multiple branches could change data on the same port, then pollute result each other.
*/

class InPlaceReshapeShareInputCheck : public SubgraphBaseTest {
protected:
void SetUp() override {
const auto rtPrc = ov::element::f32;
const ov::Shape inpShape = {1, 8, 16, 16};
const ov::Shape inpShape = {1, 16, 16};
targetStaticShapes = {{inpShape, inpShape}};
targetDevice = ov::test::utils::DEVICE_CPU;
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(rtPrc, inpShape),
std::make_shared<ov::op::v0::Parameter>(rtPrc, inpShape)};

auto add = std::make_shared<ov::op::v1::Add>(params[0], params[1]);
auto res0 = std::make_shared<ov::op::v0::Result>(add);
std::vector<int> newShape = {1, 2, 64, 16};
auto targetShape = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, newShape);
auto reshape = std::make_shared<ov::op::v1::Reshape>(add, targetShape, false);
auto mvn = std::make_shared<ov::op::v6::MVN>(reshape,
std::vector<int> newShape1 = {1, 1, 16, 16};
auto targetShape1 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, newShape1);
auto reshape1 = std::make_shared<ov::op::v1::Reshape>(add, targetShape1, false);
auto mvn = std::make_shared<ov::op::v6::MVN>(reshape1,
ov::op::v0::Constant::create(ov::element::i32, ov::Shape{2}, {2, 3}),
true,
0.1,
ov::op::MVNEpsMode::INSIDE_SQRT);
auto res1 = std::make_shared<ov::op::v0::Result>(mvn);
function = std::make_shared<ov::Model>(ov::ResultVector{res0, res1}, params, "reshape_share_input_check");

std::vector<int> newShape2 = {1, 4, 8, 8};
auto targetShape2 = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, newShape2);
auto reshape2 = std::make_shared<ov::op::v1::Reshape>(add, targetShape2, false);

auto res2 = std::make_shared<ov::op::v0::Result>(reshape2);

function = std::make_shared<ov::Model>(ov::ResultVector{res1, res2}, params, "reshape_share_input_check");
}
};

Expand Down

0 comments on commit 0274e76

Please sign in to comment.