Initialize AAE parameters
Due to an oversight, the function init_encoder_and_training was never called; this commit adds the missing call in __init__.
Also clean up the display of the AAE log a little more.
jakobnissen committed Jun 25, 2024
1 parent 160bfac commit 0ebd8f5
Showing 2 changed files with 3 additions and 8 deletions.
vamb/__main__.py: 1 addition, 0 deletions
@@ -1524,6 +1524,7 @@ def __init__(self, args):
             batchsteps=args.batchsteps_aae,
             temp=args.temp,
         )
+        self.init_encoder_and_training()
 
     def init_encoder_and_training(self):
         self.encoder_options = EncoderOptions(
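Note: the one-line addition above fixes a defined-but-never-called setup method. Below is a minimal, self-contained Python sketch of the pattern, not vamb's actual classes; everything except the name init_encoder_and_training and the EncoderOptions type is a hypothetical simplification:

class EncoderOptions:
    def __init__(self, temp):
        self.temp = temp


class AAEOptionsSketch:
    def __init__(self, temp):
        self.temp = temp
        # Before this commit the line below was missing, so the setup
        # method defined beneath never ran and encoder_options was
        # never populated.
        self.init_encoder_and_training()

    def init_encoder_and_training(self):
        # Derive encoder settings from the AAE hyperparameters.
        self.encoder_options = EncoderOptions(temp=self.temp)


opts = AAEOptionsSketch(temp=0.16)
assert hasattr(opts, "encoder_options")  # holds once __init__ makes the call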
vamb/aamb_encode.py: 2 additions, 8 deletions
@@ -2,7 +2,6 @@
 
 import numpy as np
 from math import log, isfinite
-import time
 from torch.autograd import Variable
 from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical
 import torch.nn as nn
@@ -220,7 +219,7 @@ def trainmodel(
         logger.info(f"\tAlpha: {self.alpha}")
         logger.info(f"\tY length: {self.y_len}")
         logger.info(f"\tZ length: {self.ld}")
-        logger.info("\n\tTraining properties:")
+        logger.info("\tTraining properties:")
         logger.info(f"\tN epochs: {nepochs}")
         logger.info(f"\tStarting batch size: {data_loader.batch_size}")
         batchsteps_string = (
@@ -274,7 +273,6 @@ def trainmodel(
             ) = (0, 0, 0, 0, 0, 0)
 
             total_batches_inthis_epoch = len(data_loader)
-            time_epoch_0 = time.time()
 
             # weights, abundances currently unused here
             for depths_in, tnfs_in, _, _ in data_loader:
@@ -392,11 +390,8 @@ def trainmodel(
                 CE_e += float(ce.item())
                 SSE_e += float(sse.item())
 
-            time_epoch_1 = time.time()
-            time_e = np.round((time_epoch_1 - time_epoch_0) / 60, 3)
-
             logger.info(
-                "\tEpoch: {}\t Loss Enc/Dec: {:.6f}\t Rec. loss: {:.4f}\t CE: {:.4f}\tSSE: {:.4f}\t Dz loss: {:.7f}\t Dy loss: {:.6f}\t Batchsize: {}\t Epoch time(min): {: .4}".format(
+                "\t\tEpoch: {:>3} Loss Enc/Dec: {:.5e} Rec. loss: {:.5e} CE: {:.5e} SSE: {:.5e} Dz loss: {:.5e} Dy loss: {:.5e} Batchsize: {:>4}".format(
                     epoch_i + 1,
                     ED_loss_e / total_batches_inthis_epoch,
                     V_loss_e / total_batches_inthis_epoch,
@@ -405,7 +400,6 @@ def trainmodel(
                     D_z_loss_e / total_batches_inthis_epoch,
                     D_y_loss_e / total_batches_inthis_epoch,
                     data_loader.batch_size,
-                    time_e,
                 ),
             )
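Note: besides dropping the unused timing code, the rewritten log line switches every loss field from fixed-point formatting to scientific notation ({:.5e}) and right-aligns the epoch number ({:>3}) and batch size ({:>4}). A standalone demonstration of the new format string, with made-up loss values (all numbers hypothetical):

fmt = (
    "\t\tEpoch: {:>3} Loss Enc/Dec: {:.5e} Rec. loss: {:.5e} CE: {:.5e} "
    "SSE: {:.5e} Dz loss: {:.5e} Dy loss: {:.5e} Batchsize: {:>4}"
)
print(fmt.format(7, 0.987654, 0.456789, 0.123456, 0.333333, 1.2e-05, 3.4e-04, 256))
# Prints (leading tabs not shown):
# Epoch:   7 Loss Enc/Dec: 9.87654e-01 Rec. loss: 4.56789e-01 CE: 1.23456e-01 SSE: 3.33333e-01 Dz loss: 1.20000e-05 Dy loss: 3.40000e-04 Batchsize:  256

Scientific notation keeps the columns aligned even when losses such as Dz span several orders of magnitude over a training run.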

