-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathexample.py
74 lines (64 loc) · 2.49 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import code_bert_score
import pickle
from nltk.translate import bleu_score
from nltk.translate.bleu_score import sentence_bleu
import re
def tokenize_for_bleu_eval(code):
    """Split a code snippet into tokens for BLEU scoring.

    Every character outside ``[A-Za-z0-9_]`` is padded with spaces so that it
    becomes a token of its own; the padded text is then split on whitespace.

    Args:
        code: Source code snippet as a string.

    Returns:
        List of non-empty token strings: runs of ``[A-Za-z0-9_]`` and single
        non-whitespace symbol characters. Whitespace never appears as a token.
    """
    padded = re.sub(r'([^A-Za-z0-9_])', r' \1 ', code)
    # split() with no argument splits on any run of whitespace and drops empty
    # pieces. split(' ') only splits on the space character, which would leave
    # the space-padded newlines and tabs behind as tokens of their own and
    # pollute the n-grams of multi-line snippets.
    return padded.split()
def print_results(predictions, refs, pred_results):
    """Print a per-example report: reference, prediction, scores and BLEU.

    One report is printed for every entry of ``refs``; ``predictions`` and the
    score sequences are read at the same index.

    Args:
        predictions: Candidate code snippets (strings), indexable like ``refs``.
        refs: Reference code snippets (strings).
        pred_results: Indexable holding four per-example score sequences, read
            in the order precision, recall, F1, F3.
    """
    for idx, reference in enumerate(refs):
        print(f'Example {idx}:')
        print(f'Reference: {reference}')
        candidate = predictions[idx]
        print(f'Prediction: {candidate}')

        # Positions 0..3 of pred_results are precision, recall, F1 and F3.
        precision, recall, f1, f3 = (pred_results[k][idx] for k in range(4))
        print(f'Prediction precision: {precision:.3f}, recall: {recall:.3f}, f1: {f1:.3f}, f3: {f3:.3f}')

        # sentence_bleu takes a list of tokenized references and one tokenized
        # hypothesis; here there is exactly one reference per example.
        bleu = sentence_bleu(
            [tokenize_for_bleu_eval(reference)],
            tokenize_for_bleu_eval(candidate),
        )
        print(f'BLEU score: {bleu:.3f}')
        print()
if __name__ == '__main__':
    # Demo script: scores two Java candidates against one shared reference,
    # first without and then with a natural-language context, and finally
    # scores a single Python candidate two ways.
    # The IDF pickles are opened via relative paths, so the script has to be
    # run from the directory that contains ``idf_dicts/``.
    predictions = [
        """boolean f(Object target) {
for (Object elem: this.elements) {
if (elem.equals(target)) {
return true;
}
}
return false;
}""",
        """int f(Object target) {
for (int i=0; i<this.elements.size(); i++) {
Object elem = this.elements.get(i);
if (elem.equals(target)) {
return i;
}
}
return -1;
}"""
    ]
    # Every prediction is compared against the same reference implementation.
    refs = [
        """int f(Object target) {
int i = 0;
for (Object elem: this.elements) {
if (elem.equals(target)) {
return i;
}
i++;
}
return -1;
}"""] * len(predictions)

    # NOTE: pickle.load can execute arbitrary code from the file; only load
    # IDF dictionaries that come from a trusted source.
    with open('idf_dicts/java_idf.pkl', 'rb') as f:
        java_idf = pickle.load(f)

    # NOTE(review): the result of this call is overwritten immediately below.
    # It looks like a smoke test for an empty candidate string -- confirm
    # whether it is still needed.
    pred_results = code_bert_score.score([''],['a'], sources=["a"], lang="python")

    pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf)
    print_results(predictions, refs, pred_results)

    print('When providing the context: "find the index of target in this.elements"')
    # One source/context string per candidate; sized from ``predictions`` (not
    # a hard-coded 2) so it stays in step with ``refs`` if the list changes.
    sources = ['find the index of target in this.elements'] * len(predictions)
    pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf, sources=sources)
    print_results(predictions, refs, pred_results)

    with open('idf_dicts/python_idf.pkl', 'rb') as f:
        python_idf = pickle.load(f)

    # ``refs`` is a nested list here: the single candidate is given its own
    # list of references (presumably the multi-reference form -- confirm).
    pred_results = code_bert_score.score(cands=['math.sqrt(x)'], refs=[['x ** 0.5']], no_punc=True, lang='python', idf=python_idf)
    print(pred_results)

    # NOTE(review): lang='en' with rescale_with_baseline=True on code input --
    # presumably shown for comparison with the code-specific call above; confirm.
    pred_results = code_bert_score.score(cands=['math.sqrt(x)'], refs=[['x ** 0.5']], rescale_with_baseline=True, lang='en')
    print(pred_results)