From 8deb376b53d4b7f393a0a44d9cd95910bcf690ed Mon Sep 17 00:00:00 2001
From: ShadoWxShinigamI <116374738+ShadoWxShinigamI@users.noreply.github.com>
Date: Thu, 26 Oct 2023 22:28:44 +0530
Subject: [PATCH] Final Update
Added argument for prompt
--prompt="your prompt here"
---
qwen-batch-single-pass.py | 40 +++++++++++++++++++++++----------------
1 file changed, 24 insertions(+), 16 deletions(-)
diff --git a/qwen-batch-single-pass.py b/qwen-batch-single-pass.py
index ccee8b4..39c43e5 100644
--- a/qwen-batch-single-pass.py
+++ b/qwen-batch-single-pass.py
@@ -2,45 +2,50 @@
import re
import torch
import time
+import argparse
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.generation import GenerationConfig
-import argparse
+# Argument Parsing
+parser = argparse.ArgumentParser(description='Image Captioning Script')
+parser.add_argument('--imgdir', type=str, default='path/to/img/dir', help='Directory containing images')
+parser.add_argument('--exist', type=str, default='replace', choices=['skip', 'add', 'replace'], help='Handling of existing captions')
+parser.add_argument('--prompt', type=str, default='describe this image in detail, in less than 35 words', help='Prompt to use for image captioning')
+args = parser.parse_args()
+
+# Function to check for unwanted elements in the caption
def has_unwanted_elements(caption):
patterns = [r'[.*?]', r'.*?']
return any(re.search(pattern, caption) for pattern in patterns)
+# Function to clean up the caption
def clean_caption(caption):
caption = re.sub(r'[(.*?)]', r'\1', caption)
caption = re.sub(r'.*?', '', caption)
return caption.strip()
-# Argument parsing
-parser = argparse.ArgumentParser(description='Image Captioning Script')
-parser.add_argument('--imgdir', type=str, default='img/dir/here', help='Path to image directory')
-parser.add_argument('--exist', type=str, choices=['skip', 'add', 'replace'], default='replace', help='Handling of existing txt files')
-args = parser.parse_args()
-
-image_directory = args.imgdir
+# Supported image types
image_types = ['.png', '.jpg', '.jpeg', '.bmp', '.gif']
+# Initialize the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat-Int4", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat-Int4", device_map="cuda", trust_remote_code=True, use_flash_attn=True).eval()
-files = [f for f in os.listdir(image_directory) if os.path.splitext(f)[1].lower() in image_types]
+# Get the list of image files in the specified directory
+files = [f for f in os.listdir(args.imgdir) if os.path.splitext(f)[1].lower() in image_types]
+# Initialize the progress bar
pbar = tqdm(total=len(files), desc="Captioning", dynamic_ncols=True, position=0, leave=True)
start_time = time.time()
print("Captioning phase:")
for i in range(len(files)):
filename = files[i]
- image_path = os.path.join(image_directory, filename)
+ image_path = os.path.join(args.imgdir, filename)
- # Check for existing txt file and handle based on the argument
+ # Handle based on the argument 'exist'
txt_filename = os.path.splitext(filename)[0] + '.txt'
- txt_path = os.path.join(image_directory, txt_filename)
+ txt_path = os.path.join(args.imgdir, txt_filename)
if args.exist == 'skip' and os.path.exists(txt_path):
pbar.update(1)
@@ -49,25 +54,28 @@ def clean_caption(caption):
with open(txt_path, 'r', encoding='utf-8') as f:
existing_content = f.read()
+ # Generate the caption using the model
query = tokenizer.from_list_format([
{'image': image_path},
- {'text': 'describe this image in detail, as if you are an art critic in less than 35 words'},
+ {'text': args.prompt},
])
response, _ = model.chat(tokenizer, query=query, history=None)
-
+
+ # Clean up the caption if necessary
if has_unwanted_elements(response):
response = clean_caption(response)
+ # Write the caption to the corresponding .txt file
with open(txt_path, 'w', encoding='utf-8') as f:
if args.exist == 'add' and os.path.exists(txt_path):
f.write(existing_content + "\n" + response)
else:
f.write(response)
+ # Update progress bar with some additional information about the process
elapsed_time = time.time() - start_time
images_per_sec = (i + 1) / elapsed_time
estimated_time_remaining = (len(files) - i - 1) / images_per_sec
-
pbar.set_postfix({"Time Elapsed": f"{elapsed_time:.2f}s", "ETA": f"{estimated_time_remaining:.2f}s", "Speed": f"{images_per_sec:.2f} img/s"})
pbar.update(1)