Compile SageAttention CUDA 12.8, TORCH 2.7.0 - Blackwell #107
Comments
Hi, we have tested on Blackwell and the kernel works, but the

"""
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import subprocess
from packaging.version import parse, Version
from typing import List, Set
import warnings
from setuptools import setup, find_packages
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
HAS_SM80 = False
HAS_SM86 = False
HAS_SM89 = False
HAS_SM90 = False
# Supported NVIDIA GPU architectures.
# SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0"}
# Compiler flags.
CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
NVCC_FLAGS = [
    "-O3",
    "-std=c++17",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "--use_fast_math",
    "--threads=8",
    "-Xptxas=-v",
    "-diag-suppress=174",  # suppress the specific warning
]
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
# if CUDA_HOME is None:
#     raise RuntimeError(
#         "Cannot find CUDA_HOME. CUDA must be available to build the package.")

# def get_nvcc_cuda_version(cuda_dir: str) -> Version:
#     """Get the CUDA version from nvcc.
#     Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
#     """
#     nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
#                                           universal_newlines=True)
#     output = nvcc_output.split()
#     release_idx = output.index("release") + 1
#     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
#     return nvcc_cuda_version
# def get_torch_arch_list() -> Set[str]:
#     # TORCH_CUDA_ARCH_LIST can have one or more architectures,
#     # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
#     # compiler to additionally include PTX code that can be runtime-compiled
#     # and executed on the 8.6 or newer architectures. While the PTX code will
#     # not give the best performance on the newer architectures, it provides
#     # forward compatibility.
#     env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
#     if env_arch_list is None:
#         return set()
#     # Lists are separated by ; or space.
#     torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
#     if not torch_arch_list:
#         return set()
#     # Filter out the invalid architectures and print a warning.
#     valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
#     arch_list = torch_arch_list.intersection(valid_archs)
#     # If none of the specified architectures are valid, raise an error.
#     if not arch_list:
#         raise RuntimeError(
#             "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
#             f"variable ({env_arch_list}) is supported. "
#             f"Supported CUDA architectures are: {valid_archs}.")
#     invalid_arch_list = torch_arch_list - valid_archs
#     if invalid_arch_list:
#         warnings.warn(
#             f"Unsupported CUDA architectures ({invalid_arch_list}) are "
#             "excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
#             f"({env_arch_list}). Supported CUDA architectures are: "
#             f"{valid_archs}.")
#     return arch_list
# # First, check the TORCH_CUDA_ARCH_LIST environment variable.
# compute_capabilities = get_torch_arch_list()
# if not compute_capabilities:
#     # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
#     # GPUs on the current machine.
#     device_count = torch.cuda.device_count()
#     for i in range(device_count):
#         major, minor = torch.cuda.get_device_capability(i)
#         if major < 8:
#             raise RuntimeError(
#                 "GPUs with compute capability below 8.0 are not supported.")
#         compute_capabilities.add(f"{major}.{minor}")

# nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
# if not compute_capabilities:
#     raise RuntimeError("No GPUs found. Please specify the target GPU architectures or build on a machine with GPUs.")

# # Validate the NVCC CUDA version.
# if nvcc_cuda_version < Version("12.0"):
#     raise RuntimeError("CUDA 12.0 or higher is required to build the package.")
# if nvcc_cuda_version < Version("12.4") and any(cc.startswith("8.9") for cc in compute_capabilities):
#     raise RuntimeError(
#         "CUDA 12.4 or higher is required for compute capability 8.9.")
# if nvcc_cuda_version < Version("12.3") and any(cc.startswith("9.0") for cc in compute_capabilities):
#     raise RuntimeError(
#         "CUDA 12.3 or higher is required for compute capability 9.0.")

# # Add target compute capabilities to NVCC flags.
# for capability in compute_capabilities:
#     num = capability[0] + capability[2]
#     if num == "80":
#         HAS_SM80 = True
#     elif num == "86":
#         HAS_SM86 = True
#     elif num == "89":
#         HAS_SM89 = True
#     elif num == "90":
#         HAS_SM90 = True
#         num = num + "a"  # convert sm90 to sm90a
#     NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
#     if capability.endswith("+PTX"):
#         NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
NVCC_FLAGS += ["-gencode", f"arch=compute_120,code=sm_120"]
ext_modules = []

qattn_extension = CUDAExtension(
    name="sageattention._qattn_sm80",
    sources=[
        "csrc/qattn/pybind_sm80.cpp",
        "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
    ],
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,
    },
)
ext_modules.append(qattn_extension)

qattn_extension = CUDAExtension(
    name="sageattention._qattn_sm89",
    sources=[
        "csrc/qattn/pybind_sm89.cpp",
        "csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
    ],
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,
    },
)
ext_modules.append(qattn_extension)

# Fused kernels.
fused_extension = CUDAExtension(
    name="sageattention._fused",
    sources=["csrc/fused/pybind.cpp", "csrc/fused/fused.cu"],
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,
    },
)
ext_modules.append(fused_extension)

setup(
    name='sageattention',
    version='2.1.0',
    author='SageAttention team',
    license='Apache 2.0 License',
    description='Accurate and efficient plug-and-play low-bit attention.',
    long_description=open('README.md').read(),
    long_description_content_type='text/markdown',
    url='https://github.com/thu-ml/SageAttention',
    packages=find_packages(),
    python_requires='>=3.9',
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)
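For anyone adapting this further: instead of hardcoding compute_120, the target could be derived from whatever GPU is visible at build time, in the spirit of the commented-out detection logic above. The following is only a sketch, under the assumption that the Blackwell card (compute capability 12.0) is visible to PyTorch when setup.py runs; it is not part of the official build script.

import torch

def detected_gencode_flags() -> list:
    """Collect NVCC -gencode flags for every GPU visible at build time (sketch)."""
    caps = set()
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        caps.add(f"{major}{minor}")  # (12, 0) -> "120" on an RTX 5090
    flags = []
    for num in sorted(caps):
        flags += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
    return flags

if __name__ == "__main__":
    # On a Blackwell machine this prints the same target that is hardcoded
    # above: ['-gencode', 'arch=compute_120,code=sm_120']
    print(detected_gencode_flags())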
Thanks for answering, but it still didn't work: it compiles and generates the .whl, but when I use it in ComfyUI with Hunyuan it shows these messages!
@alisson-anjos
RTX 5090
I have successfully compiled and run this on Debian 13 testing with both Flux and HunyuanVideo in ComfyUI:
Flux speed with torch.compile and SageAttention hit 5.3 it/s at 1024p; my previous record on a 4090 was around 3.6 it/s, and the baseline for the 5090 was 3.44 it/s. This is with fp8 scaledmm (fp8 fast mode in ComfyUI). The fp16 CUDA mode works with Flux but produces NaNs on HunyuanVideo.
@alisson-anjos we have updated the compilation code in #109. The problem you hit is due to Triton: Triton currently has trouble on sm120, and @kijai has run into the same problem. So we use per_warp quantization, which uses a CUDA kernel to do the quantization (we are working on a per-thread quantization CUDA kernel). Now
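If it helps to rule out ComfyUI and Triton entirely, a standalone smoke test of the rebuilt kernels can be run first. This is only a sketch based on the sageattn API as shown in the SageAttention README; the tensor shapes and the HND layout are illustrative assumptions, not values taken from this thread.

import torch
from sageattention import sageattn  # package name as defined in the setup.py above

# Hypothetical shapes: batch=1, heads=8, seq_len=1024, head_dim=128 (HND layout).
q = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")

out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)

# NaNs here would point at the CUDA kernels themselves rather than the Triton path.
print("any NaN:", torch.isnan(out).any().item())
print("max abs diff vs SDPA:", (out - ref).abs().max().item())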
It seems that because you compiled on Debian, it doesn't work on Ubuntu 24.04:

(comfyui) aa@AAdesktop:~/ComfyUI$ python main.py --use-sage-attention
ComfyUI-Manager: installing dependencies done.
** ComfyUI startup time: 2025-02-17 17:29:19.189
Prestartup times for custom nodes:
Checkpoint files will always be loaded safely.
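When a wheel built on one distro fails to load on another, the usual suspects are a mismatched libstdc++/glibc or a Torch/CUDA mismatch rather than the kernels themselves. A quick check along these lines (plain PyTorch calls only; the sageattention import name comes from the setup.py above) can help narrow it down before rebuilding:

import torch

print("torch:", torch.__version__)            # a 2.7.0 nightly/dev build is expected here
print("torch CUDA:", torch.version.cuda)      # expected: 12.8
print("device:", torch.cuda.get_device_name(0))
print("compute capability:", torch.cuda.get_device_capability(0))  # (12, 0) on an RTX 5090

try:
    import sageattention
    print("sageattention loaded from:", sageattention.__file__)
except ImportError as exc:
    # A missing-symbol or GLIBC/libstdc++ error here points at the build
    # environment (Debian vs. Ubuntu toolchains), not at the model code.
    print("failed to import sageattention:", exc)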
I updated gcc to the latest version. Now I'm getting this:

The config attributes {'use_flow_sigmas': True, 'prediction_type': 'flow_prediction'} were passed to FlowMatchDiscreteScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.
Hello, I would like to know whether, in order to compile SageAttention with CUDA 12.8 and torch 2.7.0, I need to change setup.py and add the following files:
attn_cuda_sm120.h
pybind_sm120.cpp
qk_int_sv_f8_cuda_sm120.cu
specific to this version. Could you guide me through this? Some models that I use, such as Hunyuan, have implementations that use SageAttention. The problem is that I cannot use Blackwell-architecture video cards (RTX 50xx), because many packages do not yet support torch 2.7.0.
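If dedicated sm_120 sources under the names listed above do get added, the corresponding setup.py entry would presumably mirror the existing sm89 block, roughly as sketched below. The file names and paths here come from the question and from the sm80/sm89 entries earlier in this thread, not from the repository, so treat this purely as an illustration; CXX_FLAGS, NVCC_FLAGS, and ext_modules refer to the variables in the setup.py posted above.

from torch.utils.cpp_extension import CUDAExtension

# Hypothetical sm_120 extension mirroring the sm80/sm89 blocks above.
# The two source files are the ones named in the question and may not exist upstream.
qattn_sm120_extension = CUDAExtension(
    name="sageattention._qattn_sm120",
    sources=[
        "csrc/qattn/pybind_sm120.cpp",
        "csrc/qattn/qk_int_sv_f8_cuda_sm120.cu",
    ],
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS + ["-gencode", "arch=compute_120,code=sm_120"],
    },
)
ext_modules.append(qattn_sm120_extension)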