diff --git a/src/train.py b/src/train.py
index 55741eb..1c1d910 100644
--- a/src/train.py
+++ b/src/train.py
@@ -2,7 +2,7 @@
 from architecture import LATENT_DIM, Critic, Generator
 import numpy as np
 from torch.optim.rmsprop import RMSprop
-from torch.optim.lr_scheduler import LambdaLR
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from utils.file_helpers import (
     get_device,
     save_model,
@@ -21,38 +21,6 @@
 CRITIC_STEPS = 5
 
 
-class WDistLRScheduler:
-    def __init__(self, optimizer, init_lr, beta=0.9, min_lr_factor=0.1):
-        self.optimizer = optimizer
-        self.init_lr = init_lr
-        self.beta = beta
-        self.min_lr_factor = min_lr_factor
-        self.max_w_dist = None
-        self.scheduler = LambdaLR(optimizer, self.lr_lambda)
-
-    def lr_lambda(self, _):
-        if self.max_w_dist is None:
-            return 1.0
-        scaling_factor = 1.0 - (1.0 - self.min_lr_factor) * np.exp(
-            -3 * (self.current_w_dist / self.max_w_dist)
-        )
-        return max(self.min_lr_factor, scaling_factor)
-
-    def step(self, w_dist):
-        if self.max_w_dist is None:
-            self.max_w_dist = w_dist
-        else:
-            self.max_w_dist = self.beta * self.max_w_dist + (1 - self.beta) * max(
-                w_dist, self.max_w_dist
-            )
-
-        self.current_w_dist = w_dist
-        self.scheduler.step()
-
-    def get_last_lr(self):
-        return self.scheduler.get_last_lr()
-
-
 # Loss metrics
 def compute_g_loss(critic, fake_validity, fake_audio_data, real_audio_data):
     wasserstein_dist = -torch.mean(fake_validity)
@@ -204,21 +172,21 @@ def train_epoch(
 
             total_g_loss += g_loss.item()
 
-            # # Save training progress image
-            # if i % (CRITIC_STEPS * 14) == 0:
-            #     fake_audio_to_visualize = fake_audio_data[0].cpu().detach().numpy()
-            #     graph_spectrogram(
-            #         fake_audio_to_visualize,
-            #         f"generator_epoch_{epoch_number + 1}_step_{i}.png",
-            #         True,
-            #     )
+            # Save training progress image
+            if i % (CRITIC_STEPS * 14) == 0:
+                fake_audio_to_visualize = fake_audio_data[0].cpu().detach().numpy()
+                graph_spectrogram(
+                    fake_audio_to_visualize,
+                    f"diverse_generator_epoch_{epoch_number + 1}_step_{i}.png",
+                    True,
+                )
 
     avg_g_loss = total_g_loss / len(dataloader)
     avg_c_loss = total_c_loss / len(dataloader)
     avg_w_dist = total_w_dist / len(dataloader)
 
-    scheduler_G.step(avg_g_loss)
-    scheduler_C.step(avg_c_loss)
+    scheduler_G.step(avg_w_dist)
+    scheduler_C.step(avg_w_dist)
 
     return avg_g_loss, avg_c_loss, avg_w_dist
 
@@ -272,8 +240,8 @@ def training_loop(train_loader, val_loader):
     optimizer_G = RMSprop(generator.parameters(), lr=LR_G, weight_decay=0.05)
     optimizer_C = RMSprop(critic.parameters(), lr=LR_C, weight_decay=0.05)
 
-    scheduler_G = WDistLRScheduler(optimizer_G, LR_G)
-    scheduler_C = WDistLRScheduler(optimizer_C, LR_C)
+    scheduler_G = ReduceLROnPlateau(optimizer_G, patience=2, factor=0.5)
+    scheduler_C = ReduceLROnPlateau(optimizer_C, patience=2, factor=0.5)
 
     # Train
     device = get_device()
@@ -308,11 +276,6 @@ def training_loop(train_loader, val_loader):
             f"------ Val ------ G Loss: {val_g_loss:.6f}, C Loss: {val_c_loss:.6f}, W Dist: {val_w_dist:.6f}"
         )
 
-        print(
-            f"G lr: {scheduler_G.get_last_lr()[0]:.6f}",
-            f"C lr: {scheduler_C.get_last_lr()[0]:.6f}",
-        )
-
        # Generate example audio
         if (epoch + 1) % SHOW_GENERATED_INTERVAL == 0:
            examples_to_generate = 3
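
Note on the scheduler swap (not part of the patch): both ReduceLROnPlateau instances are now stepped on the same avg_w_dist, so the generator and critic learning rates decay in lockstep once the Wasserstein estimate plateaus. Below is a minimal standalone sketch of that behavior under stock PyTorch semantics with the settings from the patch (patience=2, factor=0.5, default mode="min"); the toy parameter and the per-epoch metric values are invented purely for illustration.

import torch
from torch.optim.rmsprop import RMSprop
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Stand-in for generator.parameters(); the real loop passes the model's params.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = RMSprop(params, lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer, patience=2, factor=0.5)

# Hypothetical avg_w_dist per epoch: improving at first, then stuck.
for epoch, avg_w_dist in enumerate([0.90, 0.60, 0.40, 0.41, 0.40, 0.42]):
    scheduler.step(avg_w_dist)  # same call shape as in train_epoch above
    print(f"epoch {epoch + 1}: lr = {optimizer.param_groups[0]['lr']:.2e}")

# With patience=2, the third consecutive non-improving epoch (epoch 6 here)
# triggers the reduction, halving the learning rate from 1e-4 to 5e-5.

The sketch reads the current rate from optimizer.param_groups[0]["lr"] rather than get_last_lr(), since older PyTorch releases do not expose get_last_lr() on ReduceLROnPlateau, which is presumably why the lr logging block was dropped from training_loop above.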