From c541b38f28516d90eefdbe2c858526d30c96f1cb Mon Sep 17 00:00:00 2001
From: astariul
Date: Sat, 6 Jul 2019 16:23:06 +0900
Subject: [PATCH] wip

---
 nlgeval/__init__.py           |  6 +++++
 nlgeval/others/bert_scorer.py | 47 +++++++++++++++++++++++++++++++++++
 nlgeval/tests/test_nlgeval.py | 11 ++++++++
 3 files changed, 64 insertions(+)
 create mode 100644 nlgeval/others/bert_scorer.py

diff --git a/nlgeval/__init__.py b/nlgeval/__init__.py
index a85db78..acc74d9 100644
--- a/nlgeval/__init__.py
+++ b/nlgeval/__init__.py
@@ -9,6 +9,7 @@
 from nlgeval.pycocoevalcap.cider.cider import Cider
 from nlgeval.pycocoevalcap.meteor.meteor import Meteor
 from nlgeval.pycocoevalcap.rouge.rouge import Rouge
+from nlgeval.others.bert_scorer import BertScore
 
 
 # str/unicode stripping in Python 2 and 3 instead of `str.strip`.
@@ -34,6 +35,7 @@ def compute_metrics(hypothesis, references, no_overlap=False, no_skipthoughts=Fa
         (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
         (Meteor(), "METEOR"),
         (Rouge(), "ROUGE_L"),
+        (BertScore(), "BERT_score"),
         (Cider(), "CIDEr")
     ]
     for scorer, method in scorers:
@@ -99,6 +101,7 @@ def compute_individual_metrics(ref, hyp, no_overlap=False, no_skipthoughts=False
         (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
         (Meteor(), "METEOR"),
         (Rouge(), "ROUGE_L"),
+        (BertScore(), "BERT_score"),
         (Cider(), "CIDEr")
     ]
     for scorer, method in scorers:
@@ -152,6 +155,7 @@ class NLGEval(object):
         'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4',
         'METEOR',
         'ROUGE_L',
+        'BERT_score',
         'CIDEr',
 
         # Skip-thought
@@ -212,6 +216,8 @@ def load_scorers(self):
             self.scorers.append((Meteor(), "METEOR"))
         if 'ROUGE_L' not in self.metrics_to_omit:
             self.scorers.append((Rouge(), "ROUGE_L"))
+        if 'BERT_score' not in self.metrics_to_omit:
+            self.scorers.append((BertScore(), "BERT_score"))
         if 'CIDEr' not in self.metrics_to_omit:
             self.scorers.append((Cider(), "CIDEr"))

diff --git a/nlgeval/others/bert_scorer.py b/nlgeval/others/bert_scorer.py
new file mode 100644
index 0000000..63bdbcb
--- /dev/null
+++ b/nlgeval/others/bert_scorer.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python
+#
+# File Name : bert_scorer.py
+#
+# Description : Computes BERTScore as described by Tianyi Zhang et al. (2019)
+#
+# Creation Date : 2019-07-06
+# Author : REMOND Nicolas
+
+from bert_score import score
+
+
+class BertScore():
+    '''
+    Class for computing BERTScore for a set of candidate sentences
+    '''
+
+    def __init__(self, score_type='f_score'):
+        # Score type to be returned
+        if score_type not in ['f_score', 'recall', 'precision']:
+            raise ValueError("Score type must be either 'f_score', 'precision', or 'recall'. Given: {}".format(score_type))
+        self.score_type = score_type
+
+    def compute_score(self, gts, res):
+        """
+        Computes BERTScore given a set of reference and candidate sentences for the dataset
+        :param res: dict : candidate / test sentences.
+        :param gts: dict : reference sentences.
+        :returns: average_score: float (mean BERTScore computed by averaging scores for all the images), individual scores
+        """
+        assert(gts.keys() == res.keys())
+        imgIds = gts.keys()
+
+        hyp = [res[id][0] for id in imgIds]
+        # Take only the first reference, because bert_score supports a single reference per candidate
+        ref = [gts[id][0] for id in imgIds]
+        assert len(hyp) == len(ref)
+
+        P, R, F1 = score(hyp, ref, bert="bert-base-uncased", no_idf=(len(ref) == 1))
+
+        if self.score_type == 'recall':
+            s = R
+        elif self.score_type == 'precision':
+            s = P
+        elif self.score_type == 'f_score':
+            s = F1
+
+        return s.mean().item(), s.tolist()
\ No newline at end of file
diff --git a/nlgeval/tests/test_nlgeval.py b/nlgeval/tests/test_nlgeval.py
index 08b5d60..f3373c5 100644
--- a/nlgeval/tests/test_nlgeval.py
+++ b/nlgeval/tests/test_nlgeval.py
@@ -119,3 +119,14 @@ def test_compute_metrics(self):
         self.assertAlmostEqual(0.568696, scores['VectorExtremaCosineSimilarity'], places=5)
         self.assertAlmostEqual(0.784205, scores['GreedyMatchingScore'], places=5)
         self.assertEqual(11, len(scores))
+
+    def test_bert_score(self):
+        n = NLGEval(metrics_to_omit=['Bleu_1', 'Bleu_2', 'Bleu_3', 'ROUGE_L', 'METEOR', 'EmbeddingAverageCosineSimilairty', 'CIDEr', 'SkipThoughtCS', 'VectorExtremaCosineSimilarity', 'GreedyMatchingScore'])
+
+        # Individual metrics
+        scores = n.compute_individual_metrics(ref=["Until you start talking to Katrin Bahr."],
+                                              hyp="Until you talk to Katrin Bahr.")
+        self.assertAlmostEqual(0.9345, scores['BERT_score'], places=5)
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
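
For reference, a minimal usage sketch (not part of the patch) of how the new scorer could be exercised directly, assuming the patch is applied and the bert-score package is installed. The gts/res dicts follow the {id: [sentence, ...]} convention the other pycocoevalcap-style scorers in nlg-eval receive, and the example sentences are taken from the test above.

    from nlgeval.others.bert_scorer import BertScore

    # Reference and candidate sentences keyed by an arbitrary id,
    # mirroring the dict layout the other scorers receive.
    gts = {0: ["Until you start talking to Katrin Bahr."]}   # references
    res = {0: ["Until you talk to Katrin Bahr."]}            # candidates

    scorer = BertScore(score_type='f_score')
    mean_score, per_sentence_scores = scorer.compute_score(gts, res)
    print(mean_score, per_sentence_scores)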