diff --git a/include/oneapi/dnnl/dnnl.h b/include/oneapi/dnnl/dnnl.h index 188d869758e..f818e6979af 100644 --- a/include/oneapi/dnnl/dnnl.h +++ b/include/oneapi/dnnl/dnnl.h @@ -403,7 +403,7 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode( dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask( dnnl_primitive_attr_t attr, int arg, int mask); dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_dims( - dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims); + dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims, dnnl_data_type_t data_type); /// Sets primitive attributes scaling factors for primitive operations for a /// given memory argument. The scaling factors must be passed at execution time diff --git a/include/oneapi/dnnl/dnnl.hpp b/include/oneapi/dnnl/dnnl.hpp index ad7ae163d0d..da795f42175 100644 --- a/include/oneapi/dnnl/dnnl.hpp +++ b/include/oneapi/dnnl/dnnl.hpp @@ -884,6 +884,10 @@ struct memory : public handle { bin = dnnl_bin, /// 4-bit normalized float. nf4 = dnnl_nf4, + /// 8-bit floating-point with an 8-bit exponent and a 0-bit mantissa. + f8_e8m0 = dnnl_f8_e8m0, + /// 4-bit floating-point with a 2-bit exponent and a 1-bit mantissa. + f4_e2m1 = dnnl_f4_e2m1 }; /// Returns size of data type in bytes. 
@@ -4142,8 +4146,8 @@ struct primitive_attr : public handle { error::wrap_c_api(dnnl_primitive_attr_set_scales_mask(get(), arg, mask), "could not set scales primitive attribute"); } - void set_scales_dims(int arg, const memory::dims& dims) { - error::wrap_c_api(dnnl_primitive_attr_set_scales_dims(get(), arg, dims.data(), dims.size()), + void set_scales_dims(int arg, const memory::dims& dims, memory::data_type data_type = memory::data_type::f32) { + error::wrap_c_api(dnnl_primitive_attr_set_scales_dims(get(), arg, dims.data(), dims.size(), memory::convert_to_c(data_type)), "could not set scales primitive attribute"); } diff --git a/include/oneapi/dnnl/dnnl_common_types.h b/include/oneapi/dnnl/dnnl_common_types.h index 1dad28530ff..3c06ad47a31 100644 --- a/include/oneapi/dnnl/dnnl_common_types.h +++ b/include/oneapi/dnnl/dnnl_common_types.h @@ -104,8 +104,12 @@ typedef enum { dnnl_u4 = 12, /// 4-bit normalized float. dnnl_nf4 = 13, + /// 8-bit floating-point with an 8-bit exponent and a 0-bit mantissa. + dnnl_f8_e8m0 = 14, + /// 4-bit floating-point with a 2-bit exponent and a 1-bit mantissa. + dnnl_f4_e2m1 = 15, /// 1-bit integer. - dnnl_bin = 14, + dnnl_bin = 16, /// Parameter to allow internal only data_types without undefined behavior. /// This parameter is chosen to be valid for so long as sizeof(int) >= 2. 
diff --git a/src/common/c_types_map.hpp b/src/common/c_types_map.hpp index 63191acd115..8d9847dd7dc 100644 --- a/src/common/c_types_map.hpp +++ b/src/common/c_types_map.hpp @@ -169,11 +169,13 @@ const data_type_t u4 = dnnl_u4; const data_type_t boolean = dnnl_boolean; const data_type_t data_type_max = dnnl_data_type_max; -// Not exposed through API as all current uses are internal only -const data_type_t tf32 = static_cast(1 << 8); - const data_type_t bin = dnnl_bin; const data_type_t nf4 = dnnl_nf4; +const data_type_t f4_e2m1 = dnnl_f4_e2m1; +const data_type_t f8_e8m0 = dnnl_f8_e8m0; + +// Not exposed through API as all current uses are internal only +const data_type_t tf32 = static_cast(1 << 8); } // namespace data_type using fpmath_mode_t = dnnl_fpmath_mode_t; diff --git a/src/common/dnnl_debug_autogenerated.cpp b/src/common/dnnl_debug_autogenerated.cpp index 24fb05aab01..59365585b1f 100644 --- a/src/common/dnnl_debug_autogenerated.cpp +++ b/src/common/dnnl_debug_autogenerated.cpp @@ -58,6 +58,8 @@ const char *dnnl_dt2str(dnnl_data_type_t v) { if (v == dnnl_nf4) return "nf4"; if (v == dnnl_s4) return "s4"; if (v == dnnl_u4) return "u4"; + if (v == dnnl_f8_e8m0) return "f8_e8m0"; + if (v == dnnl_f4_e2m1) return "f4_e2m1"; if (v == dnnl_data_type_max) return "data_type_max"; assert(!"unknown dt"); return "unknown dt"; diff --git a/src/common/dnnl_traits.hpp b/src/common/dnnl_traits.hpp index afa4016d46b..f0810a6bfd0 100644 --- a/src/common/dnnl_traits.hpp +++ b/src/common/dnnl_traits.hpp @@ -100,6 +100,16 @@ template <> struct prec_traits { typedef uint8_t type; }; +template <> +struct prec_traits { + typedef uint8_t type; +}; + +template <> +struct prec_traits { + typedef uint8_t type; +}; + template <> struct data_traits { static constexpr data_type_t data_type = data_type::f8_e5m2; diff --git a/src/common/inner_product.cpp b/src/common/inner_product.cpp index be824728411..2d5058e5754 100644 --- a/src/common/inner_product.cpp +++ b/src/common/inner_product.cpp 
@@ -111,7 +111,8 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine, if (attr == nullptr) return status::success; const data_type_t src_dt = desc.src_desc.data_type; const data_type_t wei_dt = desc.weights_desc.data_type; - bool is_weight_compression = (one_of(src_dt, data_type::f32, data_type::bf16) && one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4)) || + bool is_weight_compression = (one_of(src_dt, data_type::f32, data_type::bf16) && + one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1)) || (one_of(src_dt, data_type::f32) && one_of(wei_dt, data_type::f16, data_type::bf16)); auto attr_mask = smask_t::none; // From oneDNN 3.5, those checks must be skipped if wei_decomp is enabled @@ -140,7 +141,7 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine, || utils::one_of(dst_dt, data_type::s8, data_type::u8, data_type::s32); if (engine->kind() == engine_kind::cpu) - is_int8 |= one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4); + is_int8 |= one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1); if (is_int8) fwd_attr_mask |= smask_t::scales_runtime | smask_t::zero_points_runtime | smask_t::src_dyn_quant_params; if (is_weight_compression) { fwd_attr_mask |= attr_mask; diff --git a/src/common/memory_zero_pad.cpp b/src/common/memory_zero_pad.cpp index 27dc081211b..963682afea4 100644 --- a/src/common/memory_zero_pad.cpp +++ b/src/common/memory_zero_pad.cpp @@ -294,6 +294,8 @@ static status_t zero_pad(const memory_t *memory, const exec_ctx_t &ctx) { case u4: return typed_zero_pad(memory, ctx); case bin: return typed_zero_pad(memory, ctx); case nf4: return typed_zero_pad(memory, ctx); + case f8_e8m0: return typed_zero_pad(memory, ctx); + case f4_e2m1: return typed_zero_pad(memory, ctx); default: assert(!"memory is 
undefined"); return unimplemented; } return unimplemented; diff --git a/src/common/primitive_attr.cpp b/src/common/primitive_attr.cpp index 27afcdae7f2..0fab60d4379 100644 --- a/src/common/primitive_attr.cpp +++ b/src/common/primitive_attr.cpp @@ -656,11 +656,11 @@ status_t dnnl_primitive_attr_set_scales_mask( return attr->scales_.set(arg, mask); } status_t dnnl_primitive_attr_set_scales_dims( - primitive_attr_t *attr, int arg, const dims_t dims, int ndims) { + primitive_attr_t *attr, int arg, const dims_t dims, int ndims, data_type_t data_type) { bool ok = attr && arg >= 0 && ndims > 0 && attr->output_scales_.has_default_values(); if (!ok) return invalid_arguments; - return attr->scales_.set(arg, dims, ndims); + return attr->scales_.set(arg, dims, ndims, data_type); } status_t dnnl_primitive_attr_set_scales(primitive_attr_t *attr, int arg, diff --git a/src/common/primitive_attr.hpp b/src/common/primitive_attr.hpp index 3888a6f0cc4..918c4b7fe8a 100644 --- a/src/common/primitive_attr.hpp +++ b/src/common/primitive_attr.hpp @@ -265,11 +265,12 @@ struct runtime_scales_t : public c_compatible { return status::success; } - status_t set(const dims_t dims, int ndims) { + status_t set(const dims_t dims, int ndims, data_type_t data_type = data_type::f32) { is_set_ = true; ndims_ = ndims; mask_ = 1; utils::array_copy(dims_, dims, ndims_); + data_type_ = data_type; return status::success; } @@ -348,9 +349,9 @@ struct arg_scales_t : public c_compatible { if (!check_arg(arg)) return status::invalid_arguments; return scales_[arg].set(mask); } - status_t set(int arg, const dims_t dims, int ndims) { + status_t set(int arg, const dims_t dims, int ndims, data_type_t data_type) { if (!check_arg(arg)) return status::invalid_arguments; - return scales_[arg].set(dims, ndims); + return scales_[arg].set(dims, ndims, data_type); } status_t set(int arg, int mask, int ndims, const dims_t group_dims, diff --git a/src/common/type_helpers.hpp b/src/common/type_helpers.hpp index 
9392de9c781..19adddb86bc 100644 --- a/src/common/type_helpers.hpp +++ b/src/common/type_helpers.hpp @@ -98,6 +98,8 @@ inline size_t data_type_size(data_type_t data_type) { case boolean: return sizeof(prec_traits::type); case bin: return sizeof(prec_traits::type); case nf4: return sizeof(prec_traits::type); + case f8_e8m0: return sizeof(prec_traits::type); + case f4_e2m1: return sizeof(prec_traits::type); case data_type::undef: default: assert(!"unknown data_type"); } @@ -318,7 +320,7 @@ inline data_type_t default_accum_data_type(data_type_t src_dt, /* prop_kind doesn't matter */ if (everyone_is(f32, src_dt, wei_dt)) return f32; - if (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4)) return f32; + if (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4, f4_e2m1)) return f32; if (everyone_is(f64, src_dt, wei_dt)) return f64; if (one_of(prop_kind, forward_training, forward_inference)) { @@ -1086,7 +1088,7 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims, bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS && utils::one_of(data_type, f8_e5m2, f8_e4m3, f16, bf16, f32, f64, - s32, s8, u8, nf4, s4, u4, bin); + s32, s8, u8, nf4, s4, u4, bin, f8_e8m0, f4_e2m1); if (!ok) return false; bool has_runtime_dims = false; diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp index 87d4966eeda..972345eaaf2 100644 --- a/src/cpu/cpu_inner_product_list.cpp +++ b/src/cpu/cpu_inner_product_list.cpp @@ -71,6 +71,11 @@ const std::map> &impl_list_map() CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2) nullptr, }}, + {{forward, f32, f4_e2m1, f32}, { + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core) + CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2) + nullptr, + }}, {{forward, f32, s4, f32}, { CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core) CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2) @@ -123,7 +128,7 @@ const std::map> &impl_list_map() 
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) nullptr, - }}, + }}, {{forward, bf16, s8, bf16}, { CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) @@ -139,6 +144,16 @@ const std::map> &impl_list_map() CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) nullptr, }}, + {{forward, bf16, f4_e2m1, f32}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, + {{forward, bf16, f4_e2m1, bf16}, { + CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) + CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) + nullptr, + }}, {{forward, bf16, s4, f32}, { CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx) CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16) diff --git a/src/cpu/cpu_primitive.hpp b/src/cpu/cpu_primitive.hpp index caf4b30bac3..c1802843377 100644 --- a/src/cpu/cpu_primitive.hpp +++ b/src/cpu/cpu_primitive.hpp @@ -75,7 +75,7 @@ VCHECK_ATTR(scales != nullptr, \ "Scales buffer for arg %d is missing", arg); \ const auto scales_d = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | arg); \ - bool ok = scales_d.data_type() == data_type::f32 \ + bool ok = (scales_d.data_type() == data_type::f32 || scales_d.data_type() == data_type::f8_e8m0) \ && (scales_d.ndims() == 1 || scales_d.ndims() == 2); \ if (!ok) return status::invalid_arguments; \ if (scales_d.dims()[0] == 1) { \ diff --git a/src/cpu/reorder/cpu_reorder.cpp b/src/cpu/reorder/cpu_reorder.cpp index 56eb3480919..714275635e8 100644 --- a/src/cpu/reorder/cpu_reorder.cpp +++ b/src/cpu/reorder/cpu_reorder.cpp @@ -35,6 +35,7 @@ regular_impl_list_map() { {{f32, u8, 0}, &regular_f32_u8_impl_list_map()}, {{f8_e5m2, data_type::undef, 0}, &regular_fp8_impl_list_map()}, {{f8_e4m3, data_type::undef, 0}, &regular_fp8_impl_list_map()}, + 
{{f8_e8m0, data_type::undef, 0}, &regular_fp8_impl_list_map()}, {{f32, bin, 0}, &regular_f32_bin_impl_list_map()}, {{bf16, data_type::undef, 0}, &regular_bf16_impl_list_map()}, {{f16, data_type::undef, 0}, &regular_f16_impl_list_map()}, @@ -47,6 +48,7 @@ regular_impl_list_map() { {{u4, f32, 0}, &regular_u4_impl_list_map()}, {{bin, data_type::undef, 0}, &regular_bin_impl_list_map()}, {{nf4, data_type::undef, 0}, &regular_nf4_impl_list_map()}, + {{f4_e2m1, data_type::undef, 0}, &regular_f4_impl_list_map()}, {{s4, data_type::undef, 0}, &regular_s4_impl_list_map()}, {{u4, data_type::undef, 0}, &regular_u4_impl_list_map()}, }; diff --git a/src/cpu/reorder/cpu_reorder.hpp b/src/cpu/reorder/cpu_reorder.hpp index ae3a89e8e68..29e20bc2edd 100644 --- a/src/cpu/reorder/cpu_reorder.hpp +++ b/src/cpu/reorder/cpu_reorder.hpp @@ -91,6 +91,7 @@ extern const impl_list_map_t &regular_s4_impl_list_map(); extern const impl_list_map_t &regular_u4_impl_list_map(); extern const impl_list_map_t &regular_bin_impl_list_map(); extern const impl_list_map_t &regular_nf4_impl_list_map(); +extern const impl_list_map_t &regular_f4_impl_list_map(); extern const impl_list_map_t &regular_s4_impl_list_map(); extern const impl_list_map_t &regular_u4_impl_list_map(); diff --git a/src/cpu/reorder/cpu_reorder_regular_f4.cpp b/src/cpu/reorder/cpu_reorder_regular_f4.cpp new file mode 100644 index 00000000000..f42b401726c --- /dev/null +++ b/src/cpu/reorder/cpu_reorder_regular_f4.cpp @@ -0,0 +1,48 @@ +/******************************************************************************* +* Copyright 2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu/reorder/cpu_reorder.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { + +// clang-format off + +const impl_list_map_t &regular_f4_impl_list_map() { + static const impl_list_map_t the_map = REG_REORDER_P({ + // f4_e2m1 -> + {{f4_e2m1, data_type::undef, 0}, { + REG_SR(f4_e2m1, any, f4_e2m1, OI8i8o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI8i16o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI8i24o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI8i32o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI8i64o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI16i16o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI16i32o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI16i48o2i, fmt_order_keep) + REG_SR(f4_e2m1, any, f4_e2m1, OI16i64o2i, fmt_order_keep) + nullptr, + }}, + }); + return the_map; +} + +// clang-format on + +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/reorder/cpu_reorder_regular_fp8.cpp b/src/cpu/reorder/cpu_reorder_regular_fp8.cpp index 81ef168d728..db84be29f8b 100644 --- a/src/cpu/reorder/cpu_reorder_regular_fp8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_fp8.cpp @@ -46,6 +46,12 @@ const impl_list_map_t &regular_fp8_impl_list_map() { REG_SR(f8_e4m3, any, bf16, any, fmt_order::any, spec::reference) REG_SR(f8_e4m3, any, f32, any, fmt_order::any, spec::reference) + nullptr, + }}, + // f8_e8m0 -> + {{f8_e8m0, data_type::undef, 0}, { + REG_SR(f8_e8m0, any, f8_e8m0, any, fmt_order::any, spec::reference) + nullptr, }}, }); diff --git a/src/cpu/reorder/simple_reorder.hpp b/src/cpu/reorder/simple_reorder.hpp index 5fe251fbb05..4926040aac7 100644 --- a/src/cpu/reorder/simple_reorder.hpp +++ b/src/cpu/reorder/simple_reorder.hpp @@ -1697,7 +1697,8 @@ struct simple_reorder_impl::ndims >= 4 && 
tag_traits::ndims <= 6) && (type_i != dnnl_bin && type_o != dnnl_bin) - && (type_i != dnnl_nf4 && type_o != dnnl_nf4)>::type> { + && (type_i != dnnl_nf4 && type_o != dnnl_nf4) + && (type_i != dnnl_f4_e2m1 && type_o != dnnl_f4_e2m1)>::type> { PLAIN_TO_BLOCKED_IS_APPLICABLE(); GET_SCRATCHPAD_SIZE_ZERO(); @@ -2004,7 +2005,7 @@ template struct simple_reorder_impl::block_dims == bd::_AB && - utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4) && + utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4, dnnl_f4_e2m1) && type_i == type_o>::type> { static bool is_applicable(const memory_desc_wrapper &input_d, @@ -2483,8 +2484,8 @@ struct simple_reorder_impl::type> { static bool is_applicable(const memory_desc_wrapper &input_d, const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { diff --git a/src/cpu/x64/brgemm/brgemm.cpp b/src/cpu/x64/brgemm/brgemm.cpp index 28a7dd25364..9fa023a6e87 100644 --- a/src/cpu/x64/brgemm/brgemm.cpp +++ b/src/cpu/x64/brgemm/brgemm.cpp @@ -285,6 +285,10 @@ status_t brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa, brg->with_wei_decomp_scales = !wei_scales.has_default_values(); brg->wei_decomp_scales_group_size = wei_d.dims()[1]; if (brg->with_wei_decomp_scales) { + brg->wei_decomp_scales_dt = wei_scales.data_type_; + if (!one_of(brg->wei_decomp_scales_dt, f32, f8_e8m0)) + return status::unimplemented; + auto ld_dim = wei_scales.dims_[0]; brg->wei_decomp_scales_stride = ld_dim > 1 ? 
ld_dim : 0; brg->wei_decomp_scales_group_size = wei_d.dims()[1] / wei_scales.dims_[1]; diff --git a/src/cpu/x64/brgemm/brgemm_types.hpp b/src/cpu/x64/brgemm/brgemm_types.hpp index 4166aace702..d85f690a86e 100644 --- a/src/cpu/x64/brgemm/brgemm_types.hpp +++ b/src/cpu/x64/brgemm/brgemm_types.hpp @@ -315,6 +315,7 @@ struct brgemm_desc_t { int wei_decomp_zero_points_stride = 0; int wei_decomp_scales_group_size = 0; int wei_decomp_zero_points_group_size = 0; + impl::data_type_t wei_decomp_scales_dt = data_type::undef; impl::data_type_t wei_decomp_zero_points_dt = data_type::undef; bool with_src_dyn_quant = false; int src_scales_group_size = 0; diff --git a/src/cpu/x64/brgemm/brgemm_utils.cpp b/src/cpu/x64/brgemm/brgemm_utils.cpp index 5c97e93fcb8..8cf994b256a 100644 --- a/src/cpu/x64/brgemm/brgemm_utils.cpp +++ b/src/cpu/x64/brgemm/brgemm_utils.cpp @@ -53,9 +53,9 @@ void init_kernel_datatype( brg->is_int8 = utils::one_of(dt_a, data_type::u8, data_type::s8) && utils::one_of(dt_b, data_type::u8, data_type::s8, data_type::u4); brg->is_bf16 = one_of(dt_a, data_type::bf16) && - one_of(dt_b, data_type::bf16, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4); + one_of(dt_b, data_type::bf16, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1); brg->is_f32 = one_of(dt_a, data_type::f32) && - one_of(dt_b, data_type::f32, data_type::f16, data_type::bf16, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4); + one_of(dt_b, data_type::f32, data_type::f16, data_type::bf16, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1); brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b); brg->is_fp8 = one_of(dt_a, data_type::f8_e5m2, data_type::f8_e4m3) && one_of(dt_b, data_type::f8_e5m2, data_type::f8_e4m3); @@ -221,7 +221,8 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) { if (brg->is_int8 && !brg->has_int8_vnni) 
max_bcast_block -= 2; if (one_of(brg->dt_b, data_type::nf4) && brg->isa_impl == avx2) max_bcast_block -= 5; - if (one_of(brg->dt_b, data_type::nf4) && brg->isa_impl != avx2) max_bcast_block -= 1; + if (one_of(brg->dt_b, data_type::f4_e2m1) && brg->isa_impl == avx2) max_bcast_block -= 2; + if (one_of(brg->dt_b, data_type::nf4, data_type::f4_e2m1) && brg->isa_impl != avx2) max_bcast_block -= 1; if (brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride == 0) max_bcast_block -= 1; if (brg->with_src_dyn_quant) max_bcast_block -= 2; if (brg->with_src_dyn_quant && brg->with_wei_decomp_zero_points && brg->wei_decomp_zero_points_stride != 0) max_bcast_block -= adj_ld_block2; @@ -293,7 +294,7 @@ status_t brgemm_blocking(brgemm_desc_t *brg) { = (brg->is_f16 && brg->isa_impl == avx512_core_fp16) ? 1 : data_type_vnni_granularity(brg->dt_a); - int rd_unroll = one_of(brg->dt_b, data_type::nf4, data_type::u4, data_type::s4) ? 32 : 4; + int rd_unroll = one_of(brg->dt_b, data_type::nf4, data_type::u4, data_type::s4, data_type::f4_e2m1) ? 32 : 4; if (brg->with_grouped_wei_decomp) { auto min_group_size = nstl::min(brg->wei_decomp_scales_group_size, brg->wei_decomp_zero_points_group_size); min_group_size = nstl::min(min_group_size, brg->src_scales_group_size); @@ -911,7 +912,7 @@ void init_brgemm_conf(brgemm_desc_t *brg, cpu_isa_t isa, && one_of(brg->isa_impl, avx2_vnni_2, avx512_core_fp16)) || (brg->is_bf16 && brg->isa_impl == avx2_vnni_2) || (one_of(brg->dt_a, data_type::f32, data_type::bf16) && - one_of(brg->dt_b, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4)) + one_of(brg->dt_b, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1)) || (one_of(brg->dt_a, data_type::f32) && one_of(brg->dt_b, data_type::bf16, data_type::f16)); brg->rd_step = has_no_vnni_compute_instruction ? 
1 diff --git a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp index edfcb745def..5c53e671a5d 100644 --- a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp @@ -48,7 +48,8 @@ struct jit_brgemm_kernel_t : public jit_generator { , max_effective_vregs( isa_num_vregs(brg.isa_impl) - (brg.is_int8 && !brg.has_int8_vnni ? 2 : (brg.is_fp8_via_convert() ? 5 : 0)) - (one_of(brg.dt_b, data_type::nf4) && brg.isa_impl == avx2 ? 5 : 0) - - (one_of(brg.dt_b, data_type::nf4) && brg.isa_impl != avx2 ? 1 : 0) + - (one_of(brg.dt_b, data_type::f4_e2m1) && brg.isa_impl == avx2 ? 2 : 0) + - (one_of(brg.dt_b, data_type::nf4, data_type::f4_e2m1) && brg.isa_impl != avx2 ? 1 : 0) - (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride == 0 ? 1 : 0) - (brg.with_src_dyn_quant ? 2 : 0) - (brg.with_src_dyn_quant && brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0 ? brg.ld_block2 : 0)) { @@ -460,7 +461,7 @@ int jit_brgemm_kernel_t::A_offset( template int jit_brgemm_kernel_t::B_offset( int ld, int rd, bool is_amx) const noexcept { - int typesize_scale = one_of(brg.dt_b, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1; + int typesize_scale = one_of(brg.dt_b, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 2 : 1; if (is_amx) { return brg.typesize_B * (brg.rd_step * ld * brg.ld_block) / typesize_scale; } else { @@ -494,14 +495,14 @@ int jit_brgemm_kernel_t::rdb_A_offset() const noexcept { template int jit_brgemm_kernel_t::rdb_B_offset() const noexcept { - int typesize_scale = one_of(brg.dt_b, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1; + int typesize_scale = one_of(brg.dt_b, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 
2 : 1; return brg.typesize_B * brg.rd_block * brg.LDB / typesize_scale; } template int jit_brgemm_kernel_t::ldb_B_offset( int ld_block2, bool is_tail) const noexcept { - int typesize_scale = one_of(brg.dt_b, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1; + int typesize_scale = one_of(brg.dt_b, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 2 : 1; return (is_tail) ? brg.typesize_B * brg.ldb_tail * brg.ld_step / typesize_scale : brg.typesize_B * ld_block2 * brg.ld_block * brg.ld_step / typesize_scale; } @@ -589,8 +590,8 @@ int jit_brgemm_kernel_t::scales_offset( template int jit_brgemm_kernel_t::wei_scales_offset( int ld, bool is_tail) const noexcept { - return (is_tail) ? sizeof(float) * brg.ldb_tail - : sizeof(float) * ld * brg.ld_block; + return (is_tail) ? types::data_type_size(brg.wei_decomp_scales_dt) * brg.ldb_tail + : types::data_type_size(brg.wei_decomp_scales_dt) * ld * brg.ld_block; } template @@ -2556,6 +2557,40 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, } }; + auto load_scales = [&](Vmm vmm_scales, Xbyak::Address addr) { + if (brg.wei_decomp_scales_stride == 0) { + switch (brg.wei_decomp_scales_dt) { + case data_type::f32: { + uni_vbroadcastss(vmm_scales, addr); + break; + } + case data_type::f8_e8m0: { + auto xmm_scales = Xmm(vmm_scales.getIdx()); + auto reg_ptr_32 = Reg32(reg_ptr.getIdx()); + movzx(reg_ptr_32, addr); + uni_vmovq(xmm_scales, reg_ptr); + uni_vpslld(xmm_scales, xmm_scales, 23); + uni_vbroadcastss(vmm_scales, xmm_scales); + break; + } + default: assert(!"unsupported data type"); + } + } else { + switch (brg.wei_decomp_scales_dt) { + case data_type::f32: { + uni_vmovups(vmm_scales, addr); + break; + } + case data_type::f8_e8m0: { + uni_vpmovzxbd(vmm_scales, addr); + uni_vpslld(vmm_scales, vmm_scales, 23); + break; + } + default: assert(!"unsupported data type"); + } + } + }; + mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop); mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop); 
@@ -2565,6 +2600,9 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, auto vmm_lookup = Vmm(isa_num_vregs(brg.isa_impl) - 1); auto vmm_lookup_low = Vmm(isa_num_vregs(brg.isa_impl) - 3); auto vmm_lookup_high = Vmm(isa_num_vregs(brg.isa_impl) - 4); + + auto vmm_mask_signed_bit = Vmm(isa_num_vregs(brg.isa_impl) - 2); + if (brg.dt_b == data_type::nf4) { static const float lookup[16] = { -1.0, @@ -2611,6 +2649,34 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, uni_vmovups(vmm_lookup, ptr[reg_ptr]); vmm_zero_points = Vmm(isa_num_vregs(brg.isa_impl) - 2); } + } else if (brg.dt_b == data_type::f4_e2m1) { + static const float lookup[16] = { + 0.0f, 0.5f, + 1.0f, 1.5f, + 2.0f, 3.0f, + 4.0f, 6.0f, + -0.0f, -0.5f, + -1.0f, -1.5f, + -2.0f, -3.0f, + -4.0f, -6.0f + }; + + static const uint32_t mask_signed_bit[8] = { + 0x80000000, 0x80000000, 0x80000000, 0x80000000, + 0x80000000, 0x80000000, 0x80000000, 0x80000000, + }; + + if (brg.isa_impl == avx2) { + mov(reg_ptr, (size_t)lookup); + uni_vmovups(vmm_lookup, ptr[reg_ptr]); + mov(reg_ptr, (size_t)mask_signed_bit); + uni_vmovups(vmm_mask_signed_bit, ptr[reg_ptr]); + vmm_zero_points = Vmm(isa_num_vregs(brg.isa_impl) - 3); + } else { + mov(reg_ptr, (size_t)lookup); + uni_vmovups(vmm_lookup, ptr[reg_ptr]); + vmm_zero_points = Vmm(isa_num_vregs(brg.isa_impl) - 2); + } } mov(reg_local_wei_scales, ptr[rsp + reg_aux2_wei_scales_offs_]); @@ -2641,11 +2707,10 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, } uni_vcvtdq2ps(vmm_load, vmm_load); } else if (brg.dt_b == data_type::s4) { + uni_vpmovsxbd(vmm_load, addr); if (rd % 2 == 0) { - uni_vpmovsxbd(vmm_load, addr); vpsrad(vmm_load, vmm_load, 4); } else { - uni_vpmovsxbd(vmm_load, addr); uni_vpslld(vmm_load, vmm_load, 28); vpsrad(vmm_load, vmm_load, 28); } @@ -2670,6 +2735,29 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, } else { vpermd(vmm_load, vmm_load, vmm_lookup); } 
+ } else if (brg.dt_b == data_type::f4_e2m1) { + if (brg.isa_impl == avx2) { + uni_vpmovsxbd(vmm_load, addr); + if (rd % 2 == 0) { + vpsrad(vmm_load, vmm_load, 4); + } else { + uni_vpslld(vmm_load, vmm_load, 28); + vpsrad(vmm_load, vmm_load, 28); + } + auto mask = bcst(); + uni_vpand(mask, vmm_load, vmm_mask_signed_bit); + vpermd(vmm_load, vmm_load, vmm_lookup); + uni_vorps(vmm_load, vmm_load, mask); + } else { + uni_vpmovzxbd(vmm_load, addr); + if (rd % 2 == 0) { + uni_vpsrld(vmm_load, vmm_load, 4); + } else { + uni_vpslld(vmm_load, vmm_load, 28); + uni_vpsrld(vmm_load, vmm_load, 28); + } + vpermd(vmm_load, vmm_load, vmm_lookup); + } } else { assert(!"unsupported combination"); } @@ -2685,9 +2773,9 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, if (brg.with_wei_decomp_scales && brg.bd_block != 1) { if (brg.wei_decomp_scales_stride == 0) { - uni_vbroadcastss(bcst(), ptr[reg_local_wei_scales]); + load_scales(bcst(), ptr[reg_local_wei_scales]); } else { - uni_vmovups(bcst(), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]); + load_scales(bcst(), ptr[reg_local_wei_scales + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_scales_dt)]); } uni_vmulps(vmm_load, vmm_load, bcst()); } @@ -2737,7 +2825,7 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, for (int ld = 0; ld < ld_block2; ld++) { auto vmm_accm_tmp = accm_tmp(ld_block2, 0, ld); auto vmm_accm = accm(ld_block2, 0, ld); - uni_vmovups(bcst(), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]); + load_scales(bcst(), ptr[reg_local_wei_scales + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_scales_dt)]); uni_vfmadd231ps(vmm_accm, vmm_accm_tmp, bcst()); } } @@ -2873,7 +2961,7 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, if (brg.with_wei_decomp_scales && brg.wei_decomp_scales_stride != 0) { ic_group_shift(reg_aux_wei_scales_offs_, reg_aux2_wei_scales_offs_, - brg.wei_decomp_scales_group_size, 
brg.wei_decomp_scales_stride * sizeof(float)); + brg.wei_decomp_scales_group_size, brg.wei_decomp_scales_stride * types::data_type_size(brg.wei_decomp_scales_dt)); } if (brg.with_wei_decomp_zero_points && brg.wei_decomp_zero_points_stride != 0) { diff --git a/src/cpu/x64/cpu_isa_traits.hpp b/src/cpu/x64/cpu_isa_traits.hpp index 7e106ee3dc8..21c5d21cce8 100644 --- a/src/cpu/x64/cpu_isa_traits.hpp +++ b/src/cpu/x64/cpu_isa_traits.hpp @@ -536,7 +536,8 @@ inline size_t data_type_vnni_granularity(const data_type_t data_type) { case bf16: case s4: case u4: - case nf4: return size_t(2); + case nf4: + case f4_e2m1: return size_t(2); case f8_e5m2: case f8_e4m3: case s8: diff --git a/src/cpu/x64/jit_brgemm_inner_product.cpp b/src/cpu/x64/jit_brgemm_inner_product.cpp index 32caa916725..5879b5b5a89 100644 --- a/src/cpu/x64/jit_brgemm_inner_product.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product.cpp @@ -91,21 +91,23 @@ status_t brgemm_inner_product_fwd_t::execute_forward( const auto &jbgp = pd()->jbgp_; DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); - DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); + DEFINE_ARG_SCALES_BUFFER(wei_scales_f, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); const int wei_scale_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; const float *oscales - = scale_utils::precompute_scales(scratchpad, src_scales, wei_scales, + = scale_utils::precompute_scales(scratchpad, src_scales, wei_scales_f, pd()->IC(), pd()->OC(), false, wei_scale_mask == (1 << 0), pd()->attr(), jit_scale_precompute_.get()); + auto wei_scales = reinterpret_cast(wei_scales_f); DEFINE_ZERO_POINTS_BUFFER_TYPED(wei_zero_points, DNNL_ARG_WEIGHTS, uint8_t); const auto wei_scales_d = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); const auto wei_zero_points_d = ctx.memory_mdw(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS); int wei_scales_oc_stride = wei_scales_d.dims()[0] > 1 ? 
1 : 0; int wei_zero_points_oc_stride = wei_zero_points_d.dims()[0] > 1 ? 1 : 0; + size_t wei_scales_dt_size = jbgp.wei_decomp_scales_dt == data_type::undef ? 0 : types::data_type_size(jbgp.wei_decomp_scales_dt); size_t wei_zero_points_dt_size = jbgp.wei_decomp_zero_points_dt == data_type::undef ? 0 : types::data_type_size(jbgp.wei_decomp_zero_points_dt); if (jbgp.weights_decompression) { // weights decompression algorithm requires weights scales to be @@ -116,8 +118,8 @@ status_t brgemm_inner_product_fwd_t::execute_forward( if (jbgp.oc % jbgp.simd_w != 0) { if (!pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()) { auto dims = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).dims_; - auto decomp_scales_buf = scratchpad.template get(key_decompression_scales); - std::memcpy(decomp_scales_buf, wei_scales, dims[0] * dims[1] * sizeof(float)); + auto decomp_scales_buf = scratchpad.template get(key_decompression_scales); + std::memcpy(decomp_scales_buf, wei_scales, dims[0] * dims[1] * wei_scales_dt_size); wei_scales = decomp_scales_buf; } @@ -130,7 +132,7 @@ status_t brgemm_inner_product_fwd_t::execute_forward( } } else { oscales = precompute_scales(ctx.get_scratchpad_grantor(), - src_scales, wei_scales, pd()->OC(), pd()->attr()); + src_scales, wei_scales_f, pd()->OC(), pd()->attr()); } const size_t src_dt_size = types::data_type_size(jbgp.src_dt); @@ -374,7 +376,7 @@ status_t brgemm_inner_product_fwd_t::execute_forward( (*brg_decomp_kernel_)(&dcomp_params); addr_batch[b].ptr.B = decomp_buf; } else if (jbgp.weights_decompression && jbgp.wei_decomp_algo == weights_decomp_kind_t::prepack) { - int typesize_scale = one_of(jbgp.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1; + int typesize_scale = one_of(jbgp.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 
2 : 1; auto w_off = wei_offset * types::data_type_size(jbgp.orig_wei_dt) / types::data_type_size(jbgp.wei_dt) / typesize_scale; auto weights_ptr = reinterpret_cast(&weights[w_off]); @@ -382,9 +384,9 @@ status_t brgemm_inner_product_fwd_t::execute_forward( auto decomp_buf = decomp_buf_global + ithr * decomp_buf_per_thr + wei_ic_stride * b * ic_blocks_per_batch; const int ic_internal_block = pd()->jbgp_.wei_dt == data_type::bf16 || - one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1; + one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 2 : 1; auto wei_zero_points_ptr = wei_zero_points + wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; - auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc; + auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc * wei_scales_dt_size; if (jbgp.with_grouped_weights_decompression) { weights_decompression_runtime_params_t rt_params = {}; @@ -402,7 +404,7 @@ status_t brgemm_inner_product_fwd_t::execute_forward( rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.orig_wei_dt) / typesize_scale; rt_params.decomp_buffer_ptr = decomp_buf + ic_idx * ic_internal_block *jbgp.oc_block * types::data_type_size(jbgp.wei_dt); - rt_params.scales_ptr = wei_scales_ptr + scales_idx * wei_scales_d.dims()[0]; + rt_params.scales_ptr = wei_scales_ptr + scales_idx * wei_scales_d.dims()[0] * wei_scales_dt_size; rt_params.zero_points_ptr = wei_zero_points_ptr + zero_points_idx * wei_zero_points_d.dims()[0] * wei_zero_points_dt_size; rt_params.ic_size = nstl::min(group_size, ic_size - icb_idx * group_size); (*brg_weights_decomp_kernel_)(&rt_params); @@ -419,7 +421,7 @@ status_t brgemm_inner_product_fwd_t::execute_forward( addr_batch[b].ptr.B = decomp_buf; } else { - int typesize_scale = one_of(jbgp.wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 
2 : 1; + int typesize_scale = one_of(jbgp.wei_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 2 : 1; addr_batch[b].ptr.B = weights + wei_offset / typesize_scale; } } @@ -428,7 +430,7 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int wei_zero_points_offset = 0; int src_scales_offset = 0; if (jbgp.weights_decompression) { - wei_scales_offset = wei_scales_oc_stride * oc; + wei_scales_offset = wei_scales_oc_stride * oc * wei_scales_dt_size; wei_zero_points_offset = wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size); } @@ -480,7 +482,7 @@ status_t brgemm_inner_product_fwd_t::execute_forward( = (wei_cur_ocb + wei_ic_stride * (icb + ic_block)); if (jbgp.weights_decompression && jbgp.wei_decomp_algo == weights_decomp_kind_t::prepack) { - int typesize_scale = one_of(jbgp.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1; + int typesize_scale = one_of(jbgp.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 2 : 1; auto w_off = wei_offset * types::data_type_size(jbgp.orig_wei_dt) / types::data_type_size(jbgp.wei_dt) / typesize_scale; auto weights_ptr = reinterpret_cast(&weights[w_off]); @@ -488,9 +490,9 @@ status_t brgemm_inner_product_fwd_t::execute_forward( auto decomp_buf = decomp_buf_global + ithr * decomp_buf_per_thr; const int ic_internal_block = pd()->jbgp_.wei_dt == data_type::bf16 || - one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1; + one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 
2 : 1; auto wei_zero_points_ptr = wei_zero_points + wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; - auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc; + auto wei_scales_ptr = wei_scales + wei_scales_oc_stride * oc * wei_scales_dt_size; if (jbgp.with_grouped_weights_decompression) { weights_decompression_runtime_params_t rt_params = {}; @@ -508,7 +510,7 @@ status_t brgemm_inner_product_fwd_t::execute_forward( rt_params.weights_ptr = weights_ptr + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.orig_wei_dt) / typesize_scale; rt_params.decomp_buffer_ptr = decomp_buf + ic_idx * ic_internal_block * jbgp.oc_block * types::data_type_size(jbgp.wei_dt); - rt_params.scales_ptr = wei_scales_ptr + scales_idx * wei_scales_d.dims()[0]; + rt_params.scales_ptr = wei_scales_ptr + scales_idx * wei_scales_d.dims()[0] * wei_scales_dt_size; rt_params.zero_points_ptr = wei_zero_points_ptr + zero_points_idx * wei_zero_points_d.dims()[0] * wei_zero_points_dt_size; rt_params.ic_size = nstl::min(group_size, ic_size - icb_idx * group_size); (*brg_weights_decomp_kernel_)(&rt_params); @@ -525,7 +527,7 @@ status_t brgemm_inner_product_fwd_t::execute_forward( addr_batch[0].ptr.B = decomp_buf; } else { - int typesize_scale = one_of(jbgp.wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1; + int typesize_scale = one_of(jbgp.wei_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 
2 : 1; addr_batch[0].ptr.B = weights + wei_offset / typesize_scale; } @@ -533,7 +535,7 @@ status_t brgemm_inner_product_fwd_t::execute_forward( int wei_zero_points_offset = 0; int src_scales_offset = 0; if (jbgp.weights_decompression) { - wei_scales_offset = wei_scales_oc_stride * oc; + wei_scales_offset = wei_scales_oc_stride * oc * wei_scales_dt_size; wei_zero_points_offset = wei_zero_points_oc_stride * oc * wei_zero_points_dt_size; src_scales_offset = n * div_up(jbgp.ic, jbgp.src_quant_group_size); } diff --git a/src/cpu/x64/jit_brgemm_inner_product.hpp b/src/cpu/x64/jit_brgemm_inner_product.hpp index 9cacd52218a..118b6b79fc4 100644 --- a/src/cpu/x64/jit_brgemm_inner_product.hpp +++ b/src/cpu/x64/jit_brgemm_inner_product.hpp @@ -63,7 +63,7 @@ struct brgemm_inner_product_fwd_t : public primitive_t { auto dst_dt = invariant_dst_md()->data_type; auto wei_dt = invariant_wei_md()->data_type; const bool is_int8 = one_of(src_dt, u8, s8); - const bool is_wei_decomp = (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4)) || + const bool is_wei_decomp = (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4, f4_e2m1)) || (one_of(src_dt, f32) && one_of(wei_dt, f16, bf16)); using skip_mask_t = primitive_attr_t::skip_mask_t; @@ -241,13 +241,14 @@ struct brgemm_inner_product_fwd_t : public primitive_t { weights_decompression_compile_params_t jcp = {}; jcp.oc_size = pd()->jbgp_.oc_block; jcp.ic_internal_size = pd()->jbgp_.wei_dt == data_type::bf16 || - utils::one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1; + utils::one_of(pd()->jbgp_.orig_wei_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 
2 : 1; jcp.with_scales = !pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values(); jcp.broadcast_scales = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).dims_[0] == 1; jcp.with_zero_points = !pd()->attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS); jcp.broadcast_zero_points = pd()->attr()->zero_points_.get_dims(DNNL_ARG_WEIGHTS)[0] == 1; jcp.weights_dt = pd()->jbgp_.orig_wei_dt; jcp.decomp_buffer_dt = pd()->jbgp_.wei_dt; + jcp.scales_dt = pd()->jbgp_.wei_decomp_scales_dt; jcp.zero_points_dt = pd()->jbgp_.wei_decomp_zero_points_dt; if (is_superset(pd()->jbgp_.isa, avx512_core)) { diff --git a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp index 40254c97113..85a152e2498 100644 --- a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp @@ -253,7 +253,7 @@ jit_brgemm_ip_conf_t::get_desired_weights_tag() const { pick(n_sp_dims, OI4i8o4i, OwI4i8o4i, OhwI4i8o4i, OdhwI4i8o4i)}}; } - } else if (jbgp.weights_decompression && one_of(jbgp.orig_wei_dt, nf4, s4, u4)) { + } else if (jbgp.weights_decompression && one_of(jbgp.orig_wei_dt, nf4, s4, u4, f4_e2m1)) { if (jbgp.with_src_dynamic_quant) { return {{64, pick(n_sp_dims, OI16i64o4i, OIw16i64o4i, @@ -696,25 +696,17 @@ status_t jit_brgemm_ip_fwd_conf_t::init_conf(cpu_isa_t isa, jbgp.gemm_batch_size = nb_k_blocking; } - if (jbgp.with_src_dynamic_quant) { - if ((jbgp.nb_ic_blocking * k_blk) % jbgp.src_quant_group_size != 0) { - jbgp.nb_ic_blocking = 64; - } - jbgp.K = k_blk * jbgp.nb_ic_blocking; - jbgp.gemm_batch_size = 1; - } - // Current implementation of grouped weights decompression algorithm requires K size to be aligned on group size. // Besides that "batched" usage of brgemm block is not covered, so forcing the value to 1. 
- if (jbgp.with_grouped_weights_decompression && !jbgp.with_src_dynamic_quant) { + if (jbgp.with_grouped_weights_decompression || jbgp.with_src_dynamic_quant) { auto min_ic_group_size = std::min(jbgp.wei_scales_ic_group_size, jbgp.wei_zero_points_ic_group_size); min_ic_group_size = std::min(min_ic_group_size, jbgp.src_quant_group_size); - if (jbgp.K % min_ic_group_size != 0 || jbgp.gemm_batch_size != 1) { - jbgp.nthr_ic_b = 1; - jbgp.nb_ic_blocking = min_ic_group_size / jbgp.ic_block; - jbgp.K = jbgp.ic_block * jbgp.nb_ic_blocking; - jbgp.gemm_batch_size = 1; + if ((jbgp.nb_ic_blocking * k_blk) % min_ic_group_size != 0) { + jbgp.nb_ic_blocking = 64; } + jbgp.K = k_blk * jbgp.nb_ic_blocking; + jbgp.gemm_batch_size = 1; + jbgp.nthr_ic_b = 1; } const int nthrs_other = jbgp.nthr / jbgp.nthr_ic_b; @@ -1418,11 +1410,12 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa, jbgp.dst_dt = dst_d.data_type(); jbgp.wei_dt = weights_d.data_type(); - jbgp.weights_decompression = (one_of(jbgp.src_dt, f32, bf16) && one_of(jbgp.wei_dt, u8, s8, nf4, s4, u4)) || + jbgp.weights_decompression = (one_of(jbgp.src_dt, f32, bf16) && one_of(jbgp.wei_dt, u8, s8, nf4, s4, u4, f4_e2m1)) || (one_of(jbgp.src_dt, f32) && one_of(jbgp.wei_dt, f16, bf16)); jbgp.wei_decomp_algo = weights_decomp_kind_t::immediate; jbgp.orig_wei_dt = jbgp.wei_dt; jbgp.with_grouped_weights_decompression = false; + jbgp.wei_decomp_scales_dt = data_type::undef; jbgp.wei_decomp_zero_points_dt = data_type::undef; jbgp.with_src_dynamic_quant = false; if (jbgp.weights_decompression) { @@ -1433,6 +1426,9 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa, jbgp.wei_scales_ic_group_size = jbgp.ic; auto wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); + jbgp.wei_decomp_scales_dt = wei_scales.data_type_; + if (!one_of(jbgp.wei_decomp_scales_dt, f32, f8_e8m0)) + return status::unimplemented; if (!wei_scales.has_default_values() && wei_scales.dims_[1] != 1) { jbgp.with_grouped_weights_decompression = true; 
jbgp.wei_scales_ic_group_size = div_up(jbgp.ic, wei_scales.dims_[1]); @@ -1471,7 +1467,9 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa, return status::unimplemented; if (jbgp.with_src_dynamic_quant) { - if (!(one_of(jbgp.wei_dt, u4, u8) && one_of(jbgp.wei_decomp_zero_points_dt, u8, data_type::undef))) + if (!(one_of(jbgp.wei_dt, u4, u8) && + one_of(jbgp.wei_decomp_scales_dt, f32) && + one_of(jbgp.wei_decomp_zero_points_dt, u8, data_type::undef))) return status::unimplemented; const size_t simd_width = 16; @@ -1687,7 +1685,8 @@ void jit_brgemm_ip_conf_t::init_scratchpad_base( types::data_type_size(jbgp.wei_dt)); } if (jbgp.wei_decomp_scales_buffer_size) - scratchpad.book(key_decompression_scales, jbgp.wei_decomp_scales_buffer_size, sizeof(float)); + scratchpad.book(key_decompression_scales, jbgp.wei_decomp_scales_buffer_size, + types::data_type_size(jbgp.wei_decomp_scales_dt)); if (jbgp.wei_decomp_zero_points_buffer_size) scratchpad.book(key_decompression_zero_points, jbgp.wei_decomp_zero_points_buffer_size, types::data_type_size(jbgp.wei_decomp_zero_points_dt)); diff --git a/src/cpu/x64/jit_brgemm_primitive_conf.hpp b/src/cpu/x64/jit_brgemm_primitive_conf.hpp index 47fe820069c..5f7ebd2cf0e 100644 --- a/src/cpu/x64/jit_brgemm_primitive_conf.hpp +++ b/src/cpu/x64/jit_brgemm_primitive_conf.hpp @@ -108,6 +108,7 @@ struct jit_brgemm_primitive_conf_t { size_t wei_zero_points_ic_group_size; size_t wei_decomp_scales_buffer_size; size_t wei_decomp_zero_points_buffer_size; + data_type_t wei_decomp_scales_dt; data_type_t wei_decomp_zero_points_dt; bool with_src_dynamic_quant; diff --git a/src/cpu/x64/jit_brgemm_weights_decompression_kernel.cpp b/src/cpu/x64/jit_brgemm_weights_decompression_kernel.cpp index d2bded12cb0..aec93af9b9c 100644 --- a/src/cpu/x64/jit_brgemm_weights_decompression_kernel.cpp +++ b/src/cpu/x64/jit_brgemm_weights_decompression_kernel.cpp @@ -51,6 +51,15 @@ void jit_brgemm_weights_decompression_kernel_t::init_decomp_params(std::fun 
uni_vbroadcastss(vmm_params(ocb), xmm_params); break; } + case data_type::f8_e8m0: { + auto xmm_params = Xmm(vmm_params(ocb).getIdx()); + auto reg_tmp_32 = Reg32(reg_tmp.getIdx()); + movzx(reg_tmp_32, ptr[reg_params]); + uni_vmovq(xmm_params, reg_tmp); + uni_vpslld(xmm_params, xmm_params, 23); + uni_vbroadcastss(vmm_params(ocb), xmm_params); + break; + } default: assert(!"unsupported data type"); } } else { @@ -65,6 +74,11 @@ void jit_brgemm_weights_decompression_kernel_t::init_decomp_params(std::fun uni_vcvtdq2ps(vmm_params(ocb), vmm_params(ocb)); break; } + case data_type::f8_e8m0: { + uni_vpmovzxbd(vmm_params(ocb), load_addr); + uni_vpslld(vmm_params(ocb), vmm_params(ocb), 23); + break; + } default: assert(!"unsupported data type"); } } @@ -128,6 +142,31 @@ void jit_brgemm_weights_decompression_kernel_t::load_weights(Vmm vmm_load, } break; } + case data_type::f4_e2m1: { + if (isa == avx2) { + uni_vpmovsxbd(vmm_load, addr); + if (ic % 2 == 0) { + vpsrad(vmm_load, vmm_load, 4); + } else { + uni_vpslld(vmm_load, vmm_load, 28); + vpsrad(vmm_load, vmm_load, 28); + } + auto mask = vmm_weights(1); + uni_vpand(mask, vmm_load, vmm_mask()); + vpermd(vmm_load, vmm_load, vmm_lookup()); + uni_vorps(vmm_load, vmm_load, mask); + } else { + uni_vpmovzxbd(vmm_load, addr); + if (ic % 2 == 0) { + uni_vpsrld(vmm_load, vmm_load, 4); + } else { + uni_vpslld(vmm_load, vmm_load, 28); + uni_vpsrld(vmm_load, vmm_load, 28); + } + vpermd(vmm_load, vmm_load, vmm_lookup()); + } + break; + } case data_type::f16: { vcvtph2ps(vmm_load, addr); break; @@ -211,10 +250,36 @@ void jit_brgemm_weights_decompression_kernel_t::generate() { mov(reg_tmp, (size_t)lookup); uni_vmovups(vmm_lookup(), ptr[reg_tmp]); } + } else if (jcp_.weights_dt == data_type::f4_e2m1) { + static const float lookup[16] = { + 0.0f, 0.5f, + 1.0f, 1.5f, + 2.0f, 3.0f, + 4.0f, 6.0f, + -0.0f, -0.5f, + -1.0f, -1.5f, + -2.0f, -3.0f, + -4.0f, -6.0f + }; + + static const uint32_t mask_signed_bit[8] = { + 0x80000000, 0x80000000, 
0x80000000, 0x80000000, + 0x80000000, 0x80000000, 0x80000000, 0x80000000, + }; + + if (isa == avx2) { + mov(reg_tmp, (size_t)lookup); + uni_vmovups(vmm_lookup(), ptr[reg_tmp]); + mov(reg_tmp, (size_t)mask_signed_bit); + uni_vmovups(vmm_mask(), ptr[reg_tmp]); + } else { + mov(reg_tmp, (size_t)lookup); + uni_vmovups(vmm_lookup(), ptr[reg_tmp]); + } } if (jcp_.with_scales) - init_decomp_params(std::bind(&jit_brgemm_weights_decompression_kernel_t::vmm_scales, this, _1), reg_scales, jcp_.broadcast_scales, data_type::f32); + init_decomp_params(std::bind(&jit_brgemm_weights_decompression_kernel_t::vmm_scales, this, _1), reg_scales, jcp_.broadcast_scales, jcp_.scales_dt); if (jcp_.with_zero_points) init_decomp_params(std::bind(&jit_brgemm_weights_decompression_kernel_t::vmm_zero_points, this, _1), reg_zero_points, jcp_.broadcast_zero_points, jcp_.zero_points_dt); @@ -225,7 +290,7 @@ void jit_brgemm_weights_decompression_kernel_t::generate() { Xbyak::Label ic_end_label; size_t weights_dt_size = types::data_type_size(jcp_.weights_dt); - size_t typesize_scale = one_of(jcp_.weights_dt, data_type::nf4, data_type::s4, data_type::u4) ? 2 : 1; + size_t typesize_scale = one_of(jcp_.weights_dt, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1) ? 
2 : 1; size_t decomp_buf_dt_size = types::data_type_size(jcp_.decomp_buffer_dt); L(ic_loop_label); diff --git a/src/cpu/x64/jit_brgemm_weights_decompression_kernel.hpp b/src/cpu/x64/jit_brgemm_weights_decompression_kernel.hpp index 2d189a4c30f..377b400e0d4 100644 --- a/src/cpu/x64/jit_brgemm_weights_decompression_kernel.hpp +++ b/src/cpu/x64/jit_brgemm_weights_decompression_kernel.hpp @@ -39,6 +39,7 @@ struct weights_decompression_compile_params_t { size_t ic_internal_size; data_type_t weights_dt; data_type_t decomp_buffer_dt; + data_type_t scales_dt; data_type_t zero_points_dt; }; @@ -98,10 +99,6 @@ struct jit_brgemm_weights_decompression_kernel_t : public jit_weights_decompress return Vmm(ocb); } - Vmm vmm_mask(int ic) { - return Vmm(n_vregs - ic - 2); - } - Vmm vmm_tmp(int idx) { return Vmm(n_vregs - idx - 1); } @@ -109,6 +106,7 @@ struct jit_brgemm_weights_decompression_kernel_t : public jit_weights_decompress Vmm vmm_lookup() { return vmm_tmp(0); } Vmm vmm_lookup_low() { return vmm_tmp(0); } Vmm vmm_lookup_high() { return vmm_tmp(1); } + Vmm vmm_mask() { return vmm_tmp(1); } Vmm vmm_mask8() { return vmm_tmp(2); } Vmm vmm_mask7() { return vmm_tmp(3); }