
Commit 3d2dec2
fixed vqvae projection error
mh-amani committed Dec 12, 2023
1 parent 9e07d68 commit 3d2dec2
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions configs/model/transformer_dbn_classifier.yaml
@@ -27,9 +27,9 @@ nn:
   _target_: src.models.components.transformer.TransformerDBN
   embedding_dim: 256
   output_dim: ${model.nn.num_embedding}
-  dbn_after_each_layer: True
+  dbn_after_each_layer: False
   dbn_last_layer: True
-  shared_embedding_dbn: True
+  shared_embedding_dbn: False
   num_embedding: 10
   seq_len: 11 # TODO: set this automatically based on the data config file or take it form some higher level folder.
   emb_dropout: 0.1
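For context, a minimal sketch of how flags like these could control where a discrete bottleneck (DBN) is inserted into the transformer stack. The class below, its nn.Linear stand-in for the bottleneck, and my reading of shared_embedding_dbn as "reuse one bottleneck module at every insertion point" are illustrative assumptions, not the repository's actual TransformerDBN implementation.

import torch
import torch.nn as nn

class TransformerDBNSketch(nn.Module):
    """Hypothetical sketch of the three config flags; not the repo's code."""

    def __init__(self, embedding_dim=256, num_embedding=10, num_layers=4,
                 dbn_after_each_layer=False, dbn_last_layer=True,
                 shared_embedding_dbn=False):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(embedding_dim, nhead=8, batch_first=True)
            for _ in range(num_layers))

        def make_dbn():
            # stand-in for the real discrete bottleneck module
            return nn.Linear(embedding_dim, embedding_dim)

        shared = make_dbn() if shared_embedding_dbn else None
        # dbn_after_each_layer: one bottleneck between consecutive layers
        self.inner_dbns = (nn.ModuleList(
            shared if shared_embedding_dbn else make_dbn()
            for _ in range(num_layers)) if dbn_after_each_layer else None)
        # dbn_last_layer: a single bottleneck on the final representation
        self.last_dbn = ((shared if shared_embedding_dbn else make_dbn())
                         if dbn_last_layer else None)

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if self.inner_dbns is not None:
                x = self.inner_dbns[i](x)
        if self.last_dbn is not None:
            x = self.last_dbn(x)
        return x

Under that reading, the commit flips both flags off, leaving only the single, unshared bottleneck after the last layer.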
2 changes: 1 addition & 1 deletion src/models/components/discrete_layers/vqvae.py
@@ -30,8 +30,8 @@ def project_matrix(self,x):
         return x
 
     def discretize(self, x, **kwargs) -> dict:
-        probs = self.kernel( - self.codebook_distances(x) / self.temperature)
         x = self.project_matrix(x)
+        probs = self.kernel( - self.codebook_distances(x) / self.temperature)
         indices = torch.argmax(probs, dim=-1)
 
         if self.hard:
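This is an ordering bug: discretize computed the codebook distances on x before project_matrix had mapped it into the codebook's embedding space, which is presumably the "projection error" the commit title refers to. Below is a minimal, self-contained sketch of the corrected flow; project_matrix, codebook_distances, kernel, and all dimensions are simplified assumptions, not the repository's exact implementation.

import torch
import torch.nn as nn

class VQDiscretizeSketch(nn.Module):
    """Illustrative sketch of the corrected discretize(); dims are assumptions."""

    def __init__(self, input_dim=256, embedding_dim=64, num_embedding=10,
                 temperature=1.0, hard=False):
        super().__init__()
        self.codebook = nn.Embedding(num_embedding, embedding_dim)
        self.projection = nn.Linear(input_dim, embedding_dim)
        self.kernel = nn.Softmax(dim=-1)
        self.temperature = temperature
        self.hard = hard

    def project_matrix(self, x):
        # maps inputs into the codebook's embedding space
        return self.projection(x)

    def codebook_distances(self, x):
        # squared Euclidean distance from each vector to every codebook entry
        diff = x.unsqueeze(-2) - self.codebook.weight  # (..., num_embedding, embedding_dim)
        return diff.pow(2).sum(-1)                     # (..., num_embedding)

    def discretize(self, x, **kwargs) -> dict:
        x = self.project_matrix(x)  # must run first: distances are taken in codebook space
        probs = self.kernel(-self.codebook_distances(x) / self.temperature)
        indices = torch.argmax(probs, dim=-1)
        quantized = self.codebook(indices) if self.hard else probs @ self.codebook.weight
        return {"indices": indices, "probs": probs, "quantized": quantized}

With the old ordering, codebook_distances would receive input_dim-sized vectors while the codebook entries are embedding_dim-sized, so the subtraction above would fail, or, if the dimensions happened to match, silently measure distances in the wrong space.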
2 changes: 1 addition & 1 deletion src/models/transformer_dbn_classifier.py
@@ -97,7 +97,7 @@ def training_step(
         self.train_loss(loss)
         self.train_acc(preds, targets)
         self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
-        self.log("tain/disc_loss", disc_loss, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("train/disc_loss", disc_loss, on_step=False, on_epoch=True, prog_bar=True)
         self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
 
         # return loss or backpropagation will fail
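Worth noting: because self.log keys metrics by the raw string, the misspelled "tain/disc_loss" would not raise an error; it would quietly log the discreteness loss under a separate "tain" group in TensorBoard or W&B, away from the other train/ metrics.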
