Skip to content

Commit

Permalink
Final Update
Browse files Browse the repository at this point in the history
Added argument for prompt 
--prompt="your prompt here"
  • Loading branch information
ShadoWxShinigamI authored Oct 26, 2023
1 parent 5467b46 commit 32daa79
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions qwen-batch-single-pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,50 @@
import re
import torch
import time
import argparse
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import argparse

# Argument Parsing
parser = argparse.ArgumentParser(description='Image Captioning Script')
parser.add_argument('--imgdir', type=str, default='path/to/img/dir', help='Directory containing images')
parser.add_argument('--exist', type=str, default='replace', choices=['skip', 'add', 'replace'], help='Handling of existing captions')
parser.add_argument('--prompt', type=str, default='describe this image in detail, in less than 35 words', help='Prompt to use for image captioning')
args = parser.parse_args()

# Function to check for unwanted elements in the caption
def has_unwanted_elements(caption):
patterns = [r'<ref>.*?</ref>', r'<box>.*?</box>']
patterns = [r'<ref>.*?</ref>', r'<box>.*?</box>', r'\[\d+\]', r'\(\[\d+\]\)']
return any(re.search(pattern, caption) for pattern in patterns)

# Function to clean up the caption
def clean_caption(caption):
caption = re.sub(r'<ref>(.*?)</ref>', r'\1', caption)
caption = re.sub(r'<box>.*?</box>', '', caption)
return caption.strip()

# Argument parsing
parser = argparse.ArgumentParser(description='Image Captioning Script')
parser.add_argument('--imgdir', type=str, default='img/dir/here', help='Path to image directory')
parser.add_argument('--exist', type=str, choices=['skip', 'add', 'replace'], default='replace', help='Handling of existing txt files')
args = parser.parse_args()

image_directory = args.imgdir
# Supported image types
image_types = ['.png', '.jpg', '.jpeg', '.bmp', '.gif']

# Initialize the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat-Int4", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat-Int4", device_map="cuda", trust_remote_code=True, use_flash_attn=True).eval()
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat-Int4", device_map="cuda", trust_remote_code=True).eval()

files = [f for f in os.listdir(image_directory) if os.path.splitext(f)[1].lower() in image_types]
# Get the list of image files in the specified directory
files = [f for f in os.listdir(args.imgdir) if os.path.splitext(f)[1].lower() in image_types]

# Initialize the progress bar
pbar = tqdm(total=len(files), desc="Captioning", dynamic_ncols=True, position=0, leave=True)
start_time = time.time()

print("Captioning phase:")
for i in range(len(files)):
filename = files[i]
image_path = os.path.join(image_directory, filename)
image_path = os.path.join(args.imgdir, filename)

# Check for existing txt file and handle based on the argument
# Handle based on the argument 'exist'
txt_filename = os.path.splitext(filename)[0] + '.txt'
txt_path = os.path.join(image_directory, txt_filename)
txt_path = os.path.join(args.imgdir, txt_filename)

if args.exist == 'skip' and os.path.exists(txt_path):
pbar.update(1)
Expand All @@ -49,25 +54,28 @@ def clean_caption(caption):
with open(txt_path, 'r', encoding='utf-8') as f:
existing_content = f.read()

# Generate the caption using the model
query = tokenizer.from_list_format([
{'image': image_path},
{'text': 'describe this image in detail, as if you are an art critic in less than 35 words'},
{'text': args.prompt},
])
response, _ = model.chat(tokenizer, query=query, history=None)


# Clean up the caption if necessary
if has_unwanted_elements(response):
response = clean_caption(response)

# Write the caption to the corresponding .txt file
with open(txt_path, 'w', encoding='utf-8') as f:
if args.exist == 'add' and os.path.exists(txt_path):
f.write(existing_content + "\n" + response)
else:
f.write(response)

# Update progress bar with some additional information about the process
elapsed_time = time.time() - start_time
images_per_sec = (i + 1) / elapsed_time
estimated_time_remaining = (len(files) - i - 1) / images_per_sec

pbar.set_postfix({"Time Elapsed": f"{elapsed_time:.2f}s", "ETA": f"{estimated_time_remaining:.2f}s", "Speed": f"{images_per_sec:.2f} img/s"})
pbar.update(1)

Expand Down

0 comments on commit 32daa79

Please sign in to comment.