diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index 0b76f466702fc..a099f36b0a465 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -1,4 +1,5 @@
 import json
+import math
 import os
 from typing import Dict, List
 
@@ -50,6 +51,18 @@ def test_peft_helper(sql_lora_files):
         "embed_tokens",
         "lm_head",
     ]
+    scaling = peft_helper.lora_alpha / peft_helper.r
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
+
+    # test RSLoRA
+    config = dict(r=8,
+                  lora_alpha=16,
+                  target_modules=["gate_proj"],
+                  use_rslora=True)
+    peft_helper = PEFTHelper.from_dict(config)
+
+    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
 
     expected_error = "vLLM only supports modules_to_save being None."
     with pytest.raises(ValueError, match=expected_error):
@@ -60,13 +73,6 @@ def test_peft_helper(sql_lora_files):
             modules_to_save=["lm_head"],
         )
         PEFTHelper.from_dict(config)
-    expected_error = "vLLM does not yet support RSLoRA."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(r=8,
-                      lora_alpha=16,
-                      target_modules=["gate_proj"],
-                      use_rslora=True)
-        PEFTHelper.from_dict(config)
 
     expected_error = "vLLM does not yet support DoRA."
     with pytest.raises(ValueError, match=expected_error):
diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py
index dde347b78bf81..93ad4651f4b77 100644
--- a/vllm/lora/lora.py
+++ b/vllm/lora/lora.py
@@ -67,15 +67,9 @@ def from_config(
         peft_helper: PEFTHelper,
         embeddings_tensor: Optional[torch.Tensor] = None,
     ) -> "LoRALayerWeights":
-        return cls(
-            module_name,
-            peft_helper.r,
-            peft_helper.lora_alpha,
-            None,
-            None,
-            None,
-            embeddings_tensor,
-        )
+        return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None,
+                   None, None, embeddings_tensor,
+                   peft_helper.vllm_lora_scaling_factor)
 
     @classmethod
     def create_dummy_lora_weights(
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 5c0e4e5cbc636..9cfcc6bba727f 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -173,7 +173,7 @@ def from_lora_tensors(
         return cls(lora_model_id,
                    peft_helper.r,
                    loras,
-                   scaling_factor=peft_helper.vllm_scaling_factor)
+                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)
 
     @classmethod
     def from_local_checkpoint(
diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index edf4ba5659575..ddd42ae93d290 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -4,6 +4,8 @@
 from dataclasses import MISSING, dataclass, field, fields
 from typing import Literal, Optional, Union
 
+from vllm.utils import print_info_once
+
 
 @dataclass
 class PEFTHelper:
@@ -14,21 +16,22 @@ class PEFTHelper:
 
     bias: Literal["none", "all", "lora_only"] = field(default="none")
     modules_to_save: Optional[list[str]] = field(default=None)
+    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
     use_rslora: bool = field(default=False)
+    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
    use_dora: bool = field(default=False)
-    # long lora field
+    # long context lora field
     context_length: int = field(default=0)
     # Extra vllm field, start with 'vllm_' to avoid conflict
+    vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
-    vllm_scaling_factor: Optional[float] = field(default=None)
+    vllm_long_context_scaling_factor: Optional[float] = field(default=None)
 
     def _validate_features(self):
         error_msg = []
 
         if self.modules_to_save:
             error_msg.append("vLLM only supports modules_to_save being None.")
-        if self.use_rslora:
-            error_msg.append("vLLM does not yet support RSLoRA.")
 
         if self.use_dora:
             error_msg.append("vLLM does not yet support DoRA.")
@@ -38,10 +41,15 @@ def _validate_features(self):
 
     def __post_init__(self):
         self._validate_features()
+        if self.use_rslora:
+            print_info_once("Loading LoRA weights trained with rsLoRA.")
+            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
+        else:
+            self.vllm_lora_scaling_factor = self.lora_alpha / self.r
         if self.context_length:
             if self.vllm_max_position_embeddings is None:
                 self.vllm_max_position_embeddings = self.context_length
-            self.vllm_scaling_factor = float(
+            self.vllm_long_context_scaling_factor = float(
                 math.ceil(self.context_length /
                           self.vllm_max_position_embeddings))
 
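
Note (not part of the patch): a minimal, self-contained sketch of the scaling rule this diff wires into PEFTHelper. Standard LoRA scales the adapter update by lora_alpha / r, while rsLoRA scales it by lora_alpha / sqrt(r). The _ScalingDemo class below is a hypothetical stand-in for illustration only; the real code lives in vllm/lora/peft_helper.py as shown above.

import math
from dataclasses import dataclass


@dataclass
class _ScalingDemo:
    """Illustrative stand-in mirroring the scaling computation in the patch."""
    r: int
    lora_alpha: int
    use_rslora: bool = False

    @property
    def scaling(self) -> float:
        if self.use_rslora:
            # rsLoRA (https://arxiv.org/abs/2312.03732): dividing by sqrt(r)
            # keeps the effective update magnitude stable as the rank grows.
            return self.lora_alpha / math.sqrt(self.r)
        # Classic LoRA scaling.
        return self.lora_alpha / self.r


# Same r/alpha values as the new test case in tests/lora/test_lora_manager.py.
assert _ScalingDemo(r=8, lora_alpha=16).scaling == 2.0
assert abs(_ScalingDemo(r=8, lora_alpha=16, use_rslora=True).scaling -
           16 / math.sqrt(8)) < 1e-9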