Skip to content

Commit

Permalink
[FORK][FEATURE] IP weights compression: mxfp4 (wei=f4e2m1, scales=f8e…
Browse files Browse the repository at this point in the history
…8m0) support
  • Loading branch information
dmitry-gorokhov committed Jul 31, 2024
1 parent b1c677d commit 6b99866
Show file tree
Hide file tree
Showing 29 changed files with 345 additions and 83 deletions.
2 changes: 1 addition & 1 deletion include/oneapi/dnnl/dnnl.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask(
dnnl_primitive_attr_t attr, int arg, int mask);
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_dims(
dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims);
dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims, dnnl_data_type_t data_type);

/// Sets primitive attributes scaling factors for primitive operations for a
/// given memory argument. The scaling factors must be passed at execution time
Expand Down
8 changes: 6 additions & 2 deletions include/oneapi/dnnl/dnnl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,10 @@ struct memory : public handle<dnnl_memory_t> {
bin = dnnl_bin,
/// 4-bit normalized float.
nf4 = dnnl_nf4,
/// 8-bit floating-point with a 8-bit exponent and a 0-bit mantissa.
f8_e8m0 = dnnl_f8_e8m0,
/// 4-bit floating-point with a 2-bit exponent and a 1-bit mantissa.
f4_e2m1 = dnnl_f4_e2m1
};

