Skip to content

Commit

Permalink
margainally better perhaps ?
Browse files Browse the repository at this point in the history
  • Loading branch information
shuklabhay committed Sep 18, 2024
1 parent c661a8c commit 64f033b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
Binary file modified outputs/StereoSampleGAN-Kick.pth
Binary file not shown.
21 changes: 13 additions & 8 deletions src/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from architecture import LATENT_DIM, Critic, Generator
import numpy as np
from torch.optim.rmsprop import RMSprop
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils.file_helpers import (
Expand All @@ -10,9 +11,11 @@

# Constants
N_EPOCHS = 8
VALIDATION_INTERVAL = 2
VALIDATION_INTERVAL = 4
SAVE_INTERVAL = int(N_EPOCHS / 1)

LR_G = 0.003
LR_C = 0.004
LAMBDA_GP = 5
CRITIC_STEPS = 5

Expand Down Expand Up @@ -209,9 +212,6 @@ def validate(generator, critic, dataloader, device):

def training_loop(train_loader, val_loader):
# Initialize models and optimizers
LR_G = 0.003
LR_C = 0.004

generator = Generator()
critic = Critic()
optimizer_G = RMSprop(generator.parameters(), lr=LR_G, weight_decay=0.05)
Expand Down Expand Up @@ -240,8 +240,11 @@ def training_loop(train_loader, val_loader):
f"[{epoch+1}/{N_EPOCHS}] Train - G Loss: {train_g_loss:.6f}, C Loss: {train_c_loss:.6f}"
)

# Validate periodically
if (epoch + 1) % VALIDATION_INTERVAL == 0:
# Validation and saving
early_exit_loss_thresh = 0.2
early_exit_condition = np.abs(train_g_loss) <= early_exit_loss_thresh

if (epoch + 1) % VALIDATION_INTERVAL == 0 or early_exit_condition is True:
val_g_loss, val_c_loss = validate(generator, critic, val_loader, device)
print(
f"------ Val ------ G Loss: {val_g_loss:.6f}, C Loss: {val_c_loss:.6f}"
Expand All @@ -259,7 +262,9 @@ def training_loop(train_loader, val_loader):
)

# Save model
if (epoch + 1) % SAVE_INTERVAL == 0 or train_g_loss <= 0.2:
print(f"Training stopped at epoch {epoch+1}, g_loss: {train_g_loss:.6f}")
if (epoch + 1) % SAVE_INTERVAL == 0 or early_exit_condition is True:
print(
f"Training stopped at epoch {epoch+1}, Final g_loss: {train_g_loss:.6f}"
)
save_model(generator, "StereoSampleGAN-Kick")
break

0 comments on commit 64f033b

Please sign in to comment.