From 9752a6cf239c1278adb91aa480d6a03e86e3ad9e Mon Sep 17 00:00:00 2001
From: Richard Gong
Date: Wed, 14 Feb 2024 23:36:21 +0000
Subject: [PATCH] Update codebase to work better with Mixtral out of the box

---
 README.md          | 2 +-
 config/mixtral.yml | 2 +-
 src/common.py      | 2 +-
 src/inference.py   | 5 +++--
 4 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 9a15ec0..a36a445 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# Fine-tune any LLM in minutes (ft. LLaMA, CodeLlama, Mistral)
+# Fine-tune any LLM in minutes (ft. Mixtral, LLaMA, Mistral)
 
 ### Tired of prompt engineering? You've come to the right place.
 
diff --git a/config/mixtral.yml b/config/mixtral.yml
index 15ea108..301842d 100644
--- a/config/mixtral.yml
+++ b/config/mixtral.yml
@@ -69,7 +69,7 @@ wandb_name:
 wandb_log_model:
 
 gradient_accumulation_steps: 1
-micro_batch_size: 8
+micro_batch_size: 16
 num_epochs: 1
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
diff --git a/src/common.py b/src/common.py
index c640c90..32d15bf 100644
--- a/src/common.py
+++ b/src/common.py
@@ -22,7 +22,7 @@ vllm_image = Image.from_registry(
     "nvidia/cuda:12.1.0-base-ubuntu22.04", add_python="3.10"
 ).pip_install(
-    "vllm==0.2.5",
+    "vllm==0.2.6",
     "torch==2.1.2",
 )
 
diff --git a/src/inference.py b/src/inference.py
index 8dc6bec..b4f0481 100644
--- a/src/inference.py
+++ b/src/inference.py
@@ -3,9 +3,10 @@
 
 from .common import stub, vllm_image, VOLUME_CONFIG
 
+N_INFERENCE_GPU = 2
 
 @stub.cls(
-    gpu="A100",
+    gpu=modal.gpu.H100(count=N_INFERENCE_GPU),
     image=vllm_image,
     volumes=VOLUME_CONFIG,
     allow_concurrent_inputs=30,
@@ -18,7 +19,7 @@ def __init__(self, model_path: str) -> None:
         from vllm.engine.arg_utils import AsyncEngineArgs
         from vllm.engine.async_llm_engine import AsyncLLMEngine
 
-        engine_args = AsyncEngineArgs(model=model_path, gpu_memory_utilization=0.95)
+        engine_args = AsyncEngineArgs(model=model_path, gpu_memory_utilization=0.95, tensor_parallel_size=N_INFERENCE_GPU)
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
 
     @modal.method()