/// Returns size of data type in bytes.
Expand Down Expand Up @@ -4142,8 +4146,8 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
error::wrap_c_api(dnnl_primitive_attr_set_scales_mask(get(), arg, mask),
"could not set scales primitive attribute");
}
void set_scales_dims(int arg, const memory::dims& dims) {
error::wrap_c_api(dnnl_primitive_attr_set_scales_dims(get(), arg, dims.data(), dims.size()),
void set_scales_dims(int arg, const memory::dims& dims, memory::data_type data_type = memory::data_type::f32) {
error::wrap_c_api(dnnl_primitive_attr_set_scales_dims(get(), arg, dims.data(), dims.size(), memory::convert_to_c(data_type)),
"could not set scales primitive attribute");
}

Expand Down
6 changes: 5 additions & 1 deletion include/oneapi/dnnl/dnnl_common_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,12 @@ typedef enum {
dnnl_u4 = 12,
/// 4-bit normalized float.
dnnl_nf4 = 13,
/// 8-bit floating-point with a 8-bit exponent and a 0-bit mantissa
dnnl_f8_e8m0 = 14,
/// 4-bit floating-point with a 2-bit exponent and a 1-bit mantissa
dnnl_f4_e2m1 = 15,
/// 1-bit integer.
dnnl_bin = 14,
dnnl_bin = 16,

/// Parameter to allow internal only data_types without undefined behavior.
/// This parameter is chosen to be valid for so long as sizeof(int) >= 2.
Expand Down
8 changes: 5 additions & 3 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,13 @@ const data_type_t u4 = dnnl_u4;
const data_type_t boolean = dnnl_boolean;
const data_type_t data_type_max = dnnl_data_type_max;

// Not exposed through API as all current uses are internal only
const data_type_t tf32 = static_cast<data_type_t>(1 << 8);

const data_type_t bin = dnnl_bin;
const data_type_t nf4 = dnnl_nf4;
const data_type_t f4_e2m1 = dnnl_f4_e2m1;
const data_type_t f8_e8m0 = dnnl_f8_e8m0;

// Not exposed through API as all current uses are internal only
const data_type_t tf32 = static_cast<data_type_t>(1 << 8);
} // namespace data_type

using fpmath_mode_t = dnnl_fpmath_mode_t;
Expand Down
2 changes: 2 additions & 0 deletions src/common/dnnl_debug_autogenerated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ const char *dnnl_dt2str(dnnl_data_type_t v) {
if (v == dnnl_nf4) return "nf4";
if (v == dnnl_s4) return "s4";
if (v == dnnl_u4) return "u4";
if (v == dnnl_f8_e8m0) return "f8_e8m0";
if (v == dnnl_f4_e2m1) return "f4_e2m1";
if (v == dnnl_data_type_max) return "data_type_max";
assert(!"unknown dt");
return "unknown dt";
Expand Down
10 changes: 10 additions & 0 deletions src/common/dnnl_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,16 @@ template <> struct prec_traits<data_type::nf4> {
typedef uint8_t type;
};

template <>
struct prec_traits<data_type::f8_e8m0> {
typedef uint8_t type;
};

template <>
struct prec_traits<data_type::f4_e2m1> {
typedef uint8_t type;
};

template <>
struct data_traits<float8_e5m2_t> {
static constexpr data_type_t data_type = data_type::f8_e5m2;
Expand Down
5 changes: 3 additions & 2 deletions src/common/inner_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine,
if (attr == nullptr) return status::success;
const data_type_t src_dt = desc.src_desc.data_type;
const data_type_t wei_dt = desc.weights_desc.data_type;
bool is_weight_compression = (one_of(src_dt, data_type::f32, data_type::bf16) && one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4)) ||
bool is_weight_compression = (one_of(src_dt, data_type::f32, data_type::bf16) &&
one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1)) ||
(one_of(src_dt, data_type::f32) && one_of(wei_dt, data_type::f16, data_type::bf16));
auto attr_mask = smask_t::none;
// From oneDNN 3.5, those checks must be skipped if wei_decomp is enabled
Expand Down Expand Up @@ -140,7 +141,7 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine,
|| utils::one_of(dst_dt, data_type::s8, data_type::u8,
data_type::s32);
if (engine->kind() == engine_kind::cpu)
is_int8 |= one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4);
is_int8 |= one_of(wei_dt, data_type::u8, data_type::s8, data_type::nf4, data_type::s4, data_type::u4, data_type::f4_e2m1);
if (is_int8) fwd_attr_mask |= smask_t::scales_runtime | smask_t::zero_points_runtime | smask_t::src_dyn_quant_params;
if (is_weight_compression) {
fwd_attr_mask |= attr_mask;
Expand Down
2 changes: 2 additions & 0 deletions src/common/memory_zero_pad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,8 @@ static status_t zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
case u4: return typed_zero_pad<u8>(memory, ctx);
case bin: return typed_zero_pad<u8>(memory, ctx);
case nf4: return typed_zero_pad<u8>(memory, ctx);
case f8_e8m0: return typed_zero_pad<u8>(memory, ctx);
case f4_e2m1: return typed_zero_pad<u8>(memory, ctx);
default: assert(!"memory is undefined"); return unimplemented;
}
return unimplemented;
Expand Down
4 changes: 2 additions & 2 deletions src/common/primitive_attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -656,11 +656,11 @@ status_t dnnl_primitive_attr_set_scales_mask(
return attr->scales_.set(arg, mask);
}
status_t dnnl_primitive_attr_set_scales_dims(
primitive_attr_t *attr, int arg, const dims_t dims, int ndims) {
primitive_attr_t *attr, int arg, const dims_t dims, int ndims, data_type_t data_type) {
bool ok = attr && arg >= 0 && ndims > 0
&& attr->output_scales_.has_default_values();
if (!ok) return invalid_arguments;
return attr->scales_.set(arg, dims, ndims);
return attr->scales_.set(arg, dims, ndims, data_type);
}

status_t dnnl_primitive_attr_set_scales(primitive_attr_t *attr, int arg,
Expand Down
7 changes: 4 additions & 3 deletions src/common/primitive_attr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,12 @@ struct runtime_scales_t : public c_compatible {
return status::success;
}

status_t set(const dims_t dims, int ndims) {
status_t set(const dims_t dims, int ndims, data_type_t data_type = data_type::f32) {
is_set_ = true;
ndims_ = ndims;
mask_ = 1;
utils::array_copy(dims_, dims, ndims_);
data_type_ = data_type;
return status::success;
}

Expand Down Expand Up @@ -348,9 +349,9 @@ struct arg_scales_t : public c_compatible {
if (!check_arg(arg)) return status::invalid_arguments;
return scales_[arg].set(mask);
}
status_t set(int arg, const dims_t dims, int ndims) {
status_t set(int arg, const dims_t dims, int ndims, data_type_t data_type) {
if (!check_arg(arg)) return status::invalid_arguments;
return scales_[arg].set(dims, ndims);
return scales_[arg].set(dims, ndims, data_type);
}

status_t set(int arg, int mask, int ndims, const dims_t group_dims,
Expand Down
6 changes: 4 additions & 2 deletions src/common/type_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ inline size_t data_type_size(data_type_t data_type) {
case boolean: return sizeof(prec_traits<boolean>::type);
case bin: return sizeof(prec_traits<u8>::type);
case nf4: return sizeof(prec_traits<u8>::type);
case f8_e8m0: return sizeof(prec_traits<f8_e8m0>::type);
case f4_e2m1: return sizeof(prec_traits<f4_e2m1>::type);
case data_type::undef:
default: assert(!"unknown data_type");
}
Expand Down Expand Up @@ -318,7 +320,7 @@ inline data_type_t default_accum_data_type(data_type_t src_dt,

/* prop_kind doesn't matter */
if (everyone_is(f32, src_dt, wei_dt)) return f32;
if (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4)) return f32;
if (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, s8, nf4, s4, u4, f4_e2m1)) return f32;
if (everyone_is(f64, src_dt, wei_dt)) return f64;

if (one_of(prop_kind, forward_training, forward_inference)) {
Expand Down Expand Up @@ -1086,7 +1088,7 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims,

bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS
&& utils::one_of(data_type, f8_e5m2, f8_e4m3, f16, bf16, f32, f64,
s32, s8, u8, nf4, s4, u4, bin);
s32, s8, u8, nf4, s4, u4, bin, f8_e8m0, f4_e2m1);
if (!ok) return false;

bool has_runtime_dims = false;
Expand Down
17 changes: 16 additions & 1 deletion src/cpu/cpu_inner_product_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
nullptr,
}},
{{forward, f32, f4_e2m1, f32}, {
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
nullptr,
}},
{{forward, f32, s4, f32}, {
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
Expand Down Expand Up @@ -123,7 +128,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
nullptr,
}},
}},
{{forward, bf16, s8, bf16}, {
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
Expand All @@ -139,6 +144,16 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
nullptr,
}},
{{forward, bf16, f4_e2m1, f32}, {
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
nullptr,
}},
{{forward, bf16, f4_e2m1, bf16}, {
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
nullptr,
}},
{{forward, bf16, s4, f32}, {
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/cpu_primitive.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
VCHECK_ATTR(scales != nullptr, \
"Scales buffer for arg %d is missing", arg); \
const auto scales_d = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | arg); \
bool ok = scales_d.data_type() == data_type::f32 \
bool ok = (scales_d.data_type() == data_type::f32 || scales_d.data_type() == data_type::f8_e8m0) \
&& (scales_d.ndims() == 1 || scales_d.ndims() == 2); \
if (!ok) return status::invalid_arguments; \
if (scales_d.dims()[0] == 1) { \
Expand Down
2 changes: 2 additions & 0 deletions src/cpu/reorder/cpu_reorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ regular_impl_list_map() {
{{f32, u8, 0}, &regular_f32_u8_impl_list_map()},
{{f8_e5m2, data_type::undef, 0}, &regular_fp8_impl_list_map()},
{{f8_e4m3, data_type::undef, 0}, &regular_fp8_impl_list_map()},
{{f8_e8m0, data_type::undef, 0}, &regular_fp8_impl_list_map()},
{{f32, bin, 0}, &regular_f32_bin_impl_list_map()},
{{bf16, data_type::undef, 0}, &regular_bf16_impl_list_map()},
{{f16, data_type::undef, 0}, &regular_f16_impl_list_map()},
Expand All @@ -47,6 +48,7 @@ regular_impl_list_map() {
{{u4, f32, 0}, &regular_u4_impl_list_map()},
{{bin, data_type::undef, 0}, &regular_bin_impl_list_map()},
{{nf4, data_type::undef, 0}, &regular_nf4_impl_list_map()},
{{f4_e2m1, data_type::undef, 0}, &regular_f4_impl_list_map()},
{{s4, data_type::undef, 0}, &regular_s4_impl_list_map()},
{{u4, data_type::undef, 0}, &regular_u4_impl_list_map()},
};
Expand Down
1 change: 1 addition & 0 deletions src/cpu/reorder/cpu_reorder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ extern const impl_list_map_t &regular_s4_impl_list_map();
extern const impl_list_map_t &regular_u4_impl_list_map();
extern const impl_list_map_t &regular_bin_impl_list_map();
extern const impl_list_map_t &regular_nf4_impl_list_map();
extern const impl_list_map_t &regular_f4_impl_list_map();
extern const impl_list_map_t &regular_s4_impl_list_map();
extern const impl_list_map_t &regular_u4_impl_list_map();

Expand Down
48 changes: 48 additions & 0 deletions src/cpu/reorder/cpu_reorder_regular_f4.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/reorder/cpu_reorder.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

// clang-format off

const impl_list_map_t &regular_f4_impl_list_map() {
static const impl_list_map_t the_map = REG_REORDER_P({
// f4_e2m1 ->
{{f4_e2m1, data_type::undef, 0}, {
REG_SR(f4_e2m1, any, f4_e2m1, OI8i8o2i, fmt_order_keep)
REG_SR(f4_e2m1, any, f4_e2m1, OI8i16o2i, fmt_order_keep)
REG_SR(f4_e2m1, any, f4_e2m1, OI8i24o2i, fmt_order_keep)
REG_SR(f4_e2m1, any, f4_e2m1, OI8i32o2i, fmt_order_keep)
REG_SR(f4_e2m1, any, f4_e2m1, OI8i64o2i, fmt_order_keep)
REG_SR(f4_e2m1, any, f4_e2m1, OI16i16o2i, fmt_order_keep)
REG_SR(f4_e2m1, any, f4_e2m1, OI16i32o2i, fmt_order_keep)
REG_SR(f4_e2m1, any, f4_e2m1, OI16i48o2i, fmt_order_keep)
REG_SR(f4_e2m1, any, f4_e2m1, OI16i64o2i, fmt_order_keep)
nullptr,
}},
});
return the_map;
}

// clang-format on

} // namespace cpu
} // namespace impl
} // namespace dnnl
6 changes: 6 additions & 0 deletions src/cpu/reorder/cpu_reorder_regular_fp8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ const impl_list_map_t &regular_fp8_impl_list_map() {
REG_SR(f8_e4m3, any, bf16, any, fmt_order::any, spec::reference)
REG_SR(f8_e4m3, any, f32, any, fmt_order::any, spec::reference)

nullptr,
}},
// f8_e8m0 ->
{{f8_e8m0, data_type::undef, 0}, {
REG_SR(f8_e8m0, any, f8_e8m0, any, fmt_order::any, spec::reference)

nullptr,
}},
});
Expand Down
9 changes: 5 additions & 4 deletions src/cpu/reorder/simple_reorder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1697,7 +1697,8 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
tag_traits<tag_o>::ndims >= 4
&& tag_traits<tag_o>::ndims <= 6)
&& (type_i != dnnl_bin && type_o != dnnl_bin)
&& (type_i != dnnl_nf4 && type_o != dnnl_nf4)>::type> {
&& (type_i != dnnl_nf4 && type_o != dnnl_nf4)
&& (type_i != dnnl_f4_e2m1 && type_o != dnnl_f4_e2m1)>::type> {
PLAIN_TO_BLOCKED_IS_APPLICABLE();

GET_SCRATCHPAD_SIZE_ZERO();
Expand Down Expand Up @@ -2004,7 +2005,7 @@ template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
typename utils::enable_if<tag_i == format_tag::any &&
tag_traits<tag_o>::block_dims == bd::_AB &&
utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4) &&
utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4, dnnl_f4_e2m1) &&
type_i == type_o>::type>
{
static bool is_applicable(const memory_desc_wrapper &input_d,
Expand Down Expand Up @@ -2483,8 +2484,8 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
typename utils::enable_if<tag_i == format_tag::any
&& tag_o == format_tag::any
&& order_keep == fmt_order::any
&& !(utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4) ||
utils::one_of(type_o, dnnl_nf4, dnnl_s4, dnnl_u4)),
&& !(utils::one_of(type_i, dnnl_nf4, dnnl_s4, dnnl_u4, dnnl_f4_e2m1) ||
utils::one_of(type_o, dnnl_nf4, dnnl_s4, dnnl_u4, dnnl_f4_e2m1)),
spec::reference>::type> {
static bool is_applicable(const memory_desc_wrapper &input_d,
const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
Expand Down
4 changes: 4 additions & 0 deletions src/cpu/x64/brgemm/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ status_t brgemm_desc_init(brgemm_desc_t *brg, cpu_isa_t isa,
brg->with_wei_decomp_scales = !wei_scales.has_default_values();
brg->wei_decomp_scales_group_size = wei_d.dims()[1];
if (brg->with_wei_decomp_scales) {
brg->wei_decomp_scales_dt = wei_scales.data_type_;
if (!one_of(brg->wei_decomp_scales_dt, f32, f8_e8m0))
return status::unimplemented;

auto ld_dim = wei_scales.dims_[0];
brg->wei_decomp_scales_stride = ld_dim > 1 ? ld_dim : 0;
brg->wei_decomp_scales_group_size = wei_d.dims()[1] / wei_scales.dims_[1];
Expand Down
1 change: 1 addition & 0 deletions src/cpu/x64/brgemm/brgemm_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ struct brgemm_desc_t {
int wei_decomp_zero_points_stride = 0;
int wei_decomp_scales_group_size = 0;
int wei_decomp_zero_points_group_size = 0;
impl::data_type_t wei_decomp_scales_dt = data_type::undef;
impl::data_type_t wei_decomp_zero_points_dt = data_type::undef;
bool with_src_dyn_quant = false;
int src_scales_group_size = 0;
Expand Down
Loading

0 comments on commit 6b99866

Please sign in to comment.