diff --git a/configs/experiment/pvr.yaml b/configs/experiment/pvr.yaml
index c4bf1bf..5ad4460 100644
--- a/configs/experiment/pvr.yaml
+++ b/configs/experiment/pvr.yaml
@@ -5,7 +5,7 @@ defaults:
   - override /data: pvr
-  - override /model: gpt2 # transformer_dbn_classifier, gpt2
+  - override /model: transformer_dbn_classifier # transformer_dbn_classifier, gpt2
   - override /callbacks: default
   - override /trainer: default
@@ -28,7 +28,7 @@ trainer:
 
 model:
   optimizer:
-    lr: 0.001
+    lr: 0.00001
 
 data:
diff --git a/configs/model/transformer_dbn_classifier.yaml b/configs/model/transformer_dbn_classifier.yaml
index c1a21f6..4f6788f 100644
--- a/configs/model/transformer_dbn_classifier.yaml
+++ b/configs/model/transformer_dbn_classifier.yaml
@@ -27,8 +27,8 @@ nn:
   _target_: src.models.components.transformer.TransformerDBN
   embedding_dim: 256
   output_dim: ${model.nn.num_embedding}
-  dbn_after_each_layer: False
-  dbn_last_layer: False
+  dbn_after_each_layer: True
+  dbn_last_layer: True
   shared_embedding_dbn: True
   num_embedding: 10
   seq_len: 11 # TODO: set this automatically based on the data config file or take it form some higher level folder.
diff --git a/src/models/components/transformer.py b/src/models/components/transformer.py
index 365a94e..1575f43 100644
--- a/src/models/components/transformer.py
+++ b/src/models/components/transformer.py
@@ -127,6 +127,7 @@ def forward(self, inputs):
         inputs = inputs.int()
         x = self.token_embedding(inputs)
         b, n, _ = x.shape
+        bottleneck_loss = 0
 
         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim=1)
@@ -135,14 +136,15 @@ def forward(self, inputs):
 
         for layer in self.layers:
             if isinstance(layer, AbstractDiscreteLayer):
-                indices, probs, x, vq_loss = layer(x, supervision=self.hparams['supervision']) # TODO: for now I'm adding this to the config file...
+                indices, probs, x, disc_loss = layer(x, supervision=self.hparams['supervision']) # TODO: for now I'm adding this to the config file...
+                bottleneck_loss += disc_loss
             else:
                 x = layer(x)
 
         x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
         x = self.to_latent(x)
 
-        return self.mlp_head(x)
+        return self.mlp_head(x), bottleneck_loss
 
 
 class TokenTransformer(nn.Module):
diff --git a/src/models/transformer_dbn_classifier.py b/src/models/transformer_dbn_classifier.py
index 98bf8d0..a599b53 100644
--- a/src/models/transformer_dbn_classifier.py
+++ b/src/models/transformer_dbn_classifier.py
@@ -74,8 +74,8 @@ def model_step(
         - A tensor of target labels.
         """
         x, y = batch
-        logits = self.forward(x)
-        loss = self.criterion(logits, y)
+        logits, disc_loss = self.forward(x)
+        loss = self.criterion(logits, y) + disc_loss
         preds = torch.argmax(logits, dim=1)
         return loss, preds, y