diff --git a/examples/usage/llava/srt_llava_next_test.py b/examples/usage/llava/srt_llava_next_test.py
index ec39e2264c..c083e33774 100644
--- a/examples/usage/llava/srt_llava_next_test.py
+++ b/examples/usage/llava/srt_llava_next_test.py
@@ -20,7 +20,7 @@ def image_qa(s, image, question):
 def single():
     image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
     pil_image, _ = load_image(image_url)
-    state = image_qa.run(image=pil_image, question="\nWhat is this?", max_new_tokens=512)
+    state = image_qa.run(image=pil_image, question="What is this?", max_new_tokens=512)
     print(state["answer"], "\n")
 
 
@@ -59,13 +59,12 @@ def batch():
     mp.set_start_method("spawn", force=True)
 
     runtime = sgl.Runtime(
-        model_path="/mnt/bn/vl-research/checkpoints/onevision/llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mid_to_final_next_2p4m_am9_continual_ov",
+        model_path="lmms-lab/llava-onevision-qwen2-0.5b-ov",
         tokenizer_path="lmms-lab/llavanext-qwen-siglip-tokenizer",
         host="127.0.0.1",
         tp_size=1,
         port=8000,
         chat_template="chatml-llava",
-        disable_flashinfer=True,
         # port=8000, Optional: specify the port number for the HTTP server if meets rpc issue or connection reset by peer issue.
     )
     # runtime = sgl.Runtime(
diff --git a/examples/usage/llava_video/srt_example_llava_v.py b/examples/usage/llava_video/srt_example_llava_v.py
index 27ba862d30..a13389accc 100644
--- a/examples/usage/llava_video/srt_example_llava_v.py
+++ b/examples/usage/llava_video/srt_example_llava_v.py
@@ -148,7 +148,7 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=
     parser.add_argument(
         "--video-dir",
         type=str,
-        default="./videos/Q98Z4OTh8RwmDonc.mp4",
+        default="./assets/jobs.mp4",
         help="The directory or path for the processed video files.",
     )
     parser.add_argument(
diff --git a/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 b/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4
deleted file mode 100644
index 32d912dbfa..0000000000
Binary files a/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 and /dev/null differ
diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py
index bb237f2191..af5fbd0549 100644
--- a/python/sglang/lang/backend/runtime_endpoint.py
+++ b/python/sglang/lang/backend/runtime_endpoint.py
@@ -21,7 +21,8 @@ def __init__(
         base_url: str,
         api_key: Optional[str] = None,
         verify: Optional[str] = None,
-        **kwargs
+        chat_template: Optional[str] = None,
+        **kwargs,
     ):
         super().__init__()
         self.support_concate_and_append = True
@@ -38,7 +39,7 @@ def __init__(
         self._assert_success(res)
         self.model_info = res.json()
 
-        self.chat_template = get_chat_template_by_model_path(
+        self.chat_template = chat_template or get_chat_template_by_model_path(
             self.model_info["model_path"]
         )
 
diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py
index 48d1956a7d..6a51b3238a 100644
--- a/python/sglang/lang/chat_template.py
+++ b/python/sglang/lang/chat_template.py
@@ -332,13 +332,17 @@ def match_chat_ml(model_path: str):
     if "tinyllama" in model_path:
         return get_chat_template("chatml")
     # Now the suffix for qwen2 chat model is "instruct"
-    if "qwen" in model_path and ("chat" in model_path or "instruct" in model_path) and ("llava" not in model_path):
+    if (
+        "qwen" in model_path
+        and ("chat" in model_path or "instruct" in model_path)
+        and ("llava" not in model_path)
+    ):
         return get_chat_template("qwen")
     if (
         "llava-v1.6-34b" in model_path
         or "llava-v1.6-yi-34b" in model_path
         or "llava-next-video-34b" in model_path
-        or "llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mid_to_final_next_2p4m_am9_continual_ov" in model_path
+        or "llava-onevision-qwen2" in model_path
     ):
         print("######################## Matched chatml-llava ########################")
         return get_chat_template("chatml-llava")
diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index 0bea6cc7d9..ea900c2ac2 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -407,7 +407,7 @@ def generate_chat_conv(
             for content in message.content:
                 if content.type == "text":
                     if num_image_url > 16:
-                        real_content += "\n" # for video
+                        real_content += "\n"  # for video
                     real_content += content.text
                 elif content.type == "image_url":
                     # NOTE: Only works for llava
diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index 38b820a77f..935b294a3c 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -273,7 +273,11 @@ def handle_generate_request(
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.pixel_values = recv_req.pixel_values
         if req.pixel_values is not None:
-            img_hash = sum(recv_req.image_hash) if type(recv_req.image_hash) is list else recv_req.image_hash
+            img_hash = (
+                hash(recv_req.image_hash)
+                if type(recv_req.image_hash) is list
+                else recv_req.image_hash
+            )
             req.pad_value = [
                 (img_hash) % self.model_config.vocab_size,
                 (img_hash >> 16) % self.model_config.vocab_size,
@@ -315,7 +319,9 @@ def handle_generate_request(
         self.forward_queue.append(req)
 
     def get_new_fill_batch(self) -> Optional[Batch]:
-        running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
+        running_bs = (
+            len(self.running_batch.reqs) if self.running_batch is not None else 0
+        )
         if running_bs >= self.max_running_requests:
             return
 
@@ -356,19 +362,15 @@ def get_new_fill_batch(self) -> Optional[Batch]:
                     req.extend_input_len += delta
                     req.prefix_indices = req.prefix_indices[:-delta]
                     if req.image_offset is not None:
-                        if isinstance(req.image_offse, list):
-                            req.image_offset = [x + delta for x in req.image_offset]
-                        else:
-                            req.image_offset += delta
+                        req.image_offset += delta
+
             if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                 # Need at least one token to compute logits
                 req.extend_input_len = 1
                 req.prefix_indices = req.prefix_indices[:-1]
                 if req.image_offset is not None:
-                    if isinstance(req.image_offset, list):
-                        req.image_offset = [x + 1 for x in req.image_offset]
-                    else:
-                        req.image_offset += 1
+                    req.image_offset += 1
+
             if (
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 6d44d699d7..f379d2d233 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -190,7 +190,7 @@ async def _handle_single_request(
                 # single image (anyres): num_patch, 3, 336 or 384, 336 or 384
                 pixel_values, image_hash, image_size = [], [], []
                 if len(obj.image_data) > 1:
-                    aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                    aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
                 for image_data in obj.image_data:
                     pixel_v, image_h, image_s = await self.get_pixel_values(
                         image_data, aspect_ratio
@@ -212,11 +212,11 @@ async def _handle_single_request(
                     )
                 else:
                     pixel_values, image_hash, image_size = None, None, None
-            
+
             except Exception as e:
                 print(f"Error in get_pixel_values: {e}")
                 pixel_values, image_hash, image_size = None, None, None
-            
+
             tokenized_obj = TokenizedGenerateReqInput(
                 rid=rid,
                 input_text=obj.text,
diff --git a/python/sglang/srt/mm_utils.py b/python/sglang/srt/mm_utils.py
index 3a750ba0ff..be13d6c361 100644
--- a/python/sglang/srt/mm_utils.py
+++ b/python/sglang/srt/mm_utils.py
@@ -17,8 +17,9 @@
 import ast
 import base64
 import math
-from io import BytesIO
 import re
+from io import BytesIO
+
 import numpy as np
 from PIL import Image
 
@@ -42,13 +43,20 @@ def select_best_resolution(original_size, possible_resolutions):
     for width, height in possible_resolutions:
         # Calculate the downscaled size to keep the aspect ratio
         scale = min(width / original_width, height / original_height)
-        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+        downscaled_width, downscaled_height = int(original_width * scale), int(
+            original_height * scale
+        )
 
         # Calculate effective and wasted resolutions
-        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+        effective_resolution = min(
+            downscaled_width * downscaled_height, original_width * original_height
+        )
         wasted_resolution = (width * height) - effective_resolution
 
-        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
+        if effective_resolution > max_effective_resolution or (
+            effective_resolution == max_effective_resolution
+            and wasted_resolution < min_wasted_resolution
+        ):
             max_effective_resolution = effective_resolution
             min_wasted_resolution = wasted_resolution
             best_fit = (width, height)
@@ -126,13 +134,23 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
         tuple: The shape of the image patch grid in the format (width, height).
     """
     if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
-        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
         # Use regex to extract the range from the input string
         matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
         range_start = tuple(map(int, matches[0]))
         range_end = tuple(map(int, matches[-1]))
         # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
-        grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
         # Multiply all elements by patch_size
         grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
     if type(grid_pinpoints) is list:
@@ -160,16 +178,26 @@ def process_anyres_image(image, processor, grid_pinpoints):
             patch_size = processor.size[0]
         except Exception as e:
             patch_size = processor.size["shortest_edge"]
-        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
         # Use regex to extract the range from the input string
         matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
         range_start = tuple(map(int, matches[0]))
         range_end = tuple(map(int, matches[-1]))
         # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
-        grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
         # Multiply all elements by patch_size
         grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
-    
+
     if type(grid_pinpoints) is list:
         possible_resolutions = grid_pinpoints
     else:
@@ -178,13 +206,19 @@ def process_anyres_image(image, processor, grid_pinpoints):
     image_padded = resize_and_pad_image(image, best_resolution)
 
     # For Siglip processor, only have size but no crop size
-    crop_size = processor.crop_size["height"] if "crop_size" in processor.__dict__ else processor.size["height"]
-    shortest_edge = processor.size["shortest_edge"] if "shortest_edge" in processor.size else processor.size["height"]
+    crop_size = (
+        processor.crop_size["height"]
+        if "crop_size" in processor.__dict__
+        else processor.size["height"]
+    )
+    shortest_edge = (
+        processor.size["shortest_edge"]
+        if "shortest_edge" in processor.size
+        else processor.size["height"]
+    )
     patches = divide_to_patches(image_padded, crop_size)
 
-    image_original_resize = image.resize(
-        (shortest_edge, shortest_edge)
-    )
+    image_original_resize = image.resize((shortest_edge, shortest_edge))
 
     image_patches = [image_original_resize] + patches
     image_patches = [
diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py
index 65241268b0..535d274280 100644
--- a/python/sglang/srt/models/llava.py
+++ b/python/sglang/srt/models/llava.py
@@ -15,20 +15,21 @@
 
 """Inference-only LLaVa model compatible with HuggingFace weights."""
 
+import math
+import re
 from typing import Iterable, List, Optional, Tuple
-import re, math
 
 import numpy as np
 import torch
 from torch import nn
 from transformers import (
     CLIPVisionConfig,
     CLIPVisionModel,
-    SiglipVisionConfig,
-    SiglipVisionModel,
     LlavaConfig,
     MistralConfig,
     Qwen2Config,
+    SiglipVisionConfig,
+    SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.config import CacheConfig
@@ -44,7 +45,7 @@
 from sglang.srt.models.llama2 import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-import math
+
 
 class LlavaLlamaForCausalLM(nn.Module):
     def __init__(
@@ -67,17 +68,18 @@ def __init__(
 
 
     def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
-
         # hardcode for spatial_unpad + anyres
         image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
         offset_list = []
         for image_s in image_size:
             if len(image_size) > 16:
                 # 2x2 pooling with stride 2
-                new_image_feature_len = math.ceil(self.image_size / self.patch_size / 2) ** 2
+                new_image_feature_len = (
+                    math.ceil(self.image_size / self.patch_size / 2) ** 2
+                )
             else:
                 new_image_feature_len = self.image_feature_len  # multiimage
-
+
             height = width = self.num_patches_per_side
             if "anyres" in image_aspect_ratio:
                 num_patch_width, num_patch_height = get_anyres_image_grid_shape(
@@ -88,13 +90,17 @@ def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
                 h = num_patch_height * height
                 w = num_patch_width * width
                 new_h, new_w = unpad_image_shape(h, w, image_s)
-
+
                 if "anyres_max" in self.config.image_aspect_ratio:
-                    matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", self.config.image_aspect_ratio)
+                    matched_anyres_max_num_patches = re.match(
+                        r"anyres_max_(\d+)", self.config.image_aspect_ratio
+                    )
                     if matched_anyres_max_num_patches:
                         max_num_patches = int(matched_anyres_max_num_patches.group(1))
                         # times = math.sqrt(h * w / (max_num_patches * unit**2))
-                        times = math.sqrt(new_h * new_w / (max_num_patches * self.image_feature_len))
+                        times = math.sqrt(
+                            new_h * new_w / (max_num_patches * self.image_feature_len)
+                        )
                         if times > 1.1:
                             new_h = int(new_h // times)
                             new_w = int(new_w // times)
@@ -202,23 +208,40 @@ def forward(
                             base_image_feature = image_feature[0]
                             image_feature = image_feature[1:]
                             assert height * width == base_image_feature.shape[0]
-
+
                             if "anyres_max" in image_aspect_ratio:
-                                matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
+                                matched_anyres_max_num_patches = re.match(
+                                    r"anyres_max_(\d+)", image_aspect_ratio
+                                )
                                 if matched_anyres_max_num_patches:
-                                    max_num_patches = int(matched_anyres_max_num_patches.group(1))
-
-                            if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
+                                    max_num_patches = int(
+                                        matched_anyres_max_num_patches.group(1)
+                                    )
+
+                            if (
+                                image_aspect_ratio == "anyres"
+                                or "anyres_max" in image_aspect_ratio
+                            ):
                                 vision_tower_image_size = self.image_size
                                 try:
-                                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx][0], self.config.image_grid_pinpoints, vision_tower_image_size)
+                                    num_patch_width, num_patch_height = (
+                                        get_anyres_image_grid_shape(
+                                            image_sizes[image_idx][0],
+                                            self.config.image_grid_pinpoints,
+                                            vision_tower_image_size,
+                                        )
+                                    )
                                 except Exception as e:
                                     print(f"Error: {e}")
                                     num_patch_width, num_patch_height = 2, 2
-                                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                                image_feature = image_feature.view(
+                                    num_patch_height, num_patch_width, height, width, -1
+                                )
                             else:
-                                image_feature = image_feature.view(2, 2, height, width, -1)
-
+                                image_feature = image_feature.view(
+                                    2, 2, height, width, -1
+                                )
+
                             # (
                             #     num_patch_width,
                             #     num_patch_height,
@@ -227,12 +250,16 @@ def forward(
                             #     self.image_grid_pinpoints,
                             #     self.vision_tower.config.image_size,
                             # )
-
+
                             # image_feature = image_feature.view(
                             #     num_patch_height, num_patch_width, height, width, -1
                             # )
 
-                            if "unpad" in self.mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
+                            if (
+                                "unpad" in self.mm_patch_merge_type
+                                and "anyres_max" in image_aspect_ratio
+                                and matched_anyres_max_num_patches
+                            ):
                                 unit = image_feature.shape[2]
                                 image_feature = image_feature.permute(
                                     4, 0, 2, 1, 3
@@ -247,7 +274,11 @@ def forward(
                                 times = math.sqrt(h * w / (max_num_patches * unit**2))
                                 if times > 1.1:
                                     image_feature = image_feature[None]
-                                    image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
+                                    image_feature = nn.functional.interpolate(
+                                        image_feature,
+                                        [int(h // times), int(w // times)],
+                                        mode="bilinear",
+                                    )[0]
                                 image_feature = torch.cat(
                                     (
                                         image_feature,
@@ -257,9 +288,13 @@ def forward(
                                     ),
                                     dim=-1,
                                 )
-                                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                                image_feature = image_feature.flatten(1, 2).transpose(
+                                    0, 1
+                                )
                             else:
-                                image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
+                                image_feature = image_feature.permute(
+                                    0, 2, 1, 3, 4
+                                ).contiguous()
                                 image_feature = image_feature.flatten(0, 3)
                                 image_feature = torch.cat(
                                     (base_image_feature, image_feature), dim=0
@@ -276,9 +311,18 @@ def forward(
                                 0, 3, 1, 2
                             ).contiguous()  # N, C, H, W
                             height, weight = image_feature.shape[2:]
-                            scaled_shape = [math.ceil(height / 2), math.ceil(weight / 2)]
-                            image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
-                            image_feature = image_feature.flatten(2).transpose(1, 2).contiguous() # N, C, H*W
+                            scaled_shape = [
+                                math.ceil(height / 2),
+                                math.ceil(weight / 2),
+                            ]
+                            image_feature = nn.functional.interpolate(
+                                image_feature, size=scaled_shape, mode="bilinear"
+                            )
+                            image_feature = (
+                                image_feature.flatten(2)
+                                .transpose(1, 2)
+                                .contiguous()
+                            )  # N, C, H*W
                             new_image_features.append(image_feature)
 
                 image_features = new_image_features
diff --git a/scripts/convert_llava.py b/scripts/convert_llava.py
index f07ead0147..464f306bec 100644
--- a/scripts/convert_llava.py
+++ b/scripts/convert_llava.py
@@ -8,20 +8,25 @@
 import json
 import os
 
-from transformers import AutoConfig, AutoTokenizer, AutoProcessor, LlavaConfig, AutoImageProcessor
+from transformers import (
+    AutoConfig,
+    AutoImageProcessor,
+    AutoProcessor,
+    AutoTokenizer,
+    LlavaConfig,
+)
+
 
 def add_image_token(model_path: str, hub_path: str):
     tokenizer = AutoTokenizer.from_pretrained(model_path)
-    tokenizer.add_tokens(
-        ["<image>"],
-        special_tokens=True
-    )
+    tokenizer.add_tokens(["<image>"], special_tokens=True)
 
     print(tokenizer)
     # tokenizer.save_pretrained(model_path)
     tokenizer.push_to_hub(hub_path, private=True)
     return tokenizer.convert_tokens_to_ids("<image>")
 
+
 def edit_model_config(model_path, image_token_index, hub_path):
     config = LlavaConfig.from_pretrained(model_path)
 
@@ -46,4 +51,4 @@
     args = parser.parse_args()
 
     image_token_index = add_image_token(args.model_path, args.hub_path)
-    edit_model_config(args.model_path, image_token_index, args.hub_path)
\ No newline at end of file
+    edit_model_config(args.model_path, image_token_index, args.hub_path)
diff --git a/test/srt/test_multi_image_openai_server.py b/test/srt/test_multi_image_openai_server.py
index f6abdcb67e..f52ae114db 100644
--- a/test/srt/test_multi_image_openai_server.py
+++ b/test/srt/test_multi_image_openai_server.py
@@ -1,8 +1,8 @@
 import openai
+
 client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
 
 import sys
-
 request_1 = client.chat.completions.create(
     model="default",
     messages=[
@@ -15,7 +15,10 @@
                         "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
                     },
                 },
-                {"type": "text", "text": "Please describe this image. Please list the benchmarks and the models."},
+                {
+                    "type": "text",
+                    "text": "Please describe this image. Please list the benchmarks and the models.",
+                },
             ],
         },
     ],
@@ -37,7 +40,7 @@
 # from decord import VideoReader, cpu
 # import numpy as np
 
-# video_path = "/mnt/bn/vl-research/workspace/boli01/projects/demos/sglang_codebase/assets/jobs.mp4"
+# video_path = "./assets/jobs.mp4"
 # max_frames_num = 32
 # vr = VideoReader(video_path, ctx=cpu(0))
 # total_frame_num = len(vr)
@@ -92,4 +95,4 @@
 # content = chunk.choices[0].delta.content
 # response_3 += content
 # sys.stdout.write(content)
-# sys.stdout.flush()
\ No newline at end of file
+# sys.stdout.flush()
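
For reference, here is a minimal sketch (not part of the patch) of how the updated example in examples/usage/llava/srt_llava_next_test.py exercises the new defaults introduced above: the lmms-lab/llava-onevision-qwen2-0.5b-ov checkpoint and the chatml-llava chat template. The image_qa program and the load_image helper mirror what that example file uses; the exact import path for load_image and the launch arguments are assumptions, not the canonical invocation.

```python
import sglang as sgl
from sglang.srt.utils import load_image  # assumed import path; the example script provides its own


@sgl.function
def image_qa(s, image, question):
    # Single-turn visual QA program: one image plus a text question, as in the example file.
    s += sgl.user(sgl.image(image) + question)
    s += sgl.assistant(sgl.gen("answer"))


if __name__ == "__main__":
    # Launch a local runtime with the OneVision checkpoint, tokenizer, and chat template from the patch.
    runtime = sgl.Runtime(
        model_path="lmms-lab/llava-onevision-qwen2-0.5b-ov",
        tokenizer_path="lmms-lab/llavanext-qwen-siglip-tokenizer",
        chat_template="chatml-llava",
        tp_size=1,
    )
    sgl.set_default_backend(runtime)

    # Same image URL and question as the updated single() example.
    pil_image, _ = load_image(
        "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
    )
    state = image_qa.run(image=pil_image, question="What is this?", max_new_tokens=512)
    print(state["answer"])

    runtime.shutdown()
```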