diff --git a/src/download_model.py b/src/download_model.py
index 38e2ea7..76034c9 100644
--- a/src/download_model.py
+++ b/src/download_model.py
@@ -3,25 +3,23 @@
 # Get the hugging face token
 HUGGING_FACE_HUB_TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN", None)
-MODEL_NAME = os.environ.get("MODEL_NAME")
-MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")
 MODEL_BASE_PATH = os.environ.get("MODEL_BASE_PATH", "/runpod-volume/")
 
 
-def download_model():
+def download_model(model_name: str, model_revision: str):
     # Download the model from hugging face
     download_kwargs = {}
     if HUGGING_FACE_HUB_TOKEN:
         download_kwargs["token"] = HUGGING_FACE_HUB_TOKEN
 
-    DOWNLOAD_PATH = f"{MODEL_BASE_PATH}{MODEL_NAME.split('/')[1]}"
+    DOWNLOAD_PATH = f"{MODEL_BASE_PATH}{model_name.split('/')[1]}"
     print(f"Downloading model to: {DOWNLOAD_PATH}")
 
     downloaded_path = snapshot_download(
-        repo_id=MODEL_NAME,
-        revision=MODEL_REVISION,
+        repo_id=model_name,
+        revision=model_revision,
         local_dir=DOWNLOAD_PATH,
         local_dir_use_symlinks=False,
         **download_kwargs,
diff --git a/src/inference.py b/src/inference.py
index 1f83053..8ace319 100644
--- a/src/inference.py
+++ b/src/inference.py
@@ -1,8 +1,7 @@
 import os
-from exllamav2.model import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
+from exllamav2.model import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Lora
 from exllamav2.tokenizer import ExLlamaV2Tokenizer
 from exllamav2.generator import (
-    ExLlamaV2BaseGenerator,
     ExLlamaV2Sampler,
     ExLlamaV2StreamingGenerator,
 )
@@ -11,6 +10,8 @@
 MODEL_NAME = os.environ.get("MODEL_NAME")
 MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")
+LORA_NAME = os.environ.get("LORA_ADAPTER_NAME", None)
+LORA_REVISION = os.environ.get("LORA_ADAPTER_REVISION", "main")
 MODEL_BASE_PATH = os.environ.get("MODEL_BASE_PATH", "/runpod-volume/")
 
@@ -23,7 +24,9 @@ def setup(self):
         if not os.path.isdir(model_directory):
             print("Downloading model...")
             try:
-                download_model()
+                download_model(model_name=MODEL_NAME, model_revision=MODEL_REVISION)
+                if LORA_NAME is not None:
+                    download_model(model_name=LORA_NAME, model_revision=LORA_REVISION)
             except Exception as e:
                 print(f"Error downloading model: {e}")
                 # delete model directory if it exists
@@ -41,6 +44,12 @@ def setup(self):
         self.cache = ExLlamaV2Cache(self.model)
         self.settings = ExLlamaV2Sampler.Settings()
 
+        # Load LORA adapter if specified
+        self.lora_adapter = None
+        if LORA_NAME is not None:
+            lora_directory = f"{MODEL_BASE_PATH}{LORA_NAME.split('/')[1]}"
+            self.lora_adapter = ExLlamaV2Lora.from_directory(self.model, lora_directory)
+
     def predict(self, settings):
         ### Set the generation settings
         self.settings.temperature = settings["temperature"]
@@ -63,7 +72,7 @@ def streamGenerate(self, prompt, max_new_tokens):
         generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer)
         generator.warmup()
 
-        generator.begin_stream(input_ids, self.settings)
+        generator.begin_stream(input_ids, self.settings, loras=self.lora_adapter)
 
         generated_tokens = 0
         while True:
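
For context, a minimal sketch of how the pieces this diff wires together behave outside the handler class. The directory paths and the prompt are hypothetical placeholders, and the imports use exllamav2's documented top-level and submodule names rather than the handler's `exllamav2.model` path; treat this as an illustration of the per-request LoRA flow, not the repo's actual entry point.

# Sketch only: both paths below are hypothetical and must point at an
# already-downloaded base model and adapter (e.g. via download_model()).
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Tokenizer
from exllamav2.lora import ExLlamaV2Lora
from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator

config = ExLlamaV2Config()
config.model_dir = "/runpod-volume/Mistral-7B-v0.1"  # hypothetical base model dir
config.prepare()

model = ExLlamaV2(config)
model.load()
tokenizer = ExLlamaV2Tokenizer(config)
cache = ExLlamaV2Cache(model)

# Mirrors the new setup() logic: the adapter loads from its own directory,
# separate from the base weights.
lora = ExLlamaV2Lora.from_directory(model, "/runpod-volume/my-lora")  # hypothetical

settings = ExLlamaV2Sampler.Settings()
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
generator.warmup()

input_ids = tokenizer.encode("Hello, world")
# Passing loras= applies the adapter per call instead of mutating the base
# weights; loras=None (as when LORA_ADAPTER_NAME is unset) disables it, so
# one loaded model can serve requests with or without the adapter.
generator.begin_stream(input_ids, settings, loras=lora)

generated_tokens = 0
while generated_tokens < 64:
    chunk, eos, _ = generator.stream()
    print(chunk, end="")
    generated_tokens += 1
    if eos:
        break

Because each adapter is downloaded into its own directory under MODEL_BASE_PATH, the same network volume can plausibly hold one base model alongside several adapters side by side, with the environment variables selecting which combination a worker loads.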