Skip to content

Commit

Permalink
add mask primitive
Browse files Browse the repository at this point in the history
  • Loading branch information
pyf committed Jun 6, 2024
1 parent 2f7d655 commit 3e9fb57
Show file tree
Hide file tree
Showing 23 changed files with 828 additions and 13 deletions.
38 changes: 37 additions & 1 deletion include/oneapi/dnnl/dnnl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,42 @@ dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory);
/// @addtogroup dnnl_api_primitives
/// @{

/// @addtogroup dnnl_api_mask
/// @{

/// Creates a primitive descriptor for a mask primitive.
///
/// @param mask_primitive_desc Output primitive descriptor.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param mask_desc Mask memory descriptor.
/// @param value value to fill in Source with.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_mask_primitive_desc_create(
dnnl_primitive_desc_t *mask_primitive_desc,
dnnl_engine_t engine, const_dnnl_memory_desc_t src_desc,
const_dnnl_memory_desc_t dst_desc,
const_dnnl_memory_desc_t mask_desc,
double value);

/// Creates a primitive descriptor for a mask primitive.
///
/// @param mask_primitive_desc Output primitive descriptor.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param mask_desc Mask memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_mask_primitive_desc_create(
dnnl_primitive_desc_t *mask_primitive_desc,
dnnl_engine_t engine, const_dnnl_memory_desc_t src_desc,
const_dnnl_memory_desc_t dst_desc,
const_dnnl_memory_desc_t mask_desc,
double value);

/// @} dnnl_api_mask

/// @addtogroup dnnl_api_transpose
/// @{

Expand All @@ -1373,7 +1409,7 @@ dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory);
/// otherwise.
dnnl_status_t DNNL_API dnnl_transpose_primitive_desc_create(
dnnl_primitive_desc_t *transpose_primitive_desc,
dnnl_engine *engine, const_dnnl_memory_desc_t src_desc,
dnnl_engine_t engine, const_dnnl_memory_desc_t src_desc,
const_dnnl_memory_desc_t dst_desc,
dnnl_dim_t dim1, dnnl_dim_t dim2);

