diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py index e55fc55b..504cc535 100644 --- a/fastNLP/io/embed_loader.py +++ b/fastNLP/io/embed_loader.py @@ -12,7 +12,7 @@ def __init__(self): super(EmbedLoader, self).__init__() @staticmethod - def _load_glove(emb_file): + def _load_glove(emb_dim, emb_file): """Read file as a glove embedding file format: @@ -28,12 +28,12 @@ def _load_glove(emb_file): with open(emb_file, 'r', encoding='utf-8') as f: for line in f: line = list(filter(lambda w: len(w) > 0, line.strip().split(' '))) - if len(line) > 2: + if len(line) == emb_dim + 1: emb[line[0]] = torch.Tensor(list(map(float, line[1:]))) return emb @staticmethod - def _load_pretrain(emb_file, emb_type): + def _load_pretrain(emb_dim, emb_file, emb_type): """Read txt data from embedding file and convert to np.array as pre-trained embedding :param str emb_file: the pre-trained embedding file path @@ -41,7 +41,7 @@ def _load_pretrain(emb_file, emb_type): :return: a dict of ``{str: np.array}`` """ if emb_type == 'glove': - return EmbedLoader._load_glove(emb_file) + return EmbedLoader._load_glove(emb_dim, emb_file) else: raise Exception("embedding type {} not support yet".format(emb_type)) @@ -58,7 +58,7 @@ def load_embedding(emb_dim, emb_file, emb_type, vocab): vocab - input vocab or vocab built by pre-train """ - pretrain = EmbedLoader._load_pretrain(emb_file, emb_type) + pretrain = EmbedLoader._load_pretrain(emb_dim, emb_file, emb_type) if vocab is None: # build vocabulary from pre-trained embedding vocab = Vocabulary()