From 9d83c6d715e8cdb802f82335e651923baab5cfc6 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 17 May 2024 18:18:59 +0800 Subject: [PATCH] [lazy] fix lazy cls init (#5720) * fix * fix * fix * fix * fix * remove kernel intall * rebase revert fix * fix * fix --- .github/workflows/build_on_pr.yml | 2 +- colossalai/lazy/pretrained.py | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 5bdadca783b3..a3a6d5a6ab0d 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -140,7 +140,7 @@ jobs: - name: Install Colossal-AI run: | - BUILD_EXT=1 pip install -v -e . + pip install -v -e . pip install -r requirements/requirements-test.txt - name: Store Colossal-AI Cache diff --git a/colossalai/lazy/pretrained.py b/colossalai/lazy/pretrained.py index 21d44d4244d3..736ffc5e4ea2 100644 --- a/colossalai/lazy/pretrained.py +++ b/colossalai/lazy/pretrained.py @@ -1,3 +1,4 @@ +import copy import os from typing import Callable, Optional, Union @@ -74,6 +75,24 @@ def new_from_pretrained( subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) variant = kwargs.pop("variant", None) + + kwargs.pop("state_dict", None) + kwargs.pop("from_tf", False) + kwargs.pop("from_flax", False) + kwargs.pop("output_loading_info", False) + kwargs.pop("trust_remote_code", None) + kwargs.pop("low_cpu_mem_usage", None) + kwargs.pop("device_map", None) + kwargs.pop("max_memory", None) + kwargs.pop("offload_folder", None) + kwargs.pop("offload_state_dict", False) + kwargs.pop("load_in_8bit", False) + kwargs.pop("load_in_4bit", False) + kwargs.pop("quantization_config", None) + kwargs.pop("adapter_kwargs", {}) + kwargs.pop("adapter_name", "default") + kwargs.pop("use_flash_attention_2", False) + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) if len(kwargs) > 0: @@ -108,6 +127,10 @@ def new_from_pretrained( **kwargs, ) else: + config = copy.deepcopy(config) + kwarg_attn_imp = kwargs.pop("attn_implementation", None) + if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: + config._attn_implementation = kwarg_attn_imp model_kwargs = kwargs if commit_hash is None: