diff --git a/hf_olmo/modeling_olmo.py b/hf_olmo/modeling_olmo.py
index 6a279cb10..496c55fdd 100644
--- a/hf_olmo/modeling_olmo.py
+++ b/hf_olmo/modeling_olmo.py
@@ -33,6 +33,7 @@ class OLMoForCausalLM(PreTrainedModel):
     config_class = OLMoConfig
     base_model_prefix = "model"
     _no_split_modules = ["OLMoBlock"]
+    _supports_flash_attn_2 = True
 
     def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None, init_params: bool = False):
         super().__init__(config)