fix: skip probing for incompatible scenario (#879)
When determining the max available context length, vllm generates
dummy inputs with `max_num_batched_tokens` tokens and profiles peak
VRAM usage.
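
A candidate length is treated as usable only if that profiling pass finishes without exhausting VRAM. A minimal sketch of the check (the `profile_with_batched_tokens` callback is hypothetical, standing in for vllm's dummy-input profiling run):

```python
import torch

def context_length_fits(profile_with_batched_tokens, context_length: int) -> bool:
    """Hypothetical check: profile a dummy batch of `context_length` tokens
    and report whether it fits in GPU memory."""
    try:
        # The callback is assumed to run vllm's profiling pass with
        # max_num_batched_tokens set to the candidate context length.
        profile_with_batched_tokens(context_length)
        return True
    except torch.OutOfMemoryError:
        return False
```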

1. skip probing for multi-gpu serving

If tensor parallelism is enabled, there is more than one GPU worker.
However, vllm gives us no way to adjust the context length used for
profiling on these workers.

2. skip probing when chunked-prefill is enabled

[Chunked prefill](https://arxiv.org/pdf/2308.16369) is a feature that
improves throughput for long-context serving (max_model_len > 32K).
It makes peak memory usage nonlinearly related to max_model_len.

**Notes for Reviewers**:

Test cases:
- [x] single gpu with enough vram to run max-model-len
- [x] single gpu, find max-model-len after several probing
- [x] single gpu, unable to run with even small model len
- [x] multi gpu
- [x] chunked prefill enabled

Signed-off-by: jerryzhuang <[email protected]>
zhuangqh authored Feb 13, 2025
1 parent 2a9b5c5 commit 5895b9b
Showing 1 changed file with 16 additions and 8 deletions.
presets/workspace/inference/vllm/inference_api.py
@@ -38,7 +38,7 @@ def __init__(self, *args, **kwargs):
         # They should start with "kaito-" prefix to avoid conflict with vllm args
         self.add_argument("--kaito-adapters-dir", type=str, default="/mnt/adapter", help="Directory where adapters are stored in KAITO preset.")
         self.add_argument("--kaito-config-file", type=str, default="", help="Additional args for KAITO preset.")
-        self.add_argument("--kaito-max-probe-steps", type=int, default=6, help="Maximum number of steps to find the max available seq len fitting in the GPU memory.")
+        self.add_argument("--kaito-max-probe-steps", type=int, help="Maximum number of steps to find the max available seq len fitting in the GPU memory.")

     def _reset_vllm_defaults(self):
         local_rank = int(os.environ.get("LOCAL_RANK",
@@ -110,10 +110,10 @@ def to_yaml(self) -> str:
 def load_lora_adapters(adapters_dir: str) -> Optional[LoRAModulePath]:
     lora_list: List[LoRAModulePath] = []

-    logger.info(f"Loading LoRA adapters from {adapters_dir}")
     if not os.path.exists(adapters_dir):
         return lora_list

+    logger.info(f"Loading LoRA adapters from {adapters_dir}")
     for adapter in os.listdir(adapters_dir):
         adapter_path = os.path.join(adapters_dir, adapter)
         if os.path.isdir(adapter_path):
@@ -127,6 +127,9 @@ def find_max_available_seq_len(engine_config: EngineConfig, max_probe_steps: int
     """
     # see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/engine/llm_engine.py#L335
     executor_class = LLMEngine._get_executor_cls(engine_config)
+    if engine_config.scheduler_config.enable_chunked_prefill:
+        logger.info("Chunked Prefill is enabled, skip probing.")
+        return engine_config.model_config.max_model_len
     executor = executor_class(
         model_config=engine_config.model_config,
         cache_config=engine_config.cache_config,
@@ -140,15 +143,14 @@
         observability_config=engine_config.observability_config,
     )

-    model_max_blocks = int(engine_config.model_config.max_model_len / engine_config.cache_config.block_size)
-    res = binary_search_with_limited_steps(model_max_blocks, max_probe_steps, lambda x: is_context_length_safe(executor, x))
+    res = binary_search_with_limited_steps(engine_config.model_config.max_model_len, max_probe_steps, lambda x: is_context_length_safe(executor, x))

     # release memory
     del executor
     gc.collect()
     torch.cuda.empty_cache()

-    return engine_config.cache_config.block_size * res
+    return res

 def binary_search_with_limited_steps(upper: int, max_probe_steps: int, is_valid_fn: Callable[[int], bool]) -> int:
     """
@@ -180,27 +182,33 @@ def binary_search_with_limited_steps(upper: int, max_probe_steps: int, is_valid_

     return low

-def is_context_length_safe(executor: ExecutorBase, num_gpu_blocks: int) -> bool:
+def is_context_length_safe(executor: ExecutorBase, context_length: int) -> bool:
     """
     Check if the avilable gpu blocks is enough for the given num_gpu_blocks.
     """
-    context_length = executor.cache_config.block_size * num_gpu_blocks
+    # Profile memory usage with max_num_sequences sequences and the total
+    # number of tokens equal to max_num_batched_tokens.
+    # see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/worker/model_runner.py#L1232
     executor.scheduler_config.max_num_batched_tokens = context_length

     try:
         logger.info(f"Try to determine available gpu blocks for context length {context_length}")
         # see https://github.com/vllm-project/vllm/blob/v0.6.3/vllm/engine/llm_engine.py#L477
         available_gpu_blocks, _ = executor.determine_num_available_blocks()
     except torch.OutOfMemoryError as e:
         return False

+    num_gpu_blocks = context_length // executor.cache_config.block_size
     return available_gpu_blocks >= num_gpu_blocks

 def try_set_max_available_seq_len(args: argparse.Namespace):
     if args.max_model_len is not None:
         logger.info(f"max_model_len is set to {args.max_model_len}, skip probing.")
         return

+    if args.tensor_parallel_size > 1 or args.pipeline_parallel_size > 1:
+        logger.info("Multi-GPU serving is enabled, skip probing.")
+        return
+
     max_probe_steps = 0
     if args.kaito_max_probe_steps is not None:
         try:

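The body of `binary_search_with_limited_steps` is collapsed in the diff above. For context, a bounded binary search with this signature behaves roughly as sketched below (an illustration only, not necessarily the repository's exact implementation):

```python
from typing import Callable

def binary_search_with_limited_steps(upper: int, max_probe_steps: int,
                                     is_valid_fn: Callable[[int], bool]) -> int:
    """Approximate the largest value in [0, upper] accepted by is_valid_fn,
    probing at most max_probe_steps midpoints (sketch only)."""
    low, high = 0, upper
    for _ in range(max_probe_steps):
        if low >= high:
            break
        mid = (low + high + 1) // 2  # round up so `low` can eventually reach `upper`
        if is_valid_fn(mid):
            low = mid       # mid is safe; keep searching above it
        else:
            high = mid - 1  # mid is not safe (e.g. profiling hit OOM); search below it
    return low
```

After this change the helper is called with `upper` set to `max_model_len` and a validity check backed by the profiling run, so the returned value is already a token count and no longer needs to be multiplied by the KV-cache block size.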