
Commit bf4bcf5

code
1 parent 1bc031e commit bf4bcf5

6 files changed: +1157 -0 lines changed

data.py (+123)
@@ -0,0 +1,123 @@
import os
import random
import pickle
import numpy as np
import torch
import scipy.io

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _fetch(path, name):
    """Load the bag-of-words tokens and counts for one split ('train', 'valid', or 'test')."""
    if name == 'train':
        token_file = os.path.join(path, 'bow_tr_tokens.mat')
        count_file = os.path.join(path, 'bow_tr_counts.mat')
    elif name == 'valid':
        token_file = os.path.join(path, 'bow_va_tokens.mat')
        count_file = os.path.join(path, 'bow_va_counts.mat')
    else:
        token_file = os.path.join(path, 'bow_ts_tokens.mat')
        count_file = os.path.join(path, 'bow_ts_counts.mat')
    tokens = scipy.io.loadmat(token_file)['tokens'].squeeze()
    counts = scipy.io.loadmat(count_file)['counts'].squeeze()
    if name == 'test':
        ## the test split also comes in two halves (h1/h2), used for document completion
        token_1_file = os.path.join(path, 'bow_ts_h1_tokens.mat')
        count_1_file = os.path.join(path, 'bow_ts_h1_counts.mat')
        token_2_file = os.path.join(path, 'bow_ts_h2_tokens.mat')
        count_2_file = os.path.join(path, 'bow_ts_h2_counts.mat')
        tokens_1 = scipy.io.loadmat(token_1_file)['tokens'].squeeze()
        counts_1 = scipy.io.loadmat(count_1_file)['counts'].squeeze()
        tokens_2 = scipy.io.loadmat(token_2_file)['tokens'].squeeze()
        counts_2 = scipy.io.loadmat(count_2_file)['counts'].squeeze()
        return {'tokens': tokens, 'counts': counts,
                'tokens_1': tokens_1, 'counts_1': counts_1,
                'tokens_2': tokens_2, 'counts_2': counts_2}
    return {'tokens': tokens, 'counts': counts}


def _fetch_temporal(path, name):
    """Load tokens, counts, and timestamps for one split of a time-stamped corpus."""
    if name == 'train':
        token_file = os.path.join(path, 'bow_tr_tokens.mat')
        count_file = os.path.join(path, 'bow_tr_counts.mat')
        time_file = os.path.join(path, 'bow_tr_timestamps.mat')
    elif name == 'valid':
        token_file = os.path.join(path, 'bow_va_tokens.mat')
        count_file = os.path.join(path, 'bow_va_counts.mat')
        time_file = os.path.join(path, 'bow_va_timestamps.mat')
    else:
        token_file = os.path.join(path, 'bow_ts_tokens.mat')
        count_file = os.path.join(path, 'bow_ts_counts.mat')
        time_file = os.path.join(path, 'bow_ts_timestamps.mat')
    tokens = scipy.io.loadmat(token_file)['tokens'].squeeze()
    counts = scipy.io.loadmat(count_file)['counts'].squeeze()
    times = scipy.io.loadmat(time_file)['timestamps'].squeeze()
    if name == 'test':
        token_1_file = os.path.join(path, 'bow_ts_h1_tokens.mat')
        count_1_file = os.path.join(path, 'bow_ts_h1_counts.mat')
        token_2_file = os.path.join(path, 'bow_ts_h2_tokens.mat')
        count_2_file = os.path.join(path, 'bow_ts_h2_counts.mat')
        tokens_1 = scipy.io.loadmat(token_1_file)['tokens'].squeeze()
        counts_1 = scipy.io.loadmat(count_1_file)['counts'].squeeze()
        tokens_2 = scipy.io.loadmat(token_2_file)['tokens'].squeeze()
        counts_2 = scipy.io.loadmat(count_2_file)['counts'].squeeze()
        return {'tokens': tokens, 'counts': counts, 'times': times,
                'tokens_1': tokens_1, 'counts_1': counts_1,
                'tokens_2': tokens_2, 'counts_2': counts_2}
    return {'tokens': tokens, 'counts': counts, 'times': times}


def get_data(path, temporal=False):
    """Load the vocabulary and the train/valid/test splits stored under path."""
    ### load vocabulary
    with open(os.path.join(path, 'vocab.pkl'), 'rb') as f:
        vocab = pickle.load(f)

    if not temporal:
        train = _fetch(path, 'train')
        valid = _fetch(path, 'valid')
        test = _fetch(path, 'test')
    else:
        train = _fetch_temporal(path, 'train')
        valid = _fetch_temporal(path, 'valid')
        test = _fetch_temporal(path, 'test')

    return vocab, train, valid, test


def get_batch(tokens, counts, ind, vocab_size, emsize=300, temporal=False, times=None):
    """Fetch a batch of documents as dense bag-of-words vectors (plus timestamps if temporal)."""
    batch_size = len(ind)
    data_batch = np.zeros((batch_size, vocab_size))
    if temporal:
        times_batch = np.zeros((batch_size, ))
    for i, doc_id in enumerate(ind):
        doc = tokens[doc_id]
        count = counts[doc_id]
        if temporal:
            timestamp = times[doc_id]
            times_batch[i] = timestamp
        L = count.shape[1]
        if len(doc) == 1:
            doc = [doc.squeeze()]
            count = [count.squeeze()]
        else:
            doc = doc.squeeze()
            count = count.squeeze()
        if doc_id != -1:
            for j, word in enumerate(doc):
                data_batch[i, word] = count[j]
    data_batch = torch.from_numpy(data_batch).float().to(device)
    if temporal:
        times_batch = torch.from_numpy(times_batch).to(device)
        return data_batch, times_batch
    return data_batch


def get_rnn_input(tokens, counts, times, num_times, vocab_size, num_docs):
    """Build the input to the eta inference LSTM: the average bag-of-words vector per time slice."""
    indices = torch.randperm(num_docs)
    indices = torch.split(indices, 1000)
    rnn_input = torch.zeros(num_times, vocab_size).to(device)
    cnt = torch.zeros(num_times, ).to(device)
    for idx, ind in enumerate(indices):
        data_batch, times_batch = get_batch(tokens, counts, ind, vocab_size, temporal=True, times=times)
        for t in range(num_times):
            tmp = (times_batch == t).nonzero()
            docs = data_batch[tmp].squeeze().sum(0)
            rnn_input[t] += docs
            cnt[t] += len(tmp)
        if idx % 20 == 0:
            print('idx: {}/{}'.format(idx, len(indices)))
    rnn_input = rnn_input / cnt.unsqueeze(1)
    return rnn_input
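
For orientation, a minimal sketch of how these loaders fit together. This sketch is not part of the commit: the dataset directory, the batch size of 128, and the assumption that timestamps are 0-indexed integer time slices are all hypothetical, and the module is assumed to be importable as data.

import torch
from data import get_data, get_batch, get_rnn_input

data_path = 'data/my_corpus'    # hypothetical directory holding the .mat files and vocab.pkl
vocab, train, valid, test = get_data(data_path, temporal=True)
vocab_size = len(vocab)
num_docs = len(train['tokens'])
num_times = int(train['times'].max()) + 1    # assumes 0-indexed integer time slices

## per-time-slice average bag-of-words, later fed to the eta LSTM in detm.py
rnn_inp = get_rnn_input(train['tokens'], train['counts'], train['times'],
                        num_times, vocab_size, num_docs)

## one minibatch of dense bag-of-words vectors plus timestamps
ind = torch.randperm(num_docs)[:128]
bows, times = get_batch(train['tokens'], train['counts'], ind, vocab_size,
                        temporal=True, times=train['times'])
normalized_bows = bows / (bows.sum(1, keepdim=True) + 1e-6)    # length-normalize each document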

detm.py (+215)
@@ -0,0 +1,215 @@
"""This file defines the dynamic embedded topic model (DETM) class.
"""

import torch
import torch.nn.functional as F
import numpy as np
import math

from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DETM(nn.Module):
    def __init__(self, args, embeddings):
        super(DETM, self).__init__()

        ## define hyperparameters
        self.num_topics = args.num_topics
        self.num_times = args.num_times
        self.vocab_size = args.vocab_size
        self.t_hidden_size = args.t_hidden_size
        self.eta_hidden_size = args.eta_hidden_size
        self.rho_size = args.rho_size
        self.emsize = args.emb_size
        self.enc_drop = args.enc_drop
        self.eta_nlayers = args.eta_nlayers
        self.t_drop = nn.Dropout(args.enc_drop)
        self.delta = args.delta
        self.train_embeddings = args.train_embeddings

        self.theta_act = self.get_activation(args.theta_act)

        ## define the word embedding matrix \rho
        if args.train_embeddings:
            self.rho = nn.Linear(args.rho_size, args.vocab_size, bias=False)
        else:
            num_embeddings, emsize = embeddings.size()
            rho = nn.Embedding(num_embeddings, emsize)
            rho.weight.data = embeddings
            self.rho = rho.weight.data.clone().float().to(device)

        ## define the variational parameters for the topic embeddings over time (alpha) ... alpha is K x T x L
        self.mu_q_alpha = nn.Parameter(torch.randn(args.num_topics, args.num_times, args.rho_size))
        self.logsigma_q_alpha = nn.Parameter(torch.randn(args.num_topics, args.num_times, args.rho_size))

        ## define the variational distribution for \theta_{1:D} via amortization ... theta is D x K
        self.q_theta = nn.Sequential(
            nn.Linear(args.vocab_size + args.num_topics, args.t_hidden_size),
            self.theta_act,
            nn.Linear(args.t_hidden_size, args.t_hidden_size),
            self.theta_act,
        )
        self.mu_q_theta = nn.Linear(args.t_hidden_size, args.num_topics, bias=True)
        self.logsigma_q_theta = nn.Linear(args.t_hidden_size, args.num_topics, bias=True)

        ## define the variational distribution for \eta via amortization ... eta is T x K
        self.q_eta_map = nn.Linear(args.vocab_size, args.eta_hidden_size)
        self.q_eta = nn.LSTM(args.eta_hidden_size, args.eta_hidden_size, args.eta_nlayers, dropout=args.eta_dropout)
        self.mu_q_eta = nn.Linear(args.eta_hidden_size + args.num_topics, args.num_topics, bias=True)
        self.logsigma_q_eta = nn.Linear(args.eta_hidden_size + args.num_topics, args.num_topics, bias=True)

    def get_activation(self, act):
        if act == 'tanh':
            act = nn.Tanh()
        elif act == 'relu':
            act = nn.ReLU()
        elif act == 'softplus':
            act = nn.Softplus()
        elif act == 'rrelu':
            act = nn.RReLU()
        elif act == 'leakyrelu':
            act = nn.LeakyReLU()
        elif act == 'elu':
            act = nn.ELU()
        elif act == 'selu':
            act = nn.SELU()
        elif act == 'glu':
            act = nn.GLU()
        else:
            print('Defaulting to tanh activations...')
            act = nn.Tanh()
        return act

    def reparameterize(self, mu, logvar):
        """Returns a sample from a Gaussian distribution via reparameterization.
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul_(std).add_(mu)
        else:
            return mu

    def get_kl(self, q_mu, q_logsigma, p_mu=None, p_logsigma=None):
        """Returns KL( N(q_mu, q_logsigma) || N(p_mu, p_logsigma) ).
        """
        if p_mu is not None and p_logsigma is not None:
            sigma_q_sq = torch.exp(q_logsigma)
            sigma_p_sq = torch.exp(p_logsigma)
            kl = (sigma_q_sq + (q_mu - p_mu)**2) / (sigma_p_sq + 1e-6)
            kl = kl - 1 + p_logsigma - q_logsigma
            kl = 0.5 * torch.sum(kl, dim=-1)
        else:
            kl = -0.5 * torch.sum(1 + q_logsigma - q_mu.pow(2) - q_logsigma.exp(), dim=-1)
        return kl

    def get_alpha(self):  ## mean field
        alphas = torch.zeros(self.num_times, self.num_topics, self.rho_size).to(device)
        kl_alpha = []

        alphas[0] = self.reparameterize(self.mu_q_alpha[:, 0, :], self.logsigma_q_alpha[:, 0, :])

        p_mu_0 = torch.zeros(self.num_topics, self.rho_size).to(device)
        logsigma_p_0 = torch.zeros(self.num_topics, self.rho_size).to(device)
        kl_0 = self.get_kl(self.mu_q_alpha[:, 0, :], self.logsigma_q_alpha[:, 0, :], p_mu_0, logsigma_p_0)
        kl_alpha.append(kl_0)
        for t in range(1, self.num_times):
            alphas[t] = self.reparameterize(self.mu_q_alpha[:, t, :], self.logsigma_q_alpha[:, t, :])

            p_mu_t = alphas[t-1]
            logsigma_p_t = torch.log(self.delta * torch.ones(self.num_topics, self.rho_size).to(device))
            kl_t = self.get_kl(self.mu_q_alpha[:, t, :], self.logsigma_q_alpha[:, t, :], p_mu_t, logsigma_p_t)
            kl_alpha.append(kl_t)
        kl_alpha = torch.stack(kl_alpha).sum()
        return alphas, kl_alpha.sum()

    def get_eta(self, rnn_inp):  ## structured amortized inference
        inp = self.q_eta_map(rnn_inp).unsqueeze(1)
        hidden = self.init_hidden()
        output, _ = self.q_eta(inp, hidden)
        output = output.squeeze()

        etas = torch.zeros(self.num_times, self.num_topics).to(device)
        kl_eta = []

        inp_0 = torch.cat([output[0], torch.zeros(self.num_topics,).to(device)], dim=0)
        mu_0 = self.mu_q_eta(inp_0)
        logsigma_0 = self.logsigma_q_eta(inp_0)
        etas[0] = self.reparameterize(mu_0, logsigma_0)

        p_mu_0 = torch.zeros(self.num_topics,).to(device)
        logsigma_p_0 = torch.zeros(self.num_topics,).to(device)
        kl_0 = self.get_kl(mu_0, logsigma_0, p_mu_0, logsigma_p_0)
        kl_eta.append(kl_0)
        for t in range(1, self.num_times):
            inp_t = torch.cat([output[t], etas[t-1]], dim=0)
            mu_t = self.mu_q_eta(inp_t)
            logsigma_t = self.logsigma_q_eta(inp_t)
            etas[t] = self.reparameterize(mu_t, logsigma_t)

            p_mu_t = etas[t-1]
            logsigma_p_t = torch.log(self.delta * torch.ones(self.num_topics,).to(device))
            kl_t = self.get_kl(mu_t, logsigma_t, p_mu_t, logsigma_p_t)
            kl_eta.append(kl_t)
        kl_eta = torch.stack(kl_eta).sum()
        return etas, kl_eta

    def get_theta(self, eta, bows, times):  ## amortized inference
        """Returns the topic proportions.
        """
        eta_td = eta[times.type('torch.LongTensor')]
        inp = torch.cat([bows, eta_td], dim=1)
        q_theta = self.q_theta(inp)
        if self.enc_drop > 0:
            q_theta = self.t_drop(q_theta)
        mu_theta = self.mu_q_theta(q_theta)
        logsigma_theta = self.logsigma_q_theta(q_theta)
        z = self.reparameterize(mu_theta, logsigma_theta)
        theta = F.softmax(z, dim=-1)
        kl_theta = self.get_kl(mu_theta, logsigma_theta, eta_td, torch.zeros(self.num_topics).to(device))
        return theta, kl_theta

    def get_beta(self, alpha):
        r"""Returns the topic matrix \beta of shape T x K x V.
        """
        if self.train_embeddings:
            logit = self.rho(alpha.view(alpha.size(0)*alpha.size(1), self.rho_size))
        else:
            tmp = alpha.view(alpha.size(0)*alpha.size(1), self.rho_size)
            logit = torch.mm(tmp, self.rho.permute(1, 0))
        logit = logit.view(alpha.size(0), alpha.size(1), -1)
        beta = F.softmax(logit, dim=-1)
        return beta

    def get_nll(self, theta, beta, bows):
        theta = theta.unsqueeze(1)
        loglik = torch.bmm(theta, beta).squeeze(1)
        loglik = torch.log(loglik + 1e-6)
        nll = -loglik * bows
        nll = nll.sum(-1)
        return nll

    def forward(self, bows, normalized_bows, times, rnn_inp, num_docs):
        bsz = normalized_bows.size(0)
        coeff = num_docs / bsz
        alpha, kl_alpha = self.get_alpha()
        eta, kl_eta = self.get_eta(rnn_inp)
        theta, kl_theta = self.get_theta(eta, normalized_bows, times)
        kl_theta = kl_theta.sum() * coeff

        beta = self.get_beta(alpha)
        beta = beta[times.type('torch.LongTensor')]
        nll = self.get_nll(theta, beta, bows)
        nll = nll.sum() * coeff
        nelbo = nll + kl_alpha + kl_eta + kl_theta
        return nelbo, nll, kl_alpha, kl_eta, kl_theta

    def init_hidden(self):
        r"""Initializes the first hidden state of the LSTM used as inference network for \eta.
        """
        weight = next(self.parameters())
        nlayers = self.eta_nlayers
        nhid = self.eta_hidden_size
        return (weight.new_zeros(nlayers, 1, nhid), weight.new_zeros(nlayers, 1, nhid))
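
Continuing the sketch given after data.py, the model above could be instantiated and optimized roughly as follows. The hyperparameter values and the argparse-style container are illustrative assumptions, not defaults taken from this commit; the attribute names simply match what DETM.__init__ reads, and bows, normalized_bows, times, rnn_inp, num_docs, num_times, and vocab_size come from that earlier sketch.

import argparse
import torch
from detm import DETM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## hypothetical hyperparameters; every attribute below is read in DETM.__init__
args = argparse.Namespace(
    num_topics=50, num_times=num_times, vocab_size=vocab_size,
    t_hidden_size=800, eta_hidden_size=200, rho_size=300, emb_size=300,
    enc_drop=0.0, eta_nlayers=3, eta_dropout=0.0, delta=0.005,
    train_embeddings=1, theta_act='relu')

model = DETM(args, embeddings=None).to(device)    # embeddings is unused when train_embeddings is set
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1.2e-6)

## a single training step on the minibatch built in the data.py sketch
model.train()
optimizer.zero_grad()
nelbo, nll, kl_alpha, kl_eta, kl_theta = model(bows, normalized_bows, times, rnn_inp, num_docs)
nelbo.backward()
optimizer.step()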
