Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support For Directory Batch With a Single Prompt #136

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions qwen-batch-single-pass-v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import os
import re
import torch
import time
import argparse
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Argument Parsing
parser = argparse.ArgumentParser(description='Image Captioning Script')
# Directory holding the images to caption (and, with --sub true, its subdirectories).
parser.add_argument('--imgdir', type=str, default='path/to/img/dir', help='Directory containing images')
# What to do when a sidecar .txt caption already exists next to an image:
# 'skip' leaves it alone, 'add' appends the new caption, 'replace' overwrites.
parser.add_argument('--exist', type=str, default='replace', choices=['skip', 'add', 'replace'], help='Handling of existing captions')
# Prompt sent to the model once per image.
parser.add_argument('--prompt', type=str, default='describe this image in detail, in less than 35 words', help='Prompt to use for image captioning')
# Parses 'true' (any case) as True; every other value becomes False.
parser.add_argument('--sub', type=lambda x: (str(x).lower() == 'true'), default=False, help='Search for images in subdirectories')
args = parser.parse_args()

def has_unwanted_elements(caption):
    """Return True if the caption contains <ref>...</ref> or <box>...</box> markup."""
    for pattern in (r'<ref>.*?</ref>', r'<box>.*?</box>'):
        if re.search(pattern, caption):
            return True
    return False

def clean_caption(caption):
    """Strip grounding markup: unwrap <ref> tags, drop <box> spans, trim whitespace."""
    without_refs = re.sub(r'<ref>(.*?)</ref>', r'\1', caption)
    without_boxes = re.sub(r'<box>.*?</box>', '', without_refs)
    return without_boxes.strip()

# Supported image types (matched case-insensitively against file extensions)
image_types = ['.png', '.jpg', '.jpeg', '.bmp', '.gif']

# Initialize the model and tokenizer.
# NOTE(review): device_map="cuda" + use_flash_attn=True require a CUDA GPU and
# the flash-attn package — confirm the target environment provides both.
# trust_remote_code=True executes model code downloaded from the Hub.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat-Int4", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat-Int4", device_map="cuda", trust_remote_code=True, use_flash_attn=True).eval()

# Function to get files (optionally recursively) from a directory
def get_files_from_directory(directory, image_types, search_subdirectories=False):
    """Collect image file names under *directory*.

    Returns paths *relative to directory* (bare names in the flat case,
    'sub/name.ext' in the recursive case) so callers can rebuild the full
    path with os.path.join(directory, entry).

    Bug fix: the recursive branch previously returned paths that already
    included *directory*; the caller then joined them onto args.imgdir again,
    producing doubled prefixes like 'imgdir/imgdir/sub/a.png' and making
    every subdirectory image unreadable.
    """
    # Normalize once so membership tests are case-insensitive and O(1).
    wanted = {ext.lower() for ext in image_types}
    if search_subdirectories:
        return [
            os.path.relpath(os.path.join(root, name), directory)
            for root, _dirs, names in os.walk(directory)
            for name in names
            if os.path.splitext(name)[1].lower() in wanted
        ]
    return [
        name for name in os.listdir(directory)
        if os.path.splitext(name)[1].lower() in wanted
    ]

# Get the list of image files in the specified directory, possibly including subdirectories
files = get_files_from_directory(args.imgdir, image_types, args.sub)

# Initialize the progress bar
pbar = tqdm(total=len(files), desc="Captioning", dynamic_ncols=True, position=0, leave=True)
start_time = time.time()

print("Captioning phase:")
for i, filename in enumerate(files):
    image_path = os.path.join(args.imgdir, filename)

    # Sidecar caption file: same base name as the image, .txt extension.
    txt_filename = os.path.splitext(filename)[0] + '.txt'
    txt_path = os.path.join(args.imgdir, txt_filename)

    # Bug fix: snapshot existence BEFORE opening for write. open(..., 'w')
    # creates the file, so re-checking os.path.exists() inside the write
    # block was always True, and in '--exist add' mode with no pre-existing
    # caption that referenced 'existing_content' before assignment (NameError).
    caption_existed = os.path.exists(txt_path)

    if args.exist == 'skip' and caption_existed:
        pbar.update(1)
        continue

    existing_content = ''
    if args.exist == 'add' and caption_existed:
        with open(txt_path, 'r', encoding='utf-8') as f:
            existing_content = f.read()

    # Generate the caption using the model
    query = tokenizer.from_list_format([
        {'image': image_path},
        {'text': args.prompt},
    ])
    response, _ = model.chat(tokenizer, query=query, history=None)

    # Strip <ref>/<box> grounding markup the model may emit
    if has_unwanted_elements(response):
        response = clean_caption(response)

    # Write the caption to the corresponding .txt file
    with open(txt_path, 'w', encoding='utf-8') as f:
        if args.exist == 'add' and caption_existed:
            f.write(existing_content + " " + response)
        else:
            f.write(response)

    # Update progress bar with throughput / ETA information
    elapsed_time = time.time() - start_time
    images_per_sec = (i + 1) / elapsed_time
    estimated_time_remaining = (len(files) - i - 1) / images_per_sec
    pbar.set_postfix({"Time Elapsed": f"{elapsed_time:.2f}s", "ETA": f"{estimated_time_remaining:.2f}s", "Speed": f"{images_per_sec:.2f} img/s"})
    pbar.update(1)

