""" Word2Vec.

Implements the Word2Vec (skip-gram) algorithm to compute vector
representations of words. This example trains on a small chunk of Wikipedia
articles (the text8 corpus).

References:
    - Mikolov, Tomas et al. "Efficient Estimation of Word Representations
      in Vector Space.", 2013.

Links:
    - [Word2Vec] https://arxiv.org/pdf/1301.3781.pdf

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
"""
from __future__ import division, print_function, absolute_import

import collections
import os
import random
import urllib.request
import zipfile

import numpy as np
import tensorflow as tf
# Training Parameters
learning_rate = 0.1
batch_size = 128
num_steps = 3000000
display_step = 10000
eval_step = 200000

# Evaluation Parameters
eval_words = ['five', 'of', 'going', 'hardware', 'american', 'britain']

# Word2Vec Parameters
embedding_size = 200  # Dimension of the embedding vector
max_vocabulary_size = 50000  # Total number of different words in the vocabulary
min_occurrence = 10  # Remove all words that do not appear at least n times
skip_window = 3  # How many words to consider left and right of the center word
num_skips = 2  # How many times to reuse an input to generate a label
num_sampled = 64  # Number of negative examples to sample
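
# Note: with skip_window=3 each center word sees a 7-word window (3 left,
# 3 right), and num_skips=2 means 2 (center, context) pairs are drawn from
# every window, so a batch of 128 covers 64 consecutive center words.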

# Download a small chunk of the Wikipedia articles collection (text8)
url = 'http://mattmahoney.net/dc/text8.zip'
data_path = 'text8.zip'
if not os.path.exists(data_path):
    print("Downloading the dataset... (It may take some time)")
    filename, _ = urllib.request.urlretrieve(url, data_path)
    print("Done!")
# Unzip the dataset file. The text has already been pre-processed
with zipfile.ZipFile(data_path) as f:
    text_words = f.read(f.namelist()[0]).decode('utf-8').lower().split()

# Build the dictionary and replace rare words with the UNK token
count = [('UNK', -1)]
# Retrieve the most common words
count.extend(collections.Counter(text_words).most_common(max_vocabulary_size - 1))
# Remove words with fewer than 'min_occurrence' occurrences (iterate backwards)
for i in range(len(count) - 1, -1, -1):
    if count[i][1] < min_occurrence:
        count.pop(i)
    else:
        # The collection is ordered, so stop when 'min_occurrence' is reached
        break
# Compute the vocabulary size
vocabulary_size = len(count)
# Assign an id to each word
word2id = dict()
for i, (word, _) in enumerate(count):
    word2id[word] = i

data = list()
unk_count = 0
for word in text_words:
    # Retrieve a word id, or assign it index 0 ('UNK') if not in dictionary
    index = word2id.get(word, 0)
    if index == 0:
        unk_count += 1
    data.append(index)
count[0] = ('UNK', unk_count)
id2word = dict(zip(word2id.values(), word2id.keys()))

print("Words count:", len(text_words))
print("Unique words:", len(set(text_words)))
print("Vocabulary size:", vocabulary_size)
print("Most common words:", count[:10])
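# For reference (approximate, text8-specific figures): this typically reports
# around 17 million total words, roughly 250k unique words, and a vocabulary
# slightly under the 50k cap once the min_occurrence filter is applied.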

data_index = 0
# Generate training batch for the skip-gram model
def next_batch(batch_size, num_skips, skip_window):
    global data_index
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window
    batch = np.ndarray(shape=(batch_size,), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    # Window size: skip_window words to the left and right, plus the current word
    span = 2 * skip_window + 1
    buffer = collections.deque(maxlen=span)
    if data_index + span > len(data):
        data_index = 0
    buffer.extend(data[data_index:data_index + span])
    data_index += span
    for i in range(batch_size // num_skips):
        context_words = [w for w in range(span) if w != skip_window]
        words_to_use = random.sample(context_words, num_skips)
        for j, context_word in enumerate(words_to_use):
            batch[i * num_skips + j] = buffer[skip_window]
            labels[i * num_skips + j, 0] = buffer[context_word]
        if data_index == len(data):
            buffer.extend(data[0:span])
            data_index = span
        else:
            buffer.append(data[data_index])
            data_index += 1
    # Backtrack a little bit to avoid skipping words at the end of a batch
    data_index = (data_index + len(data) - span) % len(data)
    return batch, labels
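
# Illustrative usage (commented out): each call advances the global data_index
# and returns aligned arrays of center-word ids and context-word ids, e.g.
#   demo_x, demo_y = next_batch(8, num_skips, skip_window)
#   print([id2word[i] for i in demo_x])       # 4 center words, each repeated twice
#   print([id2word[i[0]] for i in demo_y])    # one sampled context word per row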


# Input data
X = tf.placeholder(tf.int32, shape=[None])
# Input label
Y = tf.placeholder(tf.int32, shape=[None, 1])

# Ensure the following ops & variables are assigned to the CPU
# (some ops are not compatible with GPU)
with tf.device('/cpu:0'):
    # Create the embedding variable (each row represents a word embedding vector)
    embedding = tf.Variable(tf.random_normal([vocabulary_size, embedding_size]))
    # Look up the corresponding embedding vectors for each sample in X
    X_embed = tf.nn.embedding_lookup(embedding, X)
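    # Note: embedding_lookup is just a row gather; it selects the rows of
    # 'embedding' indexed by X, equivalent to a one-hot matmul without
    # materializing the one-hot vectors.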

    # Construct the variables for the NCE loss
    nce_weights = tf.Variable(tf.random_normal([vocabulary_size, embedding_size]))
    nce_biases = tf.Variable(tf.zeros([vocabulary_size]))

# Compute the average NCE loss for the batch
loss_op = tf.reduce_mean(
    tf.nn.nce_loss(weights=nce_weights,
                   biases=nce_biases,
                   labels=Y,
                   inputs=X_embed,
                   num_sampled=num_sampled,
                   num_classes=vocabulary_size))
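# Note: NCE approximates the full softmax over the vocabulary by contrasting
# each true (center, context) pair against num_sampled randomly drawn negative
# words, so the per-step cost scales with num_sampled rather than vocabulary_size.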

# Define the optimizer
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
train_op = optimizer.minimize(loss_op)

# Evaluation
# Compute the cosine similarity between the input embeddings and every embedding vector
# (both sides are L2-normalized row-wise)
X_embed_norm = X_embed / tf.sqrt(tf.reduce_sum(tf.square(X_embed), 1, keepdims=True))
embedding_norm = embedding / tf.sqrt(tf.reduce_sum(tf.square(embedding), 1, keepdims=True))
cosine_sim_op = tf.matmul(X_embed_norm, embedding_norm, transpose_b=True)
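# Each row of cosine_sim_op holds the similarity between one evaluation word
# and every word in the vocabulary; sorting a row in descending order yields
# that word's nearest neighbors.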

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()

with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    # Evaluation data (ids of the evaluation words)
    x_test = np.array([word2id[w] for w in eval_words])

    average_loss = 0
    for step in range(1, num_steps + 1):
        # Get a new batch of data
        batch_x, batch_y = next_batch(batch_size, num_skips, skip_window)
        # Run training op
        _, loss = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
        average_loss += loss

        if step % display_step == 0 or step == 1:
            if step > 1:
                average_loss /= display_step
            print("Step " + str(step) + ", Average Loss= " +
                  "{:.4f}".format(average_loss))
            average_loss = 0

        # Evaluation
        if step % eval_step == 0 or step == 1:
            print("Evaluation...")
            sim = sess.run(cosine_sim_op, feed_dict={X: x_test})
            for i in range(len(eval_words)):
                top_k = 8  # number of nearest neighbors
                # Skip index 0, which is the evaluation word itself
                nearest = (-sim[i, :]).argsort()[1:top_k + 1]
                log_str = '"%s" nearest neighbors:' % eval_words[i]
                for k in range(top_k):
                    log_str = '%s %s,' % (log_str, id2word[nearest[k]])
                print(log_str)
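
# Typical (non-deterministic) behavior after enough training steps: number
# words such as 'five' end up nearest to other numerals, and 'britain' drifts
# toward other country names. Exact neighbors vary between runs.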