Skip to content

Commit

Permalink
lib sync
Browse files Browse the repository at this point in the history
  • Loading branch information
astonzhang committed Oct 23, 2020
1 parent 54f2725 commit e90d41b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions d2l/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,9 +840,9 @@ def load_data_nmt(batch_size, num_steps, num_examples=600):
"""Return the iterator and the vocabularies of the translation dataset."""
text = preprocess_nmt(read_data_nmt())
source, target = tokenize_nmt(text, num_examples)
src_vocab = d2l.Vocab(source, min_freq=2,
src_vocab = d2l.Vocab(source, min_freq=2,
reserved_tokens=['<pad>', '<bos>', '<eos>'])
tgt_vocab = d2l.Vocab(target, min_freq=2,
tgt_vocab = d2l.Vocab(target, min_freq=2,
reserved_tokens=['<pad>', '<bos>', '<eos>'])
src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
Expand Down Expand Up @@ -976,7 +976,7 @@ def predict_s2s_ch9(model, src_sentence, src_vocab, tgt_vocab, num_steps,
# of the decoder at the next time step
dec_X = Y.argmax(axis=2)
pred = dec_X.squeeze(axis=0).astype('int32').item()
# Once the end-of-sequence token is predicted, the generation of
# Once the end-of-sequence token is predicted, the generation of
# the output sequence is complete
if pred == tgt_vocab['<eos>']:
break
Expand Down
4 changes: 2 additions & 2 deletions d2l/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,9 +843,9 @@ def load_data_nmt(batch_size, num_steps, num_examples=600):
"""Return the iterator and the vocabularies of the translation dataset."""
text = preprocess_nmt(read_data_nmt())
source, target = tokenize_nmt(text, num_examples)
src_vocab = d2l.Vocab(source, min_freq=2,
src_vocab = d2l.Vocab(source, min_freq=2,
reserved_tokens=['<pad>', '<bos>', '<eos>'])
tgt_vocab = d2l.Vocab(target, min_freq=2,
tgt_vocab = d2l.Vocab(target, min_freq=2,
reserved_tokens=['<pad>', '<bos>', '<eos>'])
src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
Expand Down
6 changes: 3 additions & 3 deletions d2l/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,9 +903,9 @@ def load_data_nmt(batch_size, num_steps, num_examples=600):
"""Return the iterator and the vocabularies of the translation dataset."""
text = preprocess_nmt(read_data_nmt())
source, target = tokenize_nmt(text, num_examples)
src_vocab = d2l.Vocab(source, min_freq=2,
src_vocab = d2l.Vocab(source, min_freq=2,
reserved_tokens=['<pad>', '<bos>', '<eos>'])
tgt_vocab = d2l.Vocab(target, min_freq=2,
tgt_vocab = d2l.Vocab(target, min_freq=2,
reserved_tokens=['<pad>', '<bos>', '<eos>'])
src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
Expand Down Expand Up @@ -1062,7 +1062,7 @@ def predict_s2s_ch9(model, src_sentence, src_vocab, tgt_vocab, num_steps,
# of the decoder at the next time step
dec_X = Y.argmax(dim=2)
pred = dec_X.squeeze(dim=0).type(torch.int32).item()
# Once the end-of-sequence token is predicted, the generation of
# Once the end-of-sequence token is predicted, the generation of
# the output sequence is complete
if pred == tgt_vocab['<eos>']:
break
Expand Down

0 comments on commit e90d41b

Please sign in to comment.