Skip to content

Commit

Permalink
[CPU] Use actual input shape to init desc for MemoryInputSDPA (openvi…
Browse files Browse the repository at this point in the history
…notoolkit#27143)

An output shape was previously used to create an input descriptor for
some reason
  • Loading branch information
EgorDuplensky authored Oct 21, 2024
1 parent 308b420 commit 2cb8222
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 47 deletions.
52 changes: 7 additions & 45 deletions src/plugins/intel_cpu/src/nodes/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,29 +427,20 @@ void MemoryInputBase::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;

auto&& shape = getOutputShapeAtPort(0);
auto precision = getOriginalOutputPrecisionAtPort(0);
auto&& descCreators = ov::intel_cpu::BlockedDescCreator::getCommonCreators();

NodeConfig config;

if (!getParentEdges().empty()) {
PortConfig inPortConfig;

inPortConfig.inPlace(-1);
inPortConfig.constant(false);
inPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));

config.inConfs.push_back(std::move(inPortConfig));
const auto& inputShape = getInputShapeAtPort(0);
config.inConfs.emplace_back(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, inputShape));
}

PortConfig outPortConfig;

outPortConfig.inPlace(0);
outPortConfig.constant(false);
outPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));

config.outConfs.push_back(std::move(outPortConfig));
const auto& outputShape = getOutputShapeAtPort(0);
config.outConfs.emplace_back(
descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, outputShape),
BlockedMemoryDesc::FULL_MASK,
0);

supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
}
Expand Down Expand Up @@ -759,35 +750,6 @@ void MemoryInputSDPA::createPrimitive() {
OPENVINO_ASSERT(m_child_port_idx != -1, getName(), " should be connected to SDPA node.");
}

void MemoryInputSDPA::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;

auto&& shape = getOutputShapeAtPort(0);
auto precision = getOriginalOutputPrecisionAtPort(0);
auto&& descCreators = ov::intel_cpu::BlockedDescCreator::getCommonCreators();
NodeConfig config;
if (!getParentEdges().empty()) {
PortConfig inPortConfig;
inPortConfig.inPlace(-1);
inPortConfig.constant(false);
inPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));
config.inConfs.push_back(std::move(inPortConfig));
}

PortConfig outPortConfig;
outPortConfig.inPlace(0);
outPortConfig.constant(false);
// layout for fake memory obj, the child sdpa also does not use it
outPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));
config.outConfs.push_back(std::move(outPortConfig));
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
}

void MemoryInputSDPA::initOptimalPrimitiveDescriptor() {
Node::initOptimalPrimitiveDescriptor();
}

void MemoryInputSDPA::assignStateHook() {
auto currentState = getAssignedState();
auto sdpaNode = m_sdpaNode.lock();
Expand Down
2 changes: 0 additions & 2 deletions src/plugins/intel_cpu/src/nodes/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,6 @@ class MemoryInputSDPA : public MemoryInputBase {
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;

void createPrimitive() override;
void initSupportedPrimitiveDescriptors() override;
void initOptimalPrimitiveDescriptor() override;
void resolveInPlaceEdges(Edge::LOOK look) override;

MemStatePtr makeState() const override;
Expand Down

0 comments on commit 2cb8222

Please sign in to comment.