Skip to content

Commit

Permalink
[misc] 1. Add platform detection before using torch.cuda;
Browse files Browse the repository at this point in the history
       2. Remove unnecessary detection of Navi4x platform;
  • Loading branch information
qli88 committed Oct 25, 2024
1 parent eb931f4 commit ae81124
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
4 changes: 1 addition & 3 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,7 @@ def __init__(
self.act_fn = SiluAndMul()

def forward(self, x):
- # Navi4x is diff from other HIP devices by using e4m3fn fp8 format
- if is_hip() and not is_navi4x() \
- and x.shape[0] == 1 and x.shape[1] == 1:
+ if is_hip() and x.shape[0] == 1 and x.shape[1] == 1:
out = torch.empty(x.shape[0],
self.gate_up_proj.weight.shape[0] // 2,
dtype=x.dtype,
Expand Down
2 changes: 1 addition & 1 deletion vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def is_hip() -> bool:

@lru_cache(maxsize=None)
def is_navi4x() -> bool:
- if not torch.cuda.is_available():
+ if not is_hip() or not torch.cuda.is_available():
return False
# All (visible) GPUs must be of the same type,
# otherwise FP8 results can't be guaranteed.
Expand Down

0 comments on commit ae81124

Please sign in to comment.