Update codebase to work better with Mixtral out of the box
gongy committed Feb 14, 2024
1 parent aeb0765 commit 9752a6c
Showing 4 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,4 +1,4 @@
-# Fine-tune any LLM in minutes (ft. LLaMA, CodeLlama, Mistral)
+# Fine-tune any LLM in minutes (ft. Mixtral, LLaMA, Mistral)

### Tired of prompt engineering? You've come to the right place.

2 changes: 1 addition & 1 deletion config/mixtral.yml
@@ -69,7 +69,7 @@ wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
-micro_batch_size: 8
+micro_batch_size: 16
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
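Since `gradient_accumulation_steps` stays at 1, this hunk simply doubles the per-device batch size from 8 to 16. A minimal sketch of the usual effective-batch-size arithmetic (the GPU count below is a hypothetical example, not a value from this config):

```python
# Illustrative arithmetic only; num_gpus is an assumed example, not from the config.
micro_batch_size = 16            # per-device batch size after this change
gradient_accumulation_steps = 1  # unchanged in config/mixtral.yml
num_gpus = 2                     # hypothetical data-parallel world size

effective_batch_size = micro_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch_size)      # 32 with these example values
```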
2 changes: 1 addition & 1 deletion src/common.py
@@ -22,7 +22,7 @@
vllm_image = Image.from_registry(
    "nvidia/cuda:12.1.0-base-ubuntu22.04", add_python="3.10"
).pip_install(
-    "vllm==0.2.5",
+    "vllm==0.2.6",
    "torch==2.1.2",
)

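The only change here is bumping the pinned vLLM wheel from 0.2.5 to 0.2.6 inside the inference image. A quick, hypothetical smoke test (not part of this repo) that imports vLLM inside that image to confirm the pin resolves; it assumes the `stub` and `vllm_image` objects from `src/common.py` are importable from the repo root:

```python
# smoke_test.py — hypothetical check, not in the repository.
# Run with: modal run smoke_test.py
from src.common import stub, vllm_image  # objects defined in src/common.py


@stub.function(image=vllm_image, gpu="A10G")  # GPU choice is arbitrary for an import check
def check_vllm_version() -> str:
    import vllm  # resolved inside vllm_image, so this should report 0.2.6

    return vllm.__version__


@stub.local_entrypoint()
def main():
    print(check_vllm_version.remote())
```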
5 changes: 3 additions & 2 deletions src/inference.py
@@ -3,9 +3,10 @@

from .common import stub, vllm_image, VOLUME_CONFIG

+N_INFERENCE_GPU = 2

@stub.cls(
-    gpu="A100",
+    gpu=modal.gpu.H100(count=N_INFERENCE_GPU),
    image=vllm_image,
    volumes=VOLUME_CONFIG,
    allow_concurrent_inputs=30,
@@ -18,7 +19,7 @@ def __init__(self, model_path: str) -> None:
        from vllm.engine.arg_utils import AsyncEngineArgs
        from vllm.engine.async_llm_engine import AsyncLLMEngine

-        engine_args = AsyncEngineArgs(model=model_path, gpu_memory_utilization=0.95)
+        engine_args = AsyncEngineArgs(model=model_path, gpu_memory_utilization=0.95, tensor_parallel_size=N_INFERENCE_GPU)
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

    @modal.method()
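The inference class now asks Modal for two H100s and tells vLLM to shard the model across them; keeping the count in a single `N_INFERENCE_GPU` constant keeps the GPU request and `tensor_parallel_size` from drifting apart. A standalone sketch of the same engine construction (the checkpoint path is a placeholder, not a path used by this repo):

```python
# Hedged sketch of tensor-parallel engine construction with vLLM 0.2.x.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

N_INFERENCE_GPU = 2  # must match the number of GPUs actually attached to the container

engine_args = AsyncEngineArgs(
    model="/path/to/mixtral-checkpoint",   # placeholder checkpoint path
    gpu_memory_utilization=0.95,
    tensor_parallel_size=N_INFERENCE_GPU,  # shard the weights across both GPUs
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
```

Tensor parallelism splits each weight matrix across the GPUs rather than replicating the whole model, which is what lets a checkpoint larger than a single card's memory be served.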
