Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cpu: aarch64: Expand brgemm aarch64 unsupported cases handling mechanism #2099

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions src/cpu/aarch64/acl_deconvolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,9 @@ struct acl_deconvolution_fwd_t : public primitive_t {
}

// Data layout
const auto acl_layout = is_nspc ? arm_compute::DataLayout::NHWC
: arm_compute::DataLayout::NCHW;
const arm_compute::DataLayout acl_layout = is_nspc
? arm_compute::DataLayout::NHWC
: arm_compute::DataLayout::NCHW;

acl_pd_conf.src_info = arm_compute::TensorInfo(is_nspc
? arm_compute::TensorShape(ic, iw, ih, mb)
Expand Down Expand Up @@ -243,18 +244,15 @@ struct acl_deconvolution_fwd_t : public primitive_t {
// padding is set for convolution. Otherwise, describe deconvolution as convolution of
// upsampling input with stride = 1 and pad = 0.
arm_compute::ConvolutionMethod conv_method;
arm_compute::TensorInfo *conv_src_info;
arm_compute::TensorInfo conv_src_info(
acl_pd_conf.src_info.clone()->set_is_resizable(true));
unsigned int pad_left = 0;
unsigned int pad_right = 0;
unsigned int pad_top = 0;
unsigned int pad_bottom = 0;
if (sh != 1 || sw != 1) {
arm_compute::TensorInfo scale_out_info(
acl_pd_conf.src_info.clone()
->set_is_resizable(true)
.reset_padding()
.set_tensor_shape(scale_out_shape));
conv_src_info = &scale_out_info;
conv_src_info.reset_padding();
conv_src_info.set_tensor_shape(scale_out_shape);
} else {
// compute correct padding here
pad_left = pr > pl ? pr - pl : 0;
Expand All @@ -269,15 +267,13 @@ struct acl_deconvolution_fwd_t : public primitive_t {
pad_right += deconv_pad_x / 2;
pad_top += deconv_pad_y / 2;
pad_bottom += deconv_pad_y / 2;

conv_src_info = &acl_pd_conf.src_info;
}
const arm_compute::PadStrideInfo conv_info(1, 1, pad_left,
pad_right, pad_top, pad_bottom,
arm_compute::DimensionRoundingType::CEIL);
conv_method
= arm_compute::NEConvolutionLayer::get_convolution_method(
conv_src_info, &acl_pd_conf.wei_info,
&conv_src_info, &acl_pd_conf.wei_info,
&acl_pd_conf.dst_info, conv_info,
arm_compute::WeightsInfo(),
arm_compute::Size2D(1U, 1U),
Expand Down
9 changes: 5 additions & 4 deletions src/cpu/aarch64/brgemm/brgemm.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*******************************************************************************
* Copyright 2020-2023 Intel Corporation
* Copyright 2023-2024 FUJITSU LIMITED
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -170,8 +171,8 @@ status_t brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa,
if (brg == nullptr) return status::invalid_arguments;
if (transA || transB) return status::unimplemented;

brgemm_utils::init_brgemm_conf(brg, isa, type, dt_a, dt_b, layout, alpha,
beta, LDA, LDB, LDC, M, N, K, strides);
CHECK(brgemm_utils::init_brgemm_conf(brg, isa, type, dt_a, dt_b, layout,
alpha, beta, LDA, LDB, LDC, M, N, K, strides));

if (M <= 0 || N <= 0 || K <= 0) return status::invalid_arguments;
bool ldx_check = (brg->is_row_major()) ? (LDA < K)
Expand All @@ -197,8 +198,8 @@ status_t brdgmm_desc_init(brgemm_t *brg, cpu_isa_t isa,
if (transA || layout != brgemm_row_major || alpha != 1.0f || beta != 0.f)
return status::unimplemented;

brgemm_utils::init_brdgmm_conf(brg, isa, type, dt_a, dt_b, layout, alpha,
beta, LDA, LDC, M, N, strides);
CHECK(brgemm_utils::init_brdgmm_conf(brg, isa, type, dt_a, dt_b, layout,
alpha, beta, LDA, LDC, M, N, strides));

const bool ldx_check = (LDA < N || LDC < N);
if (ldx_check) return status::invalid_arguments;
Expand Down
55 changes: 29 additions & 26 deletions src/cpu/aarch64/brgemm/brgemm_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*******************************************************************************
* Copyright 2022-2023 Intel Corporation
* Copyright 2023-2024 FUJITSU LIMITED
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -47,15 +48,18 @@ impl::data_type_t get_accum_datatype(brgemm_t *brg) {
return brg->is_int8 ? data_type::s32 : data_type::f32;
}

void init_kernel_datatype(
status_t init_kernel_datatype(
brgemm_t *brg, impl::data_type_t dt_a, impl::data_type_t dt_b) {
assert(dt_a != data_type::undef && dt_b != data_type::undef);
if (dt_a != data_type::undef && dt_b != data_type::undef)
return status::unimplemented;
brg->is_int8 = utils::one_of(dt_a, data_type::u8, data_type::s8)
&& utils::one_of(dt_b, data_type::u8, data_type::s8);
brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16);
brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32);
brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b);
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16);
if (brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16)
return status::unimplemented;
return status::success;
}

void init_common_conf(brgemm_t *brg, brgemm_batch_kind_t type, float alpha,
Expand Down Expand Up @@ -88,27 +92,22 @@ void maybe_try_bf32(brgemm_t *brg) {
//
}

void set_isa_impl(brgemm_t *brg) {
status_t set_isa_impl(brgemm_t *brg) {
auto is_isa_ok = [&](cpu_isa_t isa) {
return mayiuse(isa) &&
// maybe IMPLICATION(brg->isa_user != isa_undef,
// is_superset(brg->isa_user, isa)), but the API is not clear.
one_of(brg->isa_user, isa_undef, isa);
};

if (brg->is_bf32) {
assert(!"unsupported case");
} else if (brg->is_f32) {
brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(sve_512), sve_512,
is_isa_ok(sve_256), sve_256);
} else if (brg->is_bf16) {
assert(!"unsupported case");
} else if (brg->is_f16) {
assert(!"unsupported case");
} else if (brg->is_int8) {
if (brg->is_bf32 || brg->is_bf16 || brg->is_f16) {
return status::unimplemented;
} else if (brg->is_f32 || brg->is_int8) {
brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(sve_512), sve_512,
is_isa_ok(sve_256), sve_256);
return status::success;
}
return status::success;
}

void set_brg_vmm(brgemm_t *brg) {
Expand Down Expand Up @@ -187,7 +186,7 @@ inline size_t data_type_vnni_granularity(data_type_t data_type) {
}
status_t brgemm_blocking(brgemm_t *brg) {

set_isa_impl(brg);
CHECK(set_isa_impl(brg));
if (brg->isa_impl == isa_undef) return status::unimplemented;
assert(!brg->is_dgmm); // should not be called from brdgmm
set_brg_vmm(brg);
Expand Down Expand Up @@ -296,18 +295,19 @@ status_t brdgmm_blocking(brgemm_t *brg) {
return status::success;
}

void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M,
dim_t N, dim_t K, const brgemm_strides_t *strides, bool is_bf32) {
status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa,
brgemm_batch_kind_t type, impl::data_type_t dt_a,
impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, dim_t N, dim_t K,
const brgemm_strides_t *strides, bool is_bf32) {

init_common_conf(brg, type, alpha, beta, strides);

brg->layout = layout;

brg->dt_a = brg->is_row_major() ? dt_a : dt_b;
brg->dt_b = brg->is_row_major() ? dt_b : dt_a;
init_kernel_datatype(brg, brg->dt_a, brg->dt_b);
CHECK(init_kernel_datatype(brg, brg->dt_a, brg->dt_b));

brg->dt_c = get_accum_datatype(brg);
brg->dt_d = brg->dt_c;
Expand All @@ -319,7 +319,7 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
brg->typesize_D = types::data_type_size(brg->dt_d);

brg->isa_user = isa;
set_isa_impl(brg);
CHECK(set_isa_impl(brg));
brg->is_bf32 = false;

brg->has_int8_vnni = true;
Expand Down Expand Up @@ -352,11 +352,13 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
brg->rd_step = has_no_vnni_compute_instruction
? 1
: data_type_vnni_granularity(brg->dt_b);
return status::success;
}

void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
status_t init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa,
brgemm_batch_kind_t type, impl::data_type_t dt_a,
impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
dim_t LDA, dim_t LDC, dim_t M, dim_t N,
const brgemm_strides_t *strides) {

init_common_conf(brg, type, alpha, beta, strides);
Expand All @@ -365,7 +367,7 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,

brg->dt_a = dt_a;
brg->dt_b = dt_b;
init_kernel_datatype(brg, brg->dt_a, brg->dt_b);
CHECK(init_kernel_datatype(brg, brg->dt_a, brg->dt_b));

brg->dt_c = get_accum_datatype(brg);
brg->dt_d = brg->dt_c;
Expand Down Expand Up @@ -394,6 +396,7 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,

brg->bcast_dim = M;
brg->load_dim = N;
return status::success;
}

} // namespace brgemm_utils
Expand All @@ -402,4 +405,4 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
} // namespace impl
} // namespace dnnl

