From 31c564c56db7a2bf2519d1ed83e68c275cec1c84 Mon Sep 17 00:00:00 2001
From: Linyang He
Date: Thu, 17 Jan 2019 18:20:52 +0800
Subject: [PATCH] Bugfix embed_loader.py

---
 fastNLP/io/embed_loader.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py
index e55fc55b..504cc535 100644
--- a/fastNLP/io/embed_loader.py
+++ b/fastNLP/io/embed_loader.py
@@ -12,7 +12,7 @@ def __init__(self):
         super(EmbedLoader, self).__init__()
 
     @staticmethod
-    def _load_glove(emb_file):
+    def _load_glove(emb_dim, emb_file):
         """Read file as a glove embedding
 
         file format:
@@ -28,12 +28,12 @@ def _load_glove(emb_file):
         with open(emb_file, 'r', encoding='utf-8') as f:
             for line in f:
                 line = list(filter(lambda w: len(w) > 0, line.strip().split(' ')))
-                if len(line) > 2:
+                if len(line) == emb_dim + 1:
                     emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
         return emb
 
     @staticmethod
-    def _load_pretrain(emb_file, emb_type):
+    def _load_pretrain(emb_dim, emb_file, emb_type):
         """Read txt data from embedding file and convert to np.array as pre-trained embedding
 
         :param str emb_file: the pre-trained embedding file path
@@ -41,7 +41,7 @@ def _load_pretrain(emb_file, emb_type):
         :return: a dict of ``{str: np.array}``
         """
         if emb_type == 'glove':
-            return EmbedLoader._load_glove(emb_file)
+            return EmbedLoader._load_glove(emb_dim, emb_file)
         else:
             raise Exception("embedding type {} not support yet".format(emb_type))
 
@@ -58,7 +58,7 @@ def load_embedding(emb_dim, emb_file, emb_type, vocab):
                 vocab - input vocab or vocab built by pre-train
 
         """
-        pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
+        pretrain = EmbedLoader._load_pretrain(emb_dim, emb_file, emb_type)
         if vocab is None:
             # build vocabulary from pre-trained embedding
             vocab = Vocabulary()
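
The sketch below (not part of the patch) illustrates the parsing rule the new check enforces: a well-formed GloVe line is a token followed by exactly emb_dim floats, so a valid row has emb_dim + 1 whitespace-separated fields, whereas the old `len(line) > 2` check would also accept truncated rows or word2vec-style header lines. The function name and example path are placeholders for illustration only.

    # Standalone illustration of the check introduced by this patch.
    import torch

    def load_glove_sketch(emb_dim, emb_file):
        emb = {}
        with open(emb_file, 'r', encoding='utf-8') as f:
            for line in f:
                # Drop empty fields produced by repeated spaces.
                parts = [w for w in line.strip().split(' ') if len(w) > 0]
                if len(parts) == emb_dim + 1:  # token + emb_dim float values
                    emb[parts[0]] = torch.Tensor([float(x) for x in parts[1:]])
        return emb

    # e.g. load_glove_sketch(50, 'glove.6B.50d.txt')  # path is a placeholder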