making some edits regarding the bottleneck loss propagation
aryol committed Dec 12, 2023
1 parent 3289ed3 commit fe447f1
Showing 4 changed files with 10 additions and 8 deletions.
4 changes: 2 additions & 2 deletions configs/experiment/pvr.yaml
@@ -5,7 +5,7 @@

 defaults:
   - override /data: pvr
-  - override /model: gpt2 # transformer_dbn_classifier, gpt2
+  - override /model: transformer_dbn_classifier # transformer_dbn_classifier, gpt2
   - override /callbacks: default
   - override /trainer: default

@@ -28,7 +28,7 @@ trainer:

 model:
   optimizer:
-    lr: 0.001
+    lr: 0.00001


 data:
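Together, these two overrides switch the experiment from GPT-2 to the discrete-bottleneck transformer and drop the learning rate by two orders of magnitude. A quick way to sanity-check the composed config: a sketch that assumes the root config is configs/train.yaml, as in the lightning-hydra-template layout this repo appears to follow.

from hydra import compose, initialize

# Compose the pvr experiment config. config_path is relative to this script,
# and config_name="train" is an assumption about the repo's root config.
with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="train", overrides=["experiment=pvr"])

print(cfg.model.nn["_target_"])  # expect ...TransformerDBN after this commit
print(cfg.model.optimizer.lr)    # expect 1e-05
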
4 changes: 2 additions & 2 deletions configs/model/transformer_dbn_classifier.yaml
@@ -27,8 +27,8 @@ nn:
   _target_: src.models.components.transformer.TransformerDBN
   embedding_dim: 256
   output_dim: ${model.nn.num_embedding}
-  dbn_after_each_layer: False
-  dbn_last_layer: False
+  dbn_after_each_layer: True
+  dbn_last_layer: True
   shared_embedding_dbn: True
   num_embedding: 10
   seq_len: 11 # TODO: set this automatically based on the data config file or take it from some higher level folder.
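With both flags flipped to True, a discrete bottleneck now sits after every transformer block and again at the output, which is what makes the loss accumulation in the next file necessary. Below is a minimal sketch of how such flags are typically consumed when building the layer stack; Block and DBN are illustrative stand-ins, not the repo's actual classes.

import torch.nn as nn

class Block(nn.Module):  # stand-in for a transformer block
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return self.ff(x)

class DBN(nn.Module):  # stand-in for a discrete bottleneck layer
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

def build_layers(depth, dim, dbn_after_each_layer, dbn_last_layer):
    layers = []
    for i in range(depth):
        layers.append(Block(dim))
        # insert a bottleneck after each block, or only after the last one
        if dbn_after_each_layer or (dbn_last_layer and i == depth - 1):
            layers.append(DBN(dim))
    return nn.ModuleList(layers)
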
6 changes: 4 additions & 2 deletions src/models/components/transformer.py
@@ -127,6 +127,7 @@ def forward(self, inputs):
         inputs = inputs.int()
         x = self.token_embedding(inputs)
         b, n, _ = x.shape
+        bottleneck_loss = 0

         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim=1)
@@ -135,14 +136,15 @@

         for layer in self.layers:
             if isinstance(layer, AbstractDiscreteLayer):
-                indices, probs, x, vq_loss = layer(x, supervision=self.hparams['supervision']) # TODO: for now I'm adding this to the config file...
+                indices, probs, x, disc_loss = layer(x, supervision=self.hparams['supervision']) # TODO: for now I'm adding this to the config file...
+                bottleneck_loss += disc_loss
             else:
                 x = layer(x)

         x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

         x = self.to_latent(x)
-        return self.mlp_head(x)
+        return self.mlp_head(x), bottleneck_loss


 class TokenTransformer(nn.Module):
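The forward pass now threads an auxiliary loss through the layer loop: every AbstractDiscreteLayer returns its own discretization loss, the loop sums these into bottleneck_loss, and the sum is returned next to the logits. A self-contained sketch of the same pattern, using illustrative stand-in classes rather than the repo's real ones:

import torch
import torch.nn as nn

class ToyDiscreteLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        out = self.proj(x)
        # stand-in for a commitment/quantization loss emitted by the layer
        disc_loss = (out - x).pow(2).mean()
        return out, disc_loss

class ToyModel(nn.Module):
    def __init__(self, dim=16, num_classes=10):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(dim, dim), ToyDiscreteLayer(dim), nn.Linear(dim, dim)]
        )
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):  # x: (batch, seq, dim)
        bottleneck_loss = 0
        for layer in self.layers:
            if isinstance(layer, ToyDiscreteLayer):
                x, disc_loss = layer(x)
                bottleneck_loss += disc_loss  # accumulate across discrete layers
            else:
                x = layer(x)
        # mean-pool over the sequence, classify, and return both outputs
        return self.head(x.mean(dim=1)), bottleneck_loss
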
4 changes: 2 additions & 2 deletions src/models/transformer_dbn_classifier.py
@@ -74,8 +74,8 @@ def model_step(
             - A tensor of target labels.
         """
         x, y = batch
-        logits = self.forward(x)
-        loss = self.criterion(logits, y)
+        logits, disc_loss = self.forward(x)
+        loss = self.criterion(logits, y) + disc_loss
         preds = torch.argmax(logits, dim=1)
         return loss, preds, y

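On the training side, the auxiliary loss returned by forward() is added, unweighted, to the task loss before backpropagation; since forward() now returns a tuple, any other call site must unpack it the same way. A sketch of the step as a standalone function, with an optional weighting coefficient shown as a commented variant (the commit itself adds the plain sum):

import torch

def model_step(model, criterion, batch):
    x, y = batch
    logits, disc_loss = model(x)             # forward now returns two values
    loss = criterion(logits, y) + disc_loss  # bottleneck loss joins the objective
    # loss = criterion(logits, y) + lambda_dbn * disc_loss  # weighted variant
    preds = torch.argmax(logits, dim=1)
    return loss, preds, y

For example, paired with the ToyModel sketch above:

model = ToyModel()
batch = (torch.randn(4, 11, 16), torch.randint(0, 10, (4,)))
loss, preds, y = model_step(model, torch.nn.CrossEntropyLoss(), batch)
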
