load with non-standard filenames
lzhangzz committed Oct 23, 2024
1 parent 4769ef8 commit b1fa486
Showing 1 changed file with 43 additions and 23 deletions.
66 changes: 43 additions & 23 deletions lmdeploy/turbomind/deploy/loader.py
@@ -4,6 +4,7 @@
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import partial
from glob import glob
from typing import Iterator, Tuple

@@ -15,6 +16,8 @@
WEIGHT_PATTERN = 'pytorch_model*.bin'
SAFE_WEIGHT_INDEX_NAME = 'model.safetensors.index.json'
SAFE_WEIGHT_PATTERN = 'model*.safetensors'
EXTRA_WEIGHT_PATTERNS = ['*.pt', '*.bin']
EXTRA_SAFE_WEIGHT_PATTERN = '*.safetensors'


class BaseLoader(ABC):
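
A quick illustration of what the two new fallback patterns cover (the file names below are invented for the example; glob() applies the same fnmatch-style rules shown here):

# Illustration only: which files the standard patterns vs. the new EXTRA_*
# fallbacks would pick up in a hypothetical checkpoint directory.
import fnmatch

files = ['consolidated.00.safetensors', 'adapter_model.bin', 'weights.pt']

print(fnmatch.filter(files, 'model*.safetensors'))  # [] -- standard pattern misses these
print(fnmatch.filter(files, 'pytorch_model*.bin'))  # []

print(fnmatch.filter(files, '*.safetensors'))       # ['consolidated.00.safetensors']
for pattern in ['*.pt', '*.bin']:
    print(pattern, fnmatch.filter(files, pattern))  # ['weights.pt'] / ['adapter_model.bin']
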
@@ -27,17 +30,19 @@ def __init__(self, model_path: str, pattern):
def get_index(self, index_path: str,
file_pattern: str) -> Tuple[dict, list]:
"""get shards and weight map (if possible) for the model."""
index_path = osp.join(self.model_path, index_path)
get_path = partial(osp.join, self.model_path)
index_path = get_path(index_path)
if osp.exists(index_path):
with open(index_path, 'r') as f:
index = json.load(f)
index = index['weight_map']
shards = set(index.values())
shards = [osp.join(self.model_path, x) for x in shards]
index = index['weight_map']
shards = list(map(get_path, set(index.values())))
else:
index = {}
file_pattern = osp.join(self.model_path, file_pattern)
shards = glob(file_pattern)
shards = glob(get_path(file_pattern))
if not shards:
raise RuntimeError(
f'failed to locate weight files for {self.model_path}')
return sorted(shards), index

@abstractmethod
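
For readers unfamiliar with the functools.partial idiom introduced in get_index above, a minimal standalone sketch (the model path is a placeholder):

# partial(osp.join, model_path) freezes the directory argument, so every later
# call only needs the file name -- the same helper get_index builds above.
import os.path as osp
from functools import partial

model_path = '/models/demo'  # placeholder directory
get_path = partial(osp.join, model_path)

print(get_path('model.safetensors.index.json'))
# /models/demo/model.safetensors.index.json
print(sorted(map(get_path, {'model-00001.safetensors', 'model-00002.safetensors'})))
# ['/models/demo/model-00001.safetensors', '/models/demo/model-00002.safetensors']
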
@@ -47,12 +52,13 @@ def items(self) -> Iterator[Tuple[int, dict]]:

class SafetensorsLoader(BaseLoader):

def __init__(self, model_path: str, pattern: str):
def __init__(self,
model_path: str,
pattern: str,
index_name: str = None,
file_pattern: str = None):
super().__init__(model_path, pattern)
self.pattern = pattern
self.model_path = model_path
self.shards, index = self.get_index(SAFE_WEIGHT_INDEX_NAME,
SAFE_WEIGHT_PATTERN)
self.shards, index = self.get_index(index_name, file_pattern)
if not index:
for shard in self.shards:
with safe_open(shard, 'pt') as f:
@@ -83,9 +89,10 @@ def items(self):

class PytorchLoader(BaseLoader):

def __init__(self, model_path: str, pattern: str):
def __init__(self, model_path: str, pattern: str, index_name: str,
file_pattern: str):
super().__init__(model_path, pattern)
self.shards, index = self.get_index(WEIGHT_INDEX_NAME, WEIGHT_PATTERN)
self.shards, index = self.get_index(index_name, file_pattern)
for k in index.keys():
match = re.findall(self.pattern, k)
if match:
@@ -122,14 +129,27 @@ def items(self):


def create_loader(model_path: str, pattern: str) -> BaseLoader:
cls = None
args = (model_path, pattern)

if osp.exists(osp.join(model_path, SAFE_WEIGHT_INDEX_NAME)):
cls = SafetensorsLoader
elif glob(osp.join(model_path, SAFE_WEIGHT_PATTERN)):
cls = SafetensorsLoader
elif osp.exists(osp.join(model_path, WEIGHT_INDEX_NAME)):
cls = PytorchLoader
elif glob(osp.join(model_path, WEIGHT_PATTERN)):
cls = PytorchLoader
assert cls is not None, f'Failed to find valid loader for {model_path}'
return cls(model_path, pattern)
return SafetensorsLoader(*args, index_name=SAFE_WEIGHT_INDEX_NAME)

if glob(osp.join(model_path, SAFE_WEIGHT_PATTERN)):
return SafetensorsLoader(*args, file_pattern=SAFE_WEIGHT_PATTERN)

if osp.exists(osp.join(model_path, WEIGHT_INDEX_NAME)):
return PytorchLoader(*args, index_name=WEIGHT_INDEX_NAME)

if glob(osp.join(model_path, WEIGHT_PATTERN)):
return PytorchLoader(*args, file_pattern=WEIGHT_PATTERN)

# non-standard safetensors model (*.safetensors)
if glob(osp.join(model_path, EXTRA_SAFE_WEIGHT_PATTERN)):
return SafetensorsLoader(*args, file_pattern=EXTRA_SAFE_WEIGHT_PATTERN)

# non-standard pytorch model (*.bin, *.pt)
for p in EXTRA_WEIGHT_PATTERNS:
if glob(osp.join(model_path, p)):
return PytorchLoader(*args, file_pattern=p)

raise RuntimeError(f'Failed to find valid loader for {model_path}')
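
A hedged usage sketch of the reworked fallback order, for a checkpoint directory that only contains non-standard filenames (the directory layout and the layer-index regex below are illustrative, not taken from this commit):

# Sketch: assume /models/custom holds only part-00.safetensors and
# part-01.safetensors -- no index JSON and nothing matching model*.safetensors
# or pytorch_model*.bin. The four standard checks fall through, the new
# '*.safetensors' fallback matches, and a SafetensorsLoader is returned.
# A directory with only *.pt / *.bin files would reach PytorchLoader instead,
# and a directory with no recognizable weights now raises RuntimeError
# rather than failing the old assert.
from lmdeploy.turbomind.deploy.loader import create_loader

loader = create_loader('/models/custom', pattern=r'layers\.([0-9]+)\.')
for idx, tensors in loader.items():  # items() yields Tuple[int, dict]
    print(idx, len(tensors))
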
