control_recall.py
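"""Compute token- and word-level recall of oracle keywords in generated text.

For every candidate .txt file, the script checks how many oracle subword
tokens (tokens.txt) and whole words (words.txt) reappear in the generated
sentences, and reports the mean recall per file.
"""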
import argparse
import json
import os
from glob import glob
from typing import List

import numpy as np
from nltk import word_tokenize
from tqdm import tqdm
from transformers import AutoTokenizer

# Oracle keyword files expected inside --base_file.
KEY_TOKEN_FILE = "tokens.txt"
KEY_WORD_FILE = "words.txt"

# "facebook/bart-base" is the published BART base checkpoint on the
# Hugging Face Hub (downloaded on first use).
TOKENIZER = "facebook/bart-base"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

def parse_args():
    parser = argparse.ArgumentParser()
    # Directory that holds the oracle tokens.txt / words.txt files.
    parser.add_argument("--base_file", type=str, default="/home/zhuangzy/controlnet")
    # Directory of candidate .txt files to evaluate.
    parser.add_argument("--cand_dir", type=str, required=True)
    # Optional path for dumping the aggregated JSON results.
    parser.add_argument("--save_path", type=str, required=False)
    return parser.parse_args()
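
# Example invocation (paths are illustrative):
#   python control_recall.py --base_file /data/oracle --cand_dir outputs --save_path recall.json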

def compute_recall(cands: List[List[str]], oracles: List[List[str]]):
    # Per-example recall: the fraction of oracle items found in the candidate.
    recalls = []
    for cand, oracle in zip(cands, oracles):
        if not oracle:
            continue  # skip empty oracle lines to avoid division by zero
        hit = [o for o in oracle if o in cand]
        recalls.append(len(hit) / len(oracle))
    return np.round(np.mean(recalls), 2)
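
# Worked example (illustrative values): with oracle ["controllable", "summary"]
# and a candidate containing only "summary", hit = ["summary"], so that
# example's recall is 1/2 = 0.5; the script reports the mean over all examples.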

def evaluate(sentences: List[str], oracle_tokens: List[List[str]], oracle_words: List[List[str]]):
    # Word-level recall against NLTK tokenization (requires the "punkt"
    # models: nltk.download("punkt") on first use).
    words = [word_tokenize(sentence) for sentence in sentences]
    word_recall = compute_recall(words, oracle_words)
    # Token-level recall against the BART subword tokenization.
    tokens = [tokenizer.tokenize(sentence) for sentence in sentences]
    token_recall = compute_recall(tokens, oracle_tokens)
    results = {
        "token_recall": token_recall,
        "word_recall": word_recall,
    }
    for k, v in results.items():
        print(f"{k}: {v}")
    return results

def load(filename):
    # Each line holds the space-separated oracle items for one example.
    res = []
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            res.append(line.strip('\n').split(' '))
    return res
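
# Expected oracle file layout (inferred from the code above):
#   tokens.txt -> subword tokens, as produced by tokenizer.tokenize
#   words.txt  -> whole words, as produced by nltk.word_tokenize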

def read_file(filename):
    # One candidate sentence per line.
    with open(filename, 'r', encoding='utf-8') as f:
        return [line.strip('\n') for line in f]

def main(opt: argparse.Namespace):
    oracle_tokens = load(os.path.join(opt.base_file, KEY_TOKEN_FILE))
    oracle_words = load(os.path.join(opt.base_file, KEY_WORD_FILE))
    eval_results = {}
    for filename in tqdm(sorted(glob(f"{opt.cand_dir}/*.txt")), desc="Evaluating..."):
        print(f"\nEvaluating {filename}...")
        cand_sentences = read_file(filename)
        eval_results[filename] = evaluate(cand_sentences, oracle_tokens, oracle_words)
    result = json.dumps(eval_results, indent=2, ensure_ascii=False)
    print(result)
    if opt.save_path is not None:
        with open(opt.save_path, 'w', encoding='utf-8') as f:
            f.write(result)

if __name__ == '__main__':
    options = parse_args()
    main(options)