
Remove AMD workarounds as they do more harm than good on recent ROCm releases. #8289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 3 commits into master
comfy/ldm/modules/diffusionmodules/model.py (2 changes: 1 addition & 1 deletion)
@@ -297,7 +297,7 @@ def vae_attention():
     if model_management.xformers_enabled_vae():
         logging.info("Using xformers attention in VAE")
         return xformers_attention
-    elif model_management.pytorch_attention_enabled_vae():
+    elif model_management.pytorch_attention_enabled():
         logging.info("Using pytorch attention in VAE")
         return pytorch_attention
     else:
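With the AMD-specific check gone, the VAE picks its attention backend the same way as the rest of the model: xformers if available, otherwise PyTorch's fused attention whenever pytorch_attention_enabled() is true. For reference, the PyTorch path boils down to a call to torch.nn.functional.scaled_dot_product_attention over the flattened spatial dimensions; the sketch below is a simplified illustration of that path, not the exact code in model.py.

import torch

def pytorch_vae_attention(q, k, v):
    # Minimal sketch of the PyTorch VAE attention path (simplified from model.py).
    # q, k, v: (B, C, H, W) feature maps from the VAE attention block.
    b, c, h, w = q.shape
    # Flatten the spatial dims and move channels last: (B, 1, H*W, C).
    q, k, v = (t.reshape(b, 1, c, h * w).transpose(2, 3) for t in (q, k, v))
    # Fused attention; on ROCm builds this dispatches to the HIP SDPA backends.
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # Restore the (B, C, H, W) layout expected by the decoder.
    return out.transpose(2, 3).reshape(b, c, h, w)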
comfy/model_management.py (10 changes: 2 additions & 8 deletions)
@@ -307,7 +307,7 @@ def is_amd():
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx1201 and gfx950
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151", "gfx1200", "gfx1201"]): # TODO: more arches
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
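To see whether a given card falls inside the expanded allow-list, the architecture string and ROCm version can be read from a ROCm build of PyTorch. A small check along these lines (an illustrative sketch; gcnArchName is only reported by ROCm builds, hence the getattr fallback):

import torch

if torch.version.hip is not None:  # None on CUDA-only builds
    props = torch.cuda.get_device_properties(0)
    # e.g. "gfx1100" or "gfx90a:sramecc+:xnack-"
    arch = getattr(props, "gcnArchName", "")
    enabled_arches = ["gfx90a", "gfx942", "gfx950", "gfx1100",
                      "gfx1101", "gfx1151", "gfx1200", "gfx1201"]
    print("ROCm:", torch.version.hip, "arch:", arch)
    print("fused attention enabled by default:", any(a in arch for a in enabled_arches))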
@@ -875,8 +875,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
         if d == torch.float16 and should_use_fp16(device):
             return d

-        # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
-        if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
+        if d == torch.bfloat16 and should_use_bf16(device):
             return d

     return torch.float32
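Dropping the not is_amd() condition means bf16 is now considered for the VAE on AMD as well; the precedence stays fp16 (when usable), then bf16, then fp32. A simplified illustration of that order, with should_use_fp16/should_use_bf16 replaced by plain flags:

import torch

def pick_vae_dtype(allowed_dtypes, fp16_ok, bf16_ok):
    # Simplified stand-in for vae_dtype(): first allowed dtype the device can use, else fp32.
    for d in allowed_dtypes:
        if d == torch.float16 and fp16_ok:
            return d
        if d == torch.bfloat16 and bf16_ok:  # now also reachable on AMD
            return d
    return torch.float32

# Example: a device that can use bf16 but not fp16 for the VAE now gets bf16 instead of fp32.
print(pick_vae_dtype([torch.float16, torch.bfloat16], fp16_ok=False, bf16_ok=True))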
@@ -1043,11 +1042,6 @@ def pytorch_attention_enabled():
     global ENABLE_PYTORCH_ATTENTION
     return ENABLE_PYTORCH_ATTENTION

-def pytorch_attention_enabled_vae():
-    if is_amd():
-        return False # enabling pytorch attention on AMD currently causes crash when doing high res
-    return pytorch_attention_enabled()
-
 def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:
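Since the removed helper existed to dodge a crash when decoding high-resolution images on AMD, a quick way to sanity-check a ROCm setup after this change is to push a VAE-sized sequence through the fused attention kernel directly. A hypothetical standalone smoke test, not part of this PR; shapes correspond roughly to a 1024x1024 decode:

import torch

# Hypothetical smoke test: fused attention at a high-resolution VAE size.
# A 1024x1024 image gives a 128x128 latent, i.e. 16384 tokens in the VAE attention block.
device = "cuda"  # ROCm GPUs are exposed under the "cuda" device type in PyTorch
tokens, channels = 128 * 128, 512
q = torch.randn(1, 1, tokens, channels, device=device, dtype=torch.float16)
out = torch.nn.functional.scaled_dot_product_attention(q, q, q)
torch.cuda.synchronize()
print("fused SDPA at VAE resolution OK:", tuple(out.shape))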