From 3e9fb571679a88c05a28f1828011c53d4cba6cd8 Mon Sep 17 00:00:00 2001
From: pyf
Date: Thu, 6 Jun 2024 09:18:46 +0800
Subject: [PATCH] add mask primitive

---
 include/oneapi/dnnl/dnnl.h          |  38 +++++-
 include/oneapi/dnnl/dnnl.hpp        | 127 +++++++++++++++++-
 include/oneapi/dnnl/dnnl_types.h    |   4 +-
 src/common/c_types_map.hpp          |   1 +
 src/common/dnnl_traits.hpp          |   1 +
 src/common/mask.cpp                 |  82 ++++++++++++
 src/common/mask_pd.hpp              | 115 +++++++++++++++++
 src/common/opdesc.hpp               |  11 ++
 src/common/primitive_desc_iface.cpp |   2 +-
 src/common/primitive_hashing.cpp    |  16 +++
 src/common/primitive_hashing.hpp    |   2 +
 src/common/transpose.cpp            |   4 +-
 src/common/transpose_pd.hpp         |   4 +-
 src/common/type_helpers.hpp         |  12 ++
 src/common/verbose.hpp              |   1 +
 src/gpu/amd/custom/hip_customs.h    |   2 +
 src/gpu/amd/custom/hip_mask.cc      | 191 ++++++++++++++++++++++++++++
 src/gpu/amd/custom/hip_transpose.cc |   2 +
 src/gpu/amd/hip_mask.cpp            |  57 +++++++++
 src/gpu/amd/hip_mask.hpp            |  93 ++++++++++++++
 src/gpu/amd/hip_mask_impl.hpp       |  67 ++++++++++
 src/gpu/amd/hip_transpose.cpp       |   4 +-
 src/gpu/amd/sycl_hip_engine.cpp     |   5 +
 23 files changed, 828 insertions(+), 13 deletions(-)
 create mode 100644 src/common/mask.cpp
 create mode 100644 src/common/mask_pd.hpp
 create mode 100644 src/gpu/amd/custom/hip_mask.cc
 create mode 100644 src/gpu/amd/hip_mask.cpp
 create mode 100644 src/gpu/amd/hip_mask.hpp
 create mode 100644 src/gpu/amd/hip_mask_impl.hpp

diff --git a/include/oneapi/dnnl/dnnl.h b/include/oneapi/dnnl/dnnl.h
index 24aacf0728f..42cc534ec03 100644
--- a/include/oneapi/dnnl/dnnl.h
+++ b/include/oneapi/dnnl/dnnl.h
@@ -1361,6 +1361,42 @@ dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory);
 /// @addtogroup dnnl_api_primitives
 /// @{
 
+/// @addtogroup dnnl_api_mask
+/// @{
+
+/// Creates a primitive descriptor for a mask primitive.
+///
+/// @param mask_primitive_desc Output primitive descriptor.
+/// @param src_desc Source memory descriptor.
+/// @param dst_desc Destination memory descriptor.
+/// @param mask_desc Mask memory descriptor.
+/// @param value Value written to destination elements where the mask is non-zero.
+/// @returns #dnnl_success on success and a status describing the error
+/// otherwise.
+dnnl_status_t DNNL_API dnnl_mask_primitive_desc_create(
+        dnnl_primitive_desc_t *mask_primitive_desc,
+        dnnl_engine_t engine, const_dnnl_memory_desc_t src_desc,
+        const_dnnl_memory_desc_t dst_desc,
+        const_dnnl_memory_desc_t mask_desc,
+        double value);
+
+/// Creates a primitive descriptor for a mask primitive with an integer fill value.
+///
+/// @param mask_primitive_desc Output primitive descriptor.
+/// @param src_desc Source memory descriptor.
+/// @param dst_desc Destination memory descriptor.
+/// @param mask_desc Mask memory descriptor.
+/// @returns #dnnl_success on success and a status describing the error
+/// otherwise.
+dnnl_status_t DNNL_API dnnl_mask_primitive_desc_create(
+        dnnl_primitive_desc_t *mask_primitive_desc,
+        dnnl_engine_t engine, const_dnnl_memory_desc_t src_desc,
+        const_dnnl_memory_desc_t dst_desc,
+        const_dnnl_memory_desc_t mask_desc,
+        int64_t value);
+
+/// @} dnnl_api_mask
+
 /// @addtogroup dnnl_api_transpose
 /// @{
 
@@ -1373,7 +1409,7 @@ dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory);
 /// otherwise.
 dnnl_status_t DNNL_API dnnl_transpose_primitive_desc_create(
         dnnl_primitive_desc_t *transpose_primitive_desc,
-        dnnl_engine *engine, const_dnnl_memory_desc_t src_desc,
+        dnnl_engine_t engine, const_dnnl_memory_desc_t src_desc,
         const_dnnl_memory_desc_t dst_desc,
         dnnl_dim_t dim1, dnnl_dim_t dim2);
 
diff --git a/include/oneapi/dnnl/dnnl.hpp b/include/oneapi/dnnl/dnnl.hpp
index 624e923c317..f4418987b96 100644
--- a/include/oneapi/dnnl/dnnl.hpp
+++ b/include/oneapi/dnnl/dnnl.hpp
@@ -108,6 +108,8 @@ struct primitive : public handle<dnnl_primitive_t> {
     enum class kind {
         /// Undefined primitive
         undef = dnnl_undefined_primitive,
+        /// A mask primitive.
+        mask = dnnl_mask,
         /// A transpose primitive.
         transpose = dnnl_transpose,
         /// A reorder primitive.
         reorder = dnnl_reorder,
@@ -4896,6 +4898,123 @@ struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
 
 /// @} dnnl_api_primitives_common
 
+/// @addtogroup dnnl_api_mask Mask
+///
+/// A primitive that fills elements of the destination tensor with a constant
+/// value wherever a mask tensor is non-zero and copies the remaining elements
+/// from the source. The mask is broadcast to the shape of the source tensor
+/// and is passed at execution time as #DNNL_ARG_WEIGHTS.
+///
+/// @{
+
+/// Mask primitive.
+struct mask : public primitive {
+    /// Primitive descriptor for a mask primitive.
+    struct primitive_desc : public primitive_desc_base {
+        using primitive_desc_base::primitive_desc_base;
+
+        /// Default constructor. Produces an empty object.
+        primitive_desc() = default;
+
+        /// Constructs a primitive descriptor for a mask primitive.
+        /// @param aengine Engine to perform the operation on.
+        /// @param src Source memory object. It is used to obtain the source
+        ///     memory descriptor.
+        /// @param mask Mask (weights) memory object. It is used to obtain
+        ///     the mask memory descriptor.
+        /// @param dst Destination memory object. It is used to obtain the
+        ///     destination memory descriptor.
+        /// @param value Value written to masked destination elements.
+        /// @param allow_empty A flag signifying whether construction is allowed
+        ///     to fail without throwing an exception. In this case an empty
+        ///     object will be produced. This flag is optional and defaults to
+        ///     false.
+        primitive_desc(const engine &aengine, const memory &src, const memory &mask,
+                const memory &dst, double value, bool allow_empty = false) {
+            dnnl_primitive_desc_t result;
+            auto src_md = src.get_desc();
+            auto mask_md = mask.get_desc();
+            auto dst_md = dst.get_desc();
+            dnnl_status_t status = dnnl_mask_primitive_desc_create(&result,
+                    aengine.get(), src_md.get(), dst_md.get(), mask_md.get(), value);
+            if (!allow_empty)
+                error::wrap_c_api(status,
+                        "could not create a primitive descriptor for a mask "
+                        "primitive");
+            reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
+        }
+
+        /// Constructs a primitive descriptor for a mask primitive (integer value).
+        /// @param aengine Engine to perform the operation on.
+        /// @param src Source memory object. It is used to obtain the source
+        ///     memory descriptor.
+        /// @param mask Mask (weights) memory object. It is used to obtain
+        ///     the mask memory descriptor.
+        /// @param dst Destination memory object. It is used to obtain the
+        ///     destination memory descriptor.
+        /// @param value Value written to masked destination elements.
+        /// @param allow_empty A flag signifying whether construction is allowed
+        ///     to fail without throwing an exception. In this case an empty
+        ///     object will be produced. This flag is optional and defaults to
+        ///     false.
+        primitive_desc(const engine &aengine, const memory &src, const memory &mask,
+                const memory &dst, int64_t value, bool allow_empty = false) {
+            dnnl_primitive_desc_t result;
+            auto src_md = src.get_desc();
+            auto mask_md = mask.get_desc();
+            auto dst_md = dst.get_desc();
+            dnnl_status_t status = dnnl_mask_primitive_desc_create(&result,
+                    aengine.get(), src_md.get(), dst_md.get(), mask_md.get(), value);
+            if (!allow_empty)
+                error::wrap_c_api(status,
+                        "could not create a primitive descriptor for a mask "
+                        "primitive");
+            reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
+        }
+
+        /// Constructs a primitive descriptor for a mask primitive from a C
+        /// API primitive descriptor which must have a matching kind.
+        ///
+        /// @param pd C API primitive descriptor for a mask primitive.
+        primitive_desc(dnnl_primitive_desc_t pd)
+            : primitive_desc_base(pd, dnnl::primitive::kind::mask) {}
+
+        /// @copydoc dnnl::primitive_desc_base::src_desc()const
+        memory::desc src_desc() const { return base::src_desc(0); }
+
+        /// @copydoc dnnl::primitive_desc_base::dst_desc()const
+        memory::desc dst_desc() const { return base::dst_desc(0); }
+    };
+
+    /// Default constructor. Produces an empty object.
+    mask() = default;
+
+    /// Constructs a mask primitive.
+    /// @param pd Primitive descriptor for a mask primitive.
+    mask(const primitive_desc &pd) : primitive(pd.get()) {}
+
+    /// Constructs a mask primitive from a cache blob.
+    /// @param pd Primitive descriptor for a mask primitive.
+    /// @param cache_blob Cache blob.
+    mask(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
+        : primitive(pd.get(), cache_blob) {}
+
+    using primitive::execute;
+
+    /// Executes the mask primitive.
+    ///
+    /// @param astream Stream object. The stream must belong to the same engine
+    ///     as the primitive.
+    /// @param src Source memory object.
+    /// @param dst Destination memory object.
+    /// @param mask Mask memory object.
+    void execute(const stream &astream, memory &src, memory &dst, memory &mask) const {
+        primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}, {DNNL_ARG_WEIGHTS, mask}});
+    }
+};
+
+/// @} dnnl_api_mask
+
 /// @addtogroup dnnl_api_transpose Transpose
 ///
 /// A primitive to copy data between two memory objects. This primitive is
 /// typically used to exchange the two specified dimensions of the src memory.
 ///
 /// @sa @ref dev_guide_transpose in developer guide
 ///
 /// @{
 
 /// Transpose primitive.
 struct transpose : public primitive {
-    /// Primitive descriptor for a reorder primitive.
+    /// Primitive descriptor for a transpose primitive.
     struct primitive_desc : public primitive_desc_base {
         using primitive_desc_base::primitive_desc_base;
 
@@ -4920,6 +5039,8 @@ struct transpose : public primitive {
         ///     memory descriptor and engine.
         /// @param dst Destination memory object. It is used to obtain the
         ///     destination memory descriptor and engine.
+        /// @param dim1 First dimension to exchange.
+        /// @param dim2 Second dimension to exchange.
         /// @param allow_empty A flag signifying whether construction is allowed
         ///     to fail without throwing an exception. In this case an empty
         ///     object will be produced. This flag is optional and defaults to
@@ -4938,10 +5059,10 @@ struct transpose : public primitive {
             reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
         }
 
-        /// Constructs a primitive descriptor for reorder primitive from a C
+        /// Constructs a primitive descriptor for transpose primitive from a C
         /// API primitive descriptor which must have a matching kind.
         ///
-        /// @param pd C API primitive descriptor for reorder primitive.
+ /// @param pd C API primitive descriptor for transpose primitive. primitive_desc(dnnl_primitive_desc_t pd) : primitive_desc_base(pd, dnnl::primitive::kind::transpose) {} diff --git a/include/oneapi/dnnl/dnnl_types.h b/include/oneapi/dnnl/dnnl_types.h index bd6d28a0f45..da4c3ecc5d3 100644 --- a/include/oneapi/dnnl/dnnl_types.h +++ b/include/oneapi/dnnl/dnnl_types.h @@ -2019,7 +2019,9 @@ typedef enum { dnnl_group_normalization, /// A transpose primitive. dnnl_transpose, - + /// A mask primitive. + dnnl_mask, + /// Parameter to allow internal only primitives without undefined behavior. /// This parameter is chosen to be valid for so long as sizeof(int) >= 2. dnnl_primitive_kind_max = 0x7fff, diff --git a/src/common/c_types_map.hpp b/src/common/c_types_map.hpp index 7597d81452a..79720083357 100644 --- a/src/common/c_types_map.hpp +++ b/src/common/c_types_map.hpp @@ -1918,6 +1918,7 @@ using primitive_kind_t = dnnl_primitive_kind_t; namespace primitive_kind { const primitive_kind_t undefined = dnnl_undefined_primitive; const primitive_kind_t transpose = dnnl_transpose; +const primitive_kind_t mask = dnnl_mask; const primitive_kind_t reorder = dnnl_reorder; const primitive_kind_t concat = dnnl_concat; const primitive_kind_t sum = dnnl_sum; diff --git a/src/common/dnnl_traits.hpp b/src/common/dnnl_traits.hpp index 73e874a35fe..98dc0f65d84 100644 --- a/src/common/dnnl_traits.hpp +++ b/src/common/dnnl_traits.hpp @@ -176,6 +176,7 @@ PKIND_TRAITS_INST(resampling); PKIND_TRAITS_INST(reduction); PKIND_TRAITS_INST(sdpa); PKIND_TRAITS_INST(transpose); +PKIND_TRAITS_INST(mask); #undef PKIND_TRAITS_INST } // namespace impl diff --git a/src/common/mask.cpp b/src/common/mask.cpp new file mode 100644 index 00000000000..90c37512200 --- /dev/null +++ b/src/common/mask.cpp @@ -0,0 +1,82 @@ +/******************************************************************************* +* Copyright 2016-2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include "opdesc.hpp" +#include "primitive_desc_iface.hpp" + +#include "oneapi/dnnl/dnnl.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace dnnl::impl; +using namespace dnnl::impl::utils; +using namespace dnnl::impl::status; + +namespace dnnl { +namespace impl { + +#define VCHECK_TRANSPOSE(cond, msg, ...) 
\ + VCONDCHECK(primitive, create, check, mask, (cond), \ + status::invalid_arguments, msg, ##__VA_ARGS__); + +status_t mask_desc_init(mask_desc_t *mask_desc, + const memory_desc_t *src_md, const memory_desc_t *dst_md, + const memory_desc_t *mask_md, double value_f, double value_i) { + VCHECK_TRANSPOSE(!any_null(src_md, dst_md), VERBOSE_NULL_ARG); + + auto op_d = mask_desc_t(); + op_d.primitive_kind = primitive_kind::mask; + op_d.src_desc = *src_md; + op_d.dst_desc = *dst_md; + op_d.mask_desc = *mask_md; + op_d.value_f = value_f; + op_d.value_i = value_i; + + *mask_desc = op_d; + return status::success; +} + +} // namespace impl +} // namespace dnnl + +status_t dnnl_mask_primitive_desc_create( + primitive_desc_iface_t **primitive_desc_iface, dnnl_engine_t engine, + const memory_desc_t *src_md, const memory_desc_t *dst_md, + const memory_desc_t *mask_md, + double value) { + auto mask_desc = mask_desc_t(); + CHECK(mask_desc_init( + &mask_desc, src_md, dst_md, mask_md, value, 0)); + + return primitive_desc_create(primitive_desc_iface, engine, + (const op_desc_t *)&mask_desc, nullptr, nullptr); +} + +status_t dnnl_mask_primitive_desc_create( + primitive_desc_iface_t **primitive_desc_iface, dnnl_engine_t engine, + const memory_desc_t *src_md, const memory_desc_t *dst_md, + const memory_desc_t *mask_md, + int64_t value) { + auto mask_desc = mask_desc_t(); + CHECK(mask_desc_init( + &mask_desc, src_md, dst_md, mask_md, 0, value)); + + return primitive_desc_create(primitive_desc_iface, engine, + (const op_desc_t *)&mask_desc, nullptr, nullptr); +} \ No newline at end of file diff --git a/src/common/mask_pd.hpp b/src/common/mask_pd.hpp new file mode 100644 index 00000000000..71040c7300f --- /dev/null +++ b/src/common/mask_pd.hpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2016-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef COMMON_MASK_PD_HPP +#define COMMON_MASK_PD_HPP + +#include + +#include "oneapi/dnnl/dnnl.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +#define VDISPATCH_MASK(cond, msg, ...) 
\ + VCONDCHECK(primitive, create, dispatch, mask, (cond), \ + status::unimplemented, "%s," msg, this->info(engine), \ + ##__VA_ARGS__) + +namespace dnnl { +namespace impl { + +status_t mask_desc_init(mask_desc_t *mask_desc, + const memory_desc_t *src_md, const memory_desc_t *dst_md, + const memory_desc_t *mask_md, + double value_f, double value_i); + +struct mask_pd_t : public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::mask; + + typedef mask_pd_t base_class; + typedef mask_pd_t hint_class; + + const mask_desc_t *desc() const { return &desc_; } + const op_desc_t *op_desc() const override { + return reinterpret_cast(this->desc()); + } + + arg_usage_t arg_usage(int arg) const override { + if (arg == DNNL_ARG_FROM) return arg_usage_t::input; + + if (arg == DNNL_ARG_TO) return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + const memory_desc_t *arg_md( + int arg, bool user_input = false) const override { + switch (arg) { + case DNNL_ARG_FROM: return src_md(0); + case DNNL_ARG_TO: return dst_md(0, user_input); + default: return primitive_desc_t::arg_md(arg); + } + } + + const memory_desc_t *src_md( + int index = 0, bool user_input = false) const override { + if (index == 0) return user_input ? &desc()->src_desc : &src_md_; + return &glob_zero_md; + } + + const memory_desc_t *dst_md( + int index = 0, bool user_input = false) const override { + if (index == 0) return user_input ? &desc()->dst_desc : &dst_md_; + return &glob_zero_md; + } + + const memory_desc_t *mask_md( + int index = 0, bool user_input = false) const { + if (index == 0) return user_input ? &desc()->mask_desc : &mask_md_; + return &glob_zero_md; + } + + int n_inputs() const override { return 1; } + int n_outputs() const override { return 1; } + + double value_f() const { return value_f_; } + int64_t value_i() const { return value_i_; } + +protected: + mask_desc_t desc_; + memory_desc_t src_md_; + memory_desc_t dst_md_; + memory_desc_t mask_md_; + double value_f_; + int64_t value_i_; + + mask_pd_t(const mask_desc_t *adesc, const primitive_attr_t *attr, + const mask_pd_t *hint_fwd_pd) + : primitive_desc_t(attr, base_pkind) + , desc_(*adesc) + , src_md_(desc_.src_desc) + , dst_md_(desc_.dst_desc) + , mask_md_(desc_.mask_desc) + , value_f_(desc_.value_f) + , value_i_(desc_.value_i) {} +}; + +} // namespace impl +} // namespace dnnl + +#endif \ No newline at end of file diff --git a/src/common/opdesc.hpp b/src/common/opdesc.hpp index 29f931197a5..8764e2d8ec7 100644 --- a/src/common/opdesc.hpp +++ b/src/common/opdesc.hpp @@ -26,6 +26,15 @@ namespace dnnl { namespace impl { +struct mask_desc_t { + primitive_kind_t primitive_kind; + memory_desc_t src_desc; + memory_desc_t dst_desc; + memory_desc_t mask_desc; + double value_f = 0; + int64_t value_i = 0; +}; + struct transpose_desc_t { primitive_kind_t primitive_kind; memory_desc_t src_desc; @@ -627,6 +636,7 @@ struct op_desc_t { reduction_desc_t reduction; sdpa_desc_t sdpa; transpose_desc_t transpose; + mask_desc_t mask; }; #define DECL_CTOR_AND_CONVERTERS(c_type) \ @@ -654,6 +664,7 @@ struct op_desc_t { DECL_CTOR_AND_CONVERTERS(concat_desc_t); DECL_CTOR_AND_CONVERTERS(reorder_desc_t); DECL_CTOR_AND_CONVERTERS(transpose_desc_t); + DECL_CTOR_AND_CONVERTERS(mask_desc_t); DECL_CTOR_AND_CONVERTERS(sum_desc_t); DECL_CTOR_AND_CONVERTERS(binary_desc_t); DECL_CTOR_AND_CONVERTERS(matmul_desc_t); diff --git a/src/common/primitive_desc_iface.cpp b/src/common/primitive_desc_iface.cpp index fb201c9e548..3afd382cab5 100644 --- 
a/src/common/primitive_desc_iface.cpp +++ b/src/common/primitive_desc_iface.cpp @@ -41,7 +41,7 @@ status_t primitive_desc_create(primitive_desc_iface_t **primitive_desc_iface, batch_normalization, binary, convolution, deconvolution, eltwise, gemm, group_normalization, inner_product, layer_normalization, lrn, matmul, pooling, prelu, reduction, resampling, rnn, sdpa, shuffle, - softmax, transpose); + softmax, dnnl_transpose, mask); if (!known_primitive_kind) return invalid_arguments; auto pd_iface = utils::make_unique(engine, op_desc, diff --git a/src/common/primitive_hashing.cpp b/src/common/primitive_hashing.cpp index 863692d91ba..fadea116926 100644 --- a/src/common/primitive_hashing.cpp +++ b/src/common/primitive_hashing.cpp @@ -88,6 +88,7 @@ bool key_t::operator==(const key_t &rhs) const { CASE(softmax) CASE(sum) CASE(transpose) + CASE(mask) CASE(zero_pad) default: assert(!"unknown primitive kind"); } @@ -716,6 +717,21 @@ size_t get_desc_hash(const transpose_desc_t &desc) { return seed; } +size_t get_desc_hash(const mask_desc_t &desc) { + size_t seed = 0; + // Kinds + seed = hash_combine(seed, static_cast(desc.primitive_kind)); + // Memory descriptors + seed = hash_combine(seed, get_md_hash(desc.src_desc)); + seed = hash_combine(seed, get_md_hash(desc.dst_desc)); + seed = hash_combine(seed, get_md_hash(desc.mask_desc)); + // Mask value + seed = hash_combine(seed, desc.value_f); + seed = hash_combine(seed, desc.value_i); + // Combined hash for softmax desc + return seed; +} + size_t get_desc_hash(const zero_pad_desc_t &desc) { size_t seed = 0; // Kinds diff --git a/src/common/primitive_hashing.hpp b/src/common/primitive_hashing.hpp index b70e8e9c320..6a98d8912c5 100644 --- a/src/common/primitive_hashing.hpp +++ b/src/common/primitive_hashing.hpp @@ -96,6 +96,7 @@ size_t get_desc_hash(const shuffle_desc_t &desc); size_t get_desc_hash(const softmax_desc_t &desc); size_t get_desc_hash(const sum_desc_t &desc); size_t get_desc_hash(const transpose_desc_t &desc); +size_t get_desc_hash(const mask_desc_t &desc); size_t get_desc_hash(const zero_pad_desc_t &desc); template @@ -187,6 +188,7 @@ struct hash { CASE(sum) CASE(zero_pad) CASE(transpose) + CASE(mask) default: assert(!"unknown primitive_kind"); } // clang-format on diff --git a/src/common/transpose.cpp b/src/common/transpose.cpp index 974cd5063ff..2d707bb9bc9 100644 --- a/src/common/transpose.cpp +++ b/src/common/transpose.cpp @@ -32,7 +32,7 @@ namespace dnnl { namespace impl { #define VCHECK_TRANSPOSE(cond, msg, ...) 
\ - VCONDCHECK(primitive, create, check, transpose, (cond), \ + VCONDCHECK(primitive, create, check, dnnl_transpose, (cond), \ status::invalid_arguments, msg, ##__VA_ARGS__); status_t transpose_desc_init(transpose_desc_t *transpose_desc, @@ -58,7 +58,7 @@ status_t transpose_desc_init(transpose_desc_t *transpose_desc, } // namespace dnnl status_t dnnl_transpose_primitive_desc_create( - primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, + primitive_desc_iface_t **primitive_desc_iface, dnnl_engine_t engine, const memory_desc_t *src_md, const memory_desc_t *dst_md, dnnl_dim_t dim1, dnnl_dim_t dim2) { auto transpose_desc = transpose_desc_t(); diff --git a/src/common/transpose_pd.hpp b/src/common/transpose_pd.hpp index 317be266a50..33a92819130 100644 --- a/src/common/transpose_pd.hpp +++ b/src/common/transpose_pd.hpp @@ -103,6 +103,4 @@ struct transpose_pd_t : public primitive_desc_t { } // namespace impl } // namespace dnnl -#endif - -// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s +#endif \ No newline at end of file diff --git a/src/common/type_helpers.hpp b/src/common/type_helpers.hpp index 7ce7807c39a..4572407a79c 100644 --- a/src/common/type_helpers.hpp +++ b/src/common/type_helpers.hpp @@ -807,6 +807,17 @@ inline bool operator==(const transpose_desc_t &lhs, const transpose_desc_t &rhs) return ret; } +inline bool operator==(const mask_desc_t &lhs, const mask_desc_t &rhs) { + bool ret = COMPARE_DESC_MEMBERS(primitive_kind) + && COMPARE_DESC_MEMBERS(src_desc) + && COMPARE_DESC_MEMBERS(dst_desc) + && COMPARE_DESC_MEMBERS(mask_desc) + && COMPARE_DESC_MEMBERS(value_f) + && COMPARE_DESC_MEMBERS(value_i); + return ret; +} + + inline bool operator==(const zero_pad_desc_t &lhs, const zero_pad_desc_t &rhs) { bool ret = COMPARE_DESC_MEMBERS(primitive_kind); return ret; @@ -1140,6 +1151,7 @@ inline void copy_c_op_desc(op_desc_t *dst, const op_desc_t *src) { CASE_OP_DESC(shuffle); CASE_OP_DESC(softmax); CASE_OP_DESC(transpose); + CASE_OP_DESC(mask); // Internal descs CASE_OP_DESC(zero_pad); diff --git a/src/common/verbose.hpp b/src/common/verbose.hpp index 2bbc72537a0..1b49157895d 100644 --- a/src/common/verbose.hpp +++ b/src/common/verbose.hpp @@ -190,6 +190,7 @@ struct component_t { graph = 1 << 22, gemm_api = 1 << 23, transpose = 1 << 24, + mask = 1 << 25, all = (uint32_t)-1, }; }; diff --git a/src/gpu/amd/custom/hip_customs.h b/src/gpu/amd/custom/hip_customs.h index 3f25aa78121..34e1f46eacb 100644 --- a/src/gpu/amd/custom/hip_customs.h +++ b/src/gpu/amd/custom/hip_customs.h @@ -20,5 +20,7 @@ namespace hip_custom { void transpose(int dtype, void *input, void *output, const size_t *dims, int num_dims, int dim1, int dim2); +void mask(void *input, void *output, void *mask, const size_t *dims, const size_t *dims_mask, int num_dims, float masked_value, int fp_length); + } diff --git a/src/gpu/amd/custom/hip_mask.cc b/src/gpu/amd/custom/hip_mask.cc new file mode 100644 index 00000000000..bc0eda2e93a --- /dev/null +++ b/src/gpu/amd/custom/hip_mask.cc @@ -0,0 +1,191 @@ +/******************************************************************************* +* Copyright 2016-2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <hip/hip_runtime.h>
+#include "hip_customs.h"
+#include <hip/hip_fp16.h>
+
+#define MAX_DIM 8
+#define blockSize 256
+
+// Kernel to mask the input; may run in place. The mask tensor is broadcast to
+// the shape of the input tensor.
+__global__ void mask_kernel_f32(float* input, float* output, char* mask, size_t* stride_io, size_t* dims_mask, int num_dims, size_t total_elements, float masked_value)
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int Scope = blockDim.x * gridDim.x;
+    int IterCount = total_elements / Scope + 1;
+    int mask_coordinate[MAX_DIM];
+
+    __shared__ size_t s_stride_io[MAX_DIM];
+    __shared__ size_t s_dims_mask[MAX_DIM];
+    __shared__ size_t s_stride_mask[MAX_DIM];
+
+    // Every block has its own shared memory, so the guards must use
+    // threadIdx.x (not the global index); otherwise blocks other than the
+    // first one would read uninitialized tables.
+    if(threadIdx.x < num_dims)
+    {
+        s_stride_io[threadIdx.x] = stride_io[threadIdx.x];
+        s_dims_mask[threadIdx.x] = dims_mask[threadIdx.x];
+    }
+
+    if(threadIdx.x == 0)
+    {
+        s_stride_mask[num_dims-1] = 1;
+        for(int i=num_dims-2; i >= 0; i--)
+            s_stride_mask[i] = s_stride_mask[i+1] * s_dims_mask[i+1];
+    }
+
+    __syncthreads();
+
+    for(int i = 0; i < IterCount; i++)
+    {
+        size_t io_offset = idx + Scope * i;
+        if(io_offset >= total_elements)
+            continue;
+
+        // calculate the coordinate of the element in the src/dst tensor.
+        size_t margin = io_offset;
+        for(int j = 0; j < num_dims; j++)
+        {
+            mask_coordinate[j] = margin / s_stride_io[j];
+            // A broadcast mask dimension (size 1) always maps to coordinate 0.
+            if(s_dims_mask[j] == 1)
+                mask_coordinate[j] = 0;
+            margin = margin % s_stride_io[j];
+        }
+
+        // offset of the corresponding element in the (broadcast) mask tensor
+        size_t mask_offset = 0;
+        for(int j = 0; j < num_dims; j++)
+        {
+            mask_offset += s_stride_mask[j] * mask_coordinate[j];
+        }
+
+        char mask_value = mask[mask_offset];
+        // Fill elements of the output with masked_value where the mask is non-zero
+        if(mask_value != 0)
+            output[io_offset] = masked_value;
+        // if not in place and the mask is zero, copy input to output.
+        if(input != output && mask_value == 0)
+            output[io_offset] = input[io_offset];
+    }
+}
+
+// Kernel to mask the input; may run in place. The mask tensor is broadcast to
+// the shape of the input tensor.
+__global__ void mask_kernel_f16(half* input, half* output, char* mask, size_t* stride_io, size_t* dims_mask, int num_dims, size_t total_elements, half masked_value)
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int Scope = blockDim.x * gridDim.x;
+    int IterCount = total_elements / Scope + 1;
+    int mask_coordinate[MAX_DIM];
+
+    __shared__ size_t s_stride_io[MAX_DIM];
+    __shared__ size_t s_dims_mask[MAX_DIM];
+    __shared__ size_t s_stride_mask[MAX_DIM];
+
+    if(threadIdx.x < num_dims)
+    {
+        s_stride_io[threadIdx.x] = stride_io[threadIdx.x];
+        s_dims_mask[threadIdx.x] = dims_mask[threadIdx.x];
+    }
+
+    if(threadIdx.x == 0)
+    {
+        s_stride_mask[num_dims-1] = 1;
+        for(int i=num_dims-2; i >= 0; i--)
+            s_stride_mask[i] = s_stride_mask[i+1] * s_dims_mask[i+1];
+    }
+
+    __syncthreads();
+
+    for(int i = 0; i < IterCount; i++)
+    {
+        size_t io_offset = idx + Scope * i;
+        if(io_offset >= total_elements)
+            continue;
+
+        // calculate the coordinate of the element in the src/dst tensor.
+        size_t margin = io_offset;
+        for(int j = 0; j < num_dims; j++)
+        {
+            mask_coordinate[j] = margin / s_stride_io[j];
+            // A broadcast mask dimension (size 1) always maps to coordinate 0.
+            if(s_dims_mask[j] == 1)
+                mask_coordinate[j] = 0;
+            margin = margin % s_stride_io[j];
+        }
+
+        // offset of the corresponding element in the (broadcast) mask tensor
+        size_t mask_offset = 0;
+        for(int j = 0; j < num_dims; j++)
+        {
+            mask_offset += s_stride_mask[j] * mask_coordinate[j];
+        }
+
+        char mask_value = mask[mask_offset];
+        // Fill elements of the output with masked_value where the mask is non-zero
+        if(mask_value != 0)
+            output[io_offset] = masked_value;
+        // if not in place and the mask is zero, copy input to output.
+        if(input != output && mask_value == 0)
+            output[io_offset] = input[io_offset];
+    }
+}
+
+
+inline void prepare_utils(int num_dims, const size_t *dims, const size_t *dims_mask, size_t **d_io_strides, size_t **d_dims_mask, size_t* total_elements, int* numBlocks){
+    size_t* io_strides = (size_t*)malloc(num_dims * sizeof(size_t));
+    io_strides[num_dims-1] = 1;
+
+    for (int i = num_dims-2; i >= 0; i--)
+        io_strides[i] = io_strides[i+1] * dims[i+1];
+
+    hipMalloc(d_io_strides, num_dims * sizeof(size_t));
+    hipMalloc(d_dims_mask, num_dims * sizeof(size_t));
+    hipMemcpy(*d_io_strides, io_strides, num_dims*sizeof(size_t), hipMemcpyHostToDevice);
+    hipMemcpy(*d_dims_mask, dims_mask, num_dims*sizeof(size_t), hipMemcpyHostToDevice);
+
+    *total_elements = io_strides[0] * dims[0];
+    free(io_strides);
+
+    // ceiling div
+    *numBlocks = *total_elements / blockSize;
+    if(*total_elements % blockSize != 0)
+        *numBlocks += 1;
+}
+
+namespace hip_custom {
+
+// tensor mask for f32/f16 tensors
+void mask(void *input, void *output, void *mask, const size_t *dims, const size_t *dims_mask, int num_dims, float masked_value, int fp_length)
+{
+    size_t *d_io_strides, *d_dims_mask;
+    size_t total_elements;
+    int numBlocks;
+
+    prepare_utils(num_dims, dims, dims_mask, &d_io_strides, &d_dims_mask, &total_elements, &numBlocks);
+
+    if(fp_length == 32)
+        hipLaunchKernelGGL(mask_kernel_f32, dim3(numBlocks), dim3(blockSize), 0, 0, (float*)input, (float*)output, (char*)mask, d_io_strides, d_dims_mask, num_dims, total_elements, masked_value);
+    if(fp_length == 16)
+        hipLaunchKernelGGL(mask_kernel_f16, dim3(numBlocks), dim3(blockSize), 0, 0, (half*)input, (half*)output, (char*)mask, d_io_strides, d_dims_mask, num_dims, total_elements, masked_value);
+
+    hipFree(d_io_strides);
+    hipFree(d_dims_mask);
+}
+
+}
diff --git a/src/gpu/amd/custom/hip_transpose.cc b/src/gpu/amd/custom/hip_transpose.cc
index c5b0f0fba19..43b13bb7658 100644
--- a/src/gpu/amd/custom/hip_transpose.cc
+++ b/src/gpu/amd/custom/hip_transpose.cc
@@ -34,6 +34,8 @@ __global__ void transpose_kernel_f32(float *input, float *output, size_t* in_str
         istrides[threadIdx.x] = in_strides[threadIdx.x];
         ostrides[threadIdx.x] = out_strides[threadIdx.x];
     }
+
+    __syncthreads();
 
     for(int i = 0; i < IterCount; i++)
     {
diff --git a/src/gpu/amd/hip_mask.cpp b/src/gpu/amd/hip_mask.cpp
new file mode 100644
index 00000000000..f9952cef22b
--- /dev/null
+++ b/src/gpu/amd/hip_mask.cpp
@@ -0,0 +1,57 @@
+/*******************************************************************************
+* Copyright 2020-2022 Intel Corporation
+* Copyright 2020 Codeplay Software Limited
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "gpu/amd/hip_mask.hpp"
+#include "gpu/amd/sycl_hip_scoped_context.hpp"
+#include "gpu/amd/sycl_hip_stream.hpp"
+#include "xpu/sycl/buffer_memory_storage.hpp"
+#include "xpu/sycl/memory_storage_helper.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace gpu {
+namespace amd {
+
+status_t hip_mask_t::execute(const exec_ctx_t &ctx) const {
+    if (memory_desc_wrapper(pd()->src_md()).has_zero_dim())
+        return status::success;
+    amd::sycl_hip_stream_t *hip_stream
+            = utils::downcast<amd::sycl_hip_stream_t *>(ctx.stream());
+
+    return hip_stream->interop_task([&](::sycl::handler &cgh) {
+        auto arg_src = CTX_IN_SYCL_MEMORY(DNNL_ARG_SRC);
+        auto arg_dst = CTX_OUT_SYCL_MEMORY(DNNL_ARG_DST);
+        auto arg_mask = CTX_IN_SYCL_MEMORY(DNNL_ARG_WEIGHTS);
+
+        compat::host_task(cgh, [=](const compat::interop_handle &ih) {
+            auto &sycl_engine = *utils::downcast<sycl_hip_engine_t *>(
+                    hip_stream->engine());
+            auto sc = hip_sycl_scoped_context_handler_t(sycl_engine);
+
+            void *x = arg_src.get_native_pointer(ih);
+            void *y = arg_dst.get_native_pointer(ih);
+            void *mask = arg_mask.get_native_pointer(ih);
+
+            pd()->mask_impl_->execute(x, y, mask);
+        });
+    });
+}
+
+} // namespace amd
+} // namespace gpu
+} // namespace impl
+} // namespace dnnl
diff --git a/src/gpu/amd/hip_mask.hpp b/src/gpu/amd/hip_mask.hpp
new file mode 100644
index 00000000000..d0b0cdd8738
--- /dev/null
+++ b/src/gpu/amd/hip_mask.hpp
@@ -0,0 +1,93 @@
+/*******************************************************************************
+* Copyright 2020-2022 Intel Corporation
+* Copyright 2020 Codeplay Software Limited
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef GPU_AMD_HIP_MASK_HPP
+#define GPU_AMD_HIP_MASK_HPP
+
+#include "common/mask_pd.hpp"
+#include "common/c_types_map.hpp"
+#include "common/primitive.hpp"
+#include "gpu/amd/hip_mask_impl.hpp"
+#include "gpu/amd/sycl_hip_engine.hpp"
+#include "gpu/amd/sycl_hip_utils.hpp"
+#include <memory>
+
+namespace dnnl {
+namespace impl {
+namespace gpu {
+namespace amd {
+
+struct hip_mask_t : public primitive_t {
+    using primitive_t::primitive_t;
+
+    struct pd_t : public mask_pd_t {
+        using mask_pd_t::mask_pd_t;
+
+        DECLARE_COMMON_PD_T("hip:miopen:any", hip_mask_t);
+
+        status_t init(engine_t *) {
+            using namespace data_type;
+
+            // Zero-dimension tensors are valid; execute() returns early for
+            // them, so only unsupported data type combinations are rejected.
+            if (!check_data_types())
+                return status::invalid_arguments;
+
+            mask_impl_.reset(new hip_mask_impl_t());
+            return mask_impl_->init(this);
+        }
+
+        bool check_for_zero_dims() const {
+            return has_zero_dims(src_md()->dims, src_md()->ndims)
+                    || has_zero_dims(dst_md()->dims, dst_md()->ndims);
+        }
+
+        bool check_no_blocking() const {
+            // Blocking is not supported by the custom HIP mask kernel; return
+            // false if any blocks are present
+            return src_md(0)->format_desc.blocking.inner_nblks
+                    + src_md(1)->format_desc.blocking.inner_nblks
+                    + dst_md()->format_desc.blocking.inner_nblks
+                    == 0;
+        }
+
+        bool check_data_types() const {
+            using namespace data_type;
+            data_type_t input_type = src_md()->data_type;
+            data_type_t output_type = dst_md()->data_type;
+            // src and dst must match; only f32 and f16 are implemented by the
+            // HIP kernels.
+            bool type_same = (input_type == output_type);
+            bool type_supported = (input_type == f32 || input_type == f16);
+            return type_same && type_supported;
+        }
+
+        std::shared_ptr<hip_mask_impl_t> mask_impl_;
+    };
+
+    status_t execute(const exec_ctx_t &ctx) const override;
+
+private:
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
+};
+
+} // namespace amd
+} // namespace gpu
+} // namespace impl
+} // namespace dnnl
+
+#endif
diff --git a/src/gpu/amd/hip_mask_impl.hpp b/src/gpu/amd/hip_mask_impl.hpp
new file mode 100644
index 00000000000..e4a751b3e4a
--- /dev/null
+++ b/src/gpu/amd/hip_mask_impl.hpp
@@ -0,0 +1,67 @@
+/*******************************************************************************
+* Copyright 2020-2022 Intel Corporation
+* Copyright 2020 Codeplay Software Limited
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef GPU_AMD_HIP_MASK_IMPL_HPP
+#define GPU_AMD_HIP_MASK_IMPL_HPP
+
+#include "gpu/amd/sycl_hip_utils.hpp"
+#include "gpu/amd/custom/hip_customs.h"
+
+namespace dnnl {
+namespace impl {
+namespace gpu {
+namespace amd {
+
+struct hip_mask_impl_t {
+    status_t init(const mask_pd_t *pd) {
+        this->src_dtype = pd->src_md()->data_type;
+        this->value_f = pd->value_f();
+        this->value_i = pd->value_i();
+        this->num_dims = pd->src_md()->ndims;
+
+        for (int i = 0; i < this->num_dims; i++) {
+            this->dims_io[i] = pd->src_md()->dims[i];
+            this->dims_mask[i] = pd->mask_md()->dims[i];
+        }
+        return status::success;
+    }
+
+    void execute(void *x, void *y, void *mask) const {
+        if (src_dtype == dnnl::impl::data_type_t::dnnl_f32) {
+            hip_custom::mask(x, y, mask, dims_io, dims_mask, num_dims, (float)value_f, 32);
+        } else if (src_dtype == dnnl::impl::data_type_t::dnnl_f16) {
+            hip_custom::mask(x, y, mask, dims_io, dims_mask, num_dims, (float)value_f, 16);
+        }
+        return;
+    }
+
+    dnnl::impl::data_type_t src_dtype;
+    int num_dims;
+    size_t dims_io[8];
+    size_t dims_mask[8];
+
+    double value_f;
+    int64_t value_i;
+};
+
+} // namespace amd
+} // namespace gpu
+} // namespace impl
+} // namespace dnnl
+
+#endif
diff --git a/src/gpu/amd/hip_transpose.cpp b/src/gpu/amd/hip_transpose.cpp
index 06ed8d3eb6e..a2fcb7b3a25 100644
--- a/src/gpu/amd/hip_transpose.cpp
+++ b/src/gpu/amd/hip_transpose.cpp
@@ -18,8 +18,8 @@
 #include "gpu/amd/hip_transpose.hpp"
 #include "gpu/amd/sycl_hip_scoped_context.hpp"
 #include "gpu/amd/sycl_hip_stream.hpp"
-#include "sycl/sycl_buffer_memory_storage.hpp"
-#include "sycl/sycl_memory_storage_helper.hpp"
+#include "xpu/sycl/buffer_memory_storage.hpp"
+#include "xpu/sycl/memory_storage_helper.hpp"
 
 namespace dnnl {
 namespace impl {
diff --git a/src/gpu/amd/sycl_hip_engine.cpp b/src/gpu/amd/sycl_hip_engine.cpp
index f178725d29c..879cab8ccd6 100644
--- a/src/gpu/amd/sycl_hip_engine.cpp
+++ b/src/gpu/amd/sycl_hip_engine.cpp
@@ -22,6 +22,7 @@
 #include "miopen/miopen.h"
 #include "xpu/sycl/utils.hpp"
 #include "gpu/amd/hip_transpose.hpp"
+#include "gpu/amd/hip_mask.hpp"
 #include "gpu/amd/sycl_hip_compat.hpp"
 #include "gpu/amd/sycl_hip_engine.hpp"
 #include "gpu/amd/sycl_hip_scoped_context.hpp"
@@ -156,6 +157,10 @@ using namespace dnnl::impl::data_type;
 // clang-format off
 constexpr dnnl::impl::impl_list_item_t sycl_hip_impl_list[] = {
+        // Mask
+        INSTANCE(hip_mask_t)
+        // Transpose
+        INSTANCE(hip_transpose_t)
         // Binary
         INSTANCE(miopen_binary_t)
         // Elementwise
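Usage sketch (not part of the patch): the snippet below shows how the new C++ mask API introduced above could be exercised once the patch is applied. The engine kind, shapes, data types, and the fill value are illustrative assumptions.

#include <vector>
#include "oneapi/dnnl/dnnl.hpp"

void run_mask_example() {
    using namespace dnnl;
    engine eng(engine::kind::gpu, 0);
    stream strm(eng);

    // 2x3 f32 source/destination, 1x3 s8 mask (broadcast over dim 0).
    memory::desc src_md({2, 3}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc dst_md({2, 3}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc mask_md({1, 3}, memory::data_type::s8, memory::format_tag::ab);

    memory src_mem(src_md, eng), dst_mem(dst_md, eng), mask_mem(mask_md, eng);

    // Fill masked destination elements with a large negative value
    // (e.g. for attention masking).
    mask::primitive_desc pd(eng, src_mem, mask_mem, dst_mem, -1e9 /*value*/);
    mask prim(pd);

    prim.execute(strm, src_mem, dst_mem, mask_mem);
    strm.wait();
}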
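The HIP kernels in src/gpu/amd/custom/hip_mask.cc locate the mask element for each src/dst element by decomposing the linear offset into coordinates and clamping coordinates of size-1 (broadcast) mask dimensions to zero. The host-side sketch below mirrors that arithmetic with a hypothetical helper (not part of the patch) and a small worked example; it assumes at most 8 dimensions, matching MAX_DIM.

#include <cstddef>
#include <cassert>
using std::size_t;

static size_t broadcast_mask_offset(size_t io_offset, const size_t *dims_io,
        const size_t *dims_mask, int num_dims) {
    size_t stride_io[8], stride_mask[8];
    stride_io[num_dims - 1] = 1;
    stride_mask[num_dims - 1] = 1;
    for (int i = num_dims - 2; i >= 0; i--) {
        stride_io[i] = stride_io[i + 1] * dims_io[i + 1];
        stride_mask[i] = stride_mask[i + 1] * dims_mask[i + 1];
    }
    size_t mask_offset = 0, margin = io_offset;
    for (int j = 0; j < num_dims; j++) {
        size_t coord = margin / stride_io[j];
        margin %= stride_io[j];
        if (dims_mask[j] == 1) coord = 0; // broadcast dimension
        mask_offset += coord * stride_mask[j];
    }
    return mask_offset;
}

// Example: src dims {2, 3, 4}, mask dims {1, 3, 1}. Element (1, 2, 3) has
// linear offset 1*12 + 2*4 + 3 = 23 and maps to mask element (0, 2, 0),
// i.e. mask offset 2.
static void check_broadcast_example() {
    const size_t dims_io[3] = {2, 3, 4};
    const size_t dims_mask[3] = {1, 3, 1};
    assert(broadcast_mask_offset(23, dims_io, dims_mask, 3) == 2);
}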
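A minimal standalone HIP harness (assumed test scaffolding, not part of the patch) could exercise hip_custom::mask directly against the declaration in hip_customs.h; the buffer contents and the expected result in the comments are illustrative.

#include <hip/hip_runtime.h>
#include <cstdio>
#include "hip_customs.h"

int main() {
    const size_t dims[2] = {2, 3}, dims_mask[2] = {1, 3};
    float h_src[6] = {1, 2, 3, 4, 5, 6};
    char h_mask[3] = {0, 1, 0}; // mask the middle column, broadcast over rows
    float h_dst[6] = {};

    float *d_src, *d_dst;
    char *d_mask;
    hipMalloc(&d_src, sizeof(h_src));
    hipMalloc(&d_dst, sizeof(h_dst));
    hipMalloc(&d_mask, sizeof(h_mask));
    hipMemcpy(d_src, h_src, sizeof(h_src), hipMemcpyHostToDevice);
    hipMemcpy(d_mask, h_mask, sizeof(h_mask), hipMemcpyHostToDevice);

    hip_custom::mask(d_src, d_dst, d_mask, dims, dims_mask, 2,
            -1.0f /*masked_value*/, 32 /*fp_length*/);
    hipDeviceSynchronize();

    hipMemcpy(h_dst, d_dst, sizeof(h_dst), hipMemcpyDeviceToHost);
    // Expected: 1 -1 3 4 -1 6
    for (int i = 0; i < 6; i++) printf("%g ", h_dst[i]);
    printf("\n");

    hipFree(d_src);
    hipFree(d_dst);
    hipFree(d_mask);
    return 0;
}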
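mask_desc_init() in src/common/mask.cpp currently only null-checks src and dst. A possible extra validation step (an assumption, not part of the patch) would verify that src and dst shapes match and that every mask dimension either equals the corresponding source dimension or is 1, which is what the HIP kernels' broadcasting expects.

#include "c_types_map.hpp"

namespace dnnl {
namespace impl {

// Hypothetical helper: returns true when dst matches src and the mask is
// broadcast-compatible with src.
inline bool mask_shapes_ok(const memory_desc_t &src, const memory_desc_t &dst,
        const memory_desc_t &mask) {
    if (src.ndims != dst.ndims || src.ndims != mask.ndims) return false;
    for (int d = 0; d < src.ndims; d++) {
        if (src.dims[d] != dst.dims[d]) return false;
        if (mask.dims[d] != src.dims[d] && mask.dims[d] != 1) return false;
    }
    return true;
}

} // namespace impl
} // namespace dnnl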