Use Dadaptation for all DL models
This optimizer is more efficient and makes the learning rate obsolete, which is nice from a UI standpoint.
  • Loading branch information
jakobnissen committed Jun 25, 2024
1 parent 0ebd8f5 commit f580c4d
Showing 4 changed files with 17 additions and 32 deletions.
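
The change repeated across all three models is the optimizer swap: hand-tuned Adam with an explicit learning rate is replaced by D-Adaptation's Adam variant, which estimates its own step size. A minimal sketch of the pattern, assuming torch and the dadaptation package are installed (the Linear model here is only a stand-in for the VAE/AAE networks in the diffs below):

import torch
import dadaptation

model = torch.nn.Linear(16, 8)  # stand-in for the real VAE/AAE networks

# Before: a fixed learning rate had to be picked and exposed as --lrate
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# After: DAdaptAdam adapts the step size on the fly; decouple=True selects
# AdamW-style decoupled weight decay, matching the calls in this commit.
optimizer = dadaptation.DAdaptAdam(model.parameters(), decouple=True)

x = torch.randn(4, 16)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()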
2 changes: 0 additions & 2 deletions test/test_aamb_encode.py
@@ -12,7 +12,6 @@ class TestAAE(unittest.TestCase):
nlatent_l = 32
default_args = (14, 256, nlatent_l, 25, 0.5, 0.5, 0.15, False, 0)
default_temperature = 0.16
default_lr = 0.001

# Construction
def test_bad_args(self):
@@ -60,7 +59,6 @@ def test_loss_falls(self):
nepochs=3,
batchsteps=[1, 2],
T=self.default_temperature,
lr=self.default_lr,
modelfile=None,
)
mu, do, to, _, _, _, _ = aae(di, ti)
24 changes: 9 additions & 15 deletions vamb/__main__.py
@@ -433,20 +433,20 @@ def __init__(
encoder_options: EncoderOptions,
vae_options: Optional[VAETrainingOptions],
aae_options: Optional[AAETrainingOptions],
lrate: float,
lrate: Optional[float],
):
assert isinstance(lrate, float)
assert isinstance(lrate, (type(None), float))

assert (encoder_options.vae_options is None) == (vae_options is None)
assert (encoder_options.aae_options is None) == (aae_options is None)

if lrate <= 0.0:
raise argparse.ArgumentTypeError("Learning rate must be positive")
self.lrate = lrate
if lrate is not None:
logger.warning(
"The --lrate argument is deprecated, and has no effect in Vamb 5 onwards"
)

self.vae_options = vae_options
self.aae_options = aae_options
self.lrate = lrate


class ClusterOptions:
@@ -660,7 +660,6 @@ def trainvae(
vae_options: VAEOptions,
training_options: VAETrainingOptions,
vamb_options: VambOptions,
lrate: float,
alpha: Optional[float],
data_loader: DataLoader,
) -> np.ndarray:
@@ -684,7 +683,6 @@
vae.trainmodel(
vamb.encode.set_batchsize(data_loader, training_options.batchsize),
nepochs=training_options.nepochs,
lrate=lrate,
batchsteps=training_options.batchsteps,
modelfile=modelpath,
)
@@ -705,7 +703,6 @@ def trainaae(
aae_options: AAEOptions,
training_options: AAETrainingOptions,
vamb_options: VambOptions,
lrate: float,
alpha: Optional[float], # set automatically if None
contignames: Sequence[str],
) -> tuple[np.ndarray, dict[str, set[str]]]:
@@ -732,7 +729,6 @@ def trainaae(
training_options.nepochs,
training_options.batchsteps,
training_options.temp,
lrate,
modelpath,
)

@@ -933,7 +929,6 @@ def run(
vae_options=vae_options,
training_options=vae_training_options,
vamb_options=vamb_options,
lrate=training_options.lrate,
alpha=encoder_options.alpha,
data_loader=data_loader,
)
@@ -948,7 +943,6 @@ def run(
aae_options=aae_options,
vamb_options=vamb_options,
training_options=aae_training_options,
lrate=training_options.lrate,
alpha=encoder_options.alpha,
contignames=composition.metadata.identifiers, # type:ignore
)
@@ -1827,9 +1821,9 @@ def add_vae_arguments(subparser):
"-r",
dest="lrate",
metavar="",
type=float,
default=1e-3,
help="learning rate [0.001]",
type=Optional[float],
default=None,
help=argparse.SUPPRESS,
)
return subparser

10 changes: 5 additions & 5 deletions vamb/aamb_encode.py
@@ -13,6 +13,7 @@
from typing import Optional, IO, Union
from numpy.typing import NDArray
from loguru import logger
import dadaptation


############################################################################# MODEL ###########################################################
@@ -206,7 +207,6 @@ def trainmodel(
nepochs: int,
batchsteps: list[int],
T,
lr: float,
modelfile: Union[None, str, IO[bytes]] = None,
):
Tensor = torch.cuda.FloatTensor if self.usecuda else torch.FloatTensor
@@ -251,11 +251,11 @@
adversarial_loss.cuda()

#### Optimizers
optimizer_E = torch.optim.Adam(enc_params, lr=lr)
optimizer_D = torch.optim.Adam(dec_params, lr=lr)
optimizer_E = dadaptation.DAdaptAdam(enc_params, decouple=True)
optimizer_D = dadaptation.DAdaptAdam(dec_params, decouple=True)

optimizer_D_z = torch.optim.Adam(disc_z_params, lr=lr)
optimizer_D_y = torch.optim.Adam(disc_y_params, lr=lr)
optimizer_D_z = dadaptation.DAdaptAdam(disc_z_params, decouple=True)
optimizer_D_y = dadaptation.DAdaptAdam(disc_y_params, decouple=True)

for epoch_i in range(nepochs):
if epoch_i in batchsteps:
13 changes: 3 additions & 10 deletions vamb/encode.py
@@ -4,8 +4,8 @@
from torch.utils.data.dataset import TensorDataset as _TensorDataset
from torch.utils.data import DataLoader as _DataLoader
from torch.nn.functional import softmax as _softmax
from torch.optim import Adam as _Adam
from torch import Tensor
import dadaptation
from torch import nn as _nn
from math import log as _log
from loguru import logger
@@ -155,7 +155,7 @@ class VAE(_nn.Module):
dropout: Probability of dropout on forward pass [0.2]
cuda: Use CUDA (GPU accelerated training) [False]
vae.trainmodel(dataloader, nepochs batchsteps, lrate, modelfile)
vae.trainmodel(dataloader, nepochs batchsteps, modelfile)
Trains the model, returning None
vae.encode(self, data_loader):
@@ -535,7 +535,6 @@ def trainmodel(
self,
dataloader: _DataLoader[tuple[Tensor, Tensor, Tensor]],
nepochs: int = 500,
lrate: float = 1e-3,
batchsteps: Optional[list[int]] = [25, 75, 150, 300],
modelfile: Union[None, str, Path, IO[bytes]] = None,
):
@@ -544,16 +543,11 @@
Inputs:
dataloader: DataLoader made by make_dataloader
nepochs: Train for this many epochs before encoding [500]
lrate: Starting learning rate for the optimizer [0.001]
batchsteps: None or double batchsize at these epochs [25, 75, 150, 300]
modelfile: Save models to this file if not None [None]
Output: None
"""

if lrate < 0:
raise ValueError(f"Learning rate must be positive, not {lrate}")

if nepochs < 1:
raise ValueError("Minimum 1 epoch, not {nepochs}")

@@ -572,7 +566,7 @@
# Get number of features
# Following line is un-inferrable due to typing problems with DataLoader
ncontigs, nsamples = dataloader.dataset.tensors[0].shape # type: ignore
optimizer = _Adam(self.parameters(), lr=lrate)
optimizer = dadaptation.DAdaptAdam(self.parameters(), decouple=True)

logger.info("\tNetwork properties:")
logger.info(f"\tCUDA: {self.usecuda}")
@@ -588,7 +582,6 @@
", ".join(map(str, sorted(batchsteps_set))) if batchsteps_set else "None"
)
logger.info(f"\tBatchsteps: {batchsteps_string}")
logger.info(f"\tLearning rate: {lrate}")
logger.info(f"\tN sequences: {ncontigs}")
logger.info(f"\tN samples: {nsamples}")

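
For callers, the visible effect of this commit is that trainmodel no longer takes a learning rate. A hedged fragment of a post-commit call, mirroring the updated signature and docstring above (vae and dataloader are assumed to have been built with vamb.encode as before; this is not a complete script):

vae.trainmodel(
    dataloader,
    nepochs=500,                    # default from the new signature
    batchsteps=[25, 75, 150, 300],  # default batch-doubling epochs
    modelfile=None,                 # optionally a path to save the model
)
latent = vae.encode(dataloader)     # encode() is untouched by this commit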
