diff --git a/configs/model/transformer_dbn_classifier.yaml b/configs/model/transformer_dbn_classifier.yaml index 23f770e..f47e9ee 100644 --- a/configs/model/transformer_dbn_classifier.yaml +++ b/configs/model/transformer_dbn_classifier.yaml @@ -27,9 +27,9 @@ nn: _target_: src.models.components.transformer.TransformerDBN embedding_dim: 256 output_dim: ${model.nn.num_embedding} - dbn_after_each_layer: True + dbn_after_each_layer: False dbn_last_layer: True - shared_embedding_dbn: True + shared_embedding_dbn: False num_embedding: 10 seq_len: 11 # TODO: set this automatically based on the data config file or take it form some higher level folder. emb_dropout: 0.1 diff --git a/src/models/components/discrete_layers/vqvae.py b/src/models/components/discrete_layers/vqvae.py index 61db136..cb0dc22 100644 --- a/src/models/components/discrete_layers/vqvae.py +++ b/src/models/components/discrete_layers/vqvae.py @@ -30,8 +30,8 @@ def project_matrix(self,x): return x def discretize(self, x, **kwargs) -> dict: - probs = self.kernel( - self.codebook_distances(x) / self.temperature) x = self.project_matrix(x) + probs = self.kernel( - self.codebook_distances(x) / self.temperature) indices = torch.argmax(probs, dim=-1) if self.hard: diff --git a/src/models/transformer_dbn_classifier.py b/src/models/transformer_dbn_classifier.py index 6130c5b..07a698d 100644 --- a/src/models/transformer_dbn_classifier.py +++ b/src/models/transformer_dbn_classifier.py @@ -97,7 +97,7 @@ def training_step( self.train_loss(loss) self.train_acc(preds, targets) self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True) - self.log("tain/disc_loss", disc_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("train/disc_loss", disc_loss, on_step=False, on_epoch=True, prog_bar=True) self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True) # return loss or backpropagation will fail