//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
18 changes: 10 additions & 8 deletions src/cpu/aarch64/brgemm/brgemm_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,20 +45,21 @@ status_t brdgmm_blocking(brgemm_t *brg);
* having to depend on BRGeMM's API. An additional feature is that this
* function can be modified depending on needs without requiring changes
* at the API level. */
void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M,
dim_t N, dim_t K, const brgemm_strides_t *strides = nullptr,
bool is_bf32 = false);
status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa,
brgemm_batch_kind_t type, impl::data_type_t dt_a,
impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, dim_t N, dim_t K,
const brgemm_strides_t *strides = nullptr, bool is_bf32 = false);

/* The purpose of this function is to enable initialization of brgemm values
* and then call additional functions like blocking heuristics without
* having to depend on BRDGeMM's API. An additional feature is that this
* function can be modified depending on needs without requiring changes
* at the API level. */
void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
status_t init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa,
brgemm_batch_kind_t type, impl::data_type_t dt_a,
impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
dim_t LDA, dim_t LDC, dim_t M, dim_t N,
const brgemm_strides_t *strides = nullptr);

} // namespace brgemm_utils
Expand Down
11 changes: 6 additions & 5 deletions src/cpu/aarch64/jit_brgemm_conv_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -725,9 +726,9 @@ status_t brg_blocking_t::estimate_brgemm_ur() {
const float alpha = 1.0;
const float beta = 0.0;
brgemm_t brg;
brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt,
CHECK(brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt,
brgemm_row_major, alpha, beta, LDA, LDB, LDC, vM, vN, vK, nullptr,
is_bf32);
is_bf32));
CHECK(brgemm_utils::brgemm_blocking(&brg));
ur = brg.bd_block;
ur_block = brg.bd_block;
Expand Down Expand Up @@ -771,9 +772,9 @@ status_t brg_blocking_t::get_brgemm_ur(
* rnd_up(oc, oc_block) * wei_dsz;
const auto strides_ptr
= (brg_type == brgemm_strd) ? &brg_strides : nullptr;
brgemm_utils::init_brgemm_conf(&brg, isa, brg_type, src_dt,
wei_dt, brgemm_row_major, alpha, vbeta, LDA, LDB, LDC,
vM, vN, vK, strides_ptr, is_bf32);
CHECK(brgemm_utils::init_brgemm_conf(&brg, isa, brg_type,
src_dt, wei_dt, brgemm_row_major, alpha, vbeta, LDA,
LDB, LDC, vM, vN, vK, strides_ptr, is_bf32));
CHECK(brgemm_utils::brgemm_blocking(&brg));

brgemm_attr_t brgattr;
Expand Down
3 changes: 1 addition & 2 deletions src/cpu/aarch64/matmul/brgemm_matmul.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
* Copyright 2024 Arm Ltd. and affiliates
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand Down Expand Up @@ -642,7 +643,6 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
= (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
ctx.current_K_start = k;
ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K);
assert(isa == sve_512);
(*copy_B_kernel_)(&ctx);
}

