
Commit 923a931

enable full passthrough of vllm engine args, with backwards compatibility (deepjavalibrary#2639)
1 parent 958d447 commit 923a931
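This commit makes VllmRbProperties accept arbitrary extra keys (pydantic extra='allow') and forward any key that matches a vLLM EngineArgs field straight to the engine, while keeping the legacy DJL-named configs working through aliases. As a rough sketch of what that enables, assuming LMI's usual option.* prefix in serving.properties (the specific option names below are illustrative, not taken from this commit):

# hypothetical serving.properties
option.rolling_batch=vllm
option.tensor_parallel_degree=4        # DJL name, mapped to vLLM tensor_parallel_size
option.max_rolling_batch_size=64       # DJL name, mapped to vLLM max_num_seqs
option.enable_prefix_caching=true      # no DJL field; assumed to pass through to EngineArgs
option.swap_space=8                    # likewise assumed to pass through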

File tree

6 files changed: +451 −261 lines changed

engines/python/setup/djl_python/properties_manager/properties.py

+3 −3
@@ -61,9 +61,9 @@ class Properties(BaseModel):
     input_formatter: Optional[Callable] = None
     waiting_steps: Optional[int] = None
     mpi_mode: bool = False
-    tgi_compat: Optional[bool] = False
-    bedrock_compat: Optional[bool] = False
-    enable_lora: Optional[bool] = False
+    tgi_compat: bool = False
+    bedrock_compat: bool = False
+    enable_lora: bool = False
 
     # Spec_dec
     draft_model_id: Optional[str] = None
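The Optional[bool] -> bool tightening above is meaningful in pydantic v2: the old annotation admitted None, which could then leak into boolean checks downstream. A minimal sketch of the difference (not part of the diff):

from typing import Optional
from pydantic import BaseModel, ValidationError

class Old(BaseModel):
    tgi_compat: Optional[bool] = False  # None is a legal value here

class New(BaseModel):
    tgi_compat: bool = False  # only real booleans (or coercible strings) validate

print(Old(tgi_compat=None).tgi_compat)  # None slips through
try:
    New(tgi_compat=None)
except ValidationError:
    print("None is now rejected")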

engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py

+167 −105
@@ -11,70 +11,76 @@
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
 import ast
-from typing import Optional, Any, Mapping, Tuple, Dict
-
-from pydantic import field_validator, model_validator
+import logging
+from typing import Optional, Any, Dict, Tuple
+from pydantic import field_validator, model_validator, ConfigDict, Field
+from vllm import EngineArgs
+from vllm.utils import FlexibleArgumentParser
+from vllm.engine.arg_utils import StoreBoolean
 
 from djl_python.properties_manager.properties import Properties
 
+DTYPE_MAPPER = {
+    "float32": "float32",
+    "fp32": "float32",
+    "float16": "float16",
+    "fp16": "float16",
+    "bfloat16": "bfloat16",
+    "bf16": "bfloat16",
+    "auto": "auto"
+}
+
+
+def construct_vllm_args_list(vllm_engine_args: dict,
+                             parser: FlexibleArgumentParser):
+    # Modified from https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/utils.py#L1258
+    args_list = []
+    store_boolean_arguments = {
+        action.dest
+        for action in parser._actions if isinstance(action, StoreBoolean)
+    }
+    for engine_arg, engine_arg_value in vllm_engine_args.items():
+        if str(engine_arg_value).lower() in {
+                'true', 'false'
+        } and engine_arg not in store_boolean_arguments:
+            if str(engine_arg_value).lower() == 'true':
+                args_list.append(f"--{engine_arg}")
+        else:
+            args_list.append(f"--{engine_arg}={engine_arg_value}")
+    return args_list
 
 
 class VllmRbProperties(Properties):
     engine: Optional[str] = None
-    dtype: Optional[str] = "auto"
-    load_format: Optional[str] = "auto"
-    quantize: Optional[str] = None
+    # The following configs have different names in DJL compared to vLLM; we only accept the DJL name currently
     tensor_parallel_degree: int = 1
     pipeline_parallel_degree: int = 1
-    max_rolling_batch_prefill_tokens: Optional[int] = None
-    # Adjustable prefix model length for certain 32k or longer model
-    max_model_len: Optional[int] = None
-    enforce_eager: Optional[bool] = False
-    # TODO: this default may change with different vLLM versions
-    # TODO: try to get good default from vLLM to prevent revisiting
-    # TODO: last time check: vllm 0.3.1
-    gpu_memory_utilization: Optional[float] = 0.9
-    enable_lora: Optional[bool] = False
-    max_loras: Optional[int] = 4
-    max_lora_rank: Optional[int] = 16
-    fully_sharded_loras: bool = False
-    lora_extra_vocab_size: int = 256
+    # The following configs have different names in DJL compared to vLLM; either is accepted
+    quantize: Optional[str] = Field(alias="quantization",
+                                    default=EngineArgs.quantization)
+    max_rolling_batch_prefill_tokens: Optional[int] = Field(
+        alias="max_num_batched_tokens",
+        default=EngineArgs.max_num_batched_tokens)
+    cpu_offload_gb_per_gpu: float = Field(alias="cpu_offload_gb",
+                                          default=EngineArgs.cpu_offload_gb)
+    # The following configs have different defaults, or additional processing in DJL compared to vLLM
+    dtype: str = "auto"
+    max_loras: int = 4
+    # The following configs have broken processing in vllm via the FlexibleArgumentParser
     long_lora_scaling_factors: Optional[Tuple[float, ...]] = None
-    lora_dtype: Optional[str] = 'auto'
-    max_cpu_loras: Optional[int] = None
+    use_v2_block_manager: bool = True
+    # Tool calling properties
+    enable_auto_tool_choice: bool = False
+    tool_call_parser: Optional[str] = None
 
     # Neuron vLLM properties
-    device: Optional[str] = None
+    device: str = 'auto'
     preloaded_model: Optional[Any] = None
     generation_config: Optional[Any] = None
     override_neuron_config: Optional[Dict] = None
 
-    max_logprobs: Optional[int] = 20
-    enable_chunked_prefill: Optional[bool] = None
-    cpu_offload_gb_per_gpu: Optional[int] = 0
-    enable_prefix_caching: Optional[bool] = False
-    disable_sliding_window: Optional[bool] = False
-    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
-    use_v2_block_manager: bool = False
-    tokenizer_mode: str = 'auto'
-
-    # Speculative decoding configuration.
-    speculative_model: Optional[str] = None
-    speculative_model_quantization: Optional[str] = None
-    speculative_draft_tensor_parallel_size: Optional[int] = None
-    num_speculative_tokens: Optional[int] = None
-    speculative_max_model_len: Optional[int] = None
-    speculative_disable_by_batch_size: Optional[int] = None
-    ngram_prompt_lookup_max: Optional[int] = None
-    ngram_prompt_lookup_min: Optional[int] = None
-    spec_decoding_acceptance_method: str = 'rejection_sampler'
-    typical_acceptance_sampler_posterior_threshold: Optional[float] = None
-    typical_acceptance_sampler_posterior_alpha: Optional[float] = None
-    qlora_adapter_name_or_path: Optional[str] = None
-    disable_logprobs_during_spec_decoding: Optional[bool] = None
-
-    # Tool calling properties
-    enable_auto_tool_choice: Optional[bool] = False
-    tool_call_parser: Optional[str] = None
+    # This allows generic vllm engine args to be passed in and set with vllm
+    model_config = ConfigDict(extra='allow', populate_by_name=True)
 
     @field_validator('engine')
     def validate_engine(cls, engine):
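To make construct_vllm_args_list concrete, here is a hypothetical input and the flag list it would produce, assuming none of these options are StoreBoolean actions (vLLM's FlexibleArgumentParser is expected to accept underscore-style flag names):

engine_args = {
    "model": "my-model",     # assumed value
    "tensor_parallel_size": 4,
    "enable_lora": True,     # plain boolean -> emitted as a bare flag
    "enforce_eager": False,  # false boolean -> omitted entirely
}
# construct_vllm_args_list(engine_args, parser) would return:
# ['--model=my-model', '--tensor_parallel_size=4', '--enable_lora']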
@@ -83,64 +89,13 @@ def validate_engine(cls, engine):
                 f"Need python engine to start vLLM RollingBatcher")
         return engine
 
-    @field_validator('long_lora_scaling_factors', mode='before')
-    def validate_long_lora_scaling_factors(cls, val):
-        if isinstance(val, str):
-            val = ast.literal_eval(val)
-        if not isinstance(val, tuple):
-            if isinstance(val, list):
-                val = tuple(float(v) for v in val)
-            elif isinstance(val, float):
-                val = (val, )
-            elif isinstance(val, int):
-                val = (float(val), )
-            else:
-                raise ValueError(
-                    "long_lora_scaling_factors must be convertible to a tuple of floats."
-                )
-        return val
-
-    @field_validator('limit_mm_per_prompt', mode="before")
-    def validate_limit_mm_per_prompt(cls, val) -> Mapping[str, int]:
-        out_dict: Dict[str, int] = {}
-        for item in val.split(","):
-            kv_parts = [part.lower().strip() for part in item.split("=")]
-            if len(kv_parts) != 2:
-                raise ValueError("Each item should be in the form key=value")
-            key, value = kv_parts
-
-            try:
-                parsed_value = int(value)
-            except ValueError as e:
-                raise ValueError(
-                    f"Failed to parse value of item {key}={value}") from e
-
-            if key in out_dict and out_dict[key] != parsed_value:
-                raise ValueError(
-                    f"Conflicting values specified for key: {key}")
-            out_dict[key] = parsed_value
-        return out_dict
-
-    @field_validator('override_neuron_config', mode="before")
-    def validate_override_neuron_config(cls, val):
-        if isinstance(val, str):
-            neuron_config = ast.literal_eval(val)
-            if not isinstance(neuron_config, dict):
-                raise ValueError(
-                    f"Invalid json format for override_neuron_config")
-            return neuron_config
-        elif isinstance(val, Dict):
-            return val
-        else:
-            raise ValueError("Invalid json format for override_neuron_config")
-
-    @model_validator(mode='after')
-    def validate_speculative_model(self):
-        if self.speculative_model is not None and not self.use_v2_block_manager:
+    @field_validator('dtype')
+    def validate_dtype(cls, val):
+        if val not in DTYPE_MAPPER:
             raise ValueError(
-                "Speculative decoding requires usage of the V2 block manager. Enable it with option.use_v2_block_manager=true."
+                f"Invalid dtype={val} provided. Must be one of {DTYPE_MAPPER.keys()}"
             )
-        return self
+        return DTYPE_MAPPER[val]
 
     @model_validator(mode='after')
     def validate_pipeline_parallel(self):
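The new dtype validator moves the DTYPE_MAPPER lookup, previously done at engine-construction time in get_engine_args_from_config, into property validation, so shorthand names are canonicalized once and bad values fail fast:

assert DTYPE_MAPPER["fp16"] == "float16"
assert DTYPE_MAPPER["bf16"] == "bfloat16"
# a key outside the mapper (e.g. "float64") now raises ValueError at validation time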
@@ -159,3 +114,110 @@ def validate_tool_call_parser(self):
             raise ValueError(
                 f"Invalid tool call parser: {self.tool_call_parser} "
                 f"(chose from {{ {','.join(valid_tool_parses)} }})")
+
+    @field_validator('override_neuron_config', mode="before")
+    def validate_override_neuron_config(cls, val):
+        if isinstance(val, str):
+            neuron_config = ast.literal_eval(val)
+            if not isinstance(neuron_config, dict):
+                raise ValueError(
+                    f"Invalid json format for override_neuron_config")
+            return neuron_config
+        elif isinstance(val, Dict):
+            return val
+        else:
+            raise ValueError("Invalid json format for override_neuron_config")
+
+    # TODO: processing of this field is broken in vllm via from_cli_args
+    # we should upstream a fix for this to vllm
+    @field_validator('long_lora_scaling_factors', mode='before')
+    def validate_long_lora_scaling_factors(cls, val):
+        if isinstance(val, str):
+            val = ast.literal_eval(val)
+        if not isinstance(val, tuple):
+            if isinstance(val, list):
+                val = tuple(float(v) for v in val)
+            elif isinstance(val, float):
+                val = (val, )
+            elif isinstance(val, int):
+                val = (float(val), )
+            else:
+                raise ValueError(
+                    "long_lora_scaling_factors must be convertible to a tuple of floats."
+                )
+        return val
+
+    def handle_lmi_vllm_config_conflicts(self, additional_vllm_engine_args):
+
+        def validate_potential_lmi_vllm_config_conflict(
+                lmi_config_name, vllm_config_name):
+            lmi_config_val = self.__getattribute__(lmi_config_name)
+            vllm_config_val = additional_vllm_engine_args.get(vllm_config_name)
+            if vllm_config_val is not None and lmi_config_val is not None:
+                if vllm_config_val != lmi_config_val:
+                    raise ValueError(
+                        f"Both the DJL {lmi_config_name}={lmi_config_val} and vLLM {vllm_config_name}={vllm_config_val} configs have been set with conflicting values. "
+                        f"We currently only accept the DJL config {lmi_config_name}; please remove the vLLM {vllm_config_name} configuration."
+                    )
+
+        validate_potential_lmi_vllm_config_conflict("tensor_parallel_degree",
+                                                    "tensor_parallel_size")
+        validate_potential_lmi_vllm_config_conflict("pipeline_parallel_degree",
+                                                    "pipeline_parallel_size")
+        validate_potential_lmi_vllm_config_conflict("max_rolling_batch_size",
+                                                    "max_num_seqs")
+
+    def generate_vllm_engine_arg_dict(self,
+                                      passthrough_vllm_engine_args) -> dict:
+        vllm_engine_args = {
+            'model': self.model_id_or_path,
+            'tensor_parallel_size': self.tensor_parallel_degree,
+            'pipeline_parallel_size': self.pipeline_parallel_degree,
+            'max_num_seqs': self.max_rolling_batch_size,
+            'dtype': DTYPE_MAPPER[self.dtype],
+            'revision': self.revision,
+            'max_loras': self.max_loras,
+            'enable_lora': self.enable_lora,
+            'trust_remote_code': self.trust_remote_code,
+            'cpu_offload_gb': self.cpu_offload_gb_per_gpu,
+            'use_v2_block_manager': self.use_v2_block_manager,
+            'quantization': self.quantize,
+            'device': self.device,
+        }
+        if self.max_rolling_batch_prefill_tokens is not None:
+            vllm_engine_args[
+                'max_num_batched_tokens'] = self.max_rolling_batch_prefill_tokens
+        if self.device == 'neuron':
+            vllm_engine_args['block_size'] = passthrough_vllm_engine_args.get(
+                "max_model_len")
+        vllm_engine_args.update(passthrough_vllm_engine_args)
+        return vllm_engine_args
+
+    def get_engine_args(self) -> EngineArgs:
+        additional_vllm_engine_args = self.get_additional_vllm_engine_args()
+        self.handle_lmi_vllm_config_conflicts(additional_vllm_engine_args)
+        vllm_engine_arg_dict = self.generate_vllm_engine_arg_dict(
+            additional_vllm_engine_args)
+        logging.debug(
+            f"Constructing vLLM engine args from the following DJL configs: {vllm_engine_arg_dict}"
+        )
+        parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
+        args_list = construct_vllm_args_list(vllm_engine_arg_dict, parser)
+        args = parser.parse_args(args=args_list)
+        engine_args = EngineArgs.from_cli_args(args)
+        # we have to do this separately because vllm converts it into a string
+        engine_args.long_lora_scaling_factors = self.long_lora_scaling_factors
+        # These neuron configs are not implemented in the vllm arg parser
+        if self.device == 'neuron':
+            setattr(engine_args, 'preloaded_model', self.preloaded_model)
+            setattr(engine_args, 'generation_config', self.generation_config)
+            setattr(engine_args, 'override_neuron_config',
+                    self.override_neuron_config)
+        return engine_args
+
+    def get_additional_vllm_engine_args(self) -> Dict[str, Any]:
+        return {
+            k: v
+            for k, v in self.__pydantic_extra__.items()
+            if k in EngineArgs.__annotations__
+        }
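Putting the new methods together: unknown keys land in __pydantic_extra__ because of extra='allow', get filtered against EngineArgs.__annotations__, are checked for conflicts with the DJL-named configs, and are merged into the final arg dict. A rough usage sketch (field values are assumptions, not from the commit):

props = VllmRbProperties(
    engine="Python",
    model_id_or_path="my-model",  # assumed required base Properties field
    tensor_parallel_degree=2,
    enable_chunked_prefill=True,  # undeclared field -> kept as a pydantic extra
    not_a_vllm_arg="ignored",     # fails the EngineArgs.__annotations__ filter
)
props.get_additional_vllm_engine_args()  # {'enable_chunked_prefill': True}
engine_args = props.get_engine_args()    # fully-populated vLLM EngineArgs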

engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py

−68
@@ -208,74 +208,6 @@ def get_lora_request(lora_name: str, lora_requests: dict) -> dict:
     return lora_requests[lora_name]
 
 
-def get_engine_args_from_config(config: VllmRbProperties) -> EngineArgs:
-    if config.device == "neuron":
-        return EngineArgs(model=config.model_id_or_path,
-                          tensor_parallel_size=config.tensor_parallel_degree,
-                          dtype=DTYPE_MAPPER[config.dtype],
-                          seed=0,
-                          max_model_len=config.max_model_len,
-                          max_num_seqs=config.max_rolling_batch_size,
-                          block_size=config.max_model_len,
-                          trust_remote_code=config.trust_remote_code,
-                          revision=config.revision,
-                          device=config.device,
-                          override_neuron_config=config.override_neuron_config)
-    else:
-        return EngineArgs(
-            model=config.model_id_or_path,
-            tensor_parallel_size=config.tensor_parallel_degree,
-            pipeline_parallel_size=config.pipeline_parallel_degree,
-            dtype=DTYPE_MAPPER[config.dtype],
-            seed=0,
-            max_model_len=config.max_model_len,
-            enforce_eager=config.enforce_eager,
-            gpu_memory_utilization=config.gpu_memory_utilization,
-            max_num_batched_tokens=config.max_rolling_batch_prefill_tokens,
-            trust_remote_code=config.trust_remote_code,
-            load_format=config.load_format,
-            quantization=config.quantize,
-            enable_lora=config.enable_lora,
-            max_loras=config.max_loras,
-            max_lora_rank=config.max_lora_rank,
-            fully_sharded_loras=config.fully_sharded_loras,
-            lora_extra_vocab_size=config.lora_extra_vocab_size,
-            long_lora_scaling_factors=config.long_lora_scaling_factors,
-            lora_dtype=config.lora_dtype,
-            max_cpu_loras=config.max_cpu_loras,
-            revision=config.revision,
-            max_logprobs=config.max_logprobs,
-            enable_chunked_prefill=config.enable_chunked_prefill,
-            cpu_offload_gb=config.cpu_offload_gb_per_gpu,
-            enable_prefix_caching=config.enable_prefix_caching,
-            disable_sliding_window=config.disable_sliding_window,
-            max_num_seqs=config.max_rolling_batch_size,
-            use_v2_block_manager=config.use_v2_block_manager,
-            speculative_model=config.speculative_model,
-            speculative_model_quantization=config.
-            speculative_model_quantization,
-            speculative_draft_tensor_parallel_size=config.
-            speculative_draft_tensor_parallel_size,
-            num_speculative_tokens=config.num_speculative_tokens,
-            speculative_max_model_len=config.speculative_max_model_len,
-            speculative_disable_by_batch_size=config.
-            speculative_disable_by_batch_size,
-            ngram_prompt_lookup_max=config.ngram_prompt_lookup_max,
-            ngram_prompt_lookup_min=config.ngram_prompt_lookup_min,
-            spec_decoding_acceptance_method=config.
-            spec_decoding_acceptance_method,
-            typical_acceptance_sampler_posterior_threshold=config.
-            typical_acceptance_sampler_posterior_threshold,
-            typical_acceptance_sampler_posterior_alpha=config.
-            typical_acceptance_sampler_posterior_alpha,
-            qlora_adapter_name_or_path=config.qlora_adapter_name_or_path,
-            disable_logprobs_during_spec_decoding=config.
-            disable_logprobs_during_spec_decoding,
-            limit_mm_per_prompt=config.limit_mm_per_prompt,
-            tokenizer_mode=config.tokenizer_mode,
-        )
-
-
 def get_prompt_inputs(request: Request):
     text_prompt = request.request_input.input_text
     multi_modal_data = request.parameters.pop("mm_data", None)
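With the monolithic helper removed, the call site presumably shrinks to letting the properties object build its own EngineArgs; a hedged sketch of the replacement pattern:

# assumed new call site in the rolling batch handler
engine_args = config.get_engine_args()  # config: VllmRbProperties
engine = LLMEngine.from_engine_args(engine_args)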
