Commit 6b373c5

fix conflicts with latest main before PR

Luodian committed Aug 14, 2024
1 parent c09d8e5
Showing 12 changed files with 165 additions and 73 deletions.
5 changes: 2 additions & 3 deletions examples/usage/llava/srt_llava_next_test.py
@@ -20,7 +20,7 @@ def image_qa(s, image, question):
def single():
image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
pil_image, _ = load_image(image_url)
state = image_qa.run(image=pil_image, question="<image>\nWhat is this?", max_new_tokens=512)
state = image_qa.run(image=pil_image, question="What is this?", max_new_tokens=512)
print(state["answer"], "\n")


@@ -59,13 +59,12 @@ def batch():

mp.set_start_method("spawn", force=True)
runtime = sgl.Runtime(
model_path="/mnt/bn/vl-research/checkpoints/onevision/llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mid_to_final_next_2p4m_am9_continual_ov",
model_path="lmms-lab/llava-onevision-qwen2-0.5b-ov",
tokenizer_path="lmms-lab/llavanext-qwen-siglip-tokenizer",
host="127.0.0.1",
tp_size=1,
port=8000,
chat_template="chatml-llava",
disable_flashinfer=True,
# port=8000, Optional: specify the HTTP server port if you hit an RPC issue or a "connection reset by peer" error.
)
# runtime = sgl.Runtime(
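For context, a minimal sketch of driving this runtime once it is launched. image_qa and load_image are defined earlier in this file; sgl.set_default_backend and runtime.shutdown() are the frontend calls sibling sglang examples of this period use, so treat the exact flow as an assumption:

import sglang as sgl

# Assumes the sgl.Runtime(...) above was assigned to `runtime`.
sgl.set_default_backend(runtime)
pil_image, _ = load_image(
    "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
)
# No "<image>\n" prefix needed: sgl.image() inside image_qa injects the image token.
state = image_qa.run(image=pil_image, question="What is this?", max_new_tokens=512)
print(state["answer"])
runtime.shutdown()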
2 changes: 1 addition & 1 deletion examples/usage/llava_video/srt_example_llava_v.py
@@ -148,7 +148,7 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=
parser.add_argument(
"--video-dir",
type=str,
default="./videos/Q98Z4OTh8RwmDonc.mp4",
default="./assets/jobs.mp4",
help="The directory or path for the processed video files.",
)
parser.add_argument(
Binary file not shown.
5 changes: 3 additions & 2 deletions python/sglang/lang/backend/runtime_endpoint.py
@@ -21,7 +21,8 @@ def __init__(
base_url: str,
api_key: Optional[str] = None,
verify: Optional[str] = None,
**kwargs
chat_template: Optional[str] = None,
**kwargs,
):
super().__init__()
self.support_concate_and_append = True
@@ -38,7 +39,7 @@ def __init__(
self._assert_success(res)
self.model_info = res.json()

self.chat_template = get_chat_template_by_model_path(
self.chat_template = chat_template or get_chat_template_by_model_path(
self.model_info["model_path"]
)

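A hedged sketch of what the new chat_template parameter enables: pinning a template explicitly instead of relying on model-path matching. Since the stored value is later used like the object returned by get_chat_template_by_model_path, passing a ChatTemplate object (via the existing get_chat_template helper) rather than a bare string is the safe assumption:

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.chat_template import get_chat_template

# Hypothetical usage; requires a running sglang server at this URL.
backend = RuntimeEndpoint(
    "http://127.0.0.1:30000",
    chat_template=get_chat_template("chatml-llava"),  # overrides auto-detection
)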
8 changes: 6 additions & 2 deletions python/sglang/lang/chat_template.py
@@ -332,13 +332,17 @@ def match_chat_ml(model_path: str):
if "tinyllama" in model_path:
return get_chat_template("chatml")
# Now the suffix for qwen2 chat model is "instruct"
if "qwen" in model_path and ("chat" in model_path or "instruct" in model_path) and ("llava" not in model_path):
if (
"qwen" in model_path
and ("chat" in model_path or "instruct" in model_path)
and ("llava" not in model_path)
):
return get_chat_template("qwen")
if (
"llava-v1.6-34b" in model_path
or "llava-v1.6-yi-34b" in model_path
or "llava-next-video-34b" in model_path
or "llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mid_to_final_next_2p4m_am9_continual_ov" in model_path
or "llava-onevision-qwen2" in model_path
):
print("######################## Matched chatml-llava ########################")
return get_chat_template("chatml-llava")
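A quick illustration of the new matching rule: released OneVision checkpoints now match on the "llava-onevision-qwen2" substring instead of the long internal checkpoint name. A sketch, assuming the matcher sees a lowercased path as the substring checks suggest:

from sglang.lang.chat_template import get_chat_template_by_model_path

template = get_chat_template_by_model_path("lmms-lab/llava-onevision-qwen2-0.5b-ov")
print(template.name)  # expected: "chatml-llava"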
2 changes: 1 addition & 1 deletion python/sglang/srt/conversation.py
@@ -407,7 +407,7 @@ def generate_chat_conv(
for content in message.content:
if content.type == "text":
if num_image_url > 16:
real_content += "\n" # for video
real_content += "\n" # for video
real_content += content.text
elif content.type == "image_url":
# NOTE: Only works for llava
22 changes: 12 additions & 10 deletions python/sglang/srt/managers/controller/tp_worker.py
@@ -273,7 +273,11 @@ def handle_generate_request(
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
img_hash = sum(recv_req.image_hash) if type(recv_req.image_hash) is list else recv_req.image_hash
img_hash = (
hash(recv_req.image_hash)
if type(recv_req.image_hash) is list
else recv_req.image_hash
)
req.pad_value = [
(img_hash) % self.model_config.vocab_size,
(img_hash >> 16) % self.model_config.vocab_size,
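
The pad values above are successive 16-bit slices of a single integer hash, each reduced modulo the vocabulary size, so every image gets deterministic placeholder token ids. One caveat worth flagging: Python's built-in hash() raises TypeError on a plain list, so the list branch presumably needs a hashable form such as a tuple. A minimal sketch under that assumption (values hypothetical):

vocab_size = 151936  # hypothetical vocabulary size
image_hash = [1234567890123, 9876543210987]  # e.g. per-frame hashes for a video

h = hash(tuple(image_hash))  # tuple(): Python lists are not hashable
pad_value = [
    h % vocab_size,
    (h >> 16) % vocab_size,  # further 16-bit slices follow in the folded lines
]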
@@ -315,7 +319,9 @@ def handle_generate_request(
self.forward_queue.append(req)

def get_new_fill_batch(self) -> Optional[Batch]:
running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
running_bs = (
len(self.running_batch.reqs) if self.running_batch is not None else 0
)
if running_bs >= self.max_running_requests:
return

@@ -356,19 +362,15 @@ def get_new_fill_batch(self) -> Optional[Batch]:
req.extend_input_len += delta
req.prefix_indices = req.prefix_indices[:-delta]
if req.image_offset is not None:
if isinstance(req.image_offse, list):
req.image_offset = [x + delta for x in req.image_offset]
else:
req.image_offset += delta
req.image_offset += delta

if req.extend_input_len == 0 and req.max_new_tokens() > 0:
# Need at least one token to compute logits
req.extend_input_len = 1
req.prefix_indices = req.prefix_indices[:-1]
if req.image_offset is not None:
if isinstance(req.image_offset, list):
req.image_offset = [x + 1 for x in req.image_offset]
else:
req.image_offset += 1
req.image_offset += 1

if (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
6 changes: 3 additions & 3 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -190,7 +190,7 @@ async def _handle_single_request(
# single image (anyres): num_patch, 3, 336 or 384, 336 or 384
pixel_values, image_hash, image_size = [], [], []
if len(obj.image_data) > 1:
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
for image_data in obj.image_data:
pixel_v, image_h, image_s = await self.get_pixel_values(
image_data, aspect_ratio
@@ -212,11 +212,11 @@
)
else:
pixel_values, image_hash, image_size = None, None, None

except Exception as e:
print(f"Error in get_pixel_values: {e}")
pixel_values, image_hash, image_size = None, None, None

tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_text=obj.text,
62 changes: 48 additions & 14 deletions python/sglang/srt/mm_utils.py
@@ -17,8 +17,9 @@
import ast
import base64
import math
from io import BytesIO
import re
from io import BytesIO

import numpy as np
from PIL import Image

@@ -42,13 +43,20 @@ def select_best_resolution(original_size, possible_resolutions):
for width, height in possible_resolutions:
# Calculate the downscaled size to keep the aspect ratio
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
downscaled_width, downscaled_height = int(original_width * scale), int(
original_height * scale
)

# Calculate effective and wasted resolutions
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
effective_resolution = min(
downscaled_width * downscaled_height, original_width * original_height
)
wasted_resolution = (width * height) - effective_resolution

if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
if effective_resolution > max_effective_resolution or (
effective_resolution == max_effective_resolution
and wasted_resolution < min_wasted_resolution
):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
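
A worked example of the rule above, with hypothetical sizes: the candidate preserving the most source pixels after an aspect-ratio-preserving downscale wins, with ties broken by the least wasted area:

from sglang.srt.mm_utils import select_best_resolution

# For a 1000x800 image:
#   (768, 768):  scale 0.768 -> 768x614, effective pixels 471,552
#   (1536, 768): scale 0.96  -> 960x768, effective pixels 737,280  <- winner
#   (768, 1536): scale 0.768 -> 768x614, same effective pixels, more waste
print(select_best_resolution((1000, 800), [(768, 768), (1536, 768), (768, 1536)]))
# -> (1536, 768)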
@@ -126,13 +134,23 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
tuple: The shape of the image patch grid in the format (width, height).
"""
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
assert patch_size in [
224,
336,
384,
448,
512,
], "patch_size should be in [224, 336, 384, 448, 512]"
# Use regex to extract the range from the input string
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
grid_pinpoints = [
(i, j)
for i in range(range_start[0], range_end[0] + 1)
for j in range(range_start[1], range_end[1] + 1)
]
# Multiply all elements by patch_size
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if type(grid_pinpoints) is list:
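
To make the string-form handling above concrete, a standalone sketch (values hypothetical) of how a pinpoint range such as "(1x1),...,(3x3)" expands into pixel resolutions:

import re

grid_pinpoints = "(1x1),...,(3x3)"
patch_size = 384
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))  # (1, 1)
range_end = tuple(map(int, matches[-1]))   # (3, 3)
pairs = [
    (i, j)
    for i in range(range_start[0], range_end[0] + 1)
    for j in range(range_start[1], range_end[1] + 1)
]
resolutions = [[dim * patch_size for dim in pair] for pair in pairs]
print(resolutions[0], resolutions[-1])  # [384, 384] [1152, 1152]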
@@ -160,16 +178,26 @@ def process_anyres_image(image, processor, grid_pinpoints):
patch_size = processor.size[0]
except Exception as e:
patch_size = processor.size["shortest_edge"]
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
assert patch_size in [
224,
336,
384,
448,
512,
], "patch_size should be in [224, 336, 384, 448, 512]"
# Use regex to extract the range from the input string
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
grid_pinpoints = [
(i, j)
for i in range(range_start[0], range_end[0] + 1)
for j in range(range_start[1], range_end[1] + 1)
]
# Multiply all elements by patch_size
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]

if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
@@ -178,13 +206,19 @@ def process_anyres_image(image, processor, grid_pinpoints):
image_padded = resize_and_pad_image(image, best_resolution)

# For Siglip processor, only have size but no crop size
crop_size = processor.crop_size["height"] if "crop_size" in processor.__dict__ else processor.size["height"]
shortest_edge = processor.size["shortest_edge"] if "shortest_edge" in processor.size else processor.size["height"]
crop_size = (
processor.crop_size["height"]
if "crop_size" in processor.__dict__
else processor.size["height"]
)
shortest_edge = (
processor.size["shortest_edge"]
if "shortest_edge" in processor.size
else processor.size["height"]
)
patches = divide_to_patches(image_padded, crop_size)

image_original_resize = image.resize(
(shortest_edge, shortest_edge)
)
image_original_resize = image.resize((shortest_edge, shortest_edge))

image_patches = [image_original_resize] + patches
image_patches = [
(Diffs for the remaining changed files are not shown.)