From 006dc24786ea6f5c569e6f90fb4364c7f248dc53 Mon Sep 17 00:00:00 2001 From: Srujan Gurram Date: Mon, 16 Oct 2023 14:47:51 +0530 Subject: [PATCH 1/3] Added logic to accept loras --- src/download_model.py | 10 ++++------ src/inference.py | 17 +++++++++++++---- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/download_model.py b/src/download_model.py index 38e2ea7..76034c9 100644 --- a/src/download_model.py +++ b/src/download_model.py @@ -3,25 +3,23 @@ # Get the hugging face token HUGGING_FACE_HUB_TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN", None) -MODEL_NAME = os.environ.get("MODEL_NAME") -MODEL_REVISION = os.environ.get("MODEL_REVISION", "main") MODEL_BASE_PATH = os.environ.get("MODEL_BASE_PATH", "/runpod-volume/") -def download_model(): +def download_model(model_name: str, model_revision: str): # Download the model from hugging face download_kwargs = {} if HUGGING_FACE_HUB_TOKEN: download_kwargs["token"] = HUGGING_FACE_HUB_TOKEN - DOWNLOAD_PATH = f"{MODEL_BASE_PATH}{MODEL_NAME.split('/')[1]}" + DOWNLOAD_PATH = f"{MODEL_BASE_PATH}{model_name.split('/')[1]}" print(f"Downloading model to: {DOWNLOAD_PATH}") downloaded_path = snapshot_download( - repo_id=MODEL_NAME, - revision=MODEL_REVISION, + repo_id=model_name, + revision=model_revision, local_dir=DOWNLOAD_PATH, local_dir_use_symlinks=False, **download_kwargs, diff --git a/src/inference.py b/src/inference.py index 1f83053..8ace319 100644 --- a/src/inference.py +++ b/src/inference.py @@ -1,8 +1,7 @@ import os -from exllamav2.model import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config +from exllamav2.model import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Lora from exllamav2.tokenizer import ExLlamaV2Tokenizer from exllamav2.generator import ( - ExLlamaV2BaseGenerator, ExLlamaV2Sampler, ExLlamaV2StreamingGenerator, ) @@ -11,6 +10,8 @@ MODEL_NAME = os.environ.get("MODEL_NAME") MODEL_REVISION = os.environ.get("MODEL_REVISION", "main") +LORA_NAME = os.environ.get("LORA_ADAPTER_NAME", None) +LORA_REVISION = os.environ.get("LORA_ADAPTER_REVISION", "main") MODEL_BASE_PATH = os.environ.get("MODEL_BASE_PATH", "/runpod-volume/") @@ -23,7 +24,9 @@ def setup(self): if not os.path.isdir(model_directory): print("Downloading model...") try: - download_model() + download_model(model_name=MODEL_NAME, model_revision=MODEL_REVISION) + if LORA_NAME is not None: + download_model(model_name=LORA_NAME, model_revision=LORA_REVISION) except Exception as e: print(f"Error downloading model: {e}") # delete model directory if it exists @@ -41,6 +44,12 @@ def setup(self): self.cache = ExLlamaV2Cache(self.model) self.settings = ExLlamaV2Sampler.Settings() + # Load LORA adapter if specified + self.lora_adapter = None + if LORA_NAME is not None: + lora_directory = f"{MODEL_BASE_PATH}{LORA_NAME.split('/')[1]}" + self.lora_adapter = ExLlamaV2Lora.from_directory(self.model, lora_directory) + def predict(self, settings): ### Set the generation settings self.settings.temperature = settings["temperature"] @@ -63,7 +72,7 @@ def streamGenerate(self, prompt, max_new_tokens): generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer) generator.warmup() - generator.begin_stream(input_ids, self.settings) + generator.begin_stream(input_ids, self.settings, loras=self.lora_adapter) generated_tokens = 0 while True: From 05ba5e4ac7f6f506b9630ff66db84e9064591b99 Mon Sep 17 00:00:00 2001 From: Srujan Gurram Date: Mon, 16 Oct 2023 14:50:01 +0530 Subject: [PATCH 2/3] Add support for configuring the Lora adapter name and revision --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 8733fd5..e18a8ce 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,8 @@ These are the build arguments: | MODEL_NAME | your model name | false | | MODEL_REVISION | your model revision | true | | MODEL_BASE_PATH | your model base path | true | +| LORA_ADAPTER_NAME | your lora adapter name | true | +| LORA_ADAPTER_REVISION | your lora adapter revision | true | ### ⏫ push docker image to your docker registry ```bash From 90ced7c9a0bd5313b2843344f6aa305235af0c9f Mon Sep 17 00:00:00 2001 From: Srujan Gurram Date: Mon, 16 Oct 2023 15:00:00 +0530 Subject: [PATCH 3/3] Update exllamav2 version to 0.0.6 in requirements.txt --- builder/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/builder/requirements.txt b/builder/requirements.txt index 152f708..b382dcb 100644 --- a/builder/requirements.txt +++ b/builder/requirements.txt @@ -2,7 +2,7 @@ # Reccomended to lock the version number to avoid unexpected changes. runpod==1.2.0 -exllamav2==0.0.2 +exllamav2==0.0.6 pandas ninja fastparquet