
BLIP Finetuning Issue (Custom ViT image encoder): Repetitive Output for All Images #222

Open
bruceMug opened this issue Nov 10, 2024 · 0 comments


I'm attempting to fine-tune the BLIP model using a custom ViT for vision encoding. My ViT model is trained to classify medical images into three classes: healthy, COVID, and other. I've replaced the BLIP vision transformer with this custom ViT and adjusted the image embedding shape to match BLIP's requirements. However, after training, the model produces repetitive single-word predictions for all images, or in some runs the exact same caption for every image.
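For reference, the shape mismatch comes from the different input resolutions of the two vision encoders. A rough sketch of the arithmetic (using the stock google/vit-base-patch16-224-in21k checkpoint here only as a stand-in for my fine-tuned ViT):

from transformers import BlipForConditionalGeneration, ViTModel

blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

# BLIP-base vision encoder: 384x384 input, 16x16 patches -> 24*24 + 1 CLS token = 577 tokens
blip_tokens = (blip.config.vision_config.image_size // blip.config.vision_config.patch_size) ** 2 + 1
# ViT-base: 224x224 input, 16x16 patches -> 14*14 + 1 CLS token = 197 tokens
vit_tokens = (vit.config.image_size // vit.config.patch_size) ** 2 + 1

print(blip_tokens, vit_tokens)  # 577 vs 197, so I zero-pad the remaining 380 positions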

Below are the details of my implementation, training setup, and the resulting predictions.

Steps to Reproduce:

  1. Model Setup:
  • Replace the vision transformer in BLIP with a custom ViT model fine-tuned for image classification (loaded from a local checkpoint folder).
  • Pad the ViT output embedding so it matches BLIP's expected shape of (1, 577, 768).
  2. Training:
  • Use a custom dataset with captions.
  • Train with the code provided below, using the AdamW optimizer and a batch size of 2.

Code:

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BlipForConditionalGeneration, BlipProcessor, ViTModel

# Custom ViT fine-tuned on the 3-class medical image task
vit_model = ViTModel.from_pretrained('/content/vit-base-custom/checkpoint-1000')

class CustomBLIPWithViT(BlipForConditionalGeneration):
    def forward(self, pixel_values, text_input_ids, text_attention_mask, labels, return_dict=None):
        # Encode the image with the custom ViT (last_hidden_state is (batch, 197, 768) for 224x224 inputs)
        outputs = self.vision_model(pixel_values)
        image_embeds = outputs.last_hidden_state
        batch_size = image_embeds.shape[0]

        # Zero-pad the token dimension up to the 577 positions BLIP's decoder cross-attends to
        if image_embeds.shape[1] != 577:
            num_patches = image_embeds.shape[1]
            padding_size = 577 - num_patches
            padding = torch.zeros(batch_size, padding_size, 768).to(image_embeds.device)
            image_embeds = torch.cat([image_embeds, padding], dim=1)

        # Condition the text decoder on the (padded) image embeddings
        outputs = self.text_decoder(
            input_ids=text_input_ids,
            attention_mask=text_attention_mask,
            encoder_hidden_states=image_embeds,
            labels=labels,
            return_dict=return_dict,
            reduction="mean",
        )
        return outputs

custom_blip_model = CustomBLIPWithViT.from_pretrained('Salesforce/blip-image-captioning-base')
custom_blip_model.vision_model = vit_model  # swap in the custom ViT encoder

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
processor.tokenizer.padding_side = 'right'
processor.image_processor.size = {'height': 224, 'width': 224}  # match the ViT's 224x224 input

train_data = ImageCaptioningDataset(train_dataset, processor)
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=2)
test_data = ImageCaptioningDataset(test_dataset, processor)
test_dataloader = DataLoader(test_data, shuffle=False, batch_size=2)

optimizer = torch.optim.AdamW(custom_blip_model.parameters(), lr=lr)
custom_blip_model.to(device)
custom_blip_model.train()

for epoch in range(epochs):
    for idx, batch in tqdm(enumerate(train_dataloader), desc=f"Epoch: {epoch+1}"):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        # Teacher-forced captioning: the caption tokens double as the labels
        outputs = custom_blip_model(text_input_ids=input_ids,
                                    pixel_values=pixel_values,
                                    text_attention_mask=attention_mask,
                                    labels=input_ids)

        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Inference Code:

custom_blip_model.eval()
predictions, true_captions = [], []

for i, example in tqdm(enumerate(test_dataset), total=len(test_dataset)):
    image = example["image"]
    true_cap = example['text']
    inputs = processor(images=image, return_tensors="pt").to(device)
    pixel_values = inputs.pixel_values

    # Generate a caption for each test image
    generated_ids = custom_blip_model.generate(pixel_values=pixel_values, max_length=50)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    predictions.append(generated_caption)
    true_captions.append(true_cap)
    assert len(predictions) == len(true_captions), f"Length mismatch: {len(predictions)} predictions vs {len(true_captions)} true captions"
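
After the loop, a quick way to quantify the repetition is to count how many distinct captions were actually generated:

from collections import Counter

# Shows how concentrated the generated captions are across the test set
print(len(set(predictions)), Counter(predictions).most_common(3))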

Observed Output:
The model generates repetitive single-word captions for every image in the test set, as shown below:

['probably probably probably probably probably probably probably probably probably probably ...',
 'probably probably probably probably probably probably probably probably probably probably ...', ...]

In other cases, the model generates the same caption for all the images in the test set.
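
As a sanity check for the same-caption case, I can compare the image embeddings my ViT produces for two different test images (a rough sketch reusing custom_blip_model, processor, test_dataset, and device from the code above):

import torch
import torch.nn.functional as F

custom_blip_model.eval()
with torch.no_grad():
    pooled = []
    for example in [test_dataset[0], test_dataset[1]]:  # any two different test images
        pv = processor(images=example["image"], return_tensors="pt").pixel_values.to(device)
        feats = custom_blip_model.vision_model(pv).last_hidden_state  # (1, 197, 768) before padding
        pooled.append(feats.mean(dim=1))  # mean-pool over the token dimension

# A cosine similarity very close to 1.0 would mean the encoder maps different images
# to nearly the same representation, which would explain identical captions
print(F.cosine_similarity(pooled[0], pooled[1]).item())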

Expected Behavior:
Generate unique captions that accurately describe each image in my dataset.

Any insights into why this repetitive output might occur or suggestions on how to adapt my ViT output to work better with BLIP's text generation would be greatly appreciated!

