-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmmqa_oracle.py
97 lines (76 loc) · 2.43 KB
/
mmqa_oracle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import ipdb
import json
from tqdm import tqdm
import numpy as np
from utils.metrics import mmqa_metrics_approx
from utils.model_series import load_generator
from utils.utils import infer
import argparse
############### Oracle: generate from each example's supporting_context images ###############
def baseline_generate(
    val_dataset,
    generator_path,
    tokenizer,
    image_processor,
    generator_model,
    metadata_path="datasets/MMQA_ImageQ_metadata.json",
    image_root="finetune/tasks/MMQA_imgs/",
):
    """Answer each question from its supporting images and report accuracy.

    For every example, the images listed in ``supporting_context`` are
    resolved through the MMQA image metadata and passed to ``infer`` together
    with the question. The generated output is scored against the example's
    first answer with ``mmqa_metrics_approx``.

    Args:
        val_dataset: Iterable of example dicts. Each is read for
            ``"question"``, ``"answers"`` (only ``answers[0]["answer"]`` is
            used) and ``"supporting_context"`` (each item's ``"doc_id"``).
        generator_path: Forwarded unchanged to ``infer``.
        tokenizer: Forwarded unchanged to ``infer``.
        image_processor: Forwarded unchanged to ``infer``.
        generator_model: Forwarded unchanged to ``infer``.
        metadata_path: JSON file mapping each ``doc_id`` to a dict with a
            ``"path"`` key.
        image_root: Prefix prepended verbatim to every metadata ``"path"``.

    Returns:
        Mean accuracy over all examples as a float, or ``nan`` when
        ``val_dataset`` is empty. The value is also printed.

    Raises:
        KeyError: If a ``doc_id`` is missing from the metadata, or an example
            lacks one of the keys listed above.
    """
    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)

    scores = []
    for datum in tqdm(val_dataset):
        question = datum["question"]
        answer = datum["answers"][0]["answer"]

        # All supporting images are passed as one comma-joined path string
        # (presumably split again inside infer when from_array=False; confirm).
        image_path = ",".join(
            image_root + metadata[item["doc_id"]]["path"]
            for item in datum["supporting_context"]
        )

        output = infer(
            generator_path,
            image_path,
            question,
            generator_model,
            tokenizer,
            image_processor,
            from_array=False,
        )

        # Counting questions are scored under the "number" category.
        qcate = "number" if "how many" in question.lower() else "normal"
        scores.append(mmqa_metrics_approx(output, answer, qcate))

    # np.mean([]) emits a RuntimeWarning; report nan explicitly instead.
    mean_acc = float(np.mean(scores)) if scores else float("nan")
    print("Generation ACC:", mean_acc)
    return mean_acc
def main():
    """Parse CLI arguments, load the generator, and run the oracle evaluation.

    CLI options:
        --datasets: Which split to evaluate, ``"test"`` (default) or ``"dev"``.
        --generator_model: Read by ``load_generator``
            (default ``"noise_injected_lora"``).
        --series: Read by ``load_generator`` (default ``"llava"``).

    The parsed arguments are printed before and after the run.
    """
    parser = argparse.ArgumentParser()
    # choices= rejects unknown splits before the model is loaded; previously
    # any other value left val_dataset unbound (UnboundLocalError) after
    # load_generator had already run.
    parser.add_argument(
        "--datasets", type=str, default="test", choices=["test", "dev"]
    )
    parser.add_argument("--generator_model", type=str, default="noise_injected_lora")
    parser.add_argument("--series", type=str, default="llava")
    args = parser.parse_args()
    print(args)

    (tokenizer, generator_model, image_processor), generator_path = load_generator(
        args, "mmqa"
    )

    dataset_files = {
        "test": "datasets/MMQA_test_image.json",
        # NOTE(review): "dev" reads the same file as "test" (as in the
        # original script); confirm whether a separate dev split was intended.
        "dev": "datasets/MMQA_test_image.json",
    }
    with open(dataset_files[args.datasets], "r", encoding="utf-8") as f:
        val_dataset = json.load(f)

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        baseline_generate(
            val_dataset,
            generator_path,
            tokenizer,
            image_processor,
            generator_model,
        )
    print(args)


if __name__ == "__main__":
    main()