forked from camenduru/minigpt4
-
Notifications
You must be signed in to change notification settings - Fork 1
/
app.py
116 lines (87 loc) · 4.06 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import os
import random
import glob
import time
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
def parse_args():
    """Parse the command-line arguments of the batch captioning script.

    Returns:
        argparse.Namespace with ``cfg_path``, ``options``, ``image_folder``,
        ``beam_search_numbers``, ``model`` and ``save_in_imgfolder``.

    Invalid values (unknown ``--model``, non-positive
    ``--beam-search-numbers``) are rejected here by argparse (exit code 2),
    before any model is loaded.
    """

    def positive_int(text):
        # argparse turns ValueError / ArgumentTypeError into a usage error.
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
        return value

    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml', help="path to configuration file.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config; key-value pairs "
        "in xxx=yyy format are merged into the config file.",
    )
    parser.add_argument("--image-folder", type=str, required=True, help="path to the input image folder")
    parser.add_argument("--beam-search-numbers", type=positive_int, default=1, help="beam search numbers")
    parser.add_argument(
        "--model",
        type=str,
        default='llama',
        # Restrict to the values the script actually handles so a typo does
        # not silently fall back to the default model.
        choices=["llama", "llama7b"],
        help="Model to be used for generation. Options: 'llama' (default), 'llama7b'",
    )
    parser.add_argument("--save-in-imgfolder", action="store_true", help="save captions in the input image folder")
    return parser.parse_args()
def setup_seeds(config):
    """Seed the Python, NumPy and torch RNGs and make cuDNN deterministic.

    The seed is ``config.run_cfg.seed`` plus the value returned by
    ``get_rank()`` (presumably the distributed process rank, given its import
    from ``minigpt4.common.dist_utils`` — confirm).
    """
    base_seed = config.run_cfg.seed
    seed = base_seed + get_rank()
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Give up cuDNN autotuning in favour of reproducible kernels.
    cudnn.benchmark = False
    cudnn.deterministic = True
def describe_image(image_path, chat, chat_state, img, num_beams=1, temperature=1.0,
                   prompt="Describe this image.", max_new_tokens=300, max_length=2000):
    """Generate a caption for one image using a fresh MiniGPT-4 conversation.

    Args:
        image_path: Path of the image file to caption.
        chat: ``Chat`` instance used to upload the image, ask and answer.
        chat_state: Ignored; kept for backward compatibility. A new
            ``CONV_VISION.copy()`` is used on every call (presumably so turns
            from earlier images do not accumulate in the prompt).
        img: Ignored; kept for backward compatibility. A new, empty
            image-embedding list is used on every call.
        num_beams: Forwarded to ``chat.answer``.
        temperature: Forwarded to ``chat.answer``.
        prompt: Question asked about the image.
        max_new_tokens: Forwarded to ``chat.answer``.
        max_length: Forwarded to ``chat.answer``.

    Returns:
        Element ``[0]`` of ``chat.answer``'s result (the generated caption).

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or undecodable file
        (``OSError`` subclasses per the Pillow docs).
    """
    conversation = CONV_VISION.copy()
    image_embeddings = []
    # Close the file handle after decoding; the handle would otherwise leak
    # once per image in a batch run. convert() loads the pixel data, so the
    # returned image stays valid after the file is closed.
    with Image.open(image_path) as raw_image:
        # PNG/WebP inputs may be RGBA, palette or grayscale.
        # NOTE(review): upload_img is assumed to pass PIL images straight to
        # the vision processor, which expects 3-channel RGB — confirm against
        # minigpt4.conversation.
        rgb_image = raw_image.convert("RGB")
    chat.upload_img(rgb_image, conversation, image_embeddings)
    chat.ask(prompt, conversation)
    answer = chat.answer(
        conv=conversation,
        img_list=image_embeddings,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        temperature=temperature,
        max_length=max_length,
    )
    return answer[0]
# Extensions (lower-case, with dot) recognised as images; matched case-insensitively.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')


def _find_image_paths(image_folder, extensions=IMAGE_EXTENSIONS):
    """Return the sorted paths of image files directly inside ``image_folder``.

    Extensions are compared case-insensitively, so ``a.JPG`` and ``b.Jpeg``
    are found. Each file is listed exactly once, even on case-insensitive
    filesystems, where separate ``*.jpg`` / ``*.JPG`` globs would match the
    same file twice. Hidden entries (leading ``.``, which ``glob`` ``*`` also
    skips), directories and subfolders are not included.
    """
    found = []
    for name in sorted(os.listdir(image_folder)):
        if name.startswith('.'):
            continue
        path = os.path.join(image_folder, name)
        if os.path.isfile(path) and os.path.splitext(name)[1].lower() in extensions:
            found.append(path)
    return found


def _caption_output_path(image_path, image_folder, save_in_imgfolder):
    """Return ``<stem>_caption.txt`` inside ``image_folder`` or ``mycaptions``."""
    stem = os.path.splitext(os.path.basename(image_path))[0]
    target_dir = image_folder if save_in_imgfolder else "mycaptions"
    return os.path.join(target_dir, f"{stem}_caption.txt")


def main():
    """Caption every image in ``--image-folder`` with MiniGPT-4.

    One UTF-8 ``<stem>_caption.txt`` file is written per image, either next
    to the images (``--save-in-imgfolder``) or in ``./mycaptions``.
    """
    args = parse_args()
    image_folder = args.image_folder

    # Validate the input before the slow model load so a bad path fails fast.
    if not os.path.isdir(image_folder):
        raise SystemExit(f"Image folder '{image_folder}' does not exist or is not a directory")
    image_paths = _find_image_paths(image_folder)
    if not image_paths:
        print(f"No images ({', '.join(IMAGE_EXTENSIONS)}) found in '{image_folder}'")
        return

    cfg = Config(args)
    model_config = cfg.model_cfg
    if args.model == "llama7b":
        model_config.llama_model = "camenduru/MiniGPT4-7B"
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:0')
    vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
    vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
    chat = Chat(model, vis_processor)

    num_beams = args.beam_search_numbers
    temperature = 1.0  # default temperature

    if not args.save_in_imgfolder:
        os.makedirs("mycaptions", exist_ok=True)

    written = set()
    for image_path in image_paths:
        image_name = os.path.basename(image_path)
        start_time = time.perf_counter()
        try:
            caption = describe_image(image_path, chat, CONV_VISION.copy(), [], num_beams, temperature)
        except OSError as err:
            # Unreadable or corrupt file: report it and keep captioning the rest.
            print(f"Skipping {image_name}: cannot read image ({err})")
            continue

        output_path = _caption_output_path(image_path, image_folder, args.save_in_imgfolder)
        if output_path in written:
            # Same stem, different extension (e.g. a.jpg and a.png).
            print(f"Warning: '{output_path}' was already written for another image; overwriting it.")
        written.add(output_path)
        # Explicit UTF-8: the platform default encoding may not represent
        # every character the model generates.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(caption)

        time_taken = time.perf_counter() - start_time
        print(f"Caption for {image_name} saved in '{output_path}'")
        print(f"Time taken to process caption for {image_name} is: {time_taken:.2f} s")


if __name__ == '__main__':
    main()