Update model_name check to find closest match
Cyrilvallez committed Apr 16, 2024
1 parent c7eb68a commit 1cfd350
Showing 3 changed files with 34 additions and 20 deletions.
46 changes: 30 additions & 16 deletions textwiz/loader.py
@@ -2,6 +2,7 @@
 import re
 import math
 import importlib.metadata
+import difflib
 from packaging import version
 
 import torch
@@ -34,6 +35,28 @@
 }
 
 
+def check_model_name(model_name: str, available_models: list[str] | tuple[str] = ALLOWED_MODELS):
+    """Ensure that `model_name` is valid; if not, raise an error suggesting the closest matching model.
+
+    Parameters
+    ----------
+    model_name : str
+        The model name.
+    available_models : list[str] | tuple[str], optional
+        List of available models to check for the closest match. By default, check all models.
+    """
+
+    if model_name not in available_models:
+        closest_match = difflib.get_close_matches(model_name, available_models, n=1)
+        if len(closest_match) > 0:
+            raise ValueError(f'The model name you provided is invalid. Perhaps you meant "{closest_match[0]}"?')
+        else:
+            raise ValueError('The model name you provided is invalid.')
+
+
 def get_model_params(model_name: str) -> float:
     """Return the approximate number of params of the model, in billions.
@@ -47,10 +70,7 @@ def get_model_params(model_name: str) -> float:
     float
         The number of parameters.
     """
-
-    if model_name not in ALLOWED_MODELS:
-        raise ValueError(f'The model name must be one of {*ALLOWED_MODELS,}.')
-
+    check_model_name(model_name)
     return ALL_MODELS_PARAMS[model_name]


@@ -67,10 +87,7 @@ def get_model_dtype(model_name: str) -> torch.dtype:
     torch.dtype
         The default dtype.
     """
-
-    if model_name not in ALLOWED_MODELS:
-        raise ValueError(f'The model name must be one of {*ALLOWED_MODELS,}.')
-
+    check_model_name(model_name)
     return ALL_MODELS_DTYPES[model_name]


@@ -87,10 +104,7 @@ def get_model_context_size(model_name: str) -> int:
     int
         The context size.
     """
-
-    if model_name not in ALLOWED_MODELS:
-        raise ValueError(f'The model name must be one of {*ALLOWED_MODELS,}.')
-
+    check_model_name(model_name)
     return ALL_MODELS_CONTEXT_SIZE[model_name]


@@ -270,8 +284,8 @@ def load_model(model_name: str, quantization_8bits: bool = False, quantization_4
         The model.
     """
 
-    if model_name not in ALLOWED_MODELS:
-        raise ValueError(f'The model name must be one of {*ALLOWED_MODELS,}.')
+    # Check name
+    check_model_name(model_name)
 
     # Check package versions
     check_versions(model_name)
@@ -408,8 +422,8 @@ def load_tokenizer(model_name: str):
         The tokenizer.
     """
 
-    if model_name not in ALLOWED_MODELS:
-        raise ValueError(f'The model name must be one of {*ALLOWED_MODELS,}.')
+    # Check name
+    check_model_name(model_name)
 
     # Check package versions
     check_versions(model_name)
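For reference, a minimal self-contained sketch of the behavior this commit introduces (the model names below are hypothetical placeholders, not entries from textwiz's actual ALLOWED_MODELS):

    import difflib

    # Hypothetical model registry, standing in for textwiz's ALLOWED_MODELS.
    ALLOWED_MODELS = ('llama2-7B', 'llama2-13B', 'mistral-7B')

    def check_model_name(model_name, available_models=ALLOWED_MODELS):
        # Same logic as the commit: look for the single closest valid name.
        if model_name not in available_models:
            closest_match = difflib.get_close_matches(model_name, available_models, n=1)
            if len(closest_match) > 0:
                raise ValueError(f'The model name you provided is invalid. Perhaps you meant "{closest_match[0]}"?')
            raise ValueError('The model name you provided is invalid.')

    check_model_name('lama2-7B')
    # ValueError: The model name you provided is invalid. Perhaps you meant "llama2-7B"?

Note that difflib.get_close_matches scores candidates with SequenceMatcher ratios and uses a default cutoff of 0.6, so a name that shares little with any valid entry falls through to the generic error.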
4 changes: 2 additions & 2 deletions textwiz/models/causal.py
@@ -34,8 +34,8 @@ def __init__(self, model_name: str, quantization_8bits: bool = False, quantizati
                  dtype: torch.dtype | None = None, max_fraction_gpu_0: float = 0.8, max_fraction_gpus: float = 0.8,
                  device_map: dict | str | None = None, gpu_rank: int = 0):
 
-        if model_name not in loader.ALLOWED_CAUSAL_MODELS:
-            raise ValueError(f'The model name must be one of {*loader.ALLOWED_CAUSAL_MODELS,}.')
+        # Check name against only causal models
+        loader.check_model_name(model_name, loader.ALLOWED_CAUSAL_MODELS)
 
         super().__init__(model_name, quantization_8bits, quantization_4bits, dtype, max_fraction_gpu_0,
                          max_fraction_gpus, device_map, gpu_rank)
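Because the subclass passes loader.ALLOWED_CAUSAL_MODELS as the candidate list, a typo in a causal-model name is only ever matched against causal models. A hedged usage sketch (the mistyped name is made up):

    from textwiz import loader

    # Suggestions are drawn only from the restricted candidate list:
    loader.check_model_name('some-mistyped-causal-model', loader.ALLOWED_CAUSAL_MODELS)
    # Raises ValueError, suggesting the closest causal model name if one is similar enough.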
4 changes: 2 additions & 2 deletions textwiz/models/embedding.py
@@ -14,8 +14,8 @@ def __init__(self, model_name: str, quantization_8bits: bool = False, quantizati
                  dtype: torch.dtype | None = None, max_fraction_gpu_0: float = 0.8, max_fraction_gpus: float = 0.8,
                  device_map: dict | str | None = None, gpu_rank: int = 0):
 
-        if model_name not in loader.ALLOWED_EMBEDDING_MODELS:
-            raise ValueError(f'The model name must be one of {*loader.ALLOWED_EMBEDDING_MODELS,}.')
+        # Check name against only embedding models
+        loader.check_model_name(model_name, loader.ALLOWED_EMBEDDING_MODELS)
 
         super().__init__(model_name, quantization_8bits, quantization_4bits, dtype, max_fraction_gpu_0,
                          max_fraction_gpus, device_map, gpu_rank)
