Merge pull request #55 from rhymes-ai/processor
Refactor: precompute the image token embedding placeholders in `AriaProcessor` for a simplified forward pass and static KV cache support
xffxff authored Nov 11, 2024
2 parents d424be8 + c2c2aa6 commit 0d987a8
Showing 8 changed files with 312 additions and 266 deletions.
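At a glance, the change moves image-token expansion out of the model and into the processor: each `<|img|>` placeholder in a prompt is repeated once per image embedding, so tokenized sequences already have their final length before the forward pass. A minimal sketch of the expansion rule the diffs below implement (the helper name is illustrative, not part of the codebase):

# Illustrative helper, not part of the Aria codebase: the 490px -> 128 and
# 980px -> 256 mapping mirrors the diffs in aria/data.py and
# aria/model/processing_aria.py below.
NUM_IMAGE_TOKENS = {490: 128, 980: 256}

def expand_image_tokens(prompt: str, max_image_size: int = 980, image_token: str = "<|img|>") -> str:
    if max_image_size not in NUM_IMAGE_TOKENS:
        raise ValueError(f"max_image_size must be either 490 or 980, got {max_image_size}")
    return prompt.replace(image_token, image_token * NUM_IMAGE_TOKENS[max_image_size])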
4 changes: 3 additions & 1 deletion .gitignore
@@ -10,4 +10,6 @@ aria.egg-info/
 wandb
 datasets/
 gptfast/checkpoints/
-local_datasets/
+local_datasets/
+eval
+build/
11 changes: 11 additions & 0 deletions aria/data.py
@@ -31,6 +31,7 @@ def apply_chat_template_and_tokenize(
tokenizer,
num_image_crop: Iterable[torch.Tensor] = iter([]),
max_length: int = 1024,
max_image_size: int = 980,
):
IGNORE_TOKEN_ID = -100
im_start_tokens = tokenizer("<|im_start|>").input_ids
@@ -76,6 +77,16 @@ def create_target(role, input_id):
role = message["role"]
text = "".join(process_content(content) for content in message["content"])

if max_image_size == 490:
num_image_tokens = 128
elif max_image_size == 980:
num_image_tokens = 256
else:
raise ValueError(
f"max_image_size must be either 490 or 980, got {max_image_size}"
)
text = text.replace("<|img|>", "<|img|>" * num_image_tokens)

_input_id = tokenize_message(role, text)
input_id.extend(_input_id)
target.extend(create_target(role, _input_id))
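Wired through `collate_fn` in aria/train.py (see below), the new argument lets training-side tokenization pick the matching expansion. A usage sketch, assuming `messages` and `tokenizer` are set up as in tests/test_apply_chat_template.py further down:

# Usage sketch; `messages` and `tokenizer` as in the tests further below.
batch = apply_chat_template_and_tokenize(
    [messages],
    num_image_crop=iter([1]),
    tokenizer=tokenizer,
    max_image_size=490,  # each <|img|> becomes 128 placeholder tokens
)
input_ids, labels = batch["input_ids"], batch["labels"]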
4 changes: 4 additions & 0 deletions aria/model/configuration_aria.py
@@ -69,6 +69,7 @@ def __init__(
self.image_token_index = image_token_index

attn_implementation = kwargs.pop("attn_implementation", None)
self._attn_implementation = attn_implementation

# Convert the keys and values of projector_patch_to_query_dict to integers
# This ensures consistency even if they were provided as strings
@@ -95,3 +96,6 @@ def __init__(
text_config._attn_implementation = text_attn_implementation

self.text_config = text_config

# This is needed for the static kv cache
self.num_hidden_layers = self.text_config.num_hidden_layers
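Mirroring `num_hidden_layers` from `text_config` onto the top-level config matters because transformers' `StaticCache` preallocates one key/value buffer per decoder layer by reading `config.num_hidden_layers`. A hedged sketch of how such a cache would be built (exact `StaticCache` kwargs vary across transformers versions):

import torch
from transformers.cache_utils import StaticCache

# Sketch only: StaticCache reads config.num_hidden_layers (plus head counts
# and hidden size) to preallocate fixed-shape key/value buffers, which is why
# AriaConfig now exposes the value at the top level.
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=2048,
    device=model.device,
    dtype=model.dtype,
)
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)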
317 changes: 53 additions & 264 deletions aria/model/modeling_aria.py

Large diffs are not rendered by default.
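The collapsed modeling diff is where the forward pass gets simpler: because the processor now emits one placeholder token per image embedding, the model can overwrite those positions in place rather than splicing variable-length segments into the sequence. A plausible sketch of that merge pattern (common to LLaVA-style models; names are assumptions, not Aria's actual code):

import torch

# Plausible sketch of the in-place placeholder merge; tensor names are
# assumptions, not Aria's actual implementation.
def merge_image_embeddings(
    inputs_embeds: torch.Tensor,   # (batch, seq_len, hidden)
    input_ids: torch.Tensor,       # (batch, seq_len), already expanded
    image_embeds: torch.Tensor,    # (num_images * num_image_tokens, hidden)
    image_token_id: int,
) -> torch.Tensor:
    # The sequence length is static: every image already occupies
    # num_image_tokens positions, so no re-padding or length bookkeeping
    # is needed, which is what makes a static kv cache possible.
    mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
    return inputs_embeds.masked_scatter(mask, image_embeds.to(inputs_embeds.dtype))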

24 changes: 23 additions & 1 deletion aria/model/processing_aria.py
@@ -94,6 +94,7 @@ def __call__(
max_image_size: Optional[int] = 980,
split_image: Optional[bool] = False,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
return_final_prompts: Optional[bool] = False,
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequence(s) and image(s). Please refer to the docstring
@@ -168,6 +169,24 @@
)
)

max_image_size = (
max_image_size
if max_image_size is not None
else self.image_processor.max_image_size
)
if max_image_size == 490:
num_image_tokens = 128
elif max_image_size == 980:
num_image_tokens = 256
else:
raise ValueError(
f"max_image_size must be either 490 or 980, got {max_image_size}"
)
prompt_strings = [
sample.replace(self.image_token, self.image_token * num_image_tokens)
for sample in prompt_strings
]

else:
image_inputs = {}
prompt_strings = text
@@ -180,7 +199,10 @@ def __call__(
max_length=max_length,
)

return BatchFeature(data={**text_inputs, **image_inputs})
if return_final_prompts:
return BatchFeature(data={**text_inputs, **image_inputs}), prompt_strings
else:
return BatchFeature(data={**text_inputs, **image_inputs})

@staticmethod
def _extract_kwargs(func: callable, **kwargs) -> dict:
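End to end, the processor call then looks like the sketch below (condensed from tests/test_aria_processor.py further down; `processor`, `sample_messages`, and `sample_image` are the fixtures defined there):

# Condensed from tests/test_aria_processor.py below.
text = processor.apply_chat_template(sample_messages, add_generation_prompt=True)
inputs, prompts = processor(
    text=text,
    images=[sample_image],
    return_tensors="pt",
    max_image_size=980,
    return_final_prompts=True,  # new flag: also return the expanded prompts
)
assert prompts[0].count("<|img|>") == 256  # one placeholder per image embedding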
1 change: 1 addition & 0 deletions aria/train.py
@@ -194,6 +194,7 @@ def collate_fn(
tokenizer,
iter(image_inputs.pop("num_crops")),
max_length=max_seq_length,
max_image_size=processor.max_image_size,
)

batch.update(image_inputs)
30 changes: 30 additions & 0 deletions tests/test_apply_chat_template.py
@@ -26,6 +26,7 @@ def test_apply_chat_template_single_user_message(tokenizer):
}
]
expected_output = "<|im_start|>user\nWho wrote this book?\n<fim_prefix><|img|><fim_suffix><|im_end|>\n"
expected_output = expected_output.replace("<|img|>", "<|img|>" * 256)
res = apply_chat_template_and_tokenize(
[messages], num_image_crop=iter([1]), tokenizer=tokenizer
)
@@ -37,6 +38,29 @@
assert (labels == -100).sum() == input_ids.numel()


def test_apply_chat_template_single_user_message_490(tokenizer):
messages = [
{
"content": [
{"text": "Who wrote this book?\n", "type": "text"},
{"text": None, "type": "image"},
],
"role": "user",
}
]
expected_output = "<|im_start|>user\nWho wrote this book?\n<fim_prefix><|img|><fim_suffix><|im_end|>\n"
expected_output = expected_output.replace("<|img|>", "<|img|>" * 128)
res = apply_chat_template_and_tokenize(
[messages], num_image_crop=iter([1]), tokenizer=tokenizer, max_image_size=490
)
input_ids = res["input_ids"]
input_str = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
assert input_str == expected_output

labels = res["labels"]
assert (labels == -100).sum() == input_ids.numel()


def test_apply_chat_template_single_assistant_message(tokenizer):
messages = [
{
@@ -66,6 +90,7 @@ def test_apply_chat_template_multiple_messages(tokenizer):
},
]
expected_output = "<|im_start|>user\nWho wrote this book?\n<fim_prefix><|img|><fim_suffix><|im_end|>\n<|im_start|>assistant\nSylvie Covey<|im_end|>\n"
expected_output = expected_output.replace("<|img|>", "<|img|>" * 256)
res = apply_chat_template_and_tokenize(
[messages], num_image_crop=iter([1]), tokenizer=tokenizer
)
@@ -122,6 +147,7 @@ def test_apply_chat_template_multi_round_messages(tokenizer):
},
]
expected_output = "<|im_start|>user\nWho wrote this book?\n<fim_prefix><|img|><fim_suffix><|im_end|>\n<|im_start|>assistant\nSylvie Covey<|im_end|>\n<|im_start|>user\nWhat is the title of this book?<|im_end|>\n<|im_start|>assistant\nModern Printmaking: A Guide to Traditional and Digital Techniques<|im_end|>\n"
expected_output = expected_output.replace("<|img|>", "<|img|>" * 256)
res = apply_chat_template_and_tokenize(
[messages], num_image_crop=iter([1]), tokenizer=tokenizer
)
@@ -187,6 +213,10 @@ def test_apply_chat_template_batch_messages(tokenizer):
"<|im_start|>user\nWho wrote this book?\n<fim_prefix><|img|><fim_suffix><|im_end|>\n<|im_start|>assistant\nSylvie Covey<|im_end|>\n",
"<|im_start|>user\nWho wrote this book?\n<fim_prefix><|img|><fim_suffix><|im_end|>\n<|im_start|>assistant\nSylvie Covey<|im_end|>\n<|im_start|>user\nWhat is the title of this book?<|im_end|>\n<|im_start|>assistant\nModern Printmaking: A Guide to Traditional and Digital Techniques<|im_end|>\n",
]
expected_output = [
expected_output[0].replace("<|img|>", "<|img|>" * 256),
expected_output[1].replace("<|img|>", "<|img|>" * 256),
]
assert (
tokenizer.batch_decode(input_ids, skip_special_tokens=True) == expected_output
)
187 changes: 187 additions & 0 deletions tests/test_aria_processor.py
@@ -0,0 +1,187 @@
import numpy as np
import pytest
import torch
from PIL import Image
from transformers import AutoTokenizer

from aria.model.processing_aria import AriaProcessor
from aria.model.vision_processor import AriaVisionProcessor


@pytest.fixture
def processor():
tokenizer = AutoTokenizer.from_pretrained("rhymes-ai/Aria")
image_processor = AriaVisionProcessor(max_image_size=490)
return AriaProcessor(
tokenizer=tokenizer,
image_processor=image_processor,
image_token="<|img|>",
chat_template=tokenizer.chat_template,
)


@pytest.fixture
def sample_image():
return Image.fromarray(np.random.randint(0, 255, (768, 768, 3), dtype=np.uint8))


@pytest.fixture
def sample_messages():
return [
{
"role": "user",
"content": [
{"text": None, "type": "image"},
{"text": "describe the image", "type": "text"},
],
}
]


def test_apply_chat_template(processor, sample_messages):
text = processor.apply_chat_template(sample_messages, add_generation_prompt=True)

assert (
text
== "<|im_start|>user\n<fim_prefix><|img|><fim_suffix>describe the image<|im_end|>\n<|im_start|>assistant\n"
)

text = processor.apply_chat_template(sample_messages, add_generation_prompt=False)
assert (
text
== "<|im_start|>user\n<fim_prefix><|img|><fim_suffix>describe the image<|im_end|>\n"
)


def test_chat_template_with_multiple_messages(processor):
messages = [
{
"role": "user",
"content": [
{"text": None, "type": "image"},
{"text": "What's in this image?", "type": "text"},
],
},
{
"role": "assistant",
"content": "This is a beautiful landscape.",
},
{
"role": "user",
"content": [
{"text": "Can you describe it in more detail?", "type": "text"},
],
},
]

text = processor.apply_chat_template(messages, add_generation_prompt=True)
assert (
text
== "<|im_start|>user\n<fim_prefix><|img|><fim_suffix>What's in this image?<|im_end|>\n<|im_start|>assistant\nThis is a beautiful landscape.<|im_end|>\n<|im_start|>user\nCan you describe it in more detail?<|im_end|>\n<|im_start|>assistant\n"
)


def test_end_to_end_processing_980(processor, sample_messages, sample_image):
text = processor.apply_chat_template(sample_messages, add_generation_prompt=True)
inputs, prompts = processor(
text=text,
images=[sample_image],
return_tensors="pt",
max_image_size=980,
return_final_prompts=True,
)

# Verify the output contains all necessary keys
assert "input_ids" in inputs
assert "attention_mask" in inputs
assert "pixel_values" in inputs

# Check shapes
assert len(inputs["input_ids"].shape) == 2
assert len(inputs["attention_mask"].shape) == 2
assert len(inputs["pixel_values"].shape) == 4

# Check device and dtype
assert inputs["input_ids"].device.type == "cpu"
assert inputs["pixel_values"].dtype == torch.float32

expected_prompt = "<|im_start|>user\n<fim_prefix><|img|><fim_suffix>describe the image<|im_end|>\n<|im_start|>assistant\n"
expected_prompt = expected_prompt.replace("<|img|>", "<|img|>" * 256)

assert prompts[0] == expected_prompt


def test_end_to_end_processing_490(processor, sample_messages, sample_image):
text = processor.apply_chat_template(sample_messages, add_generation_prompt=True)
inputs, prompts = processor(
text=text,
images=[sample_image],
return_tensors="pt",
max_image_size=490,
return_final_prompts=True,
)

expected_prompt = "<|im_start|>user\n<fim_prefix><|img|><fim_suffix>describe the image<|im_end|>\n<|im_start|>assistant\n"
expected_prompt = expected_prompt.replace("<|img|>", "<|img|>" * 128)

assert prompts[0] == expected_prompt


def test_end_to_end_processing_invalid_max_image_size(
processor, sample_messages, sample_image
):
text = processor.apply_chat_template(sample_messages, add_generation_prompt=True)
with pytest.raises(ValueError):
processor(
text=text, images=[sample_image], return_tensors="pt", max_image_size=1000
)


def test_multiple_images_in_conversation(processor, sample_image):
messages = [
{
"role": "user",
"content": [
{"text": None, "type": "image"},
{"text": None, "type": "image"},
{"text": "Compare the two images.", "type": "text"},
],
}
]

text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs, prompts = processor(
text=text,
images=[sample_image, sample_image], # Two images
return_tensors="pt",
max_image_size=980,
return_final_prompts=True,
)

assert "pixel_values" in inputs
assert inputs["pixel_values"].shape[0] == 2 # Batch size should be 2 for two images

expected_prompt = "<|im_start|>user\n<fim_prefix><|img|><fim_suffix><fim_prefix><|img|><fim_suffix>Compare the two images.<|im_end|>\n<|im_start|>assistant\n"
expected_prompt = expected_prompt.replace("<|img|>", "<|img|>" * 256)

assert prompts[0] == expected_prompt


def test_split_image(processor, sample_messages, sample_image):
text = processor.apply_chat_template(sample_messages, add_generation_prompt=True)
inputs, prompts = processor(
text=text,
images=[sample_image],
return_tensors="pt",
max_image_size=490,
split_image=True,
return_final_prompts=True,
)

assert inputs["pixel_values"].shape == (5, 3, 490, 490)
assert inputs["pixel_mask"].shape == (5, 490, 490)

expected_prompt = "<|im_start|>user\n<fim_prefix><|img|><|img|><|img|><|img|><|img|><fim_suffix>describe the image<|im_end|>\n<|im_start|>assistant\n"
expected_prompt = expected_prompt.replace("<|img|>", "<|img|>" * 128)

assert prompts[0] == expected_prompt
