diff --git a/CMakeLists.txt b/CMakeLists.txt
index d869f90b8c..62da050cd6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -881,10 +881,6 @@ if(USE_CUDA AND USE_CUTLASS)
   install(TARGETS fpA_intB_gemm EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX})
   target_link_libraries(tvm PRIVATE fpA_intB_gemm)
   target_link_libraries(tvm_runtime PRIVATE fpA_intB_gemm)
-
-  install(TARGETS flash_attn EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX})
-  target_link_libraries(tvm PRIVATE -Wl,--no-as-needed flash_attn)
-  target_link_libraries(tvm_runtime PRIVATE -Wl,--no-as-needed flash_attn)
 endif()
 
 if(USE_CUDA AND USE_NCCL)
diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake
index bd3e3b1166..f8aaa2f40d 100644
--- a/cmake/modules/contrib/CUTLASS.cmake
+++ b/cmake/modules/contrib/CUTLASS.cmake
@@ -21,7 +21,6 @@ if(USE_CUDA AND USE_CUTLASS)
   set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
 
   add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm)
-  add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn)
 
   list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
   message(STATUS "Build with CUTLASS")
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 9b4fa78127..b6a9517f80 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -159,100 +159,3 @@ def instantiate_attention_template(attrs):
     )
 
     return substitute_template(template, attrs)
-
-
-def instantiate_flash_attention_template(attrs):
-    """Return host code for flash attention."""
-
-    template = """
-    int q_head_stride = ${head_dim};
-    int k_head_stride = ${head_dim};
-    int v_head_stride = ${head_dim};
-    int o_head_stride = ${head_dim};
-    int q_row_stride = q_head_stride * ${num_heads};
-    int k_row_stride = k_head_stride * ${num_heads};
-    int v_row_stride = v_head_stride * ${num_heads};
-    int o_row_stride = o_head_stride * ${num_heads};
-    int q_batch_stride = q_row_stride * ${num_queries};
-    int k_batch_stride = k_row_stride * ${num_keys};
-    int v_batch_stride = v_row_stride * ${num_keys};
-    int o_batch_stride = o_row_stride * ${num_queries};
-
-    flash_attn::flash_attention_forward(
-        static_cast(${query}->data),
-        static_cast(${key}->data),
-        static_cast(${value}->data),
-        static_cast(out0->data),
-        ${num_batches},
-        ${num_queries},
-        ${num_keys},
-        ${num_heads},
-        ${num_heads},
-        ${head_dim},
-        q_batch_stride,
-        k_batch_stride,
-        v_batch_stride,
-        o_batch_stride,
-        q_head_stride,
-        k_head_stride,
-        v_head_stride,
-        o_head_stride,
-        q_row_stride,
-        k_row_stride,
-        v_row_stride,
-        o_row_stride,
-        ${scale},
-        ${is_causal},
-        nullptr);
-    """
-
-    template_stacked = """
-    int q_head_stride = ${head_dim};
-    int k_head_stride = ${head_dim};
-    int v_head_stride = ${head_dim};
-    int o_head_stride = ${head_dim};
-    int row_stride = q_head_stride * ${num_heads} +
-                     k_head_stride * ${num_heads} +
-                     v_head_stride * ${num_heads};
-    int q_row_stride = row_stride;
-    int k_row_stride = row_stride;
-    int v_row_stride = row_stride;
-    int o_row_stride = o_head_stride * ${num_heads};
-
-    int q_batch_stride = q_row_stride * ${num_queries};
-    int k_batch_stride = k_row_stride * ${num_keys};
-    int v_batch_stride = v_row_stride * ${num_keys};
-    int o_batch_stride = o_row_stride * ${num_queries};
-
-    flash_attn::flash_attention_forward(
-        static_cast(${qkv}->data),
-        static_cast(${qkv}->data) + ${head_dim} * ${num_heads},
-        static_cast(${qkv}->data) + ${head_dim} * ${num_heads} * 2,
-        static_cast(out0->data),
-        ${num_batches},
-        ${num_queries},
-        ${num_keys},
-        ${num_heads},
-        ${num_heads},
-        ${head_dim},
-        q_batch_stride,
-        k_batch_stride,
-        v_batch_stride,
-        o_batch_stride,
-        q_head_stride,
-        k_head_stride,
-        v_head_stride,
-        o_head_stride,
-        q_row_stride,
-        k_row_stride,
-        v_row_stride,
-        o_row_stride,
-        ${scale},
-        ${is_causal},
-        nullptr);
-    """
-
-    if "qkv" in attrs:
-        return substitute_template(template_stacked, attrs)
-
-    return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 0c57c4750e..afff214bf9 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -59,7 +59,6 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
     cutlass_util_include = os.path.join(cutlass_root, "tools/util/include")
     cutlass_attention_include = os.path.join(cutlass_root, "examples/41_fused_multi_head_attention")
     cutlass_fpA_intB_gemm_include = os.path.join(cutlass_root, "../cutlass_fpA_intB_gemm")
