From 3006ff8b29684adead2eec82c304d81b5907fa74 Mon Sep 17 00:00:00 2001 From: magic282 Date: Wed, 9 Jan 2019 15:22:53 +0800 Subject: [PATCH] Update translate code. --- seq2seq_pt/s2s/Translator.py | 8 ++++---- seq2seq_pt/translate.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/seq2seq_pt/s2s/Translator.py b/seq2seq_pt/s2s/Translator.py index 383993a..213352e 100644 --- a/seq2seq_pt/s2s/Translator.py +++ b/seq2seq_pt/s2s/Translator.py @@ -200,14 +200,14 @@ def updateActive(t, rnnSize): return allHyp, allScores, allIsCopy, allCopyPosition, allAttn, None - def translate(self, srcBatch, goldBatch): + def translate(self, srcBatch, bio_batch, feats_batch, goldBatch): # (1) convert words to indexes - dataset = self.buildData(srcBatch, goldBatch) + dataset = self.buildData(srcBatch, bio_batch, feats_batch, goldBatch) # (wrap(srcBatch), lengths), (wrap(tgtBatch), ), indices - src, tgt, indices = dataset[0] + src, bio, feats, tgt, indices = dataset[0] # (2) translate - pred, predScore, predIsCopy, predCopyPosition, attn, _ = self.translateBatch(src, tgt) + pred, predScore, predIsCopy, predCopyPosition, attn, _ = self.translateBatch(src, bio, feats, tgt) pred, predScore, predIsCopy, predCopyPosition, attn = list(zip( *sorted(zip(pred, predScore, predIsCopy, predCopyPosition, attn, indices), key=lambda x: x[-1])))[:-1] diff --git a/seq2seq_pt/translate.py b/seq2seq_pt/translate.py index 0287d07..9f4c3c7 100644 --- a/seq2seq_pt/translate.py +++ b/seq2seq_pt/translate.py @@ -19,6 +19,8 @@ help='Path to model .pt file') parser.add_argument('-src', required=True, help='Source sequence to decode (one line per sequence)') +parser.add_argument('-bio') +parser.add_argument('-feats', default=[], nargs='+', type=str) parser.add_argument('-tgt', help='True target sequence (optional)') parser.add_argument('-output', default='pred.txt', @@ -66,7 +68,6 @@ def addPair(f1, f2): def main(): - raise Exception('Not implemented') opt = parser.parse_args() logger.info(opt) opt.cuda = opt.gpu > -1 @@ -80,15 +81,22 @@ def main(): predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0 srcBatch, tgtBatch = [], [] + bio_batch, feats_batch = [], [] count = 0 tgtF = open(opt.tgt) if opt.tgt else None + bioF = open(opt.bio, encoding='utf-8') + featFs = [open(x, encoding='utf-8') for x in opt.feats] for line in addone(open(opt.src, encoding='utf-8')): if (line is not None): srcTokens = line.strip().split(' ') srcBatch += [srcTokens] + bio_tokens = bioF.readline().strip().split(' ') + bio_batch += [bio_tokens] + feats_tokens = [reader.readline().strip().split((' ')) for reader in featFs] + feats_batch += [feats_tokens] if tgtF: tgtTokens = tgtF.readline().split(' ') if tgtF else None tgtBatch += [tgtTokens] @@ -100,7 +108,7 @@ def main(): if len(srcBatch) == 0: break - predBatch, predScore, goldScore = translator.translate(srcBatch, tgtBatch) + predBatch, predScore, goldScore = translator.translate(srcBatch, bio_batch, feats_batch, tgtBatch) predScoreTotal += sum(score[0] for score in predScore) predWordsTotal += sum(len(x[0]) for x in predBatch) @@ -136,6 +144,7 @@ def main(): logger.info('') srcBatch, tgtBatch = [], [] + bio_batch, feats_batch = [], [] reportScore('PRED', predScoreTotal, predWordsTotal) # if tgtF: