From a464e7b4f325c76cabc84d24a6bba93cff03a949 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 5 Mar 2024 17:56:03 +0100
Subject: [PATCH] Update loading of models with new attention methods

---
 pyproject.toml      |  4 ++--
 textwiz/__init__.py |  2 +-
 textwiz/loader.py   | 45 +++++++++++++++++++++++++++++++++------------
 3 files changed, 36 insertions(+), 15 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1b93f41..8f5d494 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,8 +25,8 @@ dependencies = [
     "scipy",
     "pyyaml",
     "packaging",
-    "torch>=2.0.1",
-    "transformers>=4.33.1",
+    "torch>=2.2.0",
+    "transformers>=4.37",
     "tokenizers>=0.13.3",
     "sentencepiece",
     "protobuf",
diff --git a/textwiz/__init__.py b/textwiz/__init__.py
index 5932495..03ccdcf 100644
--- a/textwiz/__init__.py
+++ b/textwiz/__init__.py
@@ -12,7 +12,7 @@
 from . import loader, conversation_template, prompt_template
 
 
-__version__ = '0.3.0'
+__version__ = '0.4.0'
 
 
 def is_chat_model(model_name: str) -> bool:
diff --git a/textwiz/loader.py b/textwiz/loader.py
index 2e73978..23cdcfe 100644
--- a/textwiz/loader.py
+++ b/textwiz/loader.py
@@ -708,10 +708,39 @@ def load_model(model_name: str, quantization_8bits: bool = False, quantization_4
         device_map = 'balanced'
 
     # Load model
-    model = AutoModelForCausalLM.from_pretrained(ALL_MODELS_MAPPING[model_name], device_map=device_map,
-                                                 torch_dtype=dtype, load_in_8bit=quantization_8bits,
-                                                 load_in_4bit=quantization_4bits, low_cpu_mem_usage=True,
-                                                 **additional_kwargs)
+    # We first try with flash attention 2
+    try:
+        model = AutoModelForCausalLM.from_pretrained(ALL_MODELS_MAPPING[model_name], attn_implementation='flash_attention_2',
+                                                     device_map=device_map, torch_dtype=dtype, load_in_8bit=quantization_8bits,
+                                                     load_in_4bit=quantization_4bits, low_cpu_mem_usage=True, **additional_kwargs)
+        success = True
+    except Exception:
+        success = False
+
+    # Second try with PyTorch native sdpa (which may itself dispatch to flash attention 2 for some, but not all, models)
+    if not success:
+        try:
+            model = AutoModelForCausalLM.from_pretrained(ALL_MODELS_MAPPING[model_name], attn_implementation='sdpa',
+                                                         device_map=device_map, torch_dtype=dtype, load_in_8bit=quantization_8bits,
+                                                         load_in_4bit=quantization_4bits, low_cpu_mem_usage=True, **additional_kwargs)
+            success = True
+        except Exception:
+            success = False
+
+    # Last try with eager attention plus BetterTransformer, which also relies on sdpa but covers more models
+    if not success:
+        model = AutoModelForCausalLM.from_pretrained(ALL_MODELS_MAPPING[model_name], attn_implementation='eager', device_map=device_map,
+                                                     torch_dtype=dtype, load_in_8bit=quantization_8bits, load_in_4bit=quantization_4bits,
+                                                     low_cpu_mem_usage=True, **additional_kwargs)
+        # BetterTransformer is nominally supported for codegen2 models but makes them crash during the forward pass
+        if not ('codegen2-' in model_name):
+            # Convert to BetterTransformer to use PyTorch optimizations if supported by the model
+            try:
+                model = model.to_bettertransformer()
+            except Exception:
+                warnings.warn(('The default manual attention implementation will be used. This will result in slower generation and '
+                               'higher memory usage. This should not be an issue for small models.'))
+
 
     # If the flag is active we directly put our model on one gpu without using any device_map (this is
     # more efficient). But if the model is quantized, this is already done automatically because quantization
@@ -719,14 +748,6 @@ def load_model(model_name: str, quantization_8bits: bool = False, quantization_4
     if only_move_to_one_gpu and not quantization:
         # This operation is in-place for nn.Module
         model.cuda(gpu_rank)
-
-    # For some reason bettertransformer is supported for codegen2 models but makes them crash during the forward
-    if not ('codegen2-' in model_name):
-        # Convert to better transformer to use Pytorch optimizations if supported by the model
-        try:
-            model = model.to_bettertransformer()
-        except:
-            pass
 
     model.eval()
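
Note: the three attempts added in loader.py form a single fallback chain over attn_implementation values ('flash_attention_2' -> 'sdpa' -> 'eager'). Below is a minimal sketch of that idea, assuming a checkpoint identifier and keyword arguments equivalent to those built inside load_model; the loop form and the helper name are illustrative, not the actual textwiz code, which uses explicit try/except blocks and additionally converts the eager model with to_bettertransformer() when possible.

    from transformers import AutoModelForCausalLM

    def load_with_attention_fallback(checkpoint: str, **kwargs):
        # Hypothetical helper, for illustration only: try the fastest attention
        # implementation first, then fall back to more widely supported ones.
        for attn in ('flash_attention_2', 'sdpa', 'eager'):
            try:
                return AutoModelForCausalLM.from_pretrained(checkpoint, attn_implementation=attn, **kwargs)
            except Exception:
                continue
        raise RuntimeError(f'Could not load {checkpoint} with any attention implementation')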