# compute_agreement.py (forked from wandb/FastChat)
"""
Compute agreement among judges.
Usage:
python compute_agreement.py --judges gpt4-pair human --votefiles human_judgments.json gpt4_pair_judgments.json
python compute_agreement.py --judges human human --votefiles human_judgments.json
"""
import argparse
import json
import os
import numpy as np
def get_judge_name(judge):
    """Map the raw ``judge`` field of a vote to a canonical judge name.

    Args:
        judge: Either a list whose first item is ``"gpt-4"`` and whose second
            item starts with ``"pair"``, or a string identifier.

    Returns:
        ``"gpt4-pair"`` for a GPT-4 pairwise judge list, ``"human"`` for a
        string starting with ``"expert"``, ``"author"`` for a string starting
        with ``"author"``, and ``None`` for anything unrecognized.
    """
    if isinstance(judge, list):
        # A list can only ever denote the GPT-4 pairwise judge.  Any other
        # list (too short, different model, different prompt) is treated as
        # unrecognized instead of falling through to the string checks,
        # where it would raise AttributeError/IndexError.
        is_gpt4_pair = (
            len(judge) >= 2
            and judge[0] == "gpt-4"
            and isinstance(judge[1], str)
            and judge[1].startswith("pair")
        )
        return "gpt4-pair" if is_gpt4_pair else None

    if not isinstance(judge, str):
        return None

    if judge.startswith("expert"):
        return "human"
    if judge.startswith("author"):
        return "author"
    return None
def revert(vote):
    """Return the vote as seen with the two models swapped.

    ``"model_a"`` becomes ``"model_b"`` and vice versa; every other value
    is returned unchanged.
    """
    for side, opposite in (("model_a", "model_b"), ("model_b", "model_a")):
        if vote == side:
            return opposite
    return vote
def get_mt_bench_votes_data(raw_votes):
    """Group raw votes by turn, by question/model pair, and by judge.

    Args:
        raw_votes: Iterable of vote collections.  Each vote is a mapping with
            the keys ``"turn"``, ``"model_a"``, ``"model_b"``,
            ``"question_id"``, ``"winner"`` and ``"judge"``.

    Returns:
        A two-element list indexed by ``turn - 1``.  Each element maps
        ``(question_id, smaller_model, larger_model)`` to a dict from judge
        name to the list of winners recorded by that judge.
    """
    per_turn = [{}, {}]

    for judge_votes in raw_votes:
        for vote in judge_votes:
            turn_index = vote["turn"] - 1
            model_a = vote["model_a"]
            model_b = vote["model_b"]
            question_id = vote["question_id"]
            winner = vote["winner"]

            # Store every comparison under a canonical model order so that
            # (A vs B) and (B vs A) share one key; flip the winner when the
            # models had to be swapped to reach that order.
            if model_a < model_b:
                key = (question_id, model_a, model_b)
            else:
                key = (question_id, model_b, model_a)
                winner = revert(winner)

            judge = get_judge_name(vote["judge"])
            by_judge = per_turn[turn_index].setdefault(key, {})
            by_judge.setdefault(judge, []).append(winner)

    return per_turn
def convertvote(vote):
    """Normalize a vote: anything containing ``"tie"`` becomes ``"tie"``."""
    return "tie" if "tie" in vote else vote
def equalvote(vote1, vote2):
    """Return whether two votes agree, treating all tie variants as equal."""
    both_are_ties = "tie" in vote1 and "tie" in vote2
    return True if both_are_ties else vote1 == vote2
# data: Dict[key -> Dict[judge name -> List[vote]]]
def get_mt_bench_agreement(data, judge1, judge2, ban):
    """Count how often two judges agree on the same comparison.

    Args:
        data: Mapping whose values are dicts from judge name to that judge's
            list of votes.
        judge1: First judge name; either a name starting with ``"gpt4"``
            (paired with ``judge2 == "human"``) or ``"human"``.
        judge2: Second judge name; must be ``"human"``.
        ban: Collection of normalized votes to exclude.  A pair is skipped
            when either vote, after ``convertvote``, is in ``ban``.

    Returns:
        ``(agree, total)``: the number of agreeing vote pairs and the number
        of vote pairs compared.

    Raises:
        Exception: If the judge combination is not supported.
    """
    gpt4_vs_human = judge1.startswith("gpt4") and judge2 == "human"
    human_vs_human = judge1 == "human" and judge2 == "human"
    if not (gpt4_vs_human or human_vs_human):
        raise Exception("Unsupported judges.")

    agree = 0
    total = 0

    if gpt4_vs_human:
        for votes in data.values():
            if not (judge1 in votes and judge2 in votes):
                continue
            # The model judge is expected to have voted exactly once.
            assert len(votes[judge1]) == 1
            model_vote = votes[judge1][0]
            if convertvote(model_vote) in ban:
                continue
            # Compare the single model vote against every human vote.
            for human_vote in votes[judge2]:
                if convertvote(human_vote) in ban:
                    continue
                total += 1
                agree += equalvote(model_vote, human_vote)
        return agree, total

    # human vs. human: compare every unordered pair of human votes.
    for votes in data.values():
        if "human" not in votes:
            continue
        human_votes = votes["human"]
        for index, first in enumerate(human_votes):
            for second in human_votes[index + 1:]:
                if convertvote(first) in ban or convertvote(second) in ban:
                    continue
                total += 1
                agree += equalvote(first, second)
    return agree, total
def run_mt_bench_agreement(judges, votefiles):
    """Load vote files and print judge agreement for both turns.

    For each of the two turns, the agreement is reported twice: once counting
    all votes ("with tie") and once with tie votes excluded ("without tie").

    Args:
        judges: Sequence of two judge names; ``judges[0]`` and ``judges[1]``
            are passed to ``get_mt_bench_agreement``.
        votefiles: Paths of JSON files, each containing a collection of votes.
    """
    # votes[i]: list of votes loaded from votefiles[i]
    votes = []
    for filename in votefiles:
        # JSON interchange text is UTF-8; do not depend on the locale default.
        with open(filename, "r", encoding="utf-8") as f:
            votes.append(json.load(f))

    votes_by_turn = get_mt_bench_votes_data(votes)

    for turn_index in (0, 1):
        for label, ban in (("with tie", []), ("without tie", ["tie"])):
            agree, total = get_mt_bench_agreement(
                votes_by_turn[turn_index], judges[0], judges[1], ban=ban
            )
            # With no comparable vote pairs the ratio is undefined; report
            # "nan" instead of crashing with ZeroDivisionError.
            ratio = agree / total if total else float("nan")
            print(
                f"turn {turn_index + 1} {label}. #total: {total}, "
                f"#agree: {agree}, ratio: {ratio:.2f}"
            )
if __name__ == "__main__":
    # Command-line entry point: pick the two judges to compare and the JSON
    # vote files to read, then print the agreement report.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--judges",
        nargs=2,
        type=str,
        default=["gpt4-pair", "human"],
    )
    default_votefiles = ["gpt4_judgments.json", "human_judgments.json"]
    cli.add_argument("--votefiles", nargs="+", type=str, default=default_votefiles)
    options = cli.parse_args()

    run_mt_bench_agreement(options.judges, options.votefiles)