
Commit fa1db71

Enable bfloat16 on AMD

It works fine on ROCm 6.4.1. It is also faster and avoids OOMs in VAE Decode.
Fixes: #7400
Fixes: ROCm/ROCm#4742

1 parent e4288bf

1 file changed: +1 -2 lines


comfy/model_management.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -865,8 +865,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
         if d == torch.float16 and should_use_fp16(device):
             return d
 
-        # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
-        if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
+        if d == torch.bfloat16 and should_use_bf16(device):
             return d
 
     return torch.float32
```
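
To make the effect of the one-line change concrete, here is a minimal, self-contained sketch of the dtype-selection logic after this commit. The names `pick_vae_dtype`, `fp16_ok`, and `bf16_ok` are hypothetical stand-ins for `vae_dtype`, `should_use_fp16`, and `should_use_bf16`; only the torch dtypes and the control flow mirror the actual code.

```python
import torch

def pick_vae_dtype(allowed_dtypes, fp16_ok, bf16_ok):
    # Hypothetical simplification of comfy.model_management.vae_dtype:
    # return the first caller-preferred dtype the device supports,
    # falling back to full precision otherwise.
    for d in allowed_dtypes:
        if d == torch.float16 and fp16_ok:
            return d
        # After this commit the "(not is_amd())" guard is gone, so
        # bfloat16 is eligible on ROCm/AMD devices as well.
        if d == torch.bfloat16 and bf16_ok:
            return d
    return torch.float32

# Example: an AMD GPU where bf16 works but fp16 does not for the VAE.
print(pick_vae_dtype([torch.float16, torch.bfloat16],
                     fp16_ok=False, bf16_ok=True))
# Before this commit the AMD path fell through to torch.float32;
# now it selects torch.bfloat16.
```

Since bfloat16 activations take 2 bytes instead of float32's 4, VAE Decode needs roughly half the activation memory, which is consistent with the OOM fixes claimed in the commit message.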