-    flash_attn_include = os.path.join(cutlass_root, "../libflash_attn/include")
 
     kwargs = {}
     kwargs["cc"] = "nvcc"
@@ -78,7 +77,6 @@
         f"-I{cutlass_util_include}",
         f"-I{cutlass_attention_include}",
         f"-I{cutlass_fpA_intB_gemm_include}",
-        f"-I{flash_attn_include}",
     ]
     if use_fast_math:
         kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 1d43ee6e9f..49adc1e79a 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -29,10 +29,7 @@
 from tvm.tir import IntImm
 
 from . import _ffi_api as ffi
-from .attention_operation import (
-    instantiate_attention_template,
-    instantiate_flash_attention_template,
-)
+from .attention_operation import instantiate_attention_template
 from .conv2d_operation import instantiate_conv2d_template
 from .gemm_operation import instantiate_gemm_template, emit_fp16A_intB_matmul
 from .layer_norm_operation import instantiate_layer_norm_template
@@ -723,6 +720,7 @@ def get_batch_on_arg(arg_name, arg_shape):
             return CodegenResult(code, headers)
 
     elif "attention" in func_name:
+        headers.append("kernel_forward.h")
         data_type = dtype_map[annotations["arg0_dtype"]]
 
         attrs["qkv_layout"] = annotations["qkv_layout"]
@@ -749,86 +747,62 @@ def get_batch_on_arg(arg_name, arg_shape):
         attrs["head_dim"] = h = annotations["head_dim"]
         attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
         attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"]))
+
+        data_type_size = DataTypeSize[data_type]
+        if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
+            attrs["kIsAligned"] = True
+        elif (h % 4 == 0) and (h_v % 4 == 0):
+            attrs["kIsAligned"] = False
+        else:
+            raise NotImplementedError()
+        if h_v > 64:
+            attrs["kQueriesPerBlock"] = 32
+            attrs["kKeysPerBlock"] = 128
+            attrs["kSingleValueIteration"] = h_v <= 128
+        else:
+            attrs["kQueriesPerBlock"] = 64
+            attrs["kKeysPerBlock"] = 64
+            attrs["kSingleValueIteration"] = True
+        attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"
         attrs["scale"] = (
             float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
         )
