Added logic to accept loras
Royal-lobster committed Oct 16, 2023
1 parent 18bc25f commit 006dc24
Showing 2 changed files with 17 additions and 10 deletions.
10 changes: 4 additions & 6 deletions src/download_model.py
```diff
@@ -3,25 +3,23 @@

 # Get the hugging face token
 HUGGING_FACE_HUB_TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN", None)
-MODEL_NAME = os.environ.get("MODEL_NAME")
-MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")
 MODEL_BASE_PATH = os.environ.get("MODEL_BASE_PATH", "/runpod-volume/")


-def download_model():
+def download_model(model_name: str, model_revision: str):
     # Download the model from hugging face
     download_kwargs = {}

     if HUGGING_FACE_HUB_TOKEN:
         download_kwargs["token"] = HUGGING_FACE_HUB_TOKEN

-    DOWNLOAD_PATH = f"{MODEL_BASE_PATH}{MODEL_NAME.split('/')[1]}"
+    DOWNLOAD_PATH = f"{MODEL_BASE_PATH}{model_name.split('/')[1]}"

     print(f"Downloading model to: {DOWNLOAD_PATH}")

     downloaded_path = snapshot_download(
-        repo_id=MODEL_NAME,
-        revision=MODEL_REVISION,
+        repo_id=model_name,
+        revision=model_revision,
         local_dir=DOWNLOAD_PATH,
         local_dir_use_symlinks=False,
         **download_kwargs,
```
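This refactor replaces the module-level `MODEL_NAME`/`MODEL_REVISION` globals with explicit parameters, so the same helper can fetch both the base model and a LoRA adapter repo. A minimal usage sketch, assuming the module is importable from `src/`; both repo ids below are hypothetical placeholders, not from this commit:

```python
from download_model import download_model

# Fetch the base model and a LoRA adapter with the same helper.
# Repo ids are placeholders for illustration only.
download_model(model_name="my-org/base-model-exl2", model_revision="main")
download_model(model_name="my-org/my-lora-adapter", model_revision="main")
```

Note that the download path is derived via `model_name.split('/')[1]`, which assumes a `namespace/repo` id; a bare repo name without a slash would raise an `IndexError`.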
17 changes: 13 additions & 4 deletions src/inference.py
```diff
@@ -1,8 +1,7 @@
 import os
-from exllamav2.model import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
+from exllamav2.model import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Lora
 from exllamav2.tokenizer import ExLlamaV2Tokenizer
 from exllamav2.generator import (
-    ExLlamaV2BaseGenerator,
     ExLlamaV2Sampler,
     ExLlamaV2StreamingGenerator,
 )
@@ -11,6 +10,8 @@

 MODEL_NAME = os.environ.get("MODEL_NAME")
 MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")
+LORA_NAME = os.environ.get("LORA_ADAPTER_NAME", None)
+LORA_REVISION = os.environ.get("LORA_ADAPTER_REVISION", "main")
 MODEL_BASE_PATH = os.environ.get("MODEL_BASE_PATH", "/runpod-volume/")


@@ -23,7 +24,9 @@ def setup(self):
         if not os.path.isdir(model_directory):
             print("Downloading model...")
             try:
-                download_model()
+                download_model(model_name=MODEL_NAME, model_revision=MODEL_REVISION)
+                if LORA_NAME is not None:
+                    download_model(model_name=LORA_NAME, model_revision=LORA_REVISION)
             except Exception as e:
                 print(f"Error downloading model: {e}")
                 # delete model directory if it exists
@@ -41,6 +44,12 @@ def setup(self):
         self.cache = ExLlamaV2Cache(self.model)
         self.settings = ExLlamaV2Sampler.Settings()

+        # Load LORA adapter if specified
+        self.lora_adapter = None
+        if LORA_NAME is not None:
+            lora_directory = f"{MODEL_BASE_PATH}{LORA_NAME.split('/')[1]}"
+            self.lora_adapter = ExLlamaV2Lora.from_directory(self.model, lora_directory)
+
     def predict(self, settings):
         ### Set the generation settings
         self.settings.temperature = settings["temperature"]
@@ -63,7 +72,7 @@ def streamGenerate(self, prompt, max_new_tokens):
         generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer)
         generator.warmup()

-        generator.begin_stream(input_ids, self.settings)
+        generator.begin_stream(input_ids, self.settings, loras=self.lora_adapter)
         generated_tokens = 0

         while True:
```
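The adapter is opt-in: if `LORA_ADAPTER_NAME` is unset, `self.lora_adapter` stays `None` and `begin_stream(..., loras=None)` generates with the base model alone. A hedged sketch of the environment variables a deployment might set to enable the adapter; all values are hypothetical placeholders, not from this commit:

```python
import os

# Hypothetical environment for a worker launch; values are placeholders.
os.environ["MODEL_NAME"] = "my-org/base-model-exl2"         # base model repo (required)
os.environ["MODEL_REVISION"] = "main"                       # optional, defaults to "main"
os.environ["LORA_ADAPTER_NAME"] = "my-org/my-lora-adapter"  # setting this enables LoRA loading
os.environ["LORA_ADAPTER_REVISION"] = "main"                # optional, defaults to "main"
os.environ["MODEL_BASE_PATH"] = "/runpod-volume/"           # optional, default shown in the diff
```

Both downloads land under `MODEL_BASE_PATH`, with the directory name taken from the segment after the `/` in the repo id, so the adapter id must also be of the form `namespace/repo`.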
