Commit 98b1c03: scheduler selling
shuklabhay committed Oct 1, 2024
1 parent 473805d commit 98b1c03
Showing 1 changed file with 13 additions and 50 deletions.
src/train.py — 63 changes: 13 additions & 50 deletions
--- a/src/train.py
+++ b/src/train.py
@@ -2,7 +2,7 @@
 from architecture import LATENT_DIM, Critic, Generator
 import numpy as np
 from torch.optim.rmsprop import RMSprop
-from torch.optim.lr_scheduler import LambdaLR
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from utils.file_helpers import (
     get_device,
     save_model,
@@ -21,38 +21,6 @@
 CRITIC_STEPS = 5
 
 
-class WDistLRScheduler:
-    def __init__(self, optimizer, init_lr, beta=0.9, min_lr_factor=0.1):
-        self.optimizer = optimizer
-        self.init_lr = init_lr
-        self.beta = beta
-        self.min_lr_factor = min_lr_factor
-        self.max_w_dist = None
-        self.scheduler = LambdaLR(optimizer, self.lr_lambda)
-
-    def lr_lambda(self, _):
-        if self.max_w_dist is None:
-            return 1.0
-        scaling_factor = 1.0 - (1.0 - self.min_lr_factor) * np.exp(
-            -3 * (self.current_w_dist / self.max_w_dist)
-        )
-        return max(self.min_lr_factor, scaling_factor)
-
-    def step(self, w_dist):
-        if self.max_w_dist is None:
-            self.max_w_dist = w_dist
-        else:
-            self.max_w_dist = self.beta * self.max_w_dist + (1 - self.beta) * max(
-                w_dist, self.max_w_dist
-            )
-
-        self.current_w_dist = w_dist
-        self.scheduler.step()
-
-    def get_last_lr(self):
-        return self.scheduler.get_last_lr()
-
-
 # Loss metrics
 def compute_g_loss(critic, fake_validity, fake_audio_data, real_audio_data):
     wasserstein_dist = -torch.mean(fake_validity)
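
For reference, the deleted WDistLRScheduler scaled the learning rate from the current Wasserstein distance $w$ and a smoothed running maximum $w_{\max}$; its lr_lambda reduces to the closed form

$$\mathrm{scale}(w) = \max\bigl(f_{\min},\; 1 - (1 - f_{\min})\, e^{-3 w / w_{\max}}\bigr)$$

where $f_{\min}$ = min_lr_factor = 0.1 by default, so the effective LR decayed toward 10% of its initial value as the distance shrank relative to $w_{\max}$ and recovered toward the full rate as it grew.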
@@ -204,21 +172,21 @@ def train_epoch(

         total_g_loss += g_loss.item()
 
-        # # Save training progress image
-        # if i % (CRITIC_STEPS * 14) == 0:
-        #     fake_audio_to_visualize = fake_audio_data[0].cpu().detach().numpy()
-        #     graph_spectrogram(
-        #         fake_audio_to_visualize,
-        #         f"generator_epoch_{epoch_number + 1}_step_{i}.png",
-        #         True,
-        #     )
+        # Save training progress image
+        if i % (CRITIC_STEPS * 14) == 0:
+            fake_audio_to_visualize = fake_audio_data[0].cpu().detach().numpy()
+            graph_spectrogram(
+                fake_audio_to_visualize,
+                f"diverse_generator_epoch_{epoch_number + 1}_step_{i}.png",
+                True,
+            )
 
     avg_g_loss = total_g_loss / len(dataloader)
     avg_c_loss = total_c_loss / len(dataloader)
     avg_w_dist = total_w_dist / len(dataloader)
 
-    scheduler_G.step(avg_g_loss)
-    scheduler_C.step(avg_c_loss)
+    scheduler_G.step(avg_w_dist)
+    scheduler_C.step(avg_w_dist)
 
     return avg_g_loss, avg_c_loss, avg_w_dist

@@ -272,8 +240,8 @@ def training_loop(train_loader, val_loader):
     optimizer_G = RMSprop(generator.parameters(), lr=LR_G, weight_decay=0.05)
     optimizer_C = RMSprop(critic.parameters(), lr=LR_C, weight_decay=0.05)
 
-    scheduler_G = WDistLRScheduler(optimizer_G, LR_G)
-    scheduler_C = WDistLRScheduler(optimizer_C, LR_C)
+    scheduler_G = ReduceLROnPlateau(optimizer_G, patience=2, factor=0.5)
+    scheduler_C = ReduceLROnPlateau(optimizer_C, patience=2, factor=0.5)
 
     # Train
     device = get_device()
@@ -308,11 +276,6 @@ def training_loop(train_loader, val_loader):
f"------ Val ------ G Loss: {val_g_loss:.6f}, C Loss: {val_c_loss:.6f}, W Dist: {val_w_dist:.6f}"
)

print(
f"G lr: {scheduler_G.get_last_lr()[0]:.6f}",
f"C lr: {scheduler_C.get_last_lr()[0]:.6f}",
)

# Generate example audio
if (epoch + 1) % SHOW_GENERATED_INTERVAL == 0:
examples_to_generate = 3
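In place of the custom class, the commit switches to PyTorch's built-in ReduceLROnPlateau and steps both schedulers once per epoch on the epoch-averaged Wasserstein distance. A minimal sketch of the resulting pattern, assuming a placeholder model and a dummy metric (neither is the repository's code):

import torch
from torch.optim.rmsprop import RMSprop
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Placeholder module standing in for the repo's Generator/Critic.
model = torch.nn.Linear(8, 1)
optimizer = RMSprop(model.parameters(), lr=1e-4, weight_decay=0.05)

# mode="min" is the default: multiply the LR by factor=0.5 once the
# monitored value has failed to improve for patience=2 epochs.
scheduler = ReduceLROnPlateau(optimizer, patience=2, factor=0.5)

for epoch in range(10):
    # ... training steps would accumulate the epoch's Wasserstein distance ...
    avg_w_dist = 1.0 / (epoch + 1)  # dummy stand-in for the epoch average

    # Unlike LambdaLR, ReduceLROnPlateau.step() takes the metric itself.
    scheduler.step(avg_w_dist)

One side effect worth noting: many PyTorch releases do not expose get_last_lr() on ReduceLROnPlateau, which is consistent with the commit also removing the per-epoch learning-rate print.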
