Skip to content

Commit

Permalink
diverse kick model
Browse files Browse the repository at this point in the history
  • Loading branch information
shuklabhay committed Sep 27, 2024
1 parent 080a71b commit 378613e
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 26 deletions.
13 changes: 6 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,21 @@ Specify training data parameters in `usage_params.py`
<1 sec long (longer hasn't been fully tested)
- Prepare training data by running `python3 src/data_processing/encode_audio_data.py`
- Train model by running `python3 src/stereo_sample_gan.py`
- You might have to tune `N_EPOCHS` and `LR_G` in `train.py` based on the amount of training data and on the results you get when running with different numbers of epochs.
- Generate audio (based on current `usage_params.py`) by running `python3 src/generate.py`

Diverse Kick Drum Model Training progress (8 epochs):

<img src="static/diverse_kick_training_progress.gif" alt="Diverse kick training progress" width="400">

## Pretrained Models

### Diverse Kick Drums
### Diverse Kick Drum

Kick drum generation model trained on ~8000 essentially random kick drums.

- More variation between generated samples, but the audio is often inconsistent and contains some artifacts.

Training progress (8 epochs):

<img src="static/diverse_kick_training_progress.gif" alt="Diverse kick training progress" width="400">

### Diverse Kick Drums
### Curated Kick Drum

Kick drum generation model trained on ~4400 kick drums that were chosen slightly more rigorously, but still essentially at random.

Expand Down
Binary file modified outputs/StereoSampleGAN-DiverseKick.pth
Binary file not shown.
15 changes: 12 additions & 3 deletions src/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,25 @@ def __init__(self):
nn.ConvTranspose2d(LATENT_DIM, 128, kernel_size=4, stride=1, padding=0),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2), # Shape: (BATCH_SIZE, 128, 4, 4)
nn.Dropout(0.3),
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2), # Shape: (BATCH_SIZE, 64, 8, 8)
nn.Dropout(0.3),
nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
# LinearAttention(32),
nn.BatchNorm2d(32),
nn.LeakyReLU(0.2), # Shape: (BATCH_SIZE, 32, 16, 16)
nn.Dropout(0.3),
nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(16),
nn.LeakyReLU(0.2), # Shape: (BATCH_SIZE, 16, 32, 32)
nn.Dropout(0.3),
nn.ConvTranspose2d(16, 8, kernel_size=6, stride=4, padding=1),
nn.BatchNorm2d(8),
nn.LeakyReLU(0.2), # Shape: (BATCH_SIZE, 8, 128, 128)
nn.Dropout(0.3),
nn.ConvTranspose2d(8, N_CHANNELS, kernel_size=6, stride=2, padding=2),
# Shape: (BATCH_SIZE, N_CHANNELS, 256, 256)
nn.Tanh(),
nn.Tanh(), # Shape: (BATCH_SIZE, N_CHANNELS, 256, 256)
)

def forward(self, z):
Expand All @@ -75,22 +78,28 @@ def __init__(self):
self.conv_blocks = nn.Sequential(
spectral_norm(nn.Conv2d(N_CHANNELS, 4, kernel_size=4, stride=2, padding=1)),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
spectral_norm(nn.Conv2d(4, 8, kernel_size=4, stride=2, padding=1)),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(8),
nn.Dropout(0.3),
spectral_norm(nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=1)),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(16),
nn.Dropout(0.3),
LinearAttention(16),
spectral_norm(nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1)),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(32),
nn.Dropout(0.3),
spectral_norm(nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(64),
nn.Dropout(0.3),
spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(128),
nn.Dropout(0.3),
spectral_norm(nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=0)),
nn.Flatten(),
)
Expand Down
27 changes: 13 additions & 14 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
SHOW_GENERATED_INTERVAL = 4
SAVE_INTERVAL = int(N_EPOCHS / 1)

LR_G = 0.001
LR_G = 0.003
LR_C = 0.004
LAMBDA_GP = 5
CRITIC_STEPS = 5
Expand Down Expand Up @@ -172,14 +172,14 @@ def train_epoch(

total_g_loss += g_loss.item()

# Save training progress image
if i % (CRITIC_STEPS * 14) == 0:
fake_audio_to_visualize = fake_audio_data[0].cpu().detach().numpy()
graph_spectrogram(
fake_audio_to_visualize,
f"generator_epoch_{epoch_number + 1}_step_{i}.png",
True,
)
# # Save training progress image
# if i % (CRITIC_STEPS * 14) == 0:
# fake_audio_to_visualize = fake_audio_data[0].cpu().detach().numpy()
# graph_spectrogram(
# fake_audio_to_visualize,
# f"generator_epoch_{epoch_number + 1}_step_{i}.png",
# True,
# )

avg_g_loss = total_g_loss / len(dataloader)
avg_c_loss = total_c_loss / len(dataloader)
Expand Down Expand Up @@ -248,6 +248,7 @@ def training_loop(train_loader, val_loader):
generator.to(device)
critic.to(device)

best_val_w_dist = float("inf") # Initialize
epochs_no_improve = 0
patience = 5 # epochs
for epoch in range(N_EPOCHS):
Expand All @@ -274,7 +275,7 @@ def training_loop(train_loader, val_loader):
f"------ Val ------ G Loss: {val_g_loss:.6f}, C Loss: {val_c_loss:.6f}, W Dist: {val_w_dist:.6f}"
)

# Display example audio
# Generate example audio
if (epoch + 1) % SHOW_GENERATED_INTERVAL == 0:
examples_to_generate = 3
z = torch.randn(examples_to_generate, LATENT_DIM, 1, 1).to(device)
Expand All @@ -288,16 +289,14 @@ def training_loop(train_loader, val_loader):
)

# Early exit/saving
if (epoch + 1) == 0:
best_val_w_dist = np.abs(val_w_dist)
print(f"initialized best_val_w_dist at {np.abs(val_w_dist)}")
elif best_val_w_dist and np.abs(val_w_dist) < best_val_w_dist:
if np.abs(val_w_dist) < best_val_w_dist:
best_val_w_dist = np.abs(val_w_dist)
epochs_no_improve = 0
save_model(generator)
print(f"Model saved at w_dist={val_w_dist:.6f}")
else:
epochs_no_improve += 1
print(f"epochs without improvement: {epochs_no_improve}")
if epochs_no_improve >= patience:
print("Early stopping triggered")
break
4 changes: 2 additions & 2 deletions src/usage_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
training_sample_length = 0.6 # seconds
outputs_dir = "outputs" # Where to save your generated audio & model

model_save_name = "StereoSampleGAN-DiverseKick" # What to name your model save
training_audio_dir = "data/kick_samples_diverse" # Your training data path
model_save_name = "StereoSampleGAN-CuratedKick" # What to name your model save
training_audio_dir = "data/kick_samples_curated" # Your training data path
compiled_data_path = "data/compiled_data.npy" # Your compiled data/output path
model_save_path = f"{outputs_dir}/{model_save_name}.pth"

Expand Down

0 comments on commit 378613e

Please sign in to comment.