diff --git a/README.md b/README.md index ef92be8..eae069a 100644 --- a/README.md +++ b/README.md @@ -61,15 +61,26 @@ corresponding hypothesis. where `references` is a list of ground truth reference text strings and `hypothesis` is the hypothesis text string. -### object oriented API for repeated calls in a script ### +### object oriented API for repeated calls in a script - single example ### from nlgeval import NLGEval nlgeval = NLGEval() # loads the models - metrics_dict = nlgeval.evaluate(references, hypothesis) + metrics_dict = nlgeval.compute_individual_metrics(references, hypothesis) where `references` is a list of ground truth reference text strings and `hypothesis` is the hypothesis text string. +### object oriented API for repeated calls in a script - multiple examples ### + + from nlgeval import NLGEval + nlgeval = NLGEval() # loads the models + metrics_dict = nlgeval.compute_metrics(references, hypothesis) + +where `references` is a list of lists of ground truth reference text strings and +`hypothesis` is a list of hypothesis text strings. Each inner list in `references` +is one set of references for the hypotheses: it holds a single reference string for +each sentence in `hypothesis`, in the same order. 
+ ## Reference ## If you use this code as part of any published research, please cite the following paper: diff --git a/nlgeval/__init__.py b/nlgeval/__init__.py index f329b5e..cd4942a 100644 --- a/nlgeval/__init__.py +++ b/nlgeval/__init__.py @@ -168,7 +168,7 @@ def load_glove(self): self.np = np self.glove_emb = Embedding() - def evaluate(self, ref, hyp): + def compute_individual_metrics(self, ref, hyp): assert isinstance(hyp, str) ref = [a.strip() for a in ref] refs = {0: ref} @@ -207,3 +207,40 @@ def evaluate(self, ref, hyp): ret_scores[name] = value return ret_scores + + def compute_metrics(self, ref_list, hyp_list): + ref_list = [map(str.strip, refs) for refs in zip(*ref_list)] + refs = {idx: strippedlines for (idx, strippedlines) in enumerate(ref_list)} + hyps = {idx: [lines.strip()] for (idx, lines) in enumerate(hyp_list)} + assert len(refs) == len(hyps) + + ret_scores = {} + if not self.no_overlap: + for scorer, method in self.scorers: + score, scores = scorer.compute_score(refs, hyps) + if isinstance(method, list): + for sc, scs, m in zip(score, scores, method): + ret_scores[m] = sc + else: + ret_scores[method] = score + + if not self.no_skipthoughts: + vector_hyps = self.skipthought_encoder.encode([h.strip() for h in hyp_list], verbose=False) + ref_list_T = self.np.array(ref_list).T.tolist() + vector_refs = map(lambda refl: self.skipthought_encoder.encode([r.strip() for r in refl], verbose=False), ref_list_T) + cosine_similarity = map(lambda refv: self.cosine_similarity(refv, vector_hyps).diagonal(), vector_refs) + cosine_similarity = self.np.max(cosine_similarity, axis=0).mean() + ret_scores['SkipThoughtCS'] = cosine_similarity + + if not self.no_glove: + glove_hyps = [h.strip() for h in hyp_list] + ref_list_T = self.np.array(ref_list).T.tolist() + glove_refs = map(lambda refl: [r.strip() for r in refl], ref_list_T) + scores = self.eval_emb_metrics(glove_hyps, glove_refs, emb=self.glove_emb) + scores = scores.split('\n') + for score in scores: + name, 
value = score.split(':') + value = float(value.strip()) + ret_scores[name] = value + + return ret_scores diff --git a/test/api.py b/test/api.py index 5bbe6f0..7076342 100644 --- a/test/api.py +++ b/test/api.py @@ -3,11 +3,19 @@ def test_oo_api(): with open("examples/hyp.txt") as f: hyp = f.readlines() + hyp = [x.strip() for x in hyp] with open("examples/ref1.txt") as f: ref1 = f.readlines() + ref1 = [x.strip() for x in ref1] with open("examples/ref2.txt") as f: ref2 = f.readlines() + ref2 = [x.strip() for x in ref2] nlge = NLGEval() - res = nlge.evaluate([ref1[0]] + [ref2[0]], hyp[0]) - res = nlge.evaluate([ref1[1]] + [ref2[1]], hyp[1]) + + res = nlge.compute_individual_metrics([ref1[0]] + [ref2[0]], hyp[0]) + res = nlge.compute_individual_metrics([ref1[1]] + [ref2[1]], hyp[1]) + + hyp_list = hyp + ref_list = [ref1, ref2] + res = nlge.compute_metrics(ref_list, hyp_list)