
Compile SageAttention CUDA 12.8, TORCH 2.7.0 - Blackwell #107

Open
alisson-anjos opened this issue Feb 12, 2025 · 8 comments

@alisson-anjos

Hello, I would like to know: to compile SageAttention with CUDA 12.8 and torch 2.7.0, do I need to change setup.py and add the following files

attn_cuda_sm120.h
pybind_sm120.cpp
qk_int_sv_f8_cuda_sm120.cu

specific to this version? Could you guide me on this? Some models I use, such as Hunyuan, have implementations that use SageAttention. The problem is that I cannot currently use Blackwell-architecture video cards (RTX 50xx) because many packages do not yet support torch 2.7.0.

@jason-huang03
Member

jason-huang03 commented Feb 13, 2025

Hi, we have tested on Blackwell and the kernel works, but the setup.py script needs some changes. I provide the modified script here; you can replace setup.py with the code below:

"""
Copyright (c) 2024 by SageAttention team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import subprocess
from packaging.version import parse, Version
from typing import List, Set
import warnings

from setuptools import setup, find_packages
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME

HAS_SM80 = False
HAS_SM86 = False
HAS_SM89 = False
HAS_SM90 = False

# Supported NVIDIA GPU architectures.
# SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0"}

# Compiler flags.
CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
NVCC_FLAGS = [
    "-O3",
    "-std=c++17",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "--use_fast_math",
    "--threads=8",
    "-Xptxas=-v",
    "-diag-suppress=174", # suppress the specific warning
]

ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]

# if CUDA_HOME is None:
#     raise RuntimeError(
#         "Cannot find CUDA_HOME. CUDA must be available to build the package.")

# def get_nvcc_cuda_version(cuda_dir: str) -> Version:
#     """Get the CUDA version from nvcc.

#     Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
#     """
#     nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
#                                           universal_newlines=True)
#     output = nvcc_output.split()
#     release_idx = output.index("release") + 1
#     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
#     return nvcc_cuda_version

# def get_torch_arch_list() -> Set[str]:
#     # TORCH_CUDA_ARCH_LIST can have one or more architectures,
#     # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
#     # compiler to additionally include PTX code that can be runtime-compiled
#     # and executed on the 8.6 or newer architectures. While the PTX code will
#     # not give the best performance on the newer architectures, it provides
#     # forward compatibility.
#     env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
#     if env_arch_list is None:
#         return set()

#     # List are separated by ; or space.
#     torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
#     if not torch_arch_list:
#         return set()

#     # Filter out the invalid architectures and print a warning.
#     valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
#     arch_list = torch_arch_list.intersection(valid_archs)
#     # If none of the specified architectures are valid, raise an error.
#     if not arch_list:
#         raise RuntimeError(
#             "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
#             f"variable ({env_arch_list}) is supported. "
#             f"Supported CUDA architectures are: {valid_archs}.")
#     invalid_arch_list = torch_arch_list - valid_archs
#     if invalid_arch_list:
#         warnings.warn(
#             f"Unsupported CUDA architectures ({invalid_arch_list}) are "
#             "excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
#             f"({env_arch_list}). Supported CUDA architectures are: "
#             f"{valid_archs}.")
#     return arch_list

# # First, check the TORCH_CUDA_ARCH_LIST environment variable.
# compute_capabilities = get_torch_arch_list()
# if not compute_capabilities:
#     # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
#     # GPUs on the current machine.
#     device_count = torch.cuda.device_count()
#     for i in range(device_count):
#         major, minor = torch.cuda.get_device_capability(i)
#         if major < 8:
#             raise RuntimeError(
#                 "GPUs with compute capability below 8.0 are not supported.")
#         compute_capabilities.add(f"{major}.{minor}")

# nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
# if not compute_capabilities:
#     raise RuntimeError("No GPUs found. Please specify the target GPU architectures or build on a machine with GPUs.")

# # Validate the NVCC CUDA version.
# if nvcc_cuda_version < Version("12.0"):
#     raise RuntimeError("CUDA 12.0 or higher is required to build the package.")
# if nvcc_cuda_version < Version("12.4") and any(cc.startswith("8.9") for cc in compute_capabilities):
#     raise RuntimeError(
#         "CUDA 12.4 or higher is required for compute capability 8.9.")
# if nvcc_cuda_version < Version("12.3") and any(cc.startswith("9.0") for cc in compute_capabilities):
#     if any(cc.startswith("9.0") for cc in compute_capabilities):
#         raise RuntimeError(
#             "CUDA 12.3 or higher is required for compute capability 9.0.")

# Add target compute capabilities to NVCC flags.
# for capability in compute_capabilities:
#     num = capability[0] + capability[2]
#     if num == "80":
#         HAS_SM80 = True
#     elif num == "86":
#         HAS_SM86 = True
#     elif num == "89":
#         HAS_SM89 = True
#     elif num == "90":
#         HAS_SM90 = True
#         num = num + "a" # convert sm90 to sm9a
#     NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
#     if capability.endswith("+PTX"):
#         NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]

# Target Blackwell GPUs (compute capability 12.0, i.e. sm_120).
NVCC_FLAGS += ["-gencode", "arch=compute_120,code=sm_120"]

ext_modules = []

qattn_extension = CUDAExtension(
    name="sageattention._qattn_sm80",
    sources=[
        "csrc/qattn/pybind_sm80.cpp",
        "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
    ],
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,
    },
)
ext_modules.append(qattn_extension)

qattn_extension = CUDAExtension(
    name="sageattention._qattn_sm89",
    sources=[
        "csrc/qattn/pybind_sm89.cpp",
        "csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
    ],
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,
    },
)
ext_modules.append(qattn_extension)

# Fused kernels.
fused_extension = CUDAExtension(
    name="sageattention._fused",
    sources=["csrc/fused/pybind.cpp", "csrc/fused/fused.cu"],
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,
    },
)
ext_modules.append(fused_extension)

setup(
    name='sageattention', 
    version='2.1.0',  
    author='SageAttention team',
    license='Apache 2.0 License',  
    description='Accurate and efficient plug-and-play low-bit attention.',  
    long_description=open('README.md').read(),  
    long_description_content_type='text/markdown', 
    url='https://github.com/thu-ml/SageAttention', 
    packages=find_packages(),
    python_requires='>=3.9',
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)
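
As a quick sanity check after building and installing the wheel, here is a minimal sketch (using the extension module names from the setup.py above; not part of the original script) to confirm that the GPU reports sm_120 and that the compiled extensions import cleanly:

# Sanity check: on RTX 50xx / Blackwell the GPU should report compute
# capability 12.0 (sm_120), and the extensions built above should import.
import torch

major, minor = torch.cuda.get_device_capability(0)
print(f"compute capability: {major}.{minor}")  # expect 12.0 on RTX 50xx

import sageattention._qattn_sm80  # built from csrc/qattn/qk_int_sv_f16_cuda_sm80.cu
import sageattention._qattn_sm89  # built from csrc/qattn/qk_int_sv_f8_cuda_sm89.cu
import sageattention._fused       # built from csrc/fused/fused.cu
print("SageAttention CUDA extensions loaded")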

@alisson-anjos
Author

alisson-anjos commented Feb 13, 2025

@jason-huang03

Thanks for answering, but it still didn't work. It compiles and generates the .whl, but when I use it in ComfyUI with Hunyuan it shows these messages:

[Image]

@jason-huang03
Member

jason-huang03 commented Feb 14, 2025

@alisson-anjos
Are you using 5090 or B series?

@alisson-anjos
Author

alisson-anjos commented Feb 14, 2025

> @alisson-anjos
> Are you using 5090 or B series?

RTX 5090

[Image]

@kijai

kijai commented Feb 14, 2025

I have successfully compiled and run this on Debian 13 testing with both Flux and HunyuanVideo in ComfyUI:

  • installed latest pytorch nightly: pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128
  • compiled and installed latest triton from source
  • edited the sageattention setup.py as shown above
  • specifically selected the sageattn_qk_int8_pv_fp8_cuda mode; autodetection will currently fail

Flux speed with torch.compile and SageAttention hit 5.3 it/s at 1024p; my previous record on the 4090 was around 3.6 it/s, and the baseline for the 5090 was 3.44 it/s. This is with fp8 scaledmm (fp8 fast mode in ComfyUI).

The fp16 CUDA mode works with Flux but produces NaNs on HunyuanVideo.
The triton mode crashes the whole ComfyUI process:

python: /home/kijai/AI/triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp:40: int mlir::triton::gpu::{anonymous}::getMMAVersionSafe(int, mlir::triton::DotOp): Assertion `false && "computeCapability not supported"' failed.
Aborted (core dumped)

@jason-huang03
Member

@alisson-anjos we have updated the compilation code in #109. The problem you hit is caused by Triton, which currently has trouble on sm120; @kijai has run into the same problem. So we use per-warp quantization, which relies on a CUDA kernel for the quantization step (we are working on a per-thread quantization CUDA kernel). Now sageattn dispatches to sageattn_qk_int8_pv_fp8 with qk_quant_gran="per_warp" and pv_accum_dtype="fp32", which has excellent speed. You can also use sageattn_qk_int8_pv_fp16 with qk_quant_gran="per_warp", but the speed of that kernel is currently less satisfactory on the RTX 5090.
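
For illustration, a minimal sketch of calling the fp8 kernel directly with the per-warp setting described above. The import path, tensor layout, and exact keyword names are assumptions based on this thread; check the signature in your installed SageAttention version:

# Sketch: bypass sageattn autodetection and call the fp8 CUDA kernel directly.
# Assumes sageattn_qk_int8_pv_fp8_cuda is exported by the installed package and
# accepts the keyword arguments named in the comment above -- verify locally.
import torch
from sageattention import sageattn_qk_int8_pv_fp8_cuda

# q, k, v in (batch, heads, seq_len, head_dim) layout ("HND"), fp16 on the GPU.
q = torch.randn(1, 24, 4096, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 24, 4096, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 24, 4096, 128, dtype=torch.float16, device="cuda")

out = sageattn_qk_int8_pv_fp8_cuda(
    q, k, v,
    tensor_layout="HND",
    is_causal=False,
    qk_quant_gran="per_warp",
    pv_accum_dtype="fp32",
)
print(out.shape)  # expected: (1, 24, 4096, 128)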

@aliabougazia

> I have successfully compiled and run this on Debian 13 testing with both Flux and HunyuanVideo in ComfyUI:
>
>   • installed latest pytorch nightly: pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128
>   • compiled and installed latest triton from source
>   • edited the sageattention setup.py as shown above
>   • specifically selected the sageattn_qk_int8_pv_fp8_cuda mode; autodetection will currently fail
>
> Flux speed with torch.compile and SageAttention hit 5.3 it/s at 1024p; my previous record on the 4090 was around 3.6 it/s, and the baseline for the 5090 was 3.44 it/s. This is with fp8 scaledmm (fp8 fast mode in ComfyUI).
>
> The fp16 CUDA mode works with Flux but produces NaNs on HunyuanVideo. The triton mode crashes the whole ComfyUI process:
>
> python: /home/kijai/AI/triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp:40: int mlir::triton::gpu::{anonymous}::getMMAVersionSafe(int, mlir::triton::DotOp): Assertion `false && "computeCapability not supported"' failed.
> Aborted (core dumped)

It seems that because you compiled on Debian, it doesn't work on Ubuntu 24.04:

(comfyui) aa@AAdesktop:~/ComfyUI$ python main.py --use-sage-attention
[START] Security scan
[DONE] Security scan

ComfyUI-Manager: installing dependencies done.

** ComfyUI startup time: 2025-02-17 17:29:19.189
** Platform: Linux
** Python version: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0]
** Python executable: /home/aa/miniconda3/envs/comfyui/bin/python
** ComfyUI Path: /home/aa/ComfyUI
** ComfyUI Base Folder Path: /home/aa/ComfyUI
** User directory: /home/aa/ComfyUI/user
** ComfyUI-Manager config path: /home/aa/ComfyUI/user/default/ComfyUI-Manager/config.ini
** Log path: /home/aa/ComfyUI/user/comfyui.log

Prestartup times for custom nodes:
1.1 seconds: /home/aa/ComfyUI/custom_nodes/ComfyUI-Manager

Checkpoint files will always be loaded safely.
Total VRAM 32607 MB, total RAM 60276 MB
pytorch version: 2.7.0.dev20250216+cu128
Set vram state to: NORMAL_VRAM
Device: cuda:0 NVIDIA GeForce RTX 5090 : cudaMallocAsync
Traceback (most recent call last):
File "/home/aa/ComfyUI/main.py", line 136, in <module>
import execution
File "/home/aa/ComfyUI/execution.py", line 13, in <module>
import nodes
File "/home/aa/ComfyUI/nodes.py", line 22, in <module>
import comfy.diffusers_load
File "/home/aa/ComfyUI/comfy/diffusers_load.py", line 3, in <module>
import comfy.sd
File "/home/aa/ComfyUI/comfy/sd.py", line 12, in <module>
import comfy.ldm.genmo.vae.model
File "/home/aa/ComfyUI/comfy/ldm/genmo/vae/model.py", line 13, in <module>
from comfy.ldm.modules.attention import optimized_attention
File "/home/aa/ComfyUI/comfy/ldm/modules/attention.py", line 22, in <module>
from sageattention import sageattn
File "/home/aa/miniconda3/envs/comfyui/lib/python3.12/site-packages/sageattention/__init__.py", line 1, in <module>
from .core import sageattn, sageattn_varlen
File "/home/aa/miniconda3/envs/comfyui/lib/python3.12/site-packages/sageattention/core.py", line 20, in <module>
from .triton.quant_per_block import per_block_int8 as per_block_int8_triton
File "/home/aa/miniconda3/envs/comfyui/lib/python3.12/site-packages/sageattention/triton/quant_per_block.py", line 18, in <module>
import triton
File "/home/aa/miniconda3/envs/comfyui/lib/python3.12/site-packages/triton/__init__.py", line 8, in <module>
from .runtime import (
File "/home/aa/miniconda3/envs/comfyui/lib/python3.12/site-packages/triton/runtime/__init__.py", line 1, in <module>
from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics)
File "/home/aa/miniconda3/envs/comfyui/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 9, in <module>
from .jit import KernelInterface
File "/home/aa/miniconda3/envs/comfyui/lib/python3.12/site-packages/triton/runtime/jit.py", line 12, in <module>
from ..runtime.driver import driver
File "/home/aa/miniconda3/envs/comfyui/lib/python3.12/site-packages/triton/runtime/driver.py", line 1, in <module>
from ..backends import backends
File "/home/aa/miniconda3/envs/comfyui/lib/python3.12/site-packages/triton/backends/__init__.py", line 50, in <module>
backends = _discover_backends()
^^^^^^^^^^^^^^^^^^^^
File "/home/aa/miniconda3/envs/comfyui/lib/python3.12/site-packages/triton/backends/__init__.py", line 43, in _discover_backends
compiler = _load_module(name, os.path.join(root, name, 'compiler.py'))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/aa/miniconda3/envs/comfyui/lib/python3.12/site-packages/triton/backends/__init__.py", line 12, in _load_module
spec.loader.exec_module(module)
File "/home/aa/miniconda3/envs/comfyui/lib/python3.12/site-packages/triton/backends/nvidia/compiler.py", line 2, in <module>
from triton._C.libtriton import ir, passes, llvm, nvidia
ImportError: /home/aa/miniconda3/envs/comfyui/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /home/aa/miniconda3/envs/comfyui/lib/python3.12/site-packages/triton/_C/libtriton.so)
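
For reference, a quick way to check whether the conda environment's own libstdc++.so.6 exports the GLIBCXX_3.4.30 symbol that libtriton.so requires. This is a minimal diagnostic sketch, not from the thread, and it assumes the strings utility is available on the system:

# Diagnostic sketch: does the active environment's libstdc++.so.6 provide
# the GLIBCXX_3.4.30 symbol version that triton/_C/libtriton.so needs?
import os
import subprocess
import sys

libstdcxx = os.path.join(sys.prefix, "lib", "libstdc++.so.6")
out = subprocess.run(["strings", libstdcxx], capture_output=True, text=True).stdout
versions = {line for line in out.splitlines() if line.startswith("GLIBCXX_")}
print(libstdcxx)
print("GLIBCXX_3.4.30 present:", "GLIBCXX_3.4.30" in versions)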

@aliabougazia

> I have successfully compiled and run this on Debian 13 testing with both Flux and HunyuanVideo in ComfyUI:
>
>   • installed latest pytorch nightly: pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128
>   • compiled and installed latest triton from source
>   • edited the sageattention setup.py as shown above
>   • specifically selected the sageattn_qk_int8_pv_fp8_cuda mode; autodetection will currently fail
>
> Flux speed with torch.compile and SageAttention hit 5.3 it/s at 1024p; my previous record on the 4090 was around 3.6 it/s, and the baseline for the 5090 was 3.44 it/s. This is with fp8 scaledmm (fp8 fast mode in ComfyUI).
>
> The fp16 CUDA mode works with Flux but produces NaNs on HunyuanVideo. The triton mode crashes the whole ComfyUI process:
>
> python: /home/kijai/AI/triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp:40: int mlir::triton::gpu::{anonymous}::getMMAVersionSafe(int, mlir::triton::DotOp): Assertion `false && "computeCapability not supported"' failed.
> Aborted (core dumped)

I updated gcc to the latest version. Now I'm getting this:

The config attributes {'use_flow_sigmas': True, 'prediction_type': 'flow_prediction'} were passed to FlowMatchDiscreteScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.
Scheduler config: FrozenDict({'num_train_timesteps': 1000, 'flow_shift': 9.0, 'reverse': True, 'solver': 'euler', 'n_tokens': None, '_use_default_values': ['n_tokens', 'num_train_timesteps']})
Using accelerate to load and assign model weights to device...
Loading LoRA: HunyuanVideo/ReverseCowgirl with strength: 0.74
Requested to load HyVideoModel
loaded completely 29489.67171936035 12555.953247070312 True
Input (height, width, video_length) = (720, 480, 65)
The config attributes {'use_flow_sigmas': True, 'prediction_type': 'flow_prediction'} were passed to FlowMatchDiscreteScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.
Scheduler config: FrozenDict({'num_train_timesteps': 1000, 'flow_shift': 8.0, 'reverse': True, 'solver': 'euler', 'n_tokens': None, '_use_default_values': ['n_tokens', 'num_train_timesteps']})
Sampling 65 frames in 17 latents at 480x720 with 25 inference steps
0%| | 0/25 [00:00<?, ?it/s]python: /home/kijai/AI/triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp:40: int mlir::triton::gpu::{anonymous}::getMMAVersionSafe(int, mlir::triton::DotOp): Assertion `false && "computeCapability not supported"' failed.
Aborted (core dumped)
