From d0993d91377f9a65df439fc163781bf7d8581f6e Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Wed, 11 Dec 2024 15:50:18 -0800 Subject: [PATCH] [xla:cpu] Add missing files from openxla/xla#16438 Add missing files and changes from a PR adding matmul reordering support to oneDNN for aarch64 CPU: https://github.com/openxla/xla/pull/16438 Also add a missing indirect convolution patch from a TF PR: https://github.com/tensorflow/tensorflow/pull/62852 PiperOrigin-RevId: 705268797 --- ..._acl_add_bf16_platform_support_check.patch | 31 ++++++ ...d_sbgemm_matmul_primitive_definition.patch | 44 ++++++++ ...d_weight_format_for_matmul_primitive.patch | 100 ++++++++++++++++++ ...l_fix_segfault_during_postop_execute.patch | 96 +++++++++++++++++ workspace2.bzl | 5 + 5 files changed, 276 insertions(+) create mode 100644 third_party/mkl_dnn/onednn_acl_add_bf16_platform_support_check.patch create mode 100644 third_party/mkl_dnn/onednn_acl_add_sbgemm_matmul_primitive_definition.patch create mode 100644 third_party/mkl_dnn/onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch create mode 100644 third_party/mkl_dnn/onednn_acl_fix_segfault_during_postop_execute.patch diff --git a/third_party/mkl_dnn/onednn_acl_add_bf16_platform_support_check.patch b/third_party/mkl_dnn/onednn_acl_add_bf16_platform_support_check.patch new file mode 100644 index 000000000..42dd26232 --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_add_bf16_platform_support_check.patch @@ -0,0 +1,31 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +diff --git a/src/cpu/platform.cpp b/src/cpu/platform.cpp +index 65b887ea21..eabdb827bd 100644 +--- a/src/cpu/platform.cpp ++++ b/src/cpu/platform.cpp +@@ -117,6 +117,8 @@ bool has_data_type_support(data_type_t data_type) { + #if defined(USE_CBLAS) && defined(BLAS_HAS_SBGEMM) && defined(__MMA__) + return true; + #endif ++#elif DNNL_AARCH64_USE_ACL ++ return arm_compute::CPUInfo::get().has_bf16(); + #else + return false; + #endif +-- +2.34.1 + diff --git a/third_party/mkl_dnn/onednn_acl_add_sbgemm_matmul_primitive_definition.patch b/third_party/mkl_dnn/onednn_acl_add_sbgemm_matmul_primitive_definition.patch new file mode 100644 index 000000000..779608a68 --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_add_sbgemm_matmul_primitive_definition.patch @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp +index ab13efb9b2..ec261e156d 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul.hpp ++++ b/src/cpu/aarch64/matmul/acl_matmul.hpp +@@ -78,11 +78,21 @@ struct acl_matmul_t : public primitive_t { + = utils::everyone_is(data_type::f16, src_md()->data_type, + weights_md()->data_type, dst_md()->data_type) + && platform::has_data_type_support(data_type::f16); ++ const bool is_fp32_bf16_ok ++ = (utils::everyone_is(data_type::f32, src_md()->data_type, ++ dst_md()->data_type, desc()->accum_data_type) ++ && platform::has_data_type_support(data_type::f32) ++ && utils::everyone_is( ++ data_type::bf16, weights_md()->data_type) ++ && platform::has_data_type_support( ++ data_type::bf16)); ++ + const bool is_weights_md_format_ok + = utils::one_of(weights_format_kind_received, + format_kind::any, format_kind::blocked); + bool ok = is_dense_data() +- && utils::one_of(true, is_fp32_ok, is_fp16_ok) ++ && utils::one_of( ++ true, is_fp32_ok, is_fp16_ok, is_fp32_bf16_ok) + && !has_zero_dim_memory() && is_weights_md_format_ok + && set_default_formats() + && attr()->has_default_values( +-- +2.34.1 diff --git a/third_party/mkl_dnn/onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch b/third_party/mkl_dnn/onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch new file mode 100644 index 000000000..ec2cb97f5 --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch @@ -0,0 +1,100 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp +index 451cc78d52..ab13efb9b2 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul.hpp ++++ b/src/cpu/aarch64/matmul/acl_matmul.hpp +@@ -67,6 +67,8 @@ struct acl_matmul_t : public primitive_t { + + status_t init(engine_t *engine) { + using smask_t = primitive_attr_t::skip_mask_t; ++ const format_kind_t weights_format_kind_received ++ = weights_md_.format_kind; + const bool is_fp32_ok + = utils::everyone_is(data_type::f32, src_md()->data_type, + weights_md()->data_type, dst_md()->data_type, +@@ -76,18 +78,20 @@ struct acl_matmul_t : public primitive_t { + = utils::everyone_is(data_type::f16, src_md()->data_type, + weights_md()->data_type, dst_md()->data_type) + && platform::has_data_type_support(data_type::f16); ++ const bool is_weights_md_format_ok ++ = utils::one_of(weights_format_kind_received, ++ format_kind::any, format_kind::blocked); + bool ok = is_dense_data() + && utils::one_of(true, is_fp32_ok, is_fp16_ok) +- && !has_zero_dim_memory() +- && weights_md_.format_kind == format_kind::any ++ && !has_zero_dim_memory() && is_weights_md_format_ok + && set_default_formats() + && attr()->has_default_values( + smask_t::oscale | smask_t::post_ops) + && attr_oscale_ok() && !has_runtime_dims_or_strides(); + if (!ok) return status::unimplemented; + +- CHECK(acl_matmul_utils::init_conf_matmul( +- amp_, src_md_, weights_md_, dst_md_, *desc(), *attr())); ++ CHECK(acl_matmul_utils::init_conf_matmul(amp_, src_md_, weights_md_, ++ dst_md_, *desc(), *attr(), weights_format_kind_received)); + + arm_compute::ActivationLayerInfo act_info; + CHECK(post_ops.init(engine, attr_.post_ops_, dst_md_, act_info)); +diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp +index a314d96384..027f915a8a 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp ++++ b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp +@@ -27,7 +27,8 @@ namespace acl_matmul_utils { + + status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + memory_desc_t &wei_md, memory_desc_t &dst_md, const matmul_desc_t &md, +- const primitive_attr_t &attr) { ++ const primitive_attr_t &attr, ++ format_kind_t weights_format_kind_received) { + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper wei_d(&wei_md); +@@ -128,9 +129,16 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + for (dim_t i = K_dim - 1; i >= 0; --i) + batch_dims.push_back(i); + ++ const memory_desc_t weights_md_received = wei_md; + acl_utils::reorder_to_weight_format(amp.wei_tensor_info, wei_md, + expected_weight_format, K_dim, N_dim, {}, batch_dims); + ++ ACL_CHECK_SUPPORT((weights_format_kind_received == format_kind::blocked) ++ && !(dnnl_memory_desc_equal(&weights_md_received, &wei_md)), ++ "specified blocked format not supported by ACL, use " ++ "format_kind_t::any to find a supported blocked format for " ++ "your platform"); ++ + return status::success; + } + +diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp +index 67bb2e78eb..5ba4241abc 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp ++++ b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp +@@ -52,7 +52,8 @@ namespace acl_matmul_utils { + + status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + memory_desc_t &wei_md, memory_desc_t &dst_md, const matmul_desc_t &md, +- const primitive_attr_t &attr); ++ const primitive_attr_t &attr, ++ format_kind_t weights_format_kind_received); + + } // namespace acl_matmul_utils + +-- +2.34.1 diff --git a/third_party/mkl_dnn/onednn_acl_fix_segfault_during_postop_execute.patch b/third_party/mkl_dnn/onednn_acl_fix_segfault_during_postop_execute.patch new file mode 100644 index 000000000..39f7e7434 --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_fix_segfault_during_postop_execute.patch @@ -0,0 +1,96 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +diff --git a/src/cpu/aarch64/acl_post_ops.cpp b/src/cpu/aarch64/acl_post_ops.cpp +index ea4bb200ec..3eb53b81bd 100644 +--- a/src/cpu/aarch64/acl_post_ops.cpp ++++ b/src/cpu/aarch64/acl_post_ops.cpp +@@ -24,7 +24,7 @@ namespace aarch64 { + + status_t acl_post_ops_t::execute(const exec_ctx_t &ctx, void *src_orig) const { + +- int post_op_index = 0; ++ int post_op_index = post_op_start_index_; + + // As these are post ops, this src will also be our dst. If we have a sum + // post op, the src/dst will start off in a temporary, then change to +diff --git a/src/cpu/aarch64/acl_post_ops.hpp b/src/cpu/aarch64/acl_post_ops.hpp +index 7b59ad71d3..ceaa95b73a 100644 +--- a/src/cpu/aarch64/acl_post_ops.hpp ++++ b/src/cpu/aarch64/acl_post_ops.hpp +@@ -32,7 +32,9 @@ struct acl_post_ops_t { + // init the acl_post_ops_t. Note that this function modifies the passed in + // post ops by setting the preferred memory formats + status_t init(engine_t *engine, post_ops_t &post_ops, +- const memory_desc_t &dst_md) { ++ const memory_desc_t &dst_md, int post_op_start_index = 0) { ++ ++ post_op_start_index_ = post_op_start_index; + + CHECK(post_ops.set_default_formats(&dst_md)); + dst_data_type = dst_md.data_type; +@@ -41,7 +43,7 @@ struct acl_post_ops_t { + sum_index = -1; + post_op_primitives = {}; + +- for (int i = 0; i < post_ops.len(); i++) { ++ for (int i = post_op_start_index; i < post_ops.len(); i++) { + auto &po = post_ops.entry_[i]; + + if (po.is_sum()) { +@@ -135,7 +137,8 @@ struct acl_post_ops_t { + // formats + status_t init(engine_t *engine, post_ops_t &base_post_ops, + const memory_desc_t &dst_md, +- arm_compute::ActivationLayerInfo &act_info_to_fuse) { ++ arm_compute::ActivationLayerInfo &act_info_to_fuse, ++ int post_op_start_index = 0) { + + CHECK(base_post_ops.set_default_formats(&dst_md)); + dst_data_type = dst_md.data_type; +@@ -149,18 +152,11 @@ struct acl_post_ops_t { + "eltwise post op scale must be 1 (no scale)"); + CHECK(acl_utils::convert_to_acl_act(first_po, act_info_to_fuse)); + +- // Copy all but the first, because it has been fused +- post_ops_t post_ops; +- for (int idx = 1; idx < base_post_ops.len(); ++idx) { +- // Construct empty entry then copy, so that we can check for failure +- post_ops.entry_.emplace_back(); +- post_ops.entry_.back().copy_from(base_post_ops.entry_[idx]); +- } +- return init(engine, post_ops, dst_md); +- ++ // post_op_start_index + 1 to skip the fused eltwise ++ return init(engine, base_post_ops, dst_md, post_op_start_index + 1); + } else { + // Nothing to fuse, just copy all post ops +- return init(engine, base_post_ops, dst_md); ++ return init(engine, base_post_ops, dst_md, post_op_start_index); + } + } + +@@ -179,6 +175,9 @@ struct acl_post_ops_t { + private: + // Index of the sum post op if there is one, < 0 means no sum + int sum_index = -1; ++ // Index of the first post op this primitive executes. This is typically the ++ // number of post ops which were fused. ++ int post_op_start_index_ = 0; + data_type_t dst_data_type; + // Vector of primitives used to execute the post ops. They are constructed + // in init to be either acl_binary_t (for sum, add, sub, div, mul, min and +-- +2.34.1 diff --git a/workspace2.bzl b/workspace2.bzl index 5cffcd00d..993450dc3 100644 --- a/workspace2.bzl +++ b/workspace2.bzl @@ -163,6 +163,11 @@ def _tf_repositories(): "//third_party/mkl_dnn:onednn_acl_thread_local_scheduler.patch", "//third_party/mkl_dnn:onednn_acl_fp32_bf16_reorder.patch", "//third_party/mkl_dnn:onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch", + "//third_party/mkl_dnn:onednn_acl_indirect_conv.patch", + "//third_party/mkl_dnn:onednn_acl_allow_blocked_weight_format_for_matmul_primitive.patch", + "//third_party/mkl_dnn:onednn_acl_fix_segfault_during_postop_execute.patch", + "//third_party/mkl_dnn:onednn_acl_add_bf16_platform_support_check.patch", + "//third_party/mkl_dnn:onednn_acl_add_sbgemm_matmul_primitive_definition.patch", ], sha256 = "2f76b407ef8893cca71340f88cd800019a1f14f8ac1bbdbb89a84be1370b52e3", strip_prefix = "oneDNN-3.2.1",