pbar.close()
82 changes: 82 additions & 0 deletions qwen-batch-single-pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import os
import re
import torch
import time
import argparse
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Argument Parsing
parser = argparse.ArgumentParser(description='Image Captioning Script')
# Directory holding the images to caption (this script is non-recursive).
parser.add_argument('--imgdir', type=str, default='path/to/img/dir', help='Directory containing images')
# What to do when a sidecar .txt caption already exists next to an image:
# 'skip' leaves it alone, 'add' appends the new caption, 'replace' overwrites.
parser.add_argument('--exist', type=str, default='replace', choices=['skip', 'add', 'replace'], help='Handling of existing captions')
# Prompt sent to the model once per image.
parser.add_argument('--prompt', type=str, default='describe this image in detail, in less than 35 words', help='Prompt to use for image captioning')
args = parser.parse_args()

def has_unwanted_elements(caption):
    """True when the caption still carries <ref>/<box> grounding tags."""
    ref_or_box = (r'<ref>.*?</ref>', r'<box>.*?</box>')
    return any(re.search(p, caption) is not None for p in ref_or_box)

def clean_caption(caption):
    """Unwrap <ref> tags, delete <box> spans, and strip surrounding whitespace."""
    return re.sub(
        r'<box>.*?</box>', '',
        re.sub(r'<ref>(.*?)</ref>', r'\1', caption),
    ).strip()

# Supported image types (matched case-insensitively against file extensions)
image_types = ['.png', '.jpg', '.jpeg', '.bmp', '.gif']

# Initialize the model and tokenizer.
# NOTE(review): device_map="cuda" + use_flash_attn=True require a CUDA GPU and
# the flash-attn package — confirm the target environment provides both.
# trust_remote_code=True executes model code downloaded from the Hub.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat-Int4", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat-Int4", device_map="cuda", trust_remote_code=True, use_flash_attn=True).eval()

# Get the list of image files in the specified directory (non-recursive)
files = [f for f in os.listdir(args.imgdir) if os.path.splitext(f)[1].lower() in image_types]

# Initialize the progress bar
pbar = tqdm(total=len(files), desc="Captioning", dynamic_ncols=True, position=0, leave=True)
start_time = time.time()

print("Captioning phase:")
for i, filename in enumerate(files):
    image_path = os.path.join(args.imgdir, filename)

    # Sidecar caption file: same base name as the image, .txt extension.
    txt_filename = os.path.splitext(filename)[0] + '.txt'
    txt_path = os.path.join(args.imgdir, txt_filename)

    # Bug fix: snapshot existence BEFORE opening for write. open(..., 'w')
    # creates the file, so re-checking os.path.exists() inside the write
    # block was always True, and in '--exist add' mode with no pre-existing
    # caption that referenced 'existing_content' before assignment (NameError).
    caption_existed = os.path.exists(txt_path)

    if args.exist == 'skip' and caption_existed:
        pbar.update(1)
        continue

    existing_content = ''
    if args.exist == 'add' and caption_existed:
        with open(txt_path, 'r', encoding='utf-8') as f:
            existing_content = f.read()

    # Generate the caption using the model
    query = tokenizer.from_list_format([
        {'image': image_path},
        {'text': args.prompt},
    ])
    response, _ = model.chat(tokenizer, query=query, history=None)

    # Strip <ref>/<box> grounding markup the model may emit
    if has_unwanted_elements(response):
        response = clean_caption(response)

    # Write the caption to the corresponding .txt file
    with open(txt_path, 'w', encoding='utf-8') as f:
        if args.exist == 'add' and caption_existed:
            f.write(existing_content + " " + response)
        else:
            f.write(response)

    # Update progress bar with throughput / ETA information
    elapsed_time = time.time() - start_time
    images_per_sec = (i + 1) / elapsed_time
    estimated_time_remaining = (len(files) - i - 1) / images_per_sec
    pbar.set_postfix({"Time Elapsed": f"{elapsed_time:.2f}s", "ETA": f"{estimated_time_remaining:.2f}s", "Speed": f"{images_per_sec:.2f} img/s"})
    pbar.update(1)

pbar.close()