Skip to content
This repository has been archived by the owner on Mar 19, 2021. It is now read-only.

Commit

Permalink
Update translate code.
Browse files: browse the repository at this point in the history
  • Loading branch information
magic282 committed Jan 9, 2019
1 parent da36a6d commit 3006ff8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
8 changes: 4 additions & 4 deletions seq2seq_pt/s2s/Translator.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -200,14 +200,14 @@ def updateActive(t, rnnSize):

return allHyp, allScores, allIsCopy, allCopyPosition, allAttn, None

def translate(self, srcBatch, goldBatch):
def translate(self, srcBatch, bio_batch, feats_batch, goldBatch):
# (1) convert words to indexes
dataset = self.buildData(srcBatch, goldBatch)
dataset = self.buildData(srcBatch, bio_batch, feats_batch, goldBatch)
# (wrap(srcBatch), lengths), (wrap(tgtBatch), ), indices
src, tgt, indices = dataset[0]
src, bio, feats, tgt, indices = dataset[0]

# (2) translate
pred, predScore, predIsCopy, predCopyPosition, attn, _ = self.translateBatch(src, tgt)
pred, predScore, predIsCopy, predCopyPosition, attn, _ = self.translateBatch(src, bio, feats, tgt)
pred, predScore, predIsCopy, predCopyPosition, attn = list(zip(
*sorted(zip(pred, predScore, predIsCopy, predCopyPosition, attn, indices),
key=lambda x: x[-1])))[:-1]
Expand Down
13 changes: 11 additions & 2 deletions seq2seq_pt/translate.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,6 +19,8 @@
help='Path to model .pt file')
parser.add_argument('-src', required=True,
help='Source sequence to decode (one line per sequence)')
parser.add_argument('-bio')
parser.add_argument('-feats', default=[], nargs='+', type=str)
parser.add_argument('-tgt',
help='True target sequence (optional)')
parser.add_argument('-output', default='pred.txt',
Expand Down Expand Up @@ -66,7 +68,6 @@ def addPair(f1, f2):


def main():
raise Exception('Not implemented')
opt = parser.parse_args()
logger.info(opt)
opt.cuda = opt.gpu > -1
Expand All @@ -80,15 +81,22 @@ def main():
predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal = 0, 0, 0, 0

srcBatch, tgtBatch = [], []
bio_batch, feats_batch = [], []

count = 0

tgtF = open(opt.tgt) if opt.tgt else None
bioF = open(opt.bio, encoding='utf-8')
featFs = [open(x, encoding='utf-8') for x in opt.feats]
for line in addone(open(opt.src, encoding='utf-8')):

if (line is not None):
srcTokens = line.strip().split(' ')
srcBatch += [srcTokens]
bio_tokens = bioF.readline().strip().split(' ')
bio_batch += [bio_tokens]
feats_tokens = [reader.readline().strip().split((' ')) for reader in featFs]
feats_batch += [feats_tokens]
if tgtF:
tgtTokens = tgtF.readline().split(' ') if tgtF else None
tgtBatch += [tgtTokens]
Expand All @@ -100,7 +108,7 @@ def main():
if len(srcBatch) == 0:
break

predBatch, predScore, goldScore = translator.translate(srcBatch, tgtBatch)
predBatch, predScore, goldScore = translator.translate(srcBatch, bio_batch, feats_batch, tgtBatch)

predScoreTotal += sum(score[0] for score in predScore)
predWordsTotal += sum(len(x[0]) for x in predBatch)
Expand Down Expand Up @@ -136,6 +144,7 @@ def main():
logger.info('')

srcBatch, tgtBatch = [], []
bio_batch, feats_batch = [], []

reportScore('PRED', predScoreTotal, predWordsTotal)
# if tgtF:
Expand Down

0 comments on commit 3006ff8

Please sign in to comment.