Experiments in optimising batch steps #180

Draft
wants to merge 2 commits into base: master
14 changes: 12 additions & 2 deletions test/test_encode.py
@@ -174,7 +174,7 @@ def test_loss_falls(self):
vae = vamb.encode.VAE(self.rpkm.shape[1])
rpkm_copy = self.rpkm.copy()
tnfs_copy = self.tnfs.copy()
dl, mask = vamb.encode.make_dataloader(
dl, _ = vamb.encode.make_dataloader(
rpkm_copy, tnfs_copy, self.lens, batchsize=16, destroy=True
)
di = torch.Tensor(rpkm_copy)
@@ -202,10 +202,20 @@ def test_loss_falls(self):
after_encoding = vae_2.encode(dl)
self.assertTrue(np.all(np.abs(before_encoding - after_encoding) < 1e-6))

def test_warn_too_many_batch_steps(self):
vae = vamb.encode.VAE(self.rpkm.shape[1])
rpkm_copy = self.rpkm.copy()
tnfs_copy = self.tnfs.copy()
dl, _ = vamb.encode.make_dataloader(
rpkm_copy, tnfs_copy, self.lens, batchsize=16, destroy=True
)
with self.assertWarns(Warning):
vae.trainmodel(dl, nepochs=4, batchsteps=[1, 2, 3])

def test_encoding(self):
nlatent = 15
vae = vamb.encode.VAE(self.rpkm.shape[1], nlatent=nlatent)
dl, mask = vamb.encode.make_dataloader(
dl, _ = vamb.encode.make_dataloader(
self.rpkm, self.tnfs, self.lens, batchsize=32
)
encoding = vae.encode(dl)
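The new `test_warn_too_many_batch_steps` above exercises the warning path added to `trainmodel` in vamb/encode.py: when more batch-size doublings are requested than the dataset can support, a warning should be emitted and training should continue with a truncated list of batch steps. Below is a minimal sketch of the same behaviour from a caller's perspective; the random matrices, the contig count of 500 and the TNF width of 103 are illustrative assumptions, not values taken from this PR.

```python
# Not part of the PR: illustrative way to trigger the new warning by
# requesting more batch-size doublings than a small dataset supports.
import warnings

import numpy as np
import vamb

n_contigs, n_samples, n_tnf = 500, 4, 103  # assumed toy dimensions
rpkm = np.random.rand(n_contigs, n_samples).astype(np.float32)
tnfs = np.random.rand(n_contigs, n_tnf).astype(np.float32)
lengths = np.full(n_contigs, 2000, dtype=np.int64)

dl, _ = vamb.encode.make_dataloader(rpkm, tnfs, lengths, batchsize=16, destroy=True)
vae = vamb.encode.VAE(n_samples)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # 500 contigs at batch size 16 allow at most (500 // 16).bit_length() - 1 = 4
    # doublings, so requesting five batch steps should warn and truncate.
    vae.trainmodel(dl, nepochs=6, batchsteps=[1, 2, 3, 4, 5])

assert len(caught) >= 1
```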
48 changes: 33 additions & 15 deletions vamb/encode.py
@@ -8,6 +8,8 @@
from torch import Tensor
from torch import nn as _nn
from math import log as _log
from time import time
import warnings

__doc__ = """Encode a depths matrix and a tnf matrix to latent representation.

@@ -367,6 +369,7 @@ def trainepoch(
epoch: int,
optimizer,
batchsteps: list[int],
start_time: float,
logfile,
) -> _DataLoader[tuple[Tensor, Tensor, Tensor]]:
self.train()
@@ -377,7 +380,7 @@
epoch_celoss = 0.0

if epoch in batchsteps:
data_loader = set_batchsize(data_loader, data_loader.batch_size * 2)
data_loader = set_batchsize(data_loader, data_loader.batch_size * 2) # type: ignore

for depths_in, tnf_in, weights in data_loader:
depths_in.requires_grad = True
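`set_batchsize` is an internal Vamb helper; since a torch `DataLoader`'s batch size cannot be changed in place, doubling it means rebuilding the loader. The sketch below illustrates that general pattern only; the function name and the copied settings are assumptions, not Vamb's actual implementation.

```python
# Illustrative only: rebuild a DataLoader with twice the batch size.
from torch.utils.data import DataLoader

def doubled_batchsize(loader: DataLoader) -> DataLoader:
    # A DataLoader's batch_size is fixed at construction, so "doubling"
    # means creating a new loader over the same dataset.
    return DataLoader(
        dataset=loader.dataset,
        batch_size=loader.batch_size * 2,
        shuffle=True,  # assumed: training batches are shuffled
        drop_last=loader.drop_last,
        num_workers=loader.num_workers,
    )

# Usage inside a training loop:
# if epoch in batchsteps:
#     data_loader = doubled_batchsize(data_loader)
```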
@@ -405,9 +408,11 @@
epoch_celoss += ce.data.item()

if logfile is not None:
elapsed = time() - start_time
print(
"\tEpoch: {}\tLoss: {:.6f}\tCE: {:.7f}\tSSE: {:.6f}\tKLD: {:.4f}\tBatchsize: {}".format(
"\tEpoch: {}\tLoss: {:.6f}\tCE: {:.7f}\tSSE: {:.6f}\tKLD: {:.4f}\tBatchsize: {}\tSeconds: {:.2f}".format(
epoch + 1,
elapsed,
epoch_loss / len(data_loader),
epoch_celoss / len(data_loader),
epoch_sseloss / len(data_loader),
@@ -446,7 +451,7 @@ def encode(self, data_loader) -> _np.ndarray:

row = 0
with _torch.no_grad():
for depths, tnf, weights in new_data_loader:
for depths, tnf, _ in new_data_loader:
# Move input to GPU if requested
if self.usecuda:
depths = depths.cuda()
@@ -547,28 +552,41 @@ def trainmodel(
if nepochs < 1:
raise ValueError("Minimum 1 epoch, not {nepochs}")

if batchsteps is None:
batchsteps_set: set[int] = set()
if batchsteps is None or len(batchsteps) == 0:
sorted_batch_steps: list[int] = []
else:
# First collect into a list so any iterable is accepted, then check that
# the elements are integers
batchsteps = list(batchsteps)
if not all(isinstance(i, int) for i in batchsteps):
raise ValueError("All elements of batchsteps must be integers")
if max(batchsteps, default=0) >= nepochs:
sorted_batch_steps = sorted(set(batchsteps))
if sorted_batch_steps[0] < 1:
raise ValueError(
f"Minimum of batchsteps must be 1, not {sorted_batch_steps[0]}"
)
if sorted_batch_steps[-1] >= nepochs:
raise ValueError("Max batchsteps must not equal or exceed nepochs")
last_batchsize = dataloader.batch_size * 2 ** len(batchsteps)
if len(dataloader.dataset) < last_batchsize: # type: ignore

n_contigs = len(dataloader.dataset) # type: ignore
starting_batch_size: int = dataloader.batch_size # type: ignore
if n_contigs < starting_batch_size:
raise ValueError(
f"Last batch size of {last_batchsize} exceeds dataset length "
f"of {len(dataloader.dataset)}. " # type: ignore
f"Starting batch size of {starting_batch_size} exceeds dataset length "
f"of {n_contigs}. "
"This means you have too few contigs left after filtering to train. "
"It is not adviced to run Vamb with fewer than 10,000 sequences "
"after filtering. "
"Please check the Vamb log file to see where the sequences were "
"filtered away, and verify BAM files has sensible content."
)
batchsteps_set = set(batchsteps)
maximum_batch_steps = (n_contigs // starting_batch_size).bit_length() - 1
if maximum_batch_steps < len(sorted_batch_steps):
warnings.warn(
f"Requested {len(sorted_batch_steps)} batch steps, but with a starting "
f"batch size of {starting_batch_size} and {n_contigs} contigs, "
f"only the first {maximum_batch_steps} batch steps can be used."
)
sorted_batch_steps = sorted_batch_steps[:maximum_batch_steps]

# Get number of features
# Following line is un-inferrable due to typing problems with DataLoader
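The capping logic above follows from the fact that each batch step doubles the batch size: at most floor(log2(n_contigs / starting_batch_size)) doublings keep the final batch no larger than the dataset, which is exactly what `(n_contigs // starting_batch_size).bit_length() - 1` computes. A standalone sketch of the arithmetic, using assumed example numbers:

```python
# Standalone illustration of the batch-step capping arithmetic (example numbers assumed).
def max_batch_steps(n_contigs: int, starting_batch_size: int) -> int:
    # Largest k such that starting_batch_size * 2**k <= n_contigs,
    # i.e. floor(log2(n_contigs // starting_batch_size)).
    return (n_contigs // starting_batch_size).bit_length() - 1

# Example: 10,000 contigs with a starting batch size of 256.
# 10_000 // 256 == 39, and 39 has bit_length 6, so at most 5 doublings:
# 256 -> 512 -> 1024 -> 2048 -> 4096 -> 8192, all <= 10,000.
assert max_batch_steps(10_000, 256) == 5

# With fewer contigs, the requested batch steps are truncated (and a warning raised):
requested = [25, 75, 150, 225, 300, 375]
usable = requested[: max_batch_steps(2_000, 256)]  # 2_000 // 256 == 7 -> 2 steps
assert usable == [25, 75]
```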
@@ -587,8 +605,8 @@
print("\tN epochs:", nepochs, file=logfile)
print("\tStarting batch size:", dataloader.batch_size, file=logfile)
batchsteps_string = (
", ".join(map(str, sorted(batchsteps_set)))
if batchsteps_set
", ".join(map(str, sorted_batch_steps))
if len(sorted_batch_steps) > 0
else "None"
)
print("\tBatchsteps:", batchsteps_string, file=logfile)
@@ -599,7 +617,7 @@
# Train
for epoch in range(nepochs):
dataloader = self.trainepoch(
dataloader, epoch, optimizer, sorted(batchsteps_set), logfile
dataloader, epoch, optimizer, sorted_batch_steps, time(), logfile
)

# Save weights - Lord forgive me, for I have sinned when catching all exceptions
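Because `time()` is evaluated inside the loop on every iteration, the `start_time` handed to `trainepoch` resets each epoch, so the Seconds value logged per epoch measures per-epoch wall time rather than cumulative runtime. A minimal sketch of that pattern, with `work()` as a hypothetical stand-in for one training epoch:

```python
# Minimal sketch of the per-epoch timing pattern used above.
from time import time

def work() -> None:
    pass  # hypothetical stand-in for training one epoch

for epoch in range(3):
    start_time = time()  # reset at the start of every epoch
    work()
    elapsed = time() - start_time
    print(f"\tEpoch: {epoch + 1}\tSeconds: {elapsed:.2f}")
```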