diff --git a/flagai/model/aquila2/modeling_aquila.py b/flagai/model/aquila2/modeling_aquila.py index b0731cce..b1ae0cac 100755 --- a/flagai/model/aquila2/modeling_aquila.py +++ b/flagai/model/aquila2/modeling_aquila.py @@ -765,57 +765,6 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - @classmethod - def from_pretrain(self, model_dir, model_name, **kwargs): - download_path = os.path.join(model_dir, model_name) - if os.path.exists(download_path): - return self.from_pretrained(download_path, **kwargs) - - - config_path = os.path.join(download_path, "config.json") - checkpoint_path = os.path.join(download_path, "pytorch_model.bin") - from flagai.model.file_utils import _get_model_id - model_id = _get_model_id(model_name) - if model_id and model_id != "null": - model_files = eval(_get_model_files(model_name)) - print("model files:" + str(model_files)) - for file_name in model_files: - if not file_name.endswith("bin"): - _get_vocab_path(download_path, file_name, model_id) - - if os.path.exists( - os.path.join(download_path, 'config.json')): - if os.getenv('ENV_TYPE') == 'deepspeed+mpu': - model_parallel_size = int(os.getenv("MODEL_PARALLEL_SIZE")) - if model_parallel_size > 1: - # if gpus == nums_of_modelhub_models - # can load - # else need to download the pytorch_model.bin and to recut. - model_hub_parallel_size = 0 - for f in model_files: - if "pytorch_model_" in f: - model_hub_parallel_size += 1 - else: - model_parallel_size = 1 - - if "pytorch_model_01.bin" in model_files and model_parallel_size > 1 and model_hub_parallel_size == model_parallel_size: - # Only to download the model slices(megatron-lm). - for file_to_load in model_files: - if "pytorch_model_" in file_to_load: - _get_checkpoint_path(download_path, file_to_load, - model_id) - - elif 'pytorch_model.bin' in model_files: - checkpoint_path = _get_checkpoint_path( - download_path, 'pytorch_model.bin', model_id) - else: - checkpoint_merge = {} - # maybe multi weights files - for file_to_load in model_files: - if "pytorch_model-0" in file_to_load: - _get_checkpoint_path(download_path, file_to_load, - model_id) - def get_input_embeddings(self): return self.model.embed_tokens