Skip to content

Commit

Permalink
Merge pull request #12 from Maluuba/oo-api-corpus
Browse files Browse the repository at this point in the history
fixes #11 : add functionality for compute_metrics to accept list of lists apart from files
  • Loading branch information
kracwarlock authored Mar 15, 2018
2 parents 86584e1 + 91b8245 commit 4d43b65
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 5 deletions.
15 changes: 13 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,26 @@ corresponding hypothesis.
where `references` is a list of ground truth reference text strings and
`hypothesis` is the hypothesis text string.

### object oriented API for repeated calls in a script ###
### object oriented API for repeated calls in a script - single example ###

from nlgeval import NLGEval
nlgeval = NLGEval() # loads the models
metrics_dict = nlgeval.evaluate(references, hypothesis)
metrics_dict = nlgeval.compute_individual_metrics(references, hypothesis)

where `references` is a list of ground truth reference text strings and
`hypothesis` is the hypothesis text string.

### object oriented API for repeated calls in a script - multiple examples ###

from nlgeval import NLGEval
nlgeval = NLGEval() # loads the models
metrics_dict = nlgeval.compute_metrics(references, hypothesis)

where `references` is a list of lists of ground truth reference text strings and
`hypothesis` is a list of hypothesis text strings. Each inner list in `references`
is one group of references and must contain exactly one reference string for each
sentence in `hypothesis`, in the same order (so `references[j][i]` is the j-th
reference for `hypothesis[i]`).

## Reference ##
If you use this code as part of any published research, please cite the following paper:

Expand Down
39 changes: 38 additions & 1 deletion nlgeval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def load_glove(self):
self.np = np
self.glove_emb = Embedding()

def evaluate(self, ref, hyp):
def compute_individual_metrics(self, ref, hyp):
assert isinstance(hyp, str)
ref = [a.strip() for a in ref]
refs = {0: ref}
Expand Down Expand Up @@ -207,3 +207,40 @@ def evaluate(self, ref, hyp):
ret_scores[name] = value

return ret_scores

def compute_metrics(self, ref_list, hyp_list):
    """Compute corpus-level metrics over many hypotheses at once.

    Args:
        ref_list: list of reference groups. Each inner list holds one
            reference string per hypothesis, in the same order as
            ``hyp_list`` (i.e. ``ref_list[j][i]`` is the j-th reference
            for hypothesis ``hyp_list[i]``).
        hyp_list: list of hypothesis strings.

    Returns:
        dict mapping metric names to corpus-level scores.
    """
    # Transpose so that entry i collects every reference for hypothesis i.
    # Use list comprehensions rather than map(): in Python 3, map() returns
    # a one-shot lazy iterator, which breaks len(), repeated iteration, and
    # the np.array() calls below.
    ref_list = [[ref.strip() for ref in refs] for refs in zip(*ref_list)]
    refs = {idx: strippedlines for (idx, strippedlines) in enumerate(ref_list)}
    hyps = {idx: [lines.strip()] for (idx, lines) in enumerate(hyp_list)}
    assert len(refs) == len(hyps)

    ret_scores = {}
    if not self.no_overlap:
        for scorer, method in self.scorers:
            score, scores = scorer.compute_score(refs, hyps)
            if isinstance(method, list):
                # Scorers such as BLEU report several sub-metrics at once.
                for sc, scs, m in zip(score, scores, method):
                    ret_scores[m] = sc
            else:
                ret_scores[method] = score

    if not self.no_skipthoughts:
        vector_hyps = self.skipthought_encoder.encode([h.strip() for h in hyp_list], verbose=False)
        # Transpose back to per-reference-group orientation.
        ref_list_T = self.np.array(ref_list).T.tolist()
        # Materialize as lists: np.max(..., axis=0) below needs a real
        # sequence, not a one-shot map iterator (Python 3).
        vector_refs = [self.skipthought_encoder.encode([r.strip() for r in refl], verbose=False)
                       for refl in ref_list_T]
        cosine_similarity = [self.cosine_similarity(refv, vector_hyps).diagonal()
                             for refv in vector_refs]
        # Best-matching reference per hypothesis, averaged over the corpus.
        cosine_similarity = self.np.max(cosine_similarity, axis=0).mean()
        ret_scores['SkipThoughtCS'] = cosine_similarity

    if not self.no_glove:
        glove_hyps = [h.strip() for h in hyp_list]
        ref_list_T = self.np.array(ref_list).T.tolist()
        # List comprehension (not map) so eval_emb_metrics gets real lists.
        glove_refs = [[r.strip() for r in refl] for refl in ref_list_T]
        scores = self.eval_emb_metrics(glove_hyps, glove_refs, emb=self.glove_emb)
        # eval_emb_metrics returns newline-separated "Name: value" lines.
        scores = scores.split('\n')
        for score in scores:
            name, value = score.split(':')
            value = float(value.strip())
            ret_scores[name] = value

    return ret_scores
12 changes: 10 additions & 2 deletions test/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,19 @@
def test_oo_api():
    """Smoke-test both NLGEval entry points on the bundled example data."""

    def _read_stripped(path):
        # Read a file as a list of whitespace-stripped lines.
        with open(path) as fp:
            return [line.strip() for line in fp]

    hyp = _read_stripped("examples/hyp.txt")
    ref1 = _read_stripped("examples/ref1.txt")
    ref2 = _read_stripped("examples/ref2.txt")

    nlge = NLGEval()

    # Single-example API: one hypothesis, its references paired by index.
    res = nlge.compute_individual_metrics([ref1[0], ref2[0]], hyp[0])
    res = nlge.compute_individual_metrics([ref1[1], ref2[1]], hyp[1])

    # Corpus API: each reference file is one reference group.
    res = nlge.compute_metrics([ref1, ref2], hyp)

0 comments on commit 4d43b65

Please sign in to comment.