Skip to content

Commit

Permalink
it kinda figures it out?
Browse files Browse the repository at this point in the history
  • Loading branch information
shuklabhay committed Sep 27, 2024
1 parent ce694f5 commit 3fcc1cf
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
16 changes: 10 additions & 6 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


# Constants
N_EPOCHS = 8
N_EPOCHS = 14
SHOW_GENERATED_INTERVAL = 4
SAVE_INTERVAL = int(N_EPOCHS / 1)

Expand Down Expand Up @@ -248,7 +248,6 @@ def training_loop(train_loader, val_loader):
generator.to(device)
critic.to(device)

best_val_w_dist = 0.0
epochs_no_improve = 0
patience = 5 # epochs
for epoch in range(N_EPOCHS):
Expand All @@ -264,13 +263,13 @@ def training_loop(train_loader, val_loader):
device,
epoch,
)
val_g_loss, val_c_loss, val_w_dist = validate(
generator, critic, val_loader, device
)

print(
f"[{epoch+1}/{N_EPOCHS}] Train - G Loss: {train_g_loss:.6f}, C Loss: {train_c_loss:.6f}, W Dist: {train_w_dist:.6f}"
)

val_g_loss, val_c_loss, val_w_dist = validate(
generator, critic, val_loader, device
)
print(
f"------ Val ------ G Loss: {val_g_loss:.6f}, C Loss: {val_c_loss:.6f}, W Dist: {val_w_dist:.6f}"
)
Expand All @@ -289,10 +288,15 @@ def training_loop(train_loader, val_loader):
)

# Early exit/saving
if (epoch + 1) == 0:
best_val_w_dist = np.abs(val_w_dist)
print(f"initialized best_val_w_dist at {np.abs(val_w_dist)}")

if np.abs(val_w_dist) < best_val_w_dist:
best_val_w_dist = np.abs(val_w_dist)
epochs_no_improve = 0
save_model(generator)
print(f"Model saved at w_dist={val_w_dist:.6f}")
else:
epochs_no_improve += 1
if epochs_no_improve >= patience:
Expand Down
4 changes: 2 additions & 2 deletions src/usage_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
training_sample_length = 0.6 # seconds
outputs_dir = "outputs" # Where to save your generated audio & model

model_save_name = "StereoSampleGAN-CuratedKick" # What to name your model save
training_audio_dir = "data/kick_samples_curated" # Your training data path
model_save_name = "StereoSampleGAN-DiverseKick" # What to name your model save
training_audio_dir = "data/kick_samples_diverse" # Your training data path
compiled_data_path = "data/compiled_data.npy" # Your compiled data/output path
model_save_path = f"{outputs_dir}/{model_save_name}.pth"

Expand Down

0 comments on commit 3fcc1cf

Please sign in to comment.