Skip to content

Commit

Permalink
add blog post: sglang v0.3 (#120)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ying1123 authored Sep 4, 2024
1 parent 8a7ea33 commit f7179dc
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 0 deletions.
105 changes: 105 additions & 0 deletions blog/2024-09-04-sglang-v0-3.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
---
title: "SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision"
author: "The SGLang Team"
date: "September 4, 2024"
previewImg: /images/blog/sglang_v0_3/preview.png
---

We're excited to announce the release of [SGLang v0.3](https://github.com/sgl-project/sglang/tree/main), which brings significant performance enhancements and expanded support for novel model architectures. Here are the key updates:

- Up to 7x higher throughput for DeepSeek Multi-head Latent Attention (MLA)
- Up to 1.5x lower latency with `torch.compile` on small batch sizes
- Support for interleaved text and multi-image/video in LLaVA-OneVision
- Support for interleaved window attention and 2x longer context length in Gemma-2

In this blog post, we'll walk you through these key features. Please do not hesitate to report any issues or contribute ideas and code.


### DeepSeek Multi-head Latent Attention (MLA) Throughput Optimizations

[Multi-head Latent Attention](https://arxiv.org/pdf/2405.04434) (MLA) is a new attention variant introduced by the DeepSeek team to improve inference efficiency. Due to its differences from standard attention mechanisms, existing open-source libraries have not fully optimized this operation. In SGLang v0.3, we implemented various optimizations for MLA, including weight absorption, grouped decoding kernels, FP8 batched MatMul, and FP8 KV cache quantization. **Benchmark results show that SGLang v0.3 with MLA optimizations achieves 3x to 7x higher throughput than the baseline system.** The benchmark measures the peak output throughput of these models with BF16 and FP8 on H100 GPUs (tensor-parallelism=1 for lite models and tensor-parallelism=8 for big models) on the ShareGPT datasets. Reproducible instructions are in the appendix. While encouraging, there is still much room for improvement. We are actively working on more optimizations to fully reproduce the results from the DeepSeek paper. Related PRs:
[#905](https://github.com/sgl-project/sglang/pull/905),
[#1060](https://github.com/sgl-project/sglang/pull/1060),
[#1138](https://github.com/sgl-project/sglang/pull/1138),
[#469](https://github.com/flashinfer-ai/flashinfer/pull/469),
[#1285](https://github.com/sgl-project/sglang/pull/1285),
[#1286](https://github.com/sgl-project/sglang/pull/1286).

<img src="/images/blog/sglang_v0_3/deepseek_mla.svg" style="display: flex; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 70%;"></img>

### Torch.compile Latency Optimizations

[Torch.compile](https://pytorch.org/assets/pytorch2-2.pdf) is a major feature of PyTorch 2.0. On NVIDIA GPUs, it performs aggressive fusion and generates highly efficient Triton kernels. We've integrated torch.compile into SGLang for linear/norm/activation layers, combining it with FlashInfer attention and sampling kernels. We turn on torch.compile for batch sizes 1 to 32, where we observed the most acceleration. With this combination, SGLang is faster than [gpt-fast](https://github.com/pytorch-labs/gpt-fast) at batch size 1 and supports all online serving features, including continuous batching and RadixAttention for prefix caching. We are actively collaborating with the torch.compile and [torchao](https://github.com/pytorch/ao) teams to incorporate their latest optimizations into SGLang. To use torch.compile in SGLang, add `--enable-torch-compile` when launching the server. **SGLang w/ torch.compile yields up to a 1.5x speedup in the following benchmark.** Reproducible instructions are in the appendix.

<img src="/images/blog/sglang_v0_3/torch_compile.svg" style="display: flex; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 70%;"></img>

### LLaVA-OneVision Support with Interleaved Text, Multi-Image, and Video

[LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/) is the first open model to achieve state-of-the-art performance in three important computer vision scenarios: single-image, multi-image, and video tasks. We collaborated with the LLaVA team to integrate these capabilities into SGLang v0.3. You can launch a server and query it using the OpenAI-compatible vision API, which supports interleaved text, multi-image, and video formats. Usage details are available [here](https://github.com/sgl-project/sglang/blob/c500f96bb16c686ee8ba5d5f1fc716a0bd8e5fff/README.md?plain=1#L241-L244). The authors validated the model's accuracy and reported benchmark results on the VideoDetailDescriptions and LLaVA-in-the-wild datasets (see [#1123](https://github.com/sgl-project/sglang/pull/1123#issuecomment-2301691452)). **SGLang archives up to 4.5x speedup than the authors’ original implementation in HuggingFace/transformers.**

<img src="/images/blog/sglang_v0_3/llava_onevision.svg" style="display: flex; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 70%;"></img>

### Gemma-2 Support with Interleaved Window Attention

Google's [Gemma-2 model](https://arxiv.org/abs/2408.00118) uses interleaved window attention to reduce computational complexity for long contexts, alternating between local sliding window attention (4K context length) and global attention (8K context length) in every other layer. We enhanced SGLang v0.3 to fully support the 8K context length by leveraging the optimized window attention kernel from FlashInfer kernels (which skips computation instead of masking) and refining our KV cache manager. Other libraries that lack this feature can only run with a 4K context length. You can launch the model with
```
python3 -m sglang.launch_server --model-path google/gemma-2b
```

<img src="/images/blog/sglang_v0_3/gemma2.svg" style="display: flex; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 70%;"></img>

## Acknowledgment

The DeepSeek MLA optimizations were contributed by Ke Bao and Yineng Zhang. The torch.compile optimizations were contributed by Liangsheng Yin. The LLaVA-OneVision contributions were made by Kaichen Zhang and Bo Li. The interleaved window attention was contributed by Ying Sheng. We also thank all 90+ open-source [contributors](https://github.com/sgl-project/sglang/graphs/contributors).

## Appendix

### Benchmark Instructions for DeepSeek MLA

```
# DeepSeekCoder-V2-Lite (BF16)
## Launch a server
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --enable-mla --disable-radix --trust-remote-code
python3 -m vllm.entrypoints.openai.api_server --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --disable-log-requests --trust-remote-code --max-model-len 4096
## Run benchmark
python3 -m sglang.bench_serving --backend sglang --num-prompts 5000
python3 -m sglang.bench_serving --backend vllm --num-prompts 5000
# DeepSeekCoder-V2 (BF16)
## Launch a server
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-Coder-V2-Instruct --disable-radix --tp 8 --trust-remote-code --enable-mla
python3 -m vllm.entrypoints.openai.api_server --model deepseek-ai/DeepSeek-Coder-V2-Instruct --disable-log-requests --tensor-parallel-size 8 --trust-remote-code --max-model-len 4096
## Run benchmark
python3 -m sglang.bench_serving --backend sglang --num-prompts 5000
python3 -m sglang.bench_serving --backend vllm --num-prompts 5000
# DeepSeekCoder-V2 (FP8)
## Launch a server
python3 -m sglang.launch_server --model neuralmagic/DeepSeek-Coder-V2-Instruct-FP8 --enable-mla --quantization fp8 --kv-cache-dtype fp8_e5m2 --disable-radix --tp 8 --trust-remote-code
python3 -m vllm.entrypoints.openai.api_server --model neuralmagic/DeepSeek-Coder-V2-Instruct-FP8 --quantization fp8 --disable-log-requests --tensor-parallel-size 8 --trust-remote-code --max-model-len 4096
## Run benchmark
python3 -m sglang.bench_serving --backend sglang --num-prompts 5000
python3 -m sglang.bench_serving --backend vllm --num-prompts 5000
```

### Benchmark Instructions for torch.compile

```
# SGLang
## Launch a server
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B --enable-torch-compile
## Run benchmark
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input-len 128 --random-output-len 512 --random-range-ratio 1 --num-prompts 1
# vLLM
## Launch a server
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B --disable-log-requests
## Run benchmark
python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input-len 128 --random-output-len 512 --random-range-ratio 1 --num-prompts 1
```

1 change: 1 addition & 0 deletions public/images/blog/sglang_v0_3/deepseek_mla.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions public/images/blog/sglang_v0_3/gemma2.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions public/images/blog/sglang_v0_3/llava_onevision.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added public/images/blog/sglang_v0_3/preview.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions public/images/blog/sglang_v0_3/torch_compile.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit f7179dc

Please sign in to comment.