diff --git a/textwiz/loader.py b/textwiz/loader.py index 8e99731..b77f643 100644 --- a/textwiz/loader.py +++ b/textwiz/loader.py @@ -2,6 +2,7 @@ import re import math import importlib.metadata +import difflib from packaging import version import torch @@ -34,6 +35,28 @@ } + +def check_model_name(model_name: str, available_models: list[str] | tuple[str] = ALLOWED_MODELS): + """Ensure that the `model_name` is valid, and raise an error trying to find closest model otherwise. + + Parameters + ---------- + model_name : str + The model name. + available_models : list[str] | tuple[str], optional + List of available models to check for closest match. By default check all models. + """ + + if model_name not in available_models: + closest_match = difflib.get_close_matches(model_name, available_models, n=1) + if len(closest_match) > 0: + raise ValueError(f'The model name you provided is invalid. Perhaps you meant "{closest_match[0]}"?') + else: + raise ValueError(f'The model name you provided is invalid.') + else: + return + + def get_model_params(model_name: str) -> float: """Return the approximate number of params of the model, in billions. @@ -47,10 +70,7 @@ def get_model_params(model_name: str) -> float: float The number of parameters. """ - - if model_name not in ALLOWED_MODELS: - raise ValueError(f'The model name must be one of {*ALLOWED_MODELS,}.') - + check_model_name(model_name) return ALL_MODELS_PARAMS[model_name] @@ -67,10 +87,7 @@ def get_model_dtype(model_name: str) -> torch.dtype: torch.dtype The default dtype. """ - - if model_name not in ALLOWED_MODELS: - raise ValueError(f'The model name must be one of {*ALLOWED_MODELS,}.') - + check_model_name(model_name) return ALL_MODELS_DTYPES[model_name] @@ -87,10 +104,7 @@ def get_model_context_size(model_name: str) -> int: int The context size. """ - - if model_name not in ALLOWED_MODELS: - raise ValueError(f'The model name must be one of {*ALLOWED_MODELS,}.') - + check_model_name(model_name) return ALL_MODELS_CONTEXT_SIZE[model_name] @@ -270,8 +284,8 @@ def load_model(model_name: str, quantization_8bits: bool = False, quantization_4 The model. """ - if model_name not in ALLOWED_MODELS: - raise ValueError(f'The model name must be one of {*ALLOWED_MODELS,}.') + # Check name + check_model_name(model_name) # Check package versions check_versions(model_name) @@ -408,8 +422,8 @@ def load_tokenizer(model_name: str): The tokenizer. """ - if model_name not in ALLOWED_MODELS: - raise ValueError(f'The model name must be one of {*ALLOWED_MODELS,}.') + # Check name + check_model_name(model_name) # Check package versions check_versions(model_name) diff --git a/textwiz/models/causal.py b/textwiz/models/causal.py index 5909b84..63483c4 100644 --- a/textwiz/models/causal.py +++ b/textwiz/models/causal.py @@ -34,8 +34,8 @@ def __init__(self, model_name: str, quantization_8bits: bool = False, quantizati dtype: torch.dtype | None = None, max_fraction_gpu_0: float = 0.8, max_fraction_gpus: float = 0.8, device_map: dict | str | None = None, gpu_rank: int = 0): - if model_name not in loader.ALLOWED_CAUSAL_MODELS: - raise ValueError(f'The model name must be one of {*loader.ALLOWED_CAUSAL_MODELS,}.') + # Check name against only causal models + loader.check_model_name(model_name, loader.ALLOWED_CAUSAL_MODELS) super().__init__(model_name, quantization_8bits, quantization_4bits, dtype, max_fraction_gpu_0, max_fraction_gpus, device_map, gpu_rank) diff --git a/textwiz/models/embedding.py b/textwiz/models/embedding.py index c3522a1..e9716e0 100644 --- a/textwiz/models/embedding.py +++ b/textwiz/models/embedding.py @@ -14,8 +14,8 @@ def __init__(self, model_name: str, quantization_8bits: bool = False, quantizati dtype: torch.dtype | None = None, max_fraction_gpu_0: float = 0.8, max_fraction_gpus: float = 0.8, device_map: dict | str | None = None, gpu_rank: int = 0): - if model_name not in loader.ALLOWED_EMBEDDING_MODELS: - raise ValueError(f'The model name must be one of {*loader.ALLOWED_EMBEDDING_MODELS,}.') + # Check name against only embedding models + loader.check_model_name(model_name, loader.ALLOWED_EMBEDDING_MODELS) super().__init__(model_name, quantization_8bits, quantization_4bits, dtype, max_fraction_gpu_0, max_fraction_gpus, device_map, gpu_rank)