Skip to content

Commit

Permalink
removed unused function
Browse files Browse the repository at this point in the history
Signed-off-by: 严照东 <[email protected]>
  • Loading branch information
严照东 committed Sep 26, 2023
1 parent 2fb5318 commit 349d8df
Showing 1 changed file with 0 additions and 51 deletions.
51 changes: 0 additions & 51 deletions flagai/model/aquila2/modeling_aquila.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,57 +765,6 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

@classmethod
def from_pretrain(self, model_dir, model_name, **kwargs):
download_path = os.path.join(model_dir, model_name)
if os.path.exists(download_path):
return self.from_pretrained(download_path, **kwargs)


config_path = os.path.join(download_path, "config.json")
checkpoint_path = os.path.join(download_path, "pytorch_model.bin")
from flagai.model.file_utils import _get_model_id
model_id = _get_model_id(model_name)
if model_id and model_id != "null":
model_files = eval(_get_model_files(model_name))
print("model files:" + str(model_files))
for file_name in model_files:
if not file_name.endswith("bin"):
_get_vocab_path(download_path, file_name, model_id)

if os.path.exists(
os.path.join(download_path, 'config.json')):
if os.getenv('ENV_TYPE') == 'deepspeed+mpu':
model_parallel_size = int(os.getenv("MODEL_PARALLEL_SIZE"))
if model_parallel_size > 1:
# if gpus == nums_of_modelhub_models
# can load
# else need to download the pytorch_model.bin and to recut.
model_hub_parallel_size = 0
for f in model_files:
if "pytorch_model_" in f:
model_hub_parallel_size += 1
else:
model_parallel_size = 1

if "pytorch_model_01.bin" in model_files and model_parallel_size > 1 and model_hub_parallel_size == model_parallel_size:
# Only to download the model slices(megatron-lm).
for file_to_load in model_files:
if "pytorch_model_" in file_to_load:
_get_checkpoint_path(download_path, file_to_load,
model_id)

elif 'pytorch_model.bin' in model_files:
checkpoint_path = _get_checkpoint_path(
download_path, 'pytorch_model.bin', model_id)
else:
checkpoint_merge = {}
# maybe multi weights files
for file_to_load in model_files:
if "pytorch_model-0" in file_to_load:
_get_checkpoint_path(download_path, file_to_load,
model_id)

def get_input_embeddings(self):
return self.model.embed_tokens

Expand Down

0 comments on commit 349d8df

Please sign in to comment.