Expand All @@ -654,7 +654,6 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
= (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
ctx.current_K_start = k;
ctx.current_K_iters = bgmmc.K % bgmmc.K_blk;
assert(isa == sve_512);
(*copy_B_kernel_)(&ctx);
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -129,7 +130,7 @@ bool post_ops_ok(brgemm_matmul_conf_t &bgmmc, const primitive_attr_t &attr,
}

status_t check_isa_with_datatype(
const cpu_isa_t isa, const brgemm_matmul_conf_utils_t &bm_conf_utils) {
const brgemm_matmul_conf_utils_t &bm_conf_utils) {
if (bm_conf_utils.is_f32() && !bm_conf_utils.is_int8()
&& !bm_conf_utils.is_bf16() && !bm_conf_utils.is_f16()
&& !bm_conf_utils.is_int8())
Expand Down Expand Up @@ -732,8 +733,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
dst_d.format_kind() == format_kind::any,
bias_md.format_kind == format_kind::any);

VCHECK_BG(check_isa_with_datatype(isa, bm_conf_utils),
VERBOSE_ISA_DT_MISMATCH);
VCHECK_BG(check_isa_with_datatype(bm_conf_utils), VERBOSE_ISA_DT_MISMATCH);

bgmmc.a_dt_sz = bgmmc.tr_a_dt_sz = types::data_type_size(bgmmc.src_dt);
bgmmc.b_dt_sz = bgmmc.tr_b_dt_sz = types::data_type_size(bgmmc.wei_dt);
Expand Down Expand Up @@ -1107,4 +1107,4 @@ void init_scratchpad(memory_tracking::registrar_t &scratchpad,
} // namespace aarch64
} // namespace cpu
} // namespace impl
} // namespace dnnl
} // namespace dnnl
5 changes: 3 additions & 2 deletions tests/benchdnn/graph/ref_primitive.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright 2023-2024 Intel Corporation
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -244,8 +245,8 @@ void ref_primitive_t::check_correctness(
check_buffer_overwrite(args.dnn_mem(i), args.arg(i), res);

const auto arg = args.arg(i);
const auto &mem_dt = args.find(arg);
const auto &mem_fp = args_.find(arg);
const dnn_mem_t &mem_dt = args.find(arg);
const dnn_mem_t &mem_fp = args_.find(arg);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this change needed? What's the issue with auto?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dzarukin Hi Dmitry, thanks for having a look into this. I have replaced auto with explicit data types for improved readability, consistency, and type safety. When debugging it was particularly useful to have the data type specified and I did go an extra mile checking that the function returns the same type. Should not be a problem but I am happy to review it if it is a deal breaker.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ve been working on this the past few days, and all tests for the Graph API should run fine. There are a few edge cases that fail but I’ll detail in an issue soon and shall be fixed in the foreseeable future.

Thanks for the approval!


if (dnnl_arg_2_data_kind_map.find(arg)
== dnnl_arg_2_data_kind_map.end()) {
Expand Down