feat: refactor sdp backends
SauravMaheshkar committed Aug 2, 2024
1 parent d772e53 commit 3ded163
Showing 3 changed files with 28 additions and 40 deletions.
2 changes: 1 addition & 1 deletion sam2/build_sam.py
@@ -100,7 +100,7 @@ def build_sam2_video_predictor(

 def _load_checkpoint(model, ckpt_path):
     if ckpt_path is not None:
-        sd = torch.load(ckpt_path, map_location="cpu")["model"]
+        sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
         missing_keys, unexpected_keys = model.load_state_dict(sd)
         if missing_keys:
             logging.error(missing_keys)
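For context, the only functional change here is passing weights_only=True, which makes torch.load use a restricted unpickler that reconstructs tensors and plain containers instead of arbitrary pickled objects. A minimal sketch of the resulting loading path (the checkpoint filename below is illustrative, not taken from this commit):

import torch

# Illustrative checkpoint path; SAM 2 checkpoints store the state dict under
# the "model" key, which is why the loader indexes into it.
ckpt_path = "checkpoints/sam2_hiera_large.pt"

# weights_only=True restricts deserialization to tensors and basic Python
# containers, so loading an untrusted file cannot execute arbitrary code.
sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
print(f"loaded {len(sd)} entries from the checkpoint state dict")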
21 changes: 6 additions & 15 deletions sam2/modeling/sam/transformer.py
@@ -15,10 +15,10 @@

 from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
 from sam2.modeling.sam2_utils import MLP
-from sam2.utils.misc import get_sdpa_settings
+from sam2.utils.misc import get_sdp_backends

 warnings.simplefilter(action="ignore", category=FutureWarning)
-OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
+# OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()


 class TwoWayTransformer(nn.Module):
@@ -245,12 +245,8 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:

         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=USE_FLASH_ATTN,
-            # if Flash attention kernel is off, then math kernel needs to be enabled
-            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-            enable_mem_efficient=OLD_GPU,
-        ):
+
+        with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)):
             out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

         out = self._recombine_heads(out)
@@ -311,13 +307,8 @@ def forward(
         )

         dropout_p = self.dropout_p if self.training else 0.0
-        # Attention
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=USE_FLASH_ATTN,
-            # if Flash attention kernel is off, then math kernel needs to be enabled
-            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-            enable_mem_efficient=OLD_GPU,
-        ):
+
+        with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)):
             out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

         out = self._recombine_heads(out)
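The old torch.backends.cuda.sdp_kernel context manager took boolean enable_* flags, while the torch.nn.attention.sdpa_kernel form used above takes an explicit SDPBackend value or a list of them. A minimal, self-contained sketch of the new-style call, assuming PyTorch 2.3+ (where the list form is available); the tensor shapes are illustrative only:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Toy (batch, heads, tokens, head_dim) tensors just to exercise the kernel.
q = k = v = torch.randn(1, 8, 64, 32)

# Only the listed backends may be used inside the context; SDPA then picks
# among the enabled backends, falling back to MATH here if flash attention
# is not supported on the current device or build.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

print(out.shape)  # torch.Size([1, 8, 64, 32])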
45 changes: 21 additions & 24 deletions sam2/utils/misc.py
@@ -7,11 +7,12 @@
 import os
 import warnings
 from threading import Thread
-from typing import Dict, List
+from typing import Dict, List, Union

 import numpy as np
 import torch
 from PIL import Image
+from torch.nn.attention import SDPBackend
 from tqdm import tqdm

 VARIANTS: List[str] = ["tiny", "small", "base_plus", "large"]
@@ -24,34 +25,30 @@
 }


-def get_sdpa_settings():
+def get_sdp_backends(dropout_p: float) -> Union[List[SDPBackend], SDPBackend]:
+    backends = []
     if torch.cuda.is_available():
-        old_gpu = torch.cuda.get_device_properties(0).major < 7
-        # only use Flash Attention on Ampere (8.0) or newer GPUs
         use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
-        if not use_flash_attn:
-            warnings.warn(
-                "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
-                category=UserWarning,
-                stacklevel=2,
-            )
-        # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
-        # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
         pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
-        if pytorch_version < (2, 2):
-            warnings.warn(
-                f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
-                "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
-                category=UserWarning,
-                stacklevel=2,
-            )
-        math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn
+
+        if torch.cuda.get_device_properties(0).major < 7:
+            backends.append(SDPBackend.EFFICIENT_ATTENTION)
+
+        if use_flash_attn:
+            backends.append(SDPBackend.FLASH_ATTENTION)
+
+        if pytorch_version < (2, 2) or not use_flash_attn:
+            backends.append(SDPBackend.MATH)
+
+        if (
+            SDPBackend.EFFICIENT_ATTENTION in backends and dropout_p > 0.0
+        ) and SDPBackend.MATH not in backends:
+            backends.append(SDPBackend.MATH)
+
     else:
-        old_gpu = True
-        use_flash_attn = False
-        math_kernel_on = True
+        backends.extend([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH])

-    return old_gpu, use_flash_attn, math_kernel_on
+    return backends


 def get_connected_components(mask):
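Putting the pieces together: the helper returns a list of SDPBackend values that the callers in transformer.py pass straight to torch.nn.attention.sdpa_kernel. A hedged sketch of that intended use, assuming PyTorch 2.3+ and an importable sam2 package; the printed result depends on the local GPU and PyTorch version:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel

from sam2.utils.misc import get_sdp_backends

dropout_p = 0.1  # illustrative value; SAM 2 sets dropout per attention module
backends = get_sdp_backends(dropout_p)
# e.g. [SDPBackend.FLASH_ATTENTION] on an Ampere+ GPU with PyTorch >= 2.2, or
# [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] when CUDA is unavailable.
print(backends)

q = k = v = torch.randn(1, 8, 64, 32)  # toy (batch, heads, tokens, head_dim)
with sdpa_kernel(backends):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)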
