From 2f1c81bb2d3f7859376767a0ab4f6bda6365be51 Mon Sep 17 00:00:00 2001 From: Seungju Date: Thu, 29 Feb 2024 19:45:58 +0900 Subject: [PATCH] hf_olmo: support flash attn --- hf_olmo/modeling_olmo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hf_olmo/modeling_olmo.py b/hf_olmo/modeling_olmo.py index 6a279cb10..496c55fdd 100644 --- a/hf_olmo/modeling_olmo.py +++ b/hf_olmo/modeling_olmo.py @@ -33,6 +33,7 @@ class OLMoForCausalLM(PreTrainedModel): config_class = OLMoConfig base_model_prefix = "model" _no_split_modules = ["OLMoBlock"] + _supports_flash_attn_2 = True def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None, init_params: bool = False): super().__init__(config)