From a3054057ae25c0896426d2b9d7b54a700bb4ac53 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Sun, 10 Nov 2024 02:24:36 +0000 Subject: [PATCH 1/7] add some tests for aria processor --- aria/model/processing_aria.py | 6 +- tests/test_aria_processor.py | 161 ++++++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+), 1 deletion(-) create mode 100644 tests/test_aria_processor.py diff --git a/aria/model/processing_aria.py b/aria/model/processing_aria.py index 08a363a..7426f12 100644 --- a/aria/model/processing_aria.py +++ b/aria/model/processing_aria.py @@ -94,6 +94,7 @@ def __call__( max_image_size: Optional[int] = 980, split_image: Optional[bool] = False, return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + return_final_prompts: Optional[bool] = False, ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). Please refer to the doctsring @@ -180,7 +181,10 @@ def __call__( max_length=max_length, ) - return BatchFeature(data={**text_inputs, **image_inputs}) + if return_final_prompts: + return BatchFeature(data={**text_inputs, **image_inputs}), prompt_strings + else: + return BatchFeature(data={**text_inputs, **image_inputs}) @staticmethod def _extract_kwargs(func: callable, **kwargs) -> dict: diff --git a/tests/test_aria_processor.py b/tests/test_aria_processor.py new file mode 100644 index 0000000..45d566d --- /dev/null +++ b/tests/test_aria_processor.py @@ -0,0 +1,161 @@ +import numpy as np +import pytest +import torch +from PIL import Image +from transformers import AutoTokenizer + +from aria.model.processing_aria import AriaProcessor +from aria.model.vision_processor import AriaVisionProcessor + + +@pytest.fixture +def processor(): + tokenizer = AutoTokenizer.from_pretrained("rhymes-ai/Aria") + image_processor = AriaVisionProcessor(max_image_size=490) + return AriaProcessor( + tokenizer=tokenizer, + image_processor=image_processor, + image_token="<|img|>", + chat_template=tokenizer.chat_template, + ) + + +@pytest.fixture +def sample_image(): + return Image.fromarray(np.random.randint(0, 255, (768, 768, 3), dtype=np.uint8)) + + +@pytest.fixture +def sample_messages(): + return [ + { + "role": "user", + "content": [ + {"text": None, "type": "image"}, + {"text": "describe the image", "type": "text"}, + ], + } + ] + + +def test_apply_chat_template(processor, sample_messages): + text = processor.apply_chat_template(sample_messages, add_generation_prompt=True) + + assert ( + text + == "<|im_start|>user\n<|img|>describe the image<|im_end|>\n<|im_start|>assistant\n" + ) + + text = processor.apply_chat_template(sample_messages, add_generation_prompt=False) + assert ( + text + == "<|im_start|>user\n<|img|>describe the image<|im_end|>\n" + ) + + +def test_chat_template_with_multiple_messages(processor): + messages = [ + { + "role": "user", + "content": [ + {"text": None, "type": "image"}, + {"text": "What's in this image?", "type": "text"}, + ], + }, + { + "role": "assistant", + "content": "This is a beautiful landscape.", + }, + { + "role": "user", + "content": [ + {"text": "Can you describe it in more detail?", "type": "text"}, + ], + }, + ] + + text = processor.apply_chat_template(messages, add_generation_prompt=True) + assert ( + text + == "<|im_start|>user\n<|img|>What's in this image?<|im_end|>\n<|im_start|>assistant\nThis is a beautiful landscape.<|im_end|>\n<|im_start|>user\nCan you describe it in more detail?<|im_end|>\n<|im_start|>assistant\n" + ) + + +def test_end_to_end_processing(processor, 
sample_messages, sample_image): + text = processor.apply_chat_template(sample_messages, add_generation_prompt=True) + inputs, prompts = processor( + text=text, + images=[sample_image], + return_tensors="pt", + max_image_size=980, + return_final_prompts=True, + ) + + # Verify the output contains all necessary keys + assert "input_ids" in inputs + assert "attention_mask" in inputs + assert "pixel_values" in inputs + + # Check shapes + assert len(inputs["input_ids"].shape) == 2 + assert len(inputs["attention_mask"].shape) == 2 + assert len(inputs["pixel_values"].shape) == 4 + + # Check device and dtype + assert inputs["input_ids"].device.type == "cpu" + assert inputs["pixel_values"].dtype == torch.float32 + + assert ( + prompts[0] + == "<|im_start|>user\n<|img|>describe the image<|im_end|>\n<|im_start|>assistant\n" + ) + + +def test_multiple_images_in_conversation(processor, sample_image): + messages = [ + { + "role": "user", + "content": [ + {"text": None, "type": "image"}, + {"text": None, "type": "image"}, + {"text": "Compare the two images.", "type": "text"}, + ], + } + ] + + text = processor.apply_chat_template(messages, add_generation_prompt=True) + inputs, prompts = processor( + text=text, + images=[sample_image, sample_image], # Two images + return_tensors="pt", + max_image_size=980, + return_final_prompts=True, + ) + + assert "pixel_values" in inputs + assert inputs["pixel_values"].shape[0] == 2 # Batch size should be 2 for two images + + assert ( + prompts[0] + == "<|im_start|>user\n<|img|><|img|>Compare the two images.<|im_end|>\n<|im_start|>assistant\n" + ) + + +def test_split_image(processor, sample_messages, sample_image): + text = processor.apply_chat_template(sample_messages, add_generation_prompt=True) + inputs, prompts = processor( + text=text, + images=[sample_image], + return_tensors="pt", + max_image_size=490, + split_image=True, + return_final_prompts=True, + ) + + assert inputs["pixel_values"].shape == (5, 3, 490, 490) + assert inputs["pixel_mask"].shape == (5, 490, 490) + + assert ( + prompts[0] + == "<|im_start|>user\n<|img|><|img|><|img|><|img|><|img|>describe the image<|im_end|>\n<|im_start|>assistant\n" + ) From 9b98c29ed0217ed948f7b21d16a42230d11d7610 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Sun, 10 Nov 2024 02:46:12 +0000 Subject: [PATCH 2/7] refactor: expand image tokens in processor --- aria/model/processing_aria.py | 18 +++++++++++++ tests/test_aria_processor.py | 50 ++++++++++++++++++++++++++--------- 2 files changed, 56 insertions(+), 12 deletions(-) diff --git a/aria/model/processing_aria.py b/aria/model/processing_aria.py index 7426f12..64e9538 100644 --- a/aria/model/processing_aria.py +++ b/aria/model/processing_aria.py @@ -169,6 +169,24 @@ def __call__( ) ) + max_image_size = ( + max_image_size + if max_image_size is not None + else self.image_processor.max_image_size + ) + if max_image_size == 490: + num_image_tokens = 128 + elif max_image_size == 980: + num_image_tokens = 256 + else: + raise ValueError( + f"max_image_size must be either 490 or 980, got {max_image_size}" + ) + prompt_strings = [ + sample.replace(self.image_token, self.image_token * num_image_tokens) + for sample in prompt_strings + ] + else: image_inputs = {} prompt_strings = text diff --git a/tests/test_aria_processor.py b/tests/test_aria_processor.py index 45d566d..2307157 100644 --- a/tests/test_aria_processor.py +++ b/tests/test_aria_processor.py @@ -81,7 +81,7 @@ def test_chat_template_with_multiple_messages(processor): ) -def 
test_end_to_end_processing(processor, sample_messages, sample_image): +def test_end_to_end_processing_980(processor, sample_messages, sample_image): text = processor.apply_chat_template(sample_messages, add_generation_prompt=True) inputs, prompts = processor( text=text, @@ -105,11 +105,37 @@ def test_end_to_end_processing(processor, sample_messages, sample_image): assert inputs["input_ids"].device.type == "cpu" assert inputs["pixel_values"].dtype == torch.float32 - assert ( - prompts[0] - == "<|im_start|>user\n<|img|>describe the image<|im_end|>\n<|im_start|>assistant\n" + expected_prompt = "<|im_start|>user\n<|img|>describe the image<|im_end|>\n<|im_start|>assistant\n" + expected_prompt = expected_prompt.replace("<|img|>", "<|img|>" * 256) + + assert prompts[0] == expected_prompt + + +def test_end_to_end_processing_490(processor, sample_messages, sample_image): + text = processor.apply_chat_template(sample_messages, add_generation_prompt=True) + inputs, prompts = processor( + text=text, + images=[sample_image], + return_tensors="pt", + max_image_size=490, + return_final_prompts=True, ) + expected_prompt = "<|im_start|>user\n<|img|>describe the image<|im_end|>\n<|im_start|>assistant\n" + expected_prompt = expected_prompt.replace("<|img|>", "<|img|>" * 128) + + assert prompts[0] == expected_prompt + + +def test_end_to_end_processing_invalid_max_image_size( + processor, sample_messages, sample_image +): + text = processor.apply_chat_template(sample_messages, add_generation_prompt=True) + with pytest.raises(ValueError): + processor( + text=text, images=[sample_image], return_tensors="pt", max_image_size=1000 + ) + def test_multiple_images_in_conversation(processor, sample_image): messages = [ @@ -135,10 +161,10 @@ def test_multiple_images_in_conversation(processor, sample_image): assert "pixel_values" in inputs assert inputs["pixel_values"].shape[0] == 2 # Batch size should be 2 for two images - assert ( - prompts[0] - == "<|im_start|>user\n<|img|><|img|>Compare the two images.<|im_end|>\n<|im_start|>assistant\n" - ) + expected_prompt = "<|im_start|>user\n<|img|><|img|>Compare the two images.<|im_end|>\n<|im_start|>assistant\n" + expected_prompt = expected_prompt.replace("<|img|>", "<|img|>" * 256) + + assert prompts[0] == expected_prompt def test_split_image(processor, sample_messages, sample_image): @@ -155,7 +181,7 @@ def test_split_image(processor, sample_messages, sample_image): assert inputs["pixel_values"].shape == (5, 3, 490, 490) assert inputs["pixel_mask"].shape == (5, 490, 490) - assert ( - prompts[0] - == "<|im_start|>user\n<|img|><|img|><|img|><|img|><|img|>describe the image<|im_end|>\n<|im_start|>assistant\n" - ) + expected_prompt = "<|im_start|>user\n<|img|><|img|><|img|><|img|><|img|>describe the image<|im_end|>\n<|im_start|>assistant\n" + expected_prompt = expected_prompt.replace("<|img|>", "<|img|>" * 128) + + assert prompts[0] == expected_prompt From a4cb6c476c1ceaaeda796d524af45efb7b730b8d Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Sun, 10 Nov 2024 03:53:00 +0000 Subject: [PATCH 3/7] refactor: improve image handling and static cache support --- aria/model/configuration_aria.py | 4 + aria/model/modeling_aria.py | 185 +++++++++---------------------- 2 files changed, 57 insertions(+), 132 deletions(-) diff --git a/aria/model/configuration_aria.py b/aria/model/configuration_aria.py index efddd89..767fbc1 100644 --- a/aria/model/configuration_aria.py +++ b/aria/model/configuration_aria.py @@ -69,6 +69,7 @@ def __init__( self.image_token_index = 
image_token_index attn_implementation = kwargs.pop("attn_implementation", None) + self._attn_implementation = attn_implementation # Convert the keys and values of projector_patch_to_query_dict to integers # This ensures consistency even if they were provided as strings @@ -95,3 +96,6 @@ def __init__( text_config._attn_implementation = text_attn_implementation self.text_config = text_config + + # This is needed for the static kv cache + self.num_hidden_layers = self.text_config.num_hidden_layers diff --git a/aria/model/modeling_aria.py b/aria/model/modeling_aria.py index 4957fd9..92f6643 100644 --- a/aria/model/modeling_aria.py +++ b/aria/model/modeling_aria.py @@ -24,7 +24,6 @@ import torch.nn as nn from torch import nn from transformers import PreTrainedModel -from transformers.cache_utils import Cache from transformers.modeling_outputs import ModelOutput from transformers.utils import logging @@ -48,6 +47,7 @@ class AriaPretrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_cache_class = True + _supports_static_cache = True @property def _supports_sdpa(self): @@ -329,6 +329,8 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, AriaCausalLMOutputWithPast]: """ Forward pass of the AriaForConditionalGeneration model. @@ -371,69 +373,38 @@ def forward( # 1. Extra the input embeddings inputs_embeds = self.get_input_embeddings()(input_ids) - # 2. Merge text and images - if pixel_values is not None and input_ids.shape[1] != 1: - image_outputs, image_attn_mask = self.vision_tower( - pixel_values, - pixel_mask=pixel_mask, - ) - selected_image_feature = image_outputs.last_hidden_state - - image_features = self.multi_modal_projector( - selected_image_feature, attn_mask=image_attn_mask - ) - - inputs_embeds = inputs_embeds.to(image_features.dtype) - ( - inputs_embeds, - attention_mask, - labels, - position_ids, - ) = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels - ) - - # In case input_ids.shape[1] == 1 & pixel_values != None & past_key_values != None, we are in the case of - # generation with cache - elif ( - past_key_values is not None - and pixel_values is not None - and input_ids.shape[1] == 1 - ): - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors - # such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where( - first_layer_past_key_value.float().sum(-2) == 0 - ) - - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) + image_features = None + if pixel_values is not None: + image_outputs, image_attn_mask = self.vision_tower( + pixel_values, + pixel_mask=pixel_mask, + ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < 
extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] + selected_image_feature = image_outputs.last_hidden_state + image_features = self.multi_modal_projector( + selected_image_feature, attn_mask=image_attn_mask + ) - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + if image_features is not None: + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] * image_features.shape[1] - attention_mask = torch.cat( - (extended_attention_mask, attention_mask[:, -target_length:]), dim=1 + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_features = image_features.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter( + special_image_mask, image_features + ) outputs = self.language_model( attention_mask=attention_mask, @@ -444,6 +415,8 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, ) logits = outputs[0] @@ -452,7 +425,11 @@ def forward( if labels is not None: # Shift so that tokens < n predict n if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:] + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to( + logits.device + ) shift_logits = logits[..., :-1, :][ shift_attention_mask.to(logits.device) != 0 ].contiguous() @@ -487,80 +464,24 @@ def prepare_inputs_for_generation( past_key_values=None, inputs_embeds=None, pixel_values=None, - pixel_mask=None, attention_mask=None, + cache_position=None, + num_logits_to_keep=None, **kwargs, ): - """ - Prepare inputs for generation step. + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + **kwargs, + ) - This method prepares the inputs for the generation step, handling both - text and image inputs, and managing the model's cache mechanism. + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values - Args: - input_ids (torch.LongTensor): Input token ids. - past_key_values (Cache or List[torch.FloatTensor], optional): Past key values for efficient processing. - inputs_embeds (torch.FloatTensor, optional): Input embeddings. - pixel_values (torch.FloatTensor, optional): Pixel values of the images. - pixel_mask (torch.LongTensor, optional): Mask for the pixel values. - attention_mask (torch.Tensor, optional): Attention mask. - **kwargs: Additional keyword arguments. 
- - Returns: - dict: A dictionary containing the prepared inputs for the generation step. - """ - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - else: - cache_length = past_length = past_key_values[0][0].shape[2] - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if ( - attention_mask is not None - and attention_mask.shape[1] > input_ids.shape[1] - ): - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - elif self.config.image_token_index in input_ids: - input_ids = input_ids[:, input_ids.shape[1] - 1 :] - # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the - # older attention values, as their corresponding values are not part of the input. - if cache_length < past_length and attention_mask is not None: - attention_mask = attention_mask[ - :, -(cache_length + input_ids.shape[1]) : - ] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "pixel_mask": pixel_mask, - } - ) return model_inputs From 4c559170ba7fad26337bc6b02935ec25aea5847b Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Sun, 10 Nov 2024 04:34:05 +0000 Subject: [PATCH 4/7] remove unused code --- aria/model/modeling_aria.py | 132 ------------------------------------ 1 file changed, 132 deletions(-) diff --git a/aria/model/modeling_aria.py b/aria/model/modeling_aria.py index 92f6643..637c8aa 100644 --- a/aria/model/modeling_aria.py +++ b/aria/model/modeling_aria.py @@ -183,138 +183,6 @@ def set_moe_aux_loss_coeff(self, value): """ self.language_model.set_aux_loss_coeff(value) - # copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration - def _merge_input_ids_with_image_features( - self, image_features, inputs_embeds, input_ids, attention_mask, labels - ): - """ - Merge input IDs with image features to create a combined input representation. - - This method handles the complex logic of interleaving text and image tokens, - adjusting attention masks and labels accordingly. - - Args: - image_features (torch.Tensor): Processed image features. - inputs_embeds (torch.Tensor): Text input embeddings. 
- input_ids (torch.Tensor): Input token IDs. - attention_mask (torch.Tensor): Attention mask for input tokens. - labels (torch.Tensor, optional): Labels for language modeling. - - Returns: - tuple: Contains the merged embeddings, updated attention mask, - updated labels, and position IDs. - """ - num_images, num_image_patches, embed_dim = image_features.shape - batch_size, sequence_length = input_ids.shape - left_padding = not torch.sum( - input_ids[:, -1] == torch.tensor(self.pad_token_id) - ) - # 1. Create a mask to know where special image tokens are - special_image_token_mask = input_ids == self.config.image_token_index - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) - # Compute the maximum embed dimension - max_embed_dim = ( - num_special_image_tokens.max() * (num_image_patches - 1) - ) + sequence_length - batch_indices, non_image_indices = torch.where( - input_ids != self.config.image_token_index - ) - - # 2. Compute the positions where text should be written - # Calculate new positions for text tokens in merged image-text sequence. - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. - # `torch.cumsum` computes how each image token shifts subsequent text token positions. - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. - new_token_positions = ( - torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - - 1 - ) - nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] - if left_padding: - new_token_positions += nb_image_pad[:, None] # offset for left padding - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] - - # 3. Create the full embedding, already padded to the maximum position - final_embedding = torch.zeros( - batch_size, - max_embed_dim, - embed_dim, - dtype=inputs_embeds.dtype, - device=inputs_embeds.device, - ) - final_attention_mask = torch.zeros( - batch_size, - max_embed_dim, - dtype=attention_mask.dtype, - device=inputs_embeds.device, - ) - if labels is not None: - final_labels = torch.full( - (batch_size, max_embed_dim), - self.config.ignore_index, - dtype=input_ids.dtype, - device=input_ids.device, - ) - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually - # set the corresponding tensors into their correct target device. - target_device = inputs_embeds.device - batch_indices, non_image_indices, text_to_overwrite = ( - batch_indices.to(target_device), - non_image_indices.to(target_device), - text_to_overwrite.to(target_device), - ) - attention_mask = attention_mask.to(target_device) - - # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[ - batch_indices, non_image_indices - ] - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[ - batch_indices, non_image_indices - ] - if labels is not None: - final_labels[batch_indices, text_to_overwrite] = labels[ - batch_indices, non_image_indices - ] - - # 5. Fill the embeddings corresponding to the images. 
Anything that is not `text_positions` needs filling (#29835) - image_to_overwrite = torch.full( - (batch_size, max_embed_dim), - True, - dtype=torch.bool, - device=inputs_embeds.device, - ) - image_to_overwrite[batch_indices, text_to_overwrite] = False - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[ - :, None - ].to(target_device) - - if image_to_overwrite.sum() != image_features.shape[:-1].numel(): - raise ValueError( - f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" - f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." - ) - - final_embedding[image_to_overwrite] = ( - image_features.contiguous().reshape(-1, embed_dim).to(target_device) - ) - final_attention_mask |= image_to_overwrite - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_( - (final_attention_mask == 0), 1 - ) - - # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. - batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) - indices_to_mask = new_token_positions[batch_indices, pad_indices] - - final_embedding[batch_indices, indices_to_mask] = 0 - - if labels is None: - final_labels = None - - return final_embedding, final_attention_mask, final_labels, position_ids - def forward( self, input_ids: torch.LongTensor = None, From 64ec4a7c71ed4d00c6ea8dc9011ea71a2741d7f9 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Mon, 11 Nov 2024 03:00:14 +0000 Subject: [PATCH 5/7] fix finetuning --- aria/data.py | 11 +++++++++++ aria/train.py | 1 + tests/test_apply_chat_template.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/aria/data.py b/aria/data.py index 192d15d..46021b2 100644 --- a/aria/data.py +++ b/aria/data.py @@ -31,6 +31,7 @@ def apply_chat_template_and_tokenize( tokenizer, num_image_crop: Iterable[torch.Tensor] = iter([]), max_length: int = 1024, + max_image_size: int = 980, ): IGNORE_TOKEN_ID = -100 im_start_tokens = tokenizer("<|im_start|>").input_ids @@ -76,6 +77,16 @@ def create_target(role, input_id): role = message["role"] text = "".join(process_content(content) for content in message["content"]) + if max_image_size == 490: + num_image_tokens = 128 + elif max_image_size == 980: + num_image_tokens = 256 + else: + raise ValueError( + f"max_image_size must be either 490 or 980, got {max_image_size}" + ) + text = text.replace("<|img|>", "<|img|>" * num_image_tokens) + _input_id = tokenize_message(role, text) input_id.extend(_input_id) target.extend(create_target(role, _input_id)) diff --git a/aria/train.py b/aria/train.py index 9753cb3..a229835 100644 --- a/aria/train.py +++ b/aria/train.py @@ -194,6 +194,7 @@ def collate_fn( tokenizer, iter(image_inputs.pop("num_crops")), max_length=max_seq_length, + max_image_size=processor.max_image_size, ) batch.update(image_inputs) diff --git a/tests/test_apply_chat_template.py b/tests/test_apply_chat_template.py index b948d59..8c71cd2 100644 --- a/tests/test_apply_chat_template.py +++ b/tests/test_apply_chat_template.py @@ -26,6 +26,7 @@ def test_apply_chat_template_single_user_message(tokenizer): } ] expected_output = "<|im_start|>user\nWho wrote this book?\n<|img|><|im_end|>\n" + expected_output = expected_output.replace("<|img|>", "<|img|>" * 256) res = apply_chat_template_and_tokenize( [messages], num_image_crop=iter([1]), tokenizer=tokenizer ) @@ -37,6 +38,29 
@@ def test_apply_chat_template_single_user_message(tokenizer): assert (labels == -100).sum() == input_ids.numel() +def test_apply_chat_template_single_user_message_490(tokenizer): + messages = [ + { + "content": [ + {"text": "Who wrote this book?\n", "type": "text"}, + {"text": None, "type": "image"}, + ], + "role": "user", + } + ] + expected_output = "<|im_start|>user\nWho wrote this book?\n<|img|><|im_end|>\n" + expected_output = expected_output.replace("<|img|>", "<|img|>" * 128) + res = apply_chat_template_and_tokenize( + [messages], num_image_crop=iter([1]), tokenizer=tokenizer, max_image_size=490 + ) + input_ids = res["input_ids"] + input_str = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0] + assert input_str == expected_output + + labels = res["labels"] + assert (labels == -100).sum() == input_ids.numel() + + def test_apply_chat_template_single_assistant_message(tokenizer): messages = [ { @@ -66,6 +90,7 @@ def test_apply_chat_template_multiple_messages(tokenizer): }, ] expected_output = "<|im_start|>user\nWho wrote this book?\n<|img|><|im_end|>\n<|im_start|>assistant\nSylvie Covey<|im_end|>\n" + expected_output = expected_output.replace("<|img|>", "<|img|>" * 256) res = apply_chat_template_and_tokenize( [messages], num_image_crop=iter([1]), tokenizer=tokenizer ) @@ -122,6 +147,7 @@ def test_apply_chat_template_multi_round_messages(tokenizer): }, ] expected_output = "<|im_start|>user\nWho wrote this book?\n<|img|><|im_end|>\n<|im_start|>assistant\nSylvie Covey<|im_end|>\n<|im_start|>user\nWhat is the title of this book?<|im_end|>\n<|im_start|>assistant\nModern Printmaking: A Guide to Traditional and Digital Techniques<|im_end|>\n" + expected_output = expected_output.replace("<|img|>", "<|img|>" * 256) res = apply_chat_template_and_tokenize( [messages], num_image_crop=iter([1]), tokenizer=tokenizer ) @@ -187,6 +213,10 @@ def test_apply_chat_template_batch_messages(tokenizer): "<|im_start|>user\nWho wrote this book?\n<|img|><|im_end|>\n<|im_start|>assistant\nSylvie Covey<|im_end|>\n", "<|im_start|>user\nWho wrote this book?\n<|img|><|im_end|>\n<|im_start|>assistant\nSylvie Covey<|im_end|>\n<|im_start|>user\nWhat is the title of this book?<|im_end|>\n<|im_start|>assistant\nModern Printmaking: A Guide to Traditional and Digital Techniques<|im_end|>\n", ] + expected_output = [ + expected_output[0].replace("<|img|>", "<|img|>" * 256), + expected_output[1].replace("<|img|>", "<|img|>" * 256), + ] assert ( tokenizer.batch_decode(input_ids, skip_special_tokens=True) == expected_output ) From 297c476de0bb3ce9fd6c4e918bd5120e23b868ee Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Mon, 11 Nov 2024 03:09:28 +0000 Subject: [PATCH 6/7] update git ignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 19dcf22..f3fc4d9 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ aria.egg-info/ wandb datasets/ gptfast/checkpoints/ -local_datasets/ \ No newline at end of file +local_datasets/ +eval \ No newline at end of file From c2c2aa6fcf496616a5bebb310a8fa185998eea55 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Mon, 11 Nov 2024 03:32:44 +0000 Subject: [PATCH 7/7] update git ignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index f3fc4d9..0ece864 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ wandb datasets/ gptfast/checkpoints/ local_datasets/ -eval \ No newline at end of file +eval +build/ \ No newline at end 
of file
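
---
Reviewer note (not part of the patches): a minimal end-to-end sketch of the behaviour this series introduces — the `return_final_prompts` flag from patch 1 and the image-token expansion from patch 2. It only restates what tests/test_aria_processor.py already exercises; it assumes the "rhymes-ai/Aria" tokenizer checkpoint is reachable and that AriaProcessor / AriaVisionProcessor are importable as in the tests.

    import numpy as np
    from PIL import Image
    from transformers import AutoTokenizer

    from aria.model.processing_aria import AriaProcessor
    from aria.model.vision_processor import AriaVisionProcessor

    # Build the processor exactly as the test fixture does.
    tokenizer = AutoTokenizer.from_pretrained("rhymes-ai/Aria")
    processor = AriaProcessor(
        tokenizer=tokenizer,
        image_processor=AriaVisionProcessor(max_image_size=490),
        image_token="<|img|>",
        chat_template=tokenizer.chat_template,
    )

    messages = [
        {
            "role": "user",
            "content": [
                {"text": None, "type": "image"},
                {"text": "describe the image", "type": "text"},
            ],
        }
    ]
    # Random stand-in image, same shape as the test fixture.
    image = Image.fromarray(np.random.randint(0, 255, (768, 768, 3), dtype=np.uint8))

    text = processor.apply_chat_template(messages, add_generation_prompt=True)

    # With return_final_prompts=True the processor also returns the prompt strings
    # after each "<|img|>" placeholder has been expanded (256 copies at
    # max_image_size=980, 128 at 490), which is what the new tests assert.
    inputs, prompts = processor(
        text=text,
        images=[image],
        return_tensors="pt",
        max_image_size=980,
        return_final_prompts=True,
    )
    print(prompts[0].count("<|img|>"))  # expected: 256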