Skip to content

Commit

Permalink
yapf
Browse files Browse the repository at this point in the history
  • Loading branch information
gshtras committed Oct 21, 2024
1 parent b10dad1 commit 634d9b0
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions vllm/attention/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,8 @@ def graph_capture_get_metadata_for_batch(
# The encoder decoder model works only with XFormers backend.
# Assert the same.
if is_hip():
assert (
self.runner.attn_backend.get_name() == "ROCM_FLASH"
), (f"Expected attn_backend name to be 'ROCM_FLASH', but "
assert (self.runner.attn_backend.get_name() == "ROCM_FLASH"), (
f"Expected attn_backend name to be 'ROCM_FLASH', but "
f" got '{self.runner.attn_backend.get_name()}'")
self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata)
Expand All @@ -355,9 +354,8 @@ def get_graph_input_buffers(
# The encoder decoder model works only with XFormers backend.
# Assert the same.
if is_hip():
assert (
self.runner.attn_backend.get_name() == "ROCM_FLASH"
), (f"Expected attn_backend name to be 'ROCM_FLASH', but "
assert (self.runner.attn_backend.get_name() == "ROCM_FLASH"), (
f"Expected attn_backend name to be 'ROCM_FLASH', but "
f" got '{self.runner.attn_backend.get_name()}'")
self._add_additonal_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers)
Expand All @@ -383,9 +381,8 @@ def prepare_graph_input_buffers(
# Assert the same.

if is_hip():
assert (
self.runner.attn_backend.get_name() == "ROCM_FLASH"
), (f"Expected attn_backend name to be 'ROCM_FLASH', but "
assert (self.runner.attn_backend.get_name() == "ROCM_FLASH"), (
f"Expected attn_backend name to be 'ROCM_FLASH', but "
f" got '{self.runner.attn_backend.get_name()}'")
self._prepare_input_buffers_for_enc_dec_model(
attn_metadata, input_buffers)
Expand Down

0 comments on commit 634d9b0

Please sign in to comment.