
Unable to run inference on the infographicVQA task #25

Open
ShubhamAwasthi1 opened this issue Mar 20, 2023 · 1 comment

Comments

@ShubhamAwasthi1

I am trying to run inference with the model for the infographic VQA task. The instructions give the CLI command for a dummy task, as follows:
python -m pix2struct.example_inference \
  --gin_search_paths="pix2struct/configs" \
  --gin_file=models/pix2struct.gin \
  --gin_file=runs/inference.gin \
  --gin_file=sizes/base.gin \
  --gin.MIXTURE_OR_TASK_NAME="'dummy_pix2struct'" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 2048, 'targets': 128}" \
  --gin.BATCH_SIZE=1 \
  --gin.CHECKPOINT_PATH="'gs://pix2struct-data/textcaps_base/checkpoint_280400'" \
  --image=$HOME/test_image.jpg

I have filled in the task name, checkpoint, and text prompt for the VQA task as shown below, but they are not in accordance with what the code expects. Please provide a correct set of input values to run inference for this task.

python -m pix2struct.example_inference \
  --gin_search_paths="pix2struct/configs" \
  --gin_file=models/pix2struct.gin \
  --gin_file=runs/inference.gin \
  --gin_file=sizes/base.gin \
  --gin.MIXTURE_OR_TASK_NAME="InfographicVQA" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 2048, 'targets': 128}" \
  --gin.BATCH_SIZE=1 \
  --gin.CHECKPOINT_PATH="gs://pix2struct-data/infographicvqa_large/checkpoint_182000" \
  --image="my_input_image.jpeg" \
  --text="What is written on the image of the calendar ?"
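One detail worth comparing against the dummy command (an observation, not a confirmed fix): gin parses flag values as Python literals, so string-valued flags carry nested quotes, as in the dummy command's `--gin.MIXTURE_OR_TASK_NAME="'dummy_pix2struct'"` and its `CHECKPOINT_PATH`, while the attempt above passes `"InfographicVQA"` with no inner quotes. A quick illustration of the quoting, using a hypothetical task name:

```shell
# The shell strips the outer double quotes; gin then sees a quoted
# Python string literal. "infographic_vqa" here is a placeholder,
# not a verified task name from the pix2struct task registry.
TASK_FLAG=--gin.MIXTURE_OR_TASK_NAME="'infographic_vqa'"
echo "$TASK_FLAG"
```

After the shell removes its own quotes, the value gin receives is `'infographic_vqa'`, a valid Python string literal.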

@NielsRogge

NielsRogge commented Apr 2, 2023

Hi,

Inference might be a bit easier now with the Hugging Face integration. All checkpoints are on the Hub, so if you want to try the infographics-vqa large checkpoint, you can do that as follows:

from PIL import Image
import requests
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

processor = AutoProcessor.from_pretrained("google/pix2struct-infographics-vqa-large")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-infographics-vqa-large")

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

question = "What is written on the image of the calendar?"

inputs = processor(images=image, text=question, return_tensors="pt")

# autoregressive generation, then decode the token ids back to text
predicted_ids = model.generate(**inputs)
predicted_text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
print(predicted_text)
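For a local file such as the `my_input_image.jpeg` from the original command, the only change is how the image is loaded; a minimal sketch (the generated stand-in image below just keeps the snippet self-contained without the actual file):

```python
from PIL import Image

# In practice: image = Image.open("my_input_image.jpeg")
# A blank stand-in keeps this sketch runnable without the file.
image = Image.new("RGB", (640, 480), color="white")

# The processor expects an RGB image; grayscale scans or PNGs with an
# alpha channel should be converted before passing them in.
image = image.convert("RGB")
```

The resulting `image` can then be passed to the processor exactly as in the snippet above.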
