Skip to content

Commit

Permalink
FusedMultiHeadAttentionInference (#9287)
Browse files Browse the repository at this point in the history
* FusedMultiHeadAttentionInference

* auto format by CI

* cmake

* fix graph

* auto format by CI

* fix cmake for mlir

* rm duplicated install

* fix align

* support float

* support causal

* support causal

* test global property

* fix

* disable clang

* skip cpu test

* skip all tests

Co-authored-by: oneflow-ci-bot <[email protected]>
Co-authored-by: jackalcooper <[email protected]>
  • Loading branch information
3 people authored Oct 25, 2022
1 parent bd6b00f commit b2091bc
Show file tree
Hide file tree
Showing 11 changed files with 631 additions and 5 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,11 @@ set(KINETO_URL
use_mirror(VARIABLE KINETO_URL URL ${KINETO_URL})
set(KINETO_MD5 f9b550591b3899fb267270c19484933f)

# CUTLASS source archive, pinned to a single commit, together with the MD5 of that archive.
# use_mirror() is a project helper; presumably it may rewrite CUTLASS_URL to a mirror -- confirm.
set(CUTLASS_URL "https://github.com/NVIDIA/cutlass/archive/4db6a6140e45c4ffe6339c55b43b159602fa1f35.zip")
use_mirror(VARIABLE CUTLASS_URL URL ${CUTLASS_URL})
set(CUTLASS_MD5 "132bc2d7b635e33666dd6fab1a9ab340")

include(cuda)
add_subdirectory(external)
include(third_party)
Expand Down
7 changes: 7 additions & 0 deletions cmake/oneflow.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,13 @@ if(BUILD_CUDA)
PROPERTIES COMPILE_FLAGS "-DCUDA_REAL_ARCHS=\"${CUDA_REAL_ARCHS}\"")
endif()

if(BUILD_CUDA)
  # Read the include directories published by the cutlass_fmha_headers interface target and
  # append them to the INCLUDE_DIRECTORIES of this one kernel source file only, rather than
  # to a whole target.
  get_target_property(
    CUTLASS_FMHA_INCLUDE_DIR
    cutlass_fmha_headers
    INTERFACE_INCLUDE_DIRECTORIES)
  set_property(
    SOURCE "${PROJECT_SOURCE_DIR}/oneflow/user/kernels/fused_multi_head_attention_inference_kernel.cu"
    APPEND
    PROPERTY INCLUDE_DIRECTORIES ${CUTLASS_FMHA_INCLUDE_DIR})
endif()

# oneflow api common
if(BUILD_PYTHON OR BUILD_CPP_API)
file(GLOB_RECURSE of_api_common_files ${PROJECT_SOURCE_DIR}/oneflow/api/common/*.h
Expand Down
6 changes: 5 additions & 1 deletion external/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,9 @@ list(APPEND EXTERNAL_TARGETS fmt)
add_subdirectory(kineto)
list(APPEND EXTERNAL_TARGETS kineto)

mark_targets_as_system(${EXTERNAL_TARGETS})
if(BUILD_CUDA)
  # The cutlass subdirectory is only added for CUDA builds; its cutlass_headers target is
  # recorded in EXTERNAL_TARGETS alongside the other external targets.
  add_subdirectory(cutlass)
  list(APPEND EXTERNAL_TARGETS cutlass_headers)
endif()

set_property(GLOBAL PROPERTY EXTERNAL_TARGETS ${EXTERNAL_TARGETS})
19 changes: 19 additions & 0 deletions external/cutlass/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
include(FetchContent)

# Fetch the CUTLASS sources. CUTLASS_URL and CUTLASS_MD5 must be set by the including project
# before this directory is added.
FetchContent_Declare(
  cutlass
  URL ${CUTLASS_URL}
  URL_HASH MD5=${CUTLASS_MD5}
)

# FetchContent_Populate() raises an error when the same content is populated more than once
# in a build, so query the population state first and only populate when needed.
# FetchContent_GetProperties() also defines cutlass_SOURCE_DIR when the content is already
# populated, so the targets below work in both cases.
FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
  FetchContent_Populate(cutlass)
endif()

# Header-only interface target exposing the CUTLASS include directory.
add_library(cutlass_headers INTERFACE)
target_include_directories(
  cutlass_headers
  INTERFACE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/include>
            $<INSTALL_INTERFACE:include>
)

# Header-only interface target exposing the fused multi-head attention example directory
# shipped inside the CUTLASS archive.
add_library(cutlass_fmha_headers INTERFACE)
target_include_directories(
  cutlass_fmha_headers
  INTERFACE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/examples/42_fused_multi_head_attention>
            $<INSTALL_INTERFACE:include>
)

# Both targets are part of the "oneflow" export set.
install(
  TARGETS cutlass_headers cutlass_fmha_headers
  EXPORT oneflow
  DESTINATION include
)
7 changes: 6 additions & 1 deletion oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2375,6 +2375,10 @@
signature: "Tensor (Tensor softmax_y, Tensor dy, Tensor mask, Int64 diagonal, Float tril_scale_value, Float mask_scale_value) => FusedScaleTrilSoftmaxMaskScaleGrad"
bind_python: False

# Fused multi-head attention (inference). Maps to the functor registered under the name
# "FusedMultiHeadAttentionInference". Each *_hidden_slice_start defaults to 0 and each
# *_hidden_slice_end defaults to -1.
# NOTE(review): -1 presumably means "no explicit end of slice" -- confirm against the op.
- name: "fused_multi_head_attention_inference"
signature: "Tensor (Tensor query, Tensor key, Tensor value, Int64 num_heads, Bool causal=False, Int64 query_hidden_slice_start=0, Int64 query_hidden_slice_end=-1, Int64 key_hidden_slice_start=0, Int64 key_hidden_slice_end=-1, Int64 value_hidden_slice_start=0, Int64 value_hidden_slice_end=-1) => FusedMultiHeadAttentionInference"
bind_python: True

- name: "send"
signature: "Void (Tensor input, Int64 dst, Bool send_meta=True) => Send"
bind_python: True
Expand Down Expand Up @@ -2871,4 +2875,5 @@

- name: "bincount"
signature: "Tensor (Tensor input, Tensor weights=None, Int64 minlength=None) => BinCount"
bind_python: True
bind_python: True

31 changes: 31 additions & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4430,6 +4430,36 @@ class BatchNormBackwardElemtFunctor {
std::shared_ptr<OpExpr> op_;
};

// Functor for the "fused_multi_head_attention_inference" user op.
//
// The op expression (inputs: query, key, value; output: out) is built once in the
// constructor. Each call copies the scalar arguments into the op attributes and dispatches
// the op on the three input tensors.
class FusedMultiHeadAttentionInferenceFunctor {
 public:
  FusedMultiHeadAttentionInferenceFunctor() {
    // CHECK_JUST unwraps the result of Build().
    op_ = CHECK_JUST(one::OpBuilder("fused_multi_head_attention_inference")
                         .Input("query")
                         .Input("key")
                         .Input("value")
                         .Output("out")
                         .Build());
  }

  // Sets every attribute from the corresponding argument (same order as the attribute names
  // below) and returns the single output tensor produced by the dispatched op.
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& query,
                           const std::shared_ptr<one::Tensor>& key,
                           const std::shared_ptr<one::Tensor>& value,
                           const int64_t& num_heads,
                           const bool& causal,
                           const int64_t& query_hidden_slice_start,
                           const int64_t& query_hidden_slice_end,
                           const int64_t& key_hidden_slice_start,
                           const int64_t& key_hidden_slice_end,
                           const int64_t& value_hidden_slice_start,
                           const int64_t& value_hidden_slice_end) const {
    // Attribute names and the values passed to SetAllAttrs must stay in the same order.
    auto& attr_map = THREAD_CACHED_MUTABLE_ATTR_MAP("num_heads",
                                                    "causal",
                                                    "query_hidden_slice_start",
                                                    "query_hidden_slice_end",
                                                    "key_hidden_slice_start",
                                                    "key_hidden_slice_end",
                                                    "value_hidden_slice_start",
                                                    "value_hidden_slice_end");
    attr_map.SetAllAttrs(num_heads,
                         causal,
                         query_hidden_slice_start,
                         query_hidden_slice_end,
                         key_hidden_slice_start,
                         key_hidden_slice_end,
                         value_hidden_slice_start,
                         value_hidden_slice_end);
    return OpInterpUtil::Dispatch<Tensor>(*op_, {query, key, value}, attr_map);
  }

 private:
  // Op expression built in the constructor and reused by every call.
  std::shared_ptr<OpExpr> op_;
};

} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
Expand Down Expand Up @@ -4550,6 +4580,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::BatchNormElemtFunctor>("BatchNormElemt");
m.add_functor<impl::BatchNormBackwardReduceFunctor>("BatchNormBackwardReduce");
m.add_functor<impl::BatchNormBackwardElemtFunctor>("BatchNormBackwardElemt");
m.add_functor<impl::FusedMultiHeadAttentionInferenceFunctor>("FusedMultiHeadAttentionInference");
}

} // namespace functional
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ const AMPList& AutoMixedPrecisionLists::WhiteList() {
"binary_cross_entropy_with_logits_reduce_mean_grad",
"fused_cross_feature_interaction",
"fused_cross_feature_interaction_v1_grad",
"fused_cross_feature_interaction_v2_grad"};
"fused_cross_feature_interaction_v2_grad",
"fused_multi_head_attention_inference"};
return white_list;
}

Expand Down
29 changes: 27 additions & 2 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2214,8 +2214,8 @@ def OneFlow_EagerSymmetricSToPOp : OneFlow_BaseOp<"eager_symmetric_s_to_p", [NoS
#endif // GET_ONEFLOW_EAGER_OP_DEFINITIONS

// Group: FUSED
// cudnn_fused_normalization_add_relu, cudnn_fused_normalization_add_relu_grad, fused_bias_add_gelu, fused_bias_add_gelu_grad, fused_bias_add_mask_scale, fused_cast_scale, fused_scale_mask_softmax, fused_scale_mask_softmax_dropout, fused_scale_mask_softmax_dropout_grad, fused_scale_mask_softmax_grad, fused_scale_tril, fused_self_attention_query_mul_key_and_value, fused_self_attention_query_mul_key_and_value_grad, fused_tril_scale_softmax_mask_scale, fused_tril_scale_softmax_mask_scale_grad, normalization_add_relu_grad, fused_dot_feature_interaction, fused_dot_feature_interaction_grad, fused_cross_feature_interaction, fused_cross_feature_interaction_grad_v1, fused_cross_feature_interaction_grad_v2, fused_lstm_cell, fused_lstm_cell_grad, fused_gru_cell, fused_gru_cell_grad
// Total: 25
// cudnn_fused_normalization_add_relu, cudnn_fused_normalization_add_relu_grad, fused_bias_add_gelu, fused_bias_add_gelu_grad, fused_bias_add_mask_scale, fused_cast_scale, fused_scale_mask_softmax, fused_scale_mask_softmax_dropout, fused_scale_mask_softmax_dropout_grad, fused_scale_mask_softmax_grad, fused_scale_tril, fused_self_attention_query_mul_key_and_value, fused_self_attention_query_mul_key_and_value_grad, fused_tril_scale_softmax_mask_scale, fused_tril_scale_softmax_mask_scale_grad, normalization_add_relu_grad, fused_dot_feature_interaction, fused_dot_feature_interaction_grad, fused_cross_feature_interaction, fused_cross_feature_interaction_grad_v1, fused_cross_feature_interaction_grad_v2, fused_lstm_cell, fused_lstm_cell_grad, fused_gru_cell, fused_gru_cell_grad, fused_multi_head_attention_inference
// Total: 26

#ifdef GET_ONEFLOW_FUSED_OP_DEFINITIONS

Expand Down Expand Up @@ -2752,6 +2752,31 @@ def OneFlow_FusedCrossFeatureInteractionV2GradOp : OneFlow_BaseOp<"fused_cross_f
let has_data_type_infer_fn = 1;
}

// Definition of the "fused_multi_head_attention_inference" op: three tensor inputs
// (query, key, value) and one tensor output (out). Declared with the NoSideEffect trait.
def OneFlow_FusedMultiHeadAttentionInferenceOp : OneFlow_BaseOp<"fused_multi_head_attention_inference", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$query,
OneFlow_Tensor:$key,
OneFlow_Tensor:$value
);
let output = (outs
OneFlow_Tensor:$out
);
// Every attribute has a default: num_heads and the *_hidden_slice_start attributes default
// to 0, causal defaults to false, and the *_hidden_slice_end attributes default to -1.
// NOTE(review): -1 presumably means "no explicit slice end" -- confirm against the op's
// infer functions.
let attrs = (ins
DefaultValuedAttr<SI64Attr, "0">:$num_heads,
DefaultValuedAttr<BoolAttr, "false">:$causal,
DefaultValuedAttr<SI64Attr, "0">:$query_hidden_slice_start,
DefaultValuedAttr<SI64Attr, "-1">:$query_hidden_slice_end,
DefaultValuedAttr<SI64Attr, "0">:$key_hidden_slice_start,
DefaultValuedAttr<SI64Attr, "-1">:$key_hidden_slice_end,
DefaultValuedAttr<SI64Attr, "0">:$value_hidden_slice_start,
DefaultValuedAttr<SI64Attr, "-1">:$value_hidden_slice_end
);
// The flags below enable the logical/physical tensor-desc inference, SBP and data-type
// inference hooks for this op (implementations are presumably provided in C++ -- confirm).
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

#endif // GET_ONEFLOW_FUSED_OP_DEFINITIONS

// Group: IDEMPOTENT
Expand Down
Loading

0 comments on commit b2091bc

Please sign in to comment.