Hi everyone!

I'm trying to initialize Llava1.6-34b-hf with Flash Attention 2, but I get the following warning, after which the model doesn't work properly and inference isn't any faster.

The point is that I explicitly pass torch_dtype=torch.float16. How should I handle this warning, and does it affect Flash Attention during inference?
Code is below:

```python
import torch
from transformers import LlavaNextForConditionalGeneration

model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-34b-hf",
    device_map="cuda",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
```
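A minimal sketch of how one could check, after loading, whether Flash Attention 2 and fp16 actually took effect (assuming a recent transformers version; the `config._attn_implementation` attribute is an assumption on my part and not from the original report):

```python
import torch
from transformers import LlavaNextForConditionalGeneration

# Same load call as above (assumption: run on a CUDA machine with
# flash-attn installed).
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-34b-hf",
    device_map="cuda",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)

# Which attention backend was actually selected.
print(model.config._attn_implementation)  # expected: "flash_attention_2"

# Whether the weights really ended up in half precision; fp32 weights
# would explain both the dtype warning and the missing speed-up.
print(model.dtype)                         # expected: torch.float16
print({p.dtype for p in model.parameters()})
```

If the printed dtype is torch.float32, the warning is likely about the dtype mismatch rather than Flash Attention itself.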