word2vec_fns.py
import tensorflow as tf
import numpy as np
import collections
data_index = 0


def generate_batch(data, batch_size, skip_window):
    """
    Generates a mini-batch of training data for training the CBOW
    embedding model.

    :param data (numpy.ndarray(dtype=int, shape=(corpus_size,))): holds the
        training corpus, with words encoded as integers
    :param batch_size (int): size of the batch to generate
    :param skip_window (int): number of words to both left and right that form
        the context window for the target word.

    Returns a pair (batch, labels): batch is an array of shape
    (batch_size, 2*skip_window) in which each row contains all the context
    words of one window, and the corresponding label is the word in the
    middle of that window.
    """
    global data_index
    batch = np.ndarray(shape=(batch_size, 2 * skip_window), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    span = 2 * skip_window + 1  # [ skip_window target skip_window ]
    buffer = collections.deque(maxlen=span)
    if data_index + span > len(data):
        data_index = 0
    buffer.extend(data[data_index:data_index + span])
    data_index += span
    for i in range(batch_size):
        # every word in the window except the centre word is a context word
        batch[i, :] = [buffer[j] for j in range(span) if j != skip_window]
        # the centre word of the window is the label to predict
        labels[i, 0] = buffer[skip_window]
        if data_index == len(data):
            # reached the end of the data, start again
            buffer.extend(data[:span])
            data_index = span
        else:
            # slide the window forward one word (n.b. buffer = deque(maxlen=span))
            buffer.append(data[data_index])
            data_index += 1
    # Backtrack a little to avoid skipping words at the end of a batch
    data_index = (data_index - span) % len(data)
    return batch, labels
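
# A quick illustration of what generate_batch produces (a sketch only: the toy
# integer-encoded corpus and the batch_size / skip_window values below are
# assumptions chosen for illustration, not data from the assignment):
#
#     data = np.array([5, 1, 2, 3, 4, 5, 6, 7, 8, 9])
#     batch, labels = generate_batch(data, batch_size=4, skip_window=1)
#     # batch  -> shape (4, 2): the left and right neighbour of each centre word
#     # labels -> shape (4, 1): the centre word of each window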


def get_mean_context_embeds(embeddings, train_inputs):
    """
    :param embeddings (tf.Variable(shape=(vocabulary_size, embedding_size)))
    :param train_inputs (tf.placeholder(shape=(batch_size, 2*skip_window)))

    returns:
        `mean_context_embeds`: the mean of the embeddings for all context words
        for each entry in the batch, should have shape (batch_size,
        embedding_size)
    """
    # cpu is recommended to avoid out of memory errors, if you don't
    # have a high capacity GPU
    with tf.device('/cpu:0'):
        # one direct way to compute this: look up the embedding of every
        # context word, giving shape (batch_size, 2*skip_window, embedding_size)
        context_embeds = tf.nn.embedding_lookup(embeddings, train_inputs)
        # then average over the context-word axis, so each batch entry is
        # represented by a single vector of shape (embedding_size,)
        mean_context_embeds = tf.reduce_mean(context_embeds, axis=1)
    return mean_context_embeds
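

# The block below is a minimal smoke test of the two functions above (a sketch
# only: the corpus, vocabulary_size, embedding_size, batch_size and skip_window
# values are assumptions chosen for illustration, not values fixed by the
# assignment). It builds one CBOW batch, feeds it through
# get_mean_context_embeds, and prints the resulting shape.
if __name__ == '__main__':
    vocabulary_size = 10    # assumed toy vocabulary size
    embedding_size = 8      # assumed embedding dimension
    batch_size = 4
    skip_window = 1

    # toy integer-encoded corpus (illustrative values only)
    data = np.array([5, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    batch, labels = generate_batch(data, batch_size, skip_window)

    embeddings = tf.Variable(
        tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
    train_inputs = tf.placeholder(tf.int32,
                                  shape=(batch_size, 2 * skip_window))
    mean_context_embeds = get_mean_context_embeds(embeddings, train_inputs)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out = sess.run(mean_context_embeds, feed_dict={train_inputs: batch})
        print(out.shape)  # expected: (4, 8) == (batch_size, embedding_size)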