"""This file defines the Dynamic Embedded Topic Model (DETM)."""

import torch
import torch.nn.functional as F
import numpy as np
import math

from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DETM(nn.Module):
    def __init__(self, args, embeddings):
        super(DETM, self).__init__()

        ## define hyperparameters
        self.num_topics = args.num_topics
        self.num_times = args.num_times
        self.vocab_size = args.vocab_size
        self.t_hidden_size = args.t_hidden_size
        self.eta_hidden_size = args.eta_hidden_size
        self.rho_size = args.rho_size
        self.emsize = args.emb_size
        self.enc_drop = args.enc_drop
        self.eta_nlayers = args.eta_nlayers
        self.t_drop = nn.Dropout(args.enc_drop)
        self.delta = args.delta
        self.train_embeddings = args.train_embeddings

        self.theta_act = self.get_activation(args.theta_act)

        ## define the word embedding matrix \rho
        if args.train_embeddings:
            self.rho = nn.Linear(args.rho_size, args.vocab_size, bias=False)
        else:
            num_embeddings, emsize = embeddings.size()
            rho = nn.Embedding(num_embeddings, emsize)
            rho.weight.data = embeddings
            self.rho = rho.weight.data.clone().float().to(device)

        ## define the variational parameters for the topic embeddings over time (alpha) ... alpha is K x T x L
        self.mu_q_alpha = nn.Parameter(torch.randn(args.num_topics, args.num_times, args.rho_size))
        self.logsigma_q_alpha = nn.Parameter(torch.randn(args.num_topics, args.num_times, args.rho_size))

        ## define variational distribution for \theta_{1:D} via amortization ... theta is D x K
        self.q_theta = nn.Sequential(
            nn.Linear(args.vocab_size + args.num_topics, args.t_hidden_size),
            self.theta_act,
            nn.Linear(args.t_hidden_size, args.t_hidden_size),
            self.theta_act,
        )
        self.mu_q_theta = nn.Linear(args.t_hidden_size, args.num_topics, bias=True)
        self.logsigma_q_theta = nn.Linear(args.t_hidden_size, args.num_topics, bias=True)

        ## define variational distribution for \eta via amortization ... eta is T x K
        self.q_eta_map = nn.Linear(args.vocab_size, args.eta_hidden_size)
        self.q_eta = nn.LSTM(args.eta_hidden_size, args.eta_hidden_size, args.eta_nlayers, dropout=args.eta_dropout)
        self.mu_q_eta = nn.Linear(args.eta_hidden_size + args.num_topics, args.num_topics, bias=True)
        self.logsigma_q_eta = nn.Linear(args.eta_hidden_size + args.num_topics, args.num_topics, bias=True)

    def get_activation(self, act):
        if act == 'tanh':
            act = nn.Tanh()
        elif act == 'relu':
            act = nn.ReLU()
        elif act == 'softplus':
            act = nn.Softplus()
        elif act == 'rrelu':
            act = nn.RReLU()
        elif act == 'leakyrelu':
            act = nn.LeakyReLU()
        elif act == 'elu':
            act = nn.ELU()
        elif act == 'selu':
            act = nn.SELU()
        elif act == 'glu':
            act = nn.GLU()
        else:
            print('Defaulting to tanh activations...')
            act = nn.Tanh()
        return act

    def reparameterize(self, mu, logvar):
        """Returns a sample from a Gaussian distribution via reparameterization.
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul_(std).add_(mu)
        else:
            return mu

    def get_kl(self, q_mu, q_logsigma, p_mu=None, p_logsigma=None):
        """Returns KL( N(q_mu, q_logsigma) || N(p_mu, p_logsigma) ).

        The *_logsigma arguments hold log-variances; if the prior is omitted it
        defaults to a standard normal.
        """
        if p_mu is not None and p_logsigma is not None:
            sigma_q_sq = torch.exp(q_logsigma)
            sigma_p_sq = torch.exp(p_logsigma)
            kl = (sigma_q_sq + (q_mu - p_mu)**2) / (sigma_p_sq + 1e-6)
            kl = kl - 1 + p_logsigma - q_logsigma
            kl = 0.5 * torch.sum(kl, dim=-1)
        else:
            kl = -0.5 * torch.sum(1 + q_logsigma - q_mu.pow(2) - q_logsigma.exp(), dim=-1)
        return kl

    def get_alpha(self): ## mean field
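        """Draws the topic-embedding trajectory alpha (num_times x num_topics x rho_size)
        from the mean-field variational posterior and accumulates the KL terms against
        the random-walk prior alpha_t ~ N(alpha_{t-1}, delta * I), with a standard
        normal prior at t = 0.
        """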
        alphas = torch.zeros(self.num_times, self.num_topics, self.rho_size).to(device)
        kl_alpha = []

        alphas[0] = self.reparameterize(self.mu_q_alpha[:, 0, :], self.logsigma_q_alpha[:, 0, :])

        p_mu_0 = torch.zeros(self.num_topics, self.rho_size).to(device)
        logsigma_p_0 = torch.zeros(self.num_topics, self.rho_size).to(device)
        kl_0 = self.get_kl(self.mu_q_alpha[:, 0, :], self.logsigma_q_alpha[:, 0, :], p_mu_0, logsigma_p_0)
        kl_alpha.append(kl_0)
        for t in range(1, self.num_times):
            alphas[t] = self.reparameterize(self.mu_q_alpha[:, t, :], self.logsigma_q_alpha[:, t, :])

            p_mu_t = alphas[t-1]
            logsigma_p_t = torch.log(self.delta * torch.ones(self.num_topics, self.rho_size).to(device))
            kl_t = self.get_kl(self.mu_q_alpha[:, t, :], self.logsigma_q_alpha[:, t, :], p_mu_t, logsigma_p_t)
            kl_alpha.append(kl_t)
        kl_alpha = torch.stack(kl_alpha).sum()
        return alphas, kl_alpha

    def get_eta(self, rnn_inp): ## structured amortized inference
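        """Structured amortized inference for eta (num_times x num_topics): an LSTM runs
        over the per-time-step bag-of-words input, each eta_t is sampled conditionally on
        the LSTM output at t and on eta_{t-1}, and the KL terms are taken against the
        random-walk prior eta_t ~ N(eta_{t-1}, delta * I), with a standard normal prior
        at t = 0.
        """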
        inp = self.q_eta_map(rnn_inp).unsqueeze(1)
        hidden = self.init_hidden()
        output, _ = self.q_eta(inp, hidden)
        output = output.squeeze()

        etas = torch.zeros(self.num_times, self.num_topics).to(device)
        kl_eta = []

        inp_0 = torch.cat([output[0], torch.zeros(self.num_topics,).to(device)], dim=0)
        mu_0 = self.mu_q_eta(inp_0)
        logsigma_0 = self.logsigma_q_eta(inp_0)
        etas[0] = self.reparameterize(mu_0, logsigma_0)

        p_mu_0 = torch.zeros(self.num_topics,).to(device)
        logsigma_p_0 = torch.zeros(self.num_topics,).to(device)
        kl_0 = self.get_kl(mu_0, logsigma_0, p_mu_0, logsigma_p_0)
        kl_eta.append(kl_0)
        for t in range(1, self.num_times):
            inp_t = torch.cat([output[t], etas[t-1]], dim=0)
            mu_t = self.mu_q_eta(inp_t)
            logsigma_t = self.logsigma_q_eta(inp_t)
            etas[t] = self.reparameterize(mu_t, logsigma_t)

            p_mu_t = etas[t-1]
            logsigma_p_t = torch.log(self.delta * torch.ones(self.num_topics,).to(device))
            kl_t = self.get_kl(mu_t, logsigma_t, p_mu_t, logsigma_p_t)
            kl_eta.append(kl_t)
        kl_eta = torch.stack(kl_eta).sum()
        return etas, kl_eta

    def get_theta(self, eta, bows, times): ## amortized inference
        """Returns the topic proportions theta (batch_size x num_topics) and the KL
        term of q(theta | bows, eta_t) against the prior N(eta_t, I).
        """
        eta_td = eta[times.long()]
        inp = torch.cat([bows, eta_td], dim=1)
        q_theta = self.q_theta(inp)
        if self.enc_drop > 0:
            q_theta = self.t_drop(q_theta)
        mu_theta = self.mu_q_theta(q_theta)
        logsigma_theta = self.logsigma_q_theta(q_theta)
        z = self.reparameterize(mu_theta, logsigma_theta)
        theta = F.softmax(z, dim=-1)
        kl_theta = self.get_kl(mu_theta, logsigma_theta, eta_td, torch.zeros(self.num_topics).to(device))
        return theta, kl_theta

    def get_beta(self, alpha):
        r"""Returns the topic-word distributions \beta of shape num_times x K x V.
        """
        if self.train_embeddings:
            logit = self.rho(alpha.view(alpha.size(0)*alpha.size(1), self.rho_size))
        else:
            tmp = alpha.view(alpha.size(0)*alpha.size(1), self.rho_size)
            logit = torch.mm(tmp, self.rho.permute(1, 0))
        logit = logit.view(alpha.size(0), alpha.size(1), -1)
        beta = F.softmax(logit, dim=-1)
        return beta

    def get_nll(self, theta, beta, bows):
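        """Returns the per-document negative log-likelihood of the word counts under
        the mixture distribution theta @ beta.
        """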
        theta = theta.unsqueeze(1)
        loglik = torch.bmm(theta, beta).squeeze(1)
        loglik = torch.log(loglik + 1e-6)
        nll = -loglik * bows
        nll = nll.sum(-1)
        return nll

    def forward(self, bows, normalized_bows, times, rnn_inp, num_docs):
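        """Computes the negative ELBO for a minibatch; the reconstruction term and the
        document-level KL are rescaled by num_docs / batch_size to estimate the bound
        over the full corpus.
        """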
        bsz = normalized_bows.size(0)
        coeff = num_docs / bsz
        alpha, kl_alpha = self.get_alpha()
        eta, kl_eta = self.get_eta(rnn_inp)
        theta, kl_theta = self.get_theta(eta, normalized_bows, times)
        kl_theta = kl_theta.sum() * coeff

        beta = self.get_beta(alpha)
        beta = beta[times.long()]
        nll = self.get_nll(theta, beta, bows)
        nll = nll.sum() * coeff
        nelbo = nll + kl_alpha + kl_eta + kl_theta
        return nelbo, nll, kl_alpha, kl_eta, kl_theta

    def init_hidden(self):
        r"""Initializes the first hidden state of the RNN used as inference network for \eta.
        """
        weight = next(self.parameters())
        nlayers = self.eta_nlayers
        nhid = self.eta_hidden_size
        return (weight.new_zeros(nlayers, 1, nhid), weight.new_zeros(nlayers, 1, nhid))
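
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model definition above: the
    # hyperparameter values and the synthetic data below are illustrative
    # assumptions, chosen only to exercise one forward pass of DETM.
    from types import SimpleNamespace

    args = SimpleNamespace(
        num_topics=10, num_times=5, vocab_size=200,
        t_hidden_size=64, eta_hidden_size=64, rho_size=32, emb_size=32,
        enc_drop=0.0, eta_nlayers=2, eta_dropout=0.0, delta=0.005,
        train_embeddings=1, theta_act='relu')
    model = DETM(args, embeddings=None).to(device)

    num_docs, bsz = 100, 8
    # synthetic bag-of-words counts, their normalized versions, time stamps,
    # and the per-time-step aggregated input for the eta inference LSTM
    bows = torch.rand(bsz, args.vocab_size).round().to(device)
    normalized_bows = bows / (bows.sum(1, keepdim=True) + 1e-6)
    times = torch.randint(0, args.num_times, (bsz,)).to(device)
    rnn_inp = torch.rand(args.num_times, args.vocab_size).to(device)

    nelbo, nll, kl_alpha, kl_eta, kl_theta = model(
        bows, normalized_bows, times, rnn_inp, num_docs)
    print('nelbo: {:.2f}  nll: {:.2f}  kl_alpha: {:.2f}  kl_eta: {:.2f}  kl_theta: {:.2f}'.format(
        nelbo.item(), nll.item(), kl_alpha.item(), kl_eta.item(), kl_theta.item()))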