-
-        use_flash = (
-            annotations["ret_dtype"] == "float16"
-            and "bias" not in attrs
-            and int(attrs["head_dim"]) <= 256
-            and int(attrs["head_dim"]) % 8 == 0
-            and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
-            # We have not thoroughly validated flash with causal mask yet, so for now we support
-            # only non-causal cases.
-            and int(annotations["custom_mask_type"]) == 0
-            # Flash v2 is currently not supported for sm < 80
-            and int(annotations["arch"]) >= 80
-        )
-
-        if use_flash:
-            headers.append("flash.h")
-            attrs["is_causal"] = int(annotations["custom_mask_type"]) == 0
-            code = instantiate_flash_attention_template(attrs)
-        else:
-            headers.append("kernel_forward.h")
-
-            data_type_size = DataTypeSize[data_type]
-            if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
-                attrs["kIsAligned"] = True
-            elif (h % 4 == 0) and (h_v % 4 == 0):
-                attrs["kIsAligned"] = False
-            else:
-                raise NotImplementedError()
-            if h_v > 64:
-                attrs["kQueriesPerBlock"] = 32
-                attrs["kKeysPerBlock"] = 128
-                attrs["kSingleValueIteration"] = h_v <= 128
-            else:
-                attrs["kQueriesPerBlock"] = 64
-                attrs["kKeysPerBlock"] = 64
-                attrs["kSingleValueIteration"] = True
-
-            assert (
-                attrs["scale"] > 0 or attrs["scale"] < 0
-            ), "Cutlass may generate nan occasionally when scale == 0.0"
-            attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
-            attrs["kSupportsDropout"] = False
-
-            attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"
-
-            attrs["custom_mask_type"] = annotations["custom_mask_type"]
-
-            for arg in func_args:
-                if "workspace" in arg:
-                    attrs["workspace"] = arg
-            if "bias" in attrs:
-                attrs["kSupportsBias"] = True
-                if len(annotations["bias_shape"]) == 4:
-                    strides = "p.num_keys"
-                    if annotations["bias_shape"][2] == 1:
-                        attrs["bias_strideM"] = 0
-                    else:
-                        attrs["bias_strideM"] = strides
-                    strides = f"p.num_queries * {strides}"
-                    if annotations["bias_shape"][1] == 1:
-                        attrs["bias_strideH"] = 0
-                    else:
-                        attrs["bias_strideH"] = strides
-                    strides = f"p.num_heads * {strides}"
-                    if annotations["bias_shape"][0] == 1:
-                        attrs["bias_strideB"] = 0
-                    else:
-                        attrs["bias_strideB"] = strides
+        attrs["custom_mask_type"] = annotations["custom_mask_type"]
+
+        assert (
+            attrs["scale"] > 0 or attrs["scale"] < 0
+        ), "Cutlass may generate nan occasionally when scale == 0.0"
+        attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
+        attrs["kSupportsDropout"] = False
+
+        for arg in func_args:
+            if "workspace" in arg:
+                attrs["workspace"] = arg
+        if "bias" in attrs:
+            attrs["kSupportsBias"] = True
+            if len(annotations["bias_shape"]) == 4:
+                strides = "p.num_keys"
+                if annotations["bias_shape"][2] == 1:
+                    attrs["bias_strideM"] = 0
+                else:
+                    attrs["bias_strideM"] = strides
+                strides = f"p.num_queries * {strides}"
+                if annotations["bias_shape"][1] == 1:
+                    attrs["bias_strideH"] = 0
+                else:
+                    attrs["bias_strideH"] = strides
+                strides = f"p.num_heads * {strides}"
+                if annotations["bias_shape"][0] == 1:
+                    attrs["bias_strideB"] = 0
                 else:
-                    raise NotImplementedError()
+                    attrs["bias_strideB"] = strides
             else:
-                # To support negative scale in current Cutlass implementation,
-                # kSupportsBias should be set true, or there are nan's as result.
-                attrs["kSupportsBias"] = attrs["scale"] < 0
-
-            code = instantiate_attention_template(attrs)
-
+                raise NotImplementedError()
+        else:
+            # To support negative scale in current Cutlass implementation,
+            # kSupportsBias should be set true, or there are nan's as result.
+            attrs["kSupportsBias"] = attrs["scale"] < 0
+        code = instantiate_attention_template(attrs)
         return CodegenResult(code, headers)
     elif "layer_norm" in func_name:
         headers.append("cutlass/util/device_layernorm.h")
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 0c5b3ea9e0..0bbcc17db1 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -854,7 +854,7 @@ def stacked_attention_size(request):
 
 def test_stacked_attention_split_offload(stacked_attention_size):
     b, s, n, (h, h_v), bias_shape, scale, single_shape = stacked_attention_size
-    qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float16")
+    qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float32")
     if scale == "none":
         mod = get_relax_stacked_attention_module(
             qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape