Skip to content

Commit

Permalink
[CPU]Define key/value cache prec/group_size priority
Browse files Browse the repository at this point in the history
Signed-off-by: Zhang Yi <[email protected]>
  • Loading branch information
zhangYiIntel committed Jan 6, 2025
1 parent 56245d0 commit 84f03a3
Show file tree
Hide file tree
Showing 12 changed files with 225 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
from openvino._pyopenvino.properties import loaded_from_cache
from openvino._pyopenvino.properties import cache_encryption_callbacks
from openvino._pyopenvino.properties import weights_path
from openvino._pyopenvino.properties import key_cache_precision
from openvino._pyopenvino.properties import value_cache_precision
from openvino._pyopenvino.properties import key_cache_group_size
from openvino._pyopenvino.properties import value_cache_group_size

# Submodules
from openvino.runtime.properties import hint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,4 @@
from openvino._pyopenvino.properties.hint import allow_auto_batching
from openvino._pyopenvino.properties.hint import dynamic_quantization_group_size
from openvino._pyopenvino.properties.hint import kv_cache_precision
from openvino._pyopenvino.properties.hint import key_cache_precision
from openvino._pyopenvino.properties.hint import value_cache_precision
from openvino._pyopenvino.properties.hint import key_cache_group_size
from openvino._pyopenvino.properties.hint import value_cache_group_size
from openvino._pyopenvino.properties.hint import activations_scale_factor
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ void regmodule_properties(py::module m) {
wrap_property_RW(m_properties, ov::force_tbb_terminate, "force_tbb_terminate");
wrap_property_RW(m_properties, ov::enable_mmap, "enable_mmap");
wrap_property_RW(m_properties, ov::weights_path, "weights_path");
wrap_property_RW(m_properties, ov::key_cache_precision, "key_cache_precision");
wrap_property_RW(m_properties, ov::value_cache_precision, "value_cache_precision");
wrap_property_RW(m_properties, ov::key_cache_group_size, "key_cache_group_size");
wrap_property_RW(m_properties, ov::value_cache_group_size, "value_cache_group_size");

wrap_property_RO(m_properties, ov::supported_properties, "supported_properties");
wrap_property_RO(m_properties, ov::available_devices, "available_devices");
Expand Down Expand Up @@ -101,10 +105,6 @@ void regmodule_properties(py::module m) {
wrap_property_RW(m_hint, ov::hint::allow_auto_batching, "allow_auto_batching");
wrap_property_RW(m_hint, ov::hint::dynamic_quantization_group_size, "dynamic_quantization_group_size");
wrap_property_RW(m_hint, ov::hint::kv_cache_precision, "kv_cache_precision");
wrap_property_RW(m_hint, ov::hint::key_cache_precision, "key_cache_precision");
wrap_property_RW(m_hint, ov::hint::value_cache_precision, "value_cache_precision");
wrap_property_RW(m_hint, ov::hint::key_cache_group_size, "key_cache_group_size");
wrap_property_RW(m_hint, ov::hint::value_cache_group_size, "value_cache_group_size");
wrap_property_RW(m_hint, ov::hint::activations_scale_factor, "activations_scale_factor");

// Submodule intel_cpu
Expand Down
24 changes: 12 additions & 12 deletions src/bindings/python/tests/test_runtime/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,18 @@ def test_properties_ro(ov_property_ro, expected_value):
"WEIGHTS_PATH",
(("./model.bin", "./model.bin"),),
),
(
props.key_cache_group_size,
"KEY_CACHE_GROUP_SIZE",
((64, 64),),
),
(
props.value_cache_group_size,
"VALUE_CACHE_GROUP_SIZE",
((64, 64),),
),
(props.key_cache_precision, "KEY_CACHE_PRECISION", ((Type.f32, Type.f32),)),
(props.value_cache_precision, "VALUE_CACHE_PRECISION", ((Type.f32, Type.f32),)),
(hints.inference_precision, "INFERENCE_PRECISION_HINT", ((Type.f32, Type.f32),)),
(
hints.model_priority,
Expand Down Expand Up @@ -334,19 +346,7 @@ def test_properties_ro(ov_property_ro, expected_value):
"DYNAMIC_QUANTIZATION_GROUP_SIZE",
((64, 64),),
),
(
hints.key_cache_group_size,
"KEY_CACHE_GROUP_SIZE",
((64, 64),),
),
(
hints.value_cache_group_size,
"VALUE_CACHE_GROUP_SIZE",
((64, 64),),
),
(hints.kv_cache_precision, "KV_CACHE_PRECISION", ((Type.f32, Type.f32),)),
(hints.key_cache_precision, "KEY_CACHE_PRECISION", ((Type.f32, Type.f32),)),
(hints.value_cache_precision, "VALUE_CACHE_PRECISION", ((Type.f32, Type.f32),)),
(
hints.activations_scale_factor,
"ACTIVATIONS_SCALE_FACTOR",
Expand Down
48 changes: 24 additions & 24 deletions src/inference/include/openvino/runtime/properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,30 +580,6 @@ static constexpr Property<uint64_t, PropertyMutability::RW> dynamic_quantization
*/
static constexpr Property<element::Type, PropertyMutability::RW> kv_cache_precision{"KV_CACHE_PRECISION"};

/**
* @brief Hint for device to use specified precision for key cache compression
* @ingroup ov_runtime_cpp_prop_api
*/
static constexpr Property<element::Type, PropertyMutability::RW> key_cache_precision{"KEY_CACHE_PRECISION"};

/**
* @brief Hint for device to use specified precision for value cache compression
* @ingroup ov_runtime_cpp_prop_api
*/
static constexpr Property<element::Type, PropertyMutability::RW> value_cache_precision{"VALUE_CACHE_PRECISION"};

/**
* @brief Hint for device to use group_size for key cache compression
* @ingroup ov_runtime_cpp_prop_api
*/
static constexpr Property<uint64_t, PropertyMutability::RW> key_cache_group_size{"KEY_CACHE_GROUP_SIZE"};

/**
* @brief Hint for device to use group_size for value cache compression
* @ingroup ov_runtime_cpp_prop_api
*/
static constexpr Property<uint64_t, PropertyMutability::RW> value_cache_group_size{"VALUE_CACHE_GROUP_SIZE"};

/**
* @brief This property scales down activations to prevent overflows when inference precision is f16.
* @ingroup ov_runtime_cpp_prop_api
Expand Down Expand Up @@ -1383,4 +1359,28 @@ static constexpr Property<std::vector<std::string>, PropertyMutability::RO> exec
* @note This property is used for weightless caching. Only used when ov::CacheMode Property is set to "OPTIMIZE_SIZE".
*/
static constexpr Property<std::string, PropertyMutability::RW> weights_path{"WEIGHTS_PATH"};

/**
* @brief The precision of key cache compression
* @ingroup ov_runtime_cpp_prop_api
*/
static constexpr Property<element::Type, PropertyMutability::RW> key_cache_precision{"KEY_CACHE_PRECISION"};

/**
* @brief The precision of value cache compression
* @ingroup ov_runtime_cpp_prop_api
*/
static constexpr Property<element::Type, PropertyMutability::RW> value_cache_precision{"VALUE_CACHE_PRECISION"};

/**
* @brief The group_size of key cache compression
* @ingroup ov_runtime_cpp_prop_api
*/
static constexpr Property<uint64_t, PropertyMutability::RW> key_cache_group_size{"KEY_CACHE_GROUP_SIZE"};

/**
* @brief The group_size of value cache compression
* @ingroup ov_runtime_cpp_prop_api
*/
static constexpr Property<uint64_t, PropertyMutability::RW> value_cache_group_size{"VALUE_CACHE_GROUP_SIZE"};
} // namespace ov
24 changes: 12 additions & 12 deletions src/plugins/intel_cpu/src/compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,10 @@ ov::Any CompiledModel::get_property(const std::string& name) const {
RO_property(ov::intel_cpu::sparse_weights_decompression_rate.name()),
RO_property(ov::hint::dynamic_quantization_group_size.name()),
RO_property(ov::hint::kv_cache_precision.name()),
RO_property(ov::hint::key_cache_precision.name()),
RO_property(ov::hint::value_cache_precision.name()),
RO_property(ov::hint::key_cache_group_size.name()),
RO_property(ov::hint::value_cache_group_size.name()),
RO_property(ov::key_cache_precision.name()),
RO_property(ov::value_cache_precision.name()),
RO_property(ov::key_cache_group_size.name()),
RO_property(ov::value_cache_group_size.name()),
};

OPENVINO_SUPPRESS_DEPRECATED_START
Expand Down Expand Up @@ -336,14 +336,14 @@ ov::Any CompiledModel::get_property(const std::string& name) const {
return decltype(ov::hint::dynamic_quantization_group_size)::value_type(config.fcDynamicQuantizationGroupSize);
} else if (name == ov::hint::kv_cache_precision) {
return decltype(ov::hint::kv_cache_precision)::value_type(config.kvCachePrecision);
} else if (name == ov::hint::key_cache_precision) {
return decltype(ov::hint::key_cache_precision)::value_type(config.keyCachePrecision);
} else if (name == ov::hint::value_cache_precision) {
return decltype(ov::hint::value_cache_precision)::value_type(config.valueCachePrecision);
} else if (name == ov::hint::key_cache_group_size) {
return decltype(ov::hint::key_cache_group_size)::value_type(config.keyCacheGroupSize);
} else if (name == ov::hint::value_cache_group_size) {
return decltype(ov::hint::value_cache_group_size)::value_type(config.valueCacheGroupSize);
} else if (name == ov::key_cache_precision) {
return decltype(ov::key_cache_precision)::value_type(config.keyCachePrecision);
} else if (name == ov::value_cache_precision) {
return decltype(ov::value_cache_precision)::value_type(config.valueCachePrecision);
} else if (name == ov::key_cache_group_size) {
return decltype(ov::key_cache_group_size)::value_type(config.keyCacheGroupSize);
} else if (name == ov::value_cache_group_size) {
return decltype(ov::value_cache_group_size)::value_type(config.valueCacheGroupSize);
}
OPENVINO_THROW("Unsupported property: ", name);
}
Expand Down
48 changes: 39 additions & 9 deletions src/plugins/intel_cpu/src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,9 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
ov::hint::kv_cache_precision.name(),
". Supported values: u8, bf16, f16, f32");
}
} else if (key == ov::hint::key_cache_precision.name()) {
} else if (key == ov::key_cache_precision.name()) {
try {
kvCachePrecisionSetExplicitly = true;
keyCachePrecisionSetExplicitly = true;
auto const prec = val.as<ov::element::Type>();
if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) {
keyCachePrecision = prec;
Expand All @@ -386,12 +386,12 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
OPENVINO_THROW("Wrong value ",
val.as<std::string>(),
" for property key ",
ov::hint::key_cache_precision.name(),
ov::key_cache_precision.name(),
". Supported values: u8, bf16, f16, f32");
}
} else if (key == ov::hint::value_cache_precision.name()) {
} else if (key == ov::value_cache_precision.name()) {
try {
kvCachePrecisionSetExplicitly = true;
valueCachePrecisionSetExplicitly = true;
auto const prec = val.as<ov::element::Type>();
if (one_of(prec,
ov::element::f32,
Expand All @@ -407,15 +407,17 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
OPENVINO_THROW("Wrong value ",
val.as<std::string>(),
" for property key ",
ov::hint::value_cache_precision.name(),
ov::value_cache_precision.name(),
". Supported values: u4, u8, bf16, f16, f32");
}
} else if (key == ov::hint::key_cache_group_size.name() || key == ov::hint::value_cache_group_size.name()) {
} else if (key == ov::key_cache_group_size.name() || key == ov::value_cache_group_size.name()) {
try {
auto const groupSize = val.as<uint64_t>();
if (key == ov::hint::key_cache_group_size.name()) {
if (key == ov::key_cache_group_size.name()) {
keyCacheGroupSizeSetExplicitly = true;
keyCacheGroupSize = groupSize;
} else {
valueCacheGroupSizeSetExplicitly = true;
valueCacheGroupSize = groupSize;
}
} catch (ov::Exception&) {
Expand Down Expand Up @@ -460,16 +462,27 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
aclFastMath = true;
}
#endif
// key/value cache precision has higher priority, if not defined use kvCachePrecision
if (!keyCachePrecisionSetExplicitly && kvCachePrecisionSetExplicitly) {
keyCachePrecision = kvCachePrecision;
}
if (!valueCachePrecisionSetExplicitly && kvCachePrecisionSetExplicitly) {
valueCachePrecision = kvCachePrecision;
}
// disable dynamic quantization and kv quantization for best accuracy
if (executionMode == ov::hint::ExecutionMode::ACCURACY) {
if (!fcDynamicQuantizationGroupSizeSetExplicitly) {
fcDynamicQuantizationGroupSize = 0;
}
if (!kvCachePrecisionSetExplicitly) {
kvCachePrecision = ov::element::f32;
valueCachePrecision = ov::element::f32;
}
if (!keyCachePrecisionSetExplicitly) {
keyCachePrecision = ov::element::f32;
}
if (!valueCachePrecisionSetExplicitly) {
valueCachePrecision = ov::element::f32;
}
}

if (!prop.empty())
Expand Down Expand Up @@ -524,6 +537,23 @@ void Config::applyRtInfo(const std::shared_ptr<const ov::Model>& model) {
this->fcDynamicQuantizationGroupSize =
model->get_rt_info<uint64_t>({"runtime_options", ov::hint::dynamic_quantization_group_size.name()});
}
if (!keyCachePrecisionSetExplicitly && model->has_rt_info({"runtime_options", ov::key_cache_precision.name()})) {
this->keyCachePrecision =
model->get_rt_info<ov::element::Type>({"runtime_options", ov::key_cache_precision.name()});
}
if (!valueCachePrecisionSetExplicitly &&
model->has_rt_info({"runtime_options", ov::value_cache_precision.name()})) {
this->valueCachePrecision =
model->get_rt_info<ov::element::Type>({"runtime_options", ov::value_cache_precision.name()});
}
if (!keyCacheGroupSizeSetExplicitly && model->has_rt_info({"runtime_options", ov::key_cache_group_size.name()})) {
this->keyCacheGroupSize = model->get_rt_info<uint64_t>({"runtime_options", ov::key_cache_group_size.name()});
}
if (!valueCacheGroupSizeSetExplicitly &&
model->has_rt_info({"runtime_options", ov::value_cache_group_size.name()})) {
this->valueCacheGroupSize =
model->get_rt_info<uint64_t>({"runtime_options", ov::value_cache_group_size.name()});
}
}

} // namespace intel_cpu
Expand Down
4 changes: 4 additions & 0 deletions src/plugins/intel_cpu/src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ struct Config {
uint64_t fcDynamicQuantizationGroupSize = 32;
bool fcDynamicQuantizationGroupSizeSetExplicitly = false;
bool kvCachePrecisionSetExplicitly = false;
bool keyCachePrecisionSetExplicitly = false;
bool valueCachePrecisionSetExplicitly = false;
bool keyCacheGroupSizeSetExplicitly = false;
bool valueCacheGroupSizeSetExplicitly = false;
#if defined(OV_CPU_WITH_ACL)
bool aclFastMath = false;
#endif
Expand Down
19 changes: 15 additions & 4 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,14 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr<ov::N
if (!isSupportedOperation(op, errorMessage)) {
OPENVINO_THROW("CPU: " + errorMessage);
}

const auto& cpuConfig = context->getConfig();
const auto& keyCachePrecision = cpuConfig.keyCachePrecision;
const auto& valueCachePrecision = cpuConfig.valueCachePrecision;
OPENVINO_ASSERT(valueCachePrecision == keyCachePrecision,
"CPU: SDPA node only supports same key/value cache precision");
OPENVINO_ASSERT(one_of(keyCachePrecision, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8),
"CPU: SDPA only supports key/value cache precision f32, f16, bf16, u8 but gets ",
keyCachePrecision);
if (const auto node = std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op)) {
m_config.config.is_causal = node->get_causal();
} else if (const auto node = std::dynamic_pointer_cast<const ScaledDotProductAttentionWithKVCache>(op)) {
Expand Down Expand Up @@ -1835,12 +1842,16 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M

ov::element::Type ScaledDotProductAttention::getKVCachePrecision() {
ov::element::Type kvcache_precision;
// TODO: SDPA only supports same key/value cache precision.
auto rtPrecision = getRuntimePrecision();
auto kvCachePrecisionHint = context->getConfig().kvCachePrecision;
auto keyCachePrecisionHint = context->getConfig().keyCachePrecision;
auto valueCachePrecisionHint = context->getConfig().valueCachePrecision;
bool enableKVCacheFP16 = m_config.config.fuse_concat && mayiuse(cpu_isa_t::avx2) &&
rtPrecision != ov::element::bf16 && kvCachePrecisionHint == ov::element::f16;
rtPrecision != ov::element::bf16 &&
(keyCachePrecisionHint == ov::element::f16 && valueCachePrecisionHint == ov::element::f16);
kvcache_precision = enableKVCacheFP16 ? ov::element::f16 : rtPrecision;
bool use_int8_kv_cache_precision = kvCachePrecisionHint == ov::element::u8;
bool use_int8_kv_cache_precision =
(keyCachePrecisionHint == ov::element::u8 && valueCachePrecisionHint == ov::element::u8);
if (use_int8_kv_cache_precision)
kvcache_precision = ov::element::u8;
else
Expand Down
24 changes: 12 additions & 12 deletions src/plugins/intel_cpu/src/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,14 +392,14 @@ ov::Any Plugin::get_property(const std::string& name, const ov::AnyMap& options)
engConfig.fcDynamicQuantizationGroupSize);
} else if (name == ov::hint::kv_cache_precision) {
return decltype(ov::hint::kv_cache_precision)::value_type(engConfig.kvCachePrecision);
} else if (name == ov::hint::key_cache_precision) {
return decltype(ov::hint::key_cache_precision)::value_type(engConfig.keyCachePrecision);
} else if (name == ov::hint::value_cache_precision) {
return decltype(ov::hint::value_cache_precision)::value_type(engConfig.valueCachePrecision);
} else if (name == ov::hint::key_cache_group_size) {
return decltype(ov::hint::key_cache_group_size)::value_type(engConfig.keyCacheGroupSize);
} else if (name == ov::hint::value_cache_group_size) {
return decltype(ov::hint::value_cache_group_size)::value_type(engConfig.valueCacheGroupSize);
} else if (name == ov::key_cache_precision) {
return decltype(ov::key_cache_precision)::value_type(engConfig.keyCachePrecision);
} else if (name == ov::value_cache_precision) {
return decltype(ov::value_cache_precision)::value_type(engConfig.valueCachePrecision);
} else if (name == ov::key_cache_group_size) {
return decltype(ov::key_cache_group_size)::value_type(engConfig.keyCacheGroupSize);
} else if (name == ov::value_cache_group_size) {
return decltype(ov::value_cache_group_size)::value_type(engConfig.valueCacheGroupSize);
}
return get_ro_property(name, options);
}
Expand Down Expand Up @@ -443,10 +443,10 @@ ov::Any Plugin::get_ro_property(const std::string& name, const ov::AnyMap& optio
RW_property(ov::intel_cpu::sparse_weights_decompression_rate.name()),
RW_property(ov::hint::dynamic_quantization_group_size.name()),
RW_property(ov::hint::kv_cache_precision.name()),
RW_property(ov::hint::key_cache_precision.name()),
RW_property(ov::hint::value_cache_precision.name()),
RW_property(ov::hint::key_cache_group_size.name()),
RW_property(ov::hint::value_cache_group_size.name()),
RW_property(ov::key_cache_precision.name()),
RW_property(ov::value_cache_precision.name()),
RW_property(ov::key_cache_group_size.name()),
RW_property(ov::value_cache_group_size.name()),
};

OPENVINO_SUPPRESS_DEPRECATED_START
Expand Down
Loading

0 comments on commit 84f03a3

Please sign in to comment.