diff --git a/src/cpp/src/device_config.hpp b/src/cpp/src/device_config.hpp index cc2e21b9a1..fee6c7abd1 100644 --- a/src/cpp/src/device_config.hpp +++ b/src/cpp/src/device_config.hpp @@ -117,22 +117,22 @@ class DeviceConfig { } for (size_t layer_id = 0; layer_id < m_num_decoder_layers; layer_id++) { - m_key_cache_shape.push_back(ov::PartialShape{ov::Dimension::dynamic(), - ov::Dimension(m_num_kv_heads[layer_id]), - ov::Dimension(m_block_size), - ov::Dimension(m_head_size)}); - m_value_cache_shape.push_back(ov::PartialShape{ov::Dimension::dynamic(), ov::Dimension(m_num_kv_heads[layer_id]), ov::Dimension(m_block_size), ov::Dimension(m_head_size)}); - if (m_device.find("GPU") != std::string::npos) { + if (m_device.find("GPU") == std::string::npos) { + m_key_cache_shape.push_back(ov::PartialShape{ov::Dimension::dynamic(), + ov::Dimension(m_num_kv_heads[layer_id]), + ov::Dimension(m_block_size), + ov::Dimension(m_head_size)}); + } else if (m_device.find("GPU") != std::string::npos) { // Update key shape, as the key's shape is different from the value's shape m_key_cache_shape.push_back(ov::PartialShape{ov::Dimension::dynamic(), - ov::Dimension(m_num_kv_heads[layer_id]), - ov::Dimension(m_head_size), - ov::Dimension(m_block_size)}); + ov::Dimension(m_num_kv_heads[layer_id]), + ov::Dimension(m_head_size), + ov::Dimension(m_block_size)}); } } }