Expand Down
127 changes: 124 additions & 3 deletions include/oneapi/dnnl/dnnl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ struct primitive : public handle<dnnl_primitive_t> {
enum class kind {
/// Undefined primitive
undef = dnnl_undefined_primitive,
/// A mask primitive.
mask = dnnl_mask,
/// A transpose primitive.
transpose = dnnl_transpose,
/// A reorder primitive.
Expand Down Expand Up @@ -4896,6 +4898,123 @@ struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {

/// @} dnnl_api_primitives_common

/// @addtogroup dnnl_api_mask mask
///
/// A primitive to copy data between two memory objects. This primitive is
/// typically used to exchange the two specified dimensions of the src memory.
///
/// @sa @ref dev_guide_transpose in developer guide
///
/// @{

/// Transpose primitive.
struct mask : public primitive {
/// Primitive descriptor for a mask primitive.
struct primitive_desc : public primitive_desc_base {
using primitive_desc_base::primitive_desc_base;

/// Default constructor. Produces an empty object.
primitive_desc() = default;

/// Constructs a primitive descriptor for mask primitive.
///
/// @param src Source memory object. It is used to obtain the source
/// memory descriptor and engine.
/// @param mask mask(weight) memory object. It is used to obtain the mask
/// memory descriptor and engine.
/// @param dst Destination memory object. It is used to obtain the
/// destination memory descriptor and engine.
/// @param value value that used to fill masked place of Source memory.
/// @param allow_empty A flag signifying whether construction is allowed
/// to fail without throwing an exception. In this case an empty
/// object will be produced. This flag is optional and defaults to
/// false.
primitive_desc(const engine &aengine, const memory &src, const memory &mask,
const memory &dst, double value, bool allow_empty = false) {
dnnl_primitive_desc_t result;
auto src_md = src.get_desc();
auto mask_md = mask.get_desc();
auto dst_md = dst.get_desc();
dnnl_status_t status = dnnl_mask_primitive_desc_create(&result,
aengine.get(), src_md.get(), dst_md.get(), mask_md.get(), value);
if (!allow_empty)
error::wrap_c_api(status,
"could not create a primitive descriptor for a mask "
"primitive");
reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
}

/// Constructs a primitive descriptor for mask primitive, for integer memory.
///
/// @param src Source memory object. It is used to obtain the source
/// memory descriptor and engine.
/// @param mask mask(weight) memory object. It is used to obtain the mask
/// memory descriptor and engine.
/// @param dst Destination memory object. It is used to obtain the
/// destination memory descriptor and engine.
/// @param value value that used to fill masked place of Source memory.
/// @param allow_empty A flag signifying whether construction is allowed
/// to fail without throwing an exception. In this case an empty
/// object will be produced. This flag is optional and defaults to
/// false.
primitive_desc(const engine &aengine, const memory &src, const memory &mask,
const memory &dst, int64_t value, bool allow_empty = false) {
dnnl_primitive_desc_t result;
auto src_md = src.get_desc();
auto mask_md = mask.get_desc();
auto dst_md = dst.get_desc();
dnnl_status_t status = dnnl_mask_primitive_desc_create(&result,
aengine.get(), src_md.get(), dst_md.get(), mask_md.get(), value);
if (!allow_empty)
error::wrap_c_api(status,
"could not create a primitive descriptor for a mask "
"primitive");
reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
}

/// Constructs a primitive descriptor for reorder primitive from a C
/// API primitive descriptor which must have a matching kind.
///
/// @param pd C API primitive descriptor for reorder primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: primitive_desc_base(pd, dnnl::primitive::kind::mask) {}

/// @copydoc dnnl::primitive_desc_base::src_desc()const
memory::desc src_desc() const { return base::src_desc(0); }

/// @copydoc dnnl::primitive_desc_base::dst_desc()const
memory::desc dst_desc() const { return base::dst_desc(0); }
};

/// Default constructor. Produces an empty object.
mask() = default;

/// Constructs a mask primitive.
/// @param pd Primitive descriptor for mask primitive.
mask(const primitive_desc &pd) : primitive(pd.get()) {}

/// Constructs a mask primitive from a cache blob.
/// @param pd Primitive descriptor for mask primitive.
/// @param cache_blob Cache blob.
mask(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
: primitive(pd.get(), cache_blob) {}

using primitive::execute;

/// Executes the mask primitive.
///
/// @param astream Stream object. The stream must belong to the same engine
/// as the primitive.
/// @param src Source memory object.
/// @param dst Destination memory object.
/// @param mask Mask memory object.
void execute(const stream &astream, memory &src, memory &dst, memory &mask) const {
primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}, {DNNL_ARG_WEIGHTS, mask}});
}
};

/// @} dnnl_api_mask

/// @addtogroup dnnl_api_transpose Transpose
///
/// A primitive to copy data between two memory objects. This primitive is
Expand All @@ -4907,7 +5026,7 @@ struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {

/// Transpose primitive.
struct transpose : public primitive {
/// Primitive descriptor for a reorder primitive.
/// Primitive descriptor for a transpose primitive.
struct primitive_desc : public primitive_desc_base {
using primitive_desc_base::primitive_desc_base;

Expand All @@ -4920,6 +5039,8 @@ struct transpose : public primitive {
/// memory descriptor and engine.
/// @param dst Destination memory object. It is used to obtain the
/// destination memory descriptor and engine.
/// @param dim1 dim1 that transpose from.
/// @param dim2 dim1 that transpose to.
/// @param allow_empty A flag signifying whether construction is allowed
/// to fail without throwing an exception. In this case an empty
/// object will be produced. This flag is optional and defaults to
Expand All @@ -4938,10 +5059,10 @@ struct transpose : public primitive {
reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
}

/// Constructs a primitive descriptor for reorder primitive from a C
/// Constructs a primitive descriptor for transpose primitive from a C
/// API primitive descriptor which must have a matching kind.
///
/// @param pd C API primitive descriptor for reorder primitive.
/// @param pd C API primitive descriptor for transpose primitive.
primitive_desc(dnnl_primitive_desc_t pd)
: primitive_desc_base(pd, dnnl::primitive::kind::transpose) {}

Expand Down
4 changes: 3 additions & 1 deletion include/oneapi/dnnl/dnnl_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -2019,7 +2019,9 @@ typedef enum {
dnnl_group_normalization,
/// A transpose primitive.
dnnl_transpose,

/// A mask primitive.
dnnl_mask,

/// Parameter to allow internal only primitives without undefined behavior.
/// This parameter is chosen to be valid for so long as sizeof(int) >= 2.
dnnl_primitive_kind_max = 0x7fff,
Expand Down
1 change: 1 addition & 0 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1918,6 +1918,7 @@ using primitive_kind_t = dnnl_primitive_kind_t;
namespace primitive_kind {
const primitive_kind_t undefined = dnnl_undefined_primitive;
const primitive_kind_t transpose = dnnl_transpose;
const primitive_kind_t mask = dnnl_mask;
const primitive_kind_t reorder = dnnl_reorder;
const primitive_kind_t concat = dnnl_concat;
const primitive_kind_t sum = dnnl_sum;
Expand Down
1 change: 1 addition & 0 deletions src/common/dnnl_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ PKIND_TRAITS_INST(resampling);
PKIND_TRAITS_INST(reduction);
PKIND_TRAITS_INST(sdpa);
PKIND_TRAITS_INST(transpose);
PKIND_TRAITS_INST(mask);
#undef PKIND_TRAITS_INST

} // namespace impl
Expand Down
82 changes: 82 additions & 0 deletions src/common/mask.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*******************************************************************************
* Copyright 2016-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include "opdesc.hpp"
#include "primitive_desc_iface.hpp"

#include "oneapi/dnnl/dnnl.h"

#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

using namespace dnnl::impl;
using namespace dnnl::impl::utils;
using namespace dnnl::impl::status;

namespace dnnl {
namespace impl {

#define VCHECK_TRANSPOSE(cond, msg, ...) \
VCONDCHECK(primitive, create, check, mask, (cond), \
status::invalid_arguments, msg, ##__VA_ARGS__);

status_t mask_desc_init(mask_desc_t *mask_desc,
const memory_desc_t *src_md, const memory_desc_t *dst_md,
const memory_desc_t *mask_md, double value_f, double value_i) {
VCHECK_TRANSPOSE(!any_null(src_md, dst_md), VERBOSE_NULL_ARG);

auto op_d = mask_desc_t();
op_d.primitive_kind = primitive_kind::mask;
op_d.src_desc = *src_md;
op_d.dst_desc = *dst_md;
op_d.mask_desc = *mask_md;
op_d.value_f = value_f;
op_d.value_i = value_i;

*mask_desc = op_d;
return status::success;
}

} // namespace impl
} // namespace dnnl

status_t dnnl_mask_primitive_desc_create(
primitive_desc_iface_t **primitive_desc_iface, dnnl_engine_t engine,
const memory_desc_t *src_md, const memory_desc_t *dst_md,
const memory_desc_t *mask_md,
double value) {
auto mask_desc = mask_desc_t();
CHECK(mask_desc_init(
&mask_desc, src_md, dst_md, mask_md, value, 0));

return primitive_desc_create(primitive_desc_iface, engine,
(const op_desc_t *)&mask_desc, nullptr, nullptr);
}

status_t dnnl_mask_primitive_desc_create(
primitive_desc_iface_t **primitive_desc_iface, dnnl_engine_t engine,
const memory_desc_t *src_md, const memory_desc_t *dst_md,
const memory_desc_t *mask_md,
int64_t value) {
auto mask_desc = mask_desc_t();
CHECK(mask_desc_init(
&mask_desc, src_md, dst_md, mask_md, 0, value));

return primitive_desc_create(primitive_desc_iface, engine,
(const op_desc_t *)&mask_desc, nullptr, nullptr);
}
Loading

0 comments on commit 3e9fb57

Please sign in to comment.