Adds logic to accept LoRAs #2

Merged: 3 commits, Oct 16, 2023
2 changes: 2 additions & 0 deletions README.md

````diff
@@ -21,6 +21,8 @@ These are the build arguments:
 | MODEL_NAME | your model name | false |
 | MODEL_REVISION | your model revision | true |
 | MODEL_BASE_PATH | your model base path | true |
+| LORA_ADAPTER_NAME | your lora adapter name | true |
+| LORA_ADAPTER_REVISION | your lora adapter revision | true |
 
 ### ⏫ push docker image to your docker registry
 ```bash
````
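For reference, these values are supplied at image build time. A minimal sketch, assuming a standard `docker build` invocation; the image tag and repo ids are placeholders, not from this PR:

```bash
docker build \
  --build-arg MODEL_NAME=some-org/base-model \
  --build-arg LORA_ADAPTER_NAME=some-org/lora-adapter \
  --build-arg LORA_ADAPTER_REVISION=main \
  -t my-registry/exllama-worker .
```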
2 changes: 1 addition & 1 deletion builder/requirements.txt

```diff
@@ -2,7 +2,7 @@
 # Recommended to lock the version number to avoid unexpected changes.
 
 runpod==1.2.0
-exllamav2==0.0.2
+exllamav2==0.0.6
 pandas
 ninja
 fastparquet
```
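The pin bump is load-bearing for this PR: `ExLlamaV2Lora`, imported in `src/inference.py` below, is presumably not available in 0.0.2. A quick post-install sanity check (a sketch, not part of the PR):

```python
# Verify the pinned exllamav2 exposes the LoRA class this PR relies on;
# the import fails on releases that predate LoRA support.
from exllamav2.model import ExLlamaV2Lora

print("exllamav2 LoRA support available")
```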
10 changes: 4 additions & 6 deletions src/download_model.py

```diff
@@ -3,25 +3,23 @@
 
 # Get the hugging face token
 HUGGING_FACE_HUB_TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN", None)
-MODEL_NAME = os.environ.get("MODEL_NAME")
-MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")
 MODEL_BASE_PATH = os.environ.get("MODEL_BASE_PATH", "/runpod-volume/")
 
 
-def download_model():
+def download_model(model_name: str, model_revision: str):
     # Download the model from hugging face
     download_kwargs = {}
 
     if HUGGING_FACE_HUB_TOKEN:
         download_kwargs["token"] = HUGGING_FACE_HUB_TOKEN
 
-    DOWNLOAD_PATH = f"{MODEL_BASE_PATH}{MODEL_NAME.split('/')[1]}"
+    DOWNLOAD_PATH = f"{MODEL_BASE_PATH}{model_name.split('/')[1]}"
 
     print(f"Downloading model to: {DOWNLOAD_PATH}")
 
     downloaded_path = snapshot_download(
-        repo_id=MODEL_NAME,
-        revision=MODEL_REVISION,
+        repo_id=model_name,
+        revision=model_revision,
         local_dir=DOWNLOAD_PATH,
         local_dir_use_symlinks=False,
         **download_kwargs,
```
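With the `MODEL_NAME`/`MODEL_REVISION` lookups moved out of this module, callers now supply them explicitly, so one code path fetches both base models and adapters. A minimal usage sketch with placeholder repo ids; note the function derives the target folder from the segment after the `/`, so ids must be in `org/name` form:

```python
# Usage sketch for the parameterized downloader. Repo ids are placeholders;
# MODEL_BASE_PATH and HUGGING_FACE_HUB_TOKEN still come from the environment.
from download_model import download_model

# Base model and LoRA adapter now share the same code path:
download_model(model_name="some-org/base-model", model_revision="main")
download_model(model_name="some-org/lora-adapter", model_revision="main")
```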
17 changes: 13 additions & 4 deletions src/inference.py

```diff
@@ -1,8 +1,7 @@
 import os
-from exllamav2.model import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
+from exllamav2.model import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Lora
 from exllamav2.tokenizer import ExLlamaV2Tokenizer
 from exllamav2.generator import (
-    ExLlamaV2BaseGenerator,
     ExLlamaV2Sampler,
     ExLlamaV2StreamingGenerator,
 )
@@ -11,6 +10,8 @@
 
 MODEL_NAME = os.environ.get("MODEL_NAME")
 MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")
+LORA_NAME = os.environ.get("LORA_ADAPTER_NAME", None)
+LORA_REVISION = os.environ.get("LORA_ADAPTER_REVISION", "main")
 MODEL_BASE_PATH = os.environ.get("MODEL_BASE_PATH", "/runpod-volume/")
 
 
@@ -23,7 +24,9 @@ def setup(self):
         if not os.path.isdir(model_directory):
             print("Downloading model...")
             try:
-                download_model()
+                download_model(model_name=MODEL_NAME, model_revision=MODEL_REVISION)
+                if LORA_NAME is not None:
+                    download_model(model_name=LORA_NAME, model_revision=LORA_REVISION)
             except Exception as e:
                 print(f"Error downloading model: {e}")
                 # delete model directory if it exists
@@ -41,6 +44,12 @@ def setup(self):
         self.cache = ExLlamaV2Cache(self.model)
         self.settings = ExLlamaV2Sampler.Settings()
 
+        # Load LORA adapter if specified
+        self.lora_adapter = None
+        if LORA_NAME is not None:
+            lora_directory = f"{MODEL_BASE_PATH}{LORA_NAME.split('/')[1]}"
+            self.lora_adapter = ExLlamaV2Lora.from_directory(self.model, lora_directory)
+
     def predict(self, settings):
         ### Set the generation settings
         self.settings.temperature = settings["temperature"]
@@ -63,7 +72,7 @@ def streamGenerate(self, prompt, max_new_tokens):
         generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer)
         generator.warmup()
 
-        generator.begin_stream(input_ids, self.settings)
+        generator.begin_stream(input_ids, self.settings, loras=self.lora_adapter)
         generated_tokens = 0
 
         while True:
```
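Taken together, the new generation path condenses to the sketch below. It uses only calls that appear in this diff plus the usual exllamav2 setup; directory paths, the prompt, and the token budget are placeholders:

```python
# Condensed sketch of the LoRA-aware streaming path this PR wires up,
# mirroring setup() and streamGenerate() in src/inference.py.
from exllamav2.model import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Lora
from exllamav2.tokenizer import ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator

config = ExLlamaV2Config()
config.model_dir = "/runpod-volume/base-model"  # placeholder path
config.prepare()

model = ExLlamaV2(config)
model.load()
tokenizer = ExLlamaV2Tokenizer(config)
cache = ExLlamaV2Cache(model)
settings = ExLlamaV2Sampler.Settings()

# Optional adapter, loaded exactly as in setup(); leaving it None disables it.
lora_adapter = ExLlamaV2Lora.from_directory(model, "/runpod-volume/lora-adapter")

generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
generator.warmup()

input_ids = tokenizer.encode("Hello, world")
generator.begin_stream(input_ids, settings, loras=lora_adapter)

# Stream until EOS or a small token budget, as streamGenerate() does.
generated_tokens = 0
while True:
    chunk, eos, _ = generator.stream()
    print(chunk, end="")
    generated_tokens += 1
    if eos or generated_tokens >= 64:
        break
```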