diff --git a/outputs/StereoSampleGAN-Kick.pth b/outputs/StereoSampleGAN-Kick.pth index fda9f6f..4a8cc61 100644 Binary files a/outputs/StereoSampleGAN-Kick.pth and b/outputs/StereoSampleGAN-Kick.pth differ diff --git a/src/train.py b/src/train.py index 76959d2..30d9f93 100644 --- a/src/train.py +++ b/src/train.py @@ -1,5 +1,6 @@ import torch from architecture import LATENT_DIM, Critic, Generator +import numpy as np from torch.optim.rmsprop import RMSprop from torch.optim.lr_scheduler import ReduceLROnPlateau from utils.file_helpers import ( @@ -10,9 +11,11 @@ # Constants N_EPOCHS = 8 -VALIDATION_INTERVAL = 2 +VALIDATION_INTERVAL = 4 SAVE_INTERVAL = int(N_EPOCHS / 1) +LR_G = 0.003 +LR_C = 0.004 LAMBDA_GP = 5 CRITIC_STEPS = 5 @@ -209,9 +212,6 @@ def validate(generator, critic, dataloader, device): def training_loop(train_loader, val_loader): # Initialize models and optimizers - LR_G = 0.003 - LR_C = 0.004 - generator = Generator() critic = Critic() optimizer_G = RMSprop(generator.parameters(), lr=LR_G, weight_decay=0.05) @@ -240,8 +240,11 @@ def training_loop(train_loader, val_loader): f"[{epoch+1}/{N_EPOCHS}] Train - G Loss: {train_g_loss:.6f}, C Loss: {train_c_loss:.6f}" ) - # Validate periodically - if (epoch + 1) % VALIDATION_INTERVAL == 0: + # Validation and saving + early_exit_loss_thresh = 0.2 + early_exit_condition = np.abs(train_g_loss) <= early_exit_loss_thresh + + if (epoch + 1) % VALIDATION_INTERVAL == 0 or early_exit_condition is True: val_g_loss, val_c_loss = validate(generator, critic, val_loader, device) print( f"------ Val ------ G Loss: {val_g_loss:.6f}, C Loss: {val_c_loss:.6f}" @@ -259,7 +262,9 @@ def training_loop(train_loader, val_loader): ) # Save model - if (epoch + 1) % SAVE_INTERVAL == 0 or train_g_loss <= 0.2: - print(f"Training stopped at epoch {epoch+1}, g_loss: {train_g_loss:.6f}") + if (epoch + 1) % SAVE_INTERVAL == 0 or early_exit_condition is True: + print( + f"Training stopped at epoch {epoch+1}, Final g_loss: {train_g_loss:.6f}" + ) save_model(generator, "StereoSampleGAN-Kick") break