hyperparam tuning #18

Merged · 8 commits · Sep 21, 2024
Binary file modified outputs/StereoSampleGAN-Kick.pth
Binary file not shown.
Binary file added outputs/StereoSampleGAN-OldKick.pth
Binary file not shown.
paper/main.md (2 changes: 2 additions & 0 deletions)
@@ -74,6 +74,8 @@ two foundational changes to making task happen with deep conv: discrimiator real

 This model seeks to replicate a DCGAN's multi-channel image generation capabilities[1] to create varied two-channel audio representations.

+Early versions of this model suffered from a long-range dependency problem: the generator captured local detail but did not learn the overall spectral trend. Adding attention specifically fixed this.
+
 ### 4.2. Training

 This work uses 80% of the dataset as training data and 20% as validation, with all data split into batches of 16.
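The attention fix described above corresponds to the `LinearAttention(32)` slot that appears (commented out) in src/architecture.py below. That module's implementation is not part of this diff, so the following is only a minimal sketch of a linear attention block over the generator's (B, C, H, W) feature maps, assuming an efficient-attention-style formulation; the repo's actual `LinearAttention` may differ.

```python
import torch
import torch.nn as nn

class LinearAttention(nn.Module):
    """Sketch of a linear attention block for 2D feature maps (assumption,
    not the repo's implementation)."""

    def __init__(self, channels):
        super().__init__()
        self.to_qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False)
        self.to_out = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=1)      # each: (B, C, H, W)
        q = q.flatten(2).softmax(dim=1)               # normalize over channels
        k = k.flatten(2).softmax(dim=-1)              # normalize over the H*W positions
        v = v.flatten(2)                              # (B, C, N)
        context = torch.einsum("bcn,bdn->bcd", k, v)  # (B, C, C): cost is linear in N
        out = torch.einsum("bcd,bcn->bdn", context, q)
        return self.to_out(out.reshape(b, c, h, w)) + x  # residual connection

# Usage at the commented slot in the generator:
# attn = LinearAttention(32); attn(torch.randn(16, 32, 16, 16)).shape -> (16, 32, 16, 16)
```

Because the context matrix is (C, C) rather than (N, N), the block captures global structure across the whole spectrogram at linear cost in the number of positions, which is what addresses the long-range dependency issue noted in the paper.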
src/architecture.py (10 changes: 5 additions & 5 deletions)
@@ -45,20 +45,20 @@ def __init__(self):
         self.deconv_blocks = nn.Sequential(
             nn.ConvTranspose2d(LATENT_DIM, 128, kernel_size=4, stride=1, padding=0),
             nn.BatchNorm2d(128),
-            nn.ReLU(),  # Shape: (BATCH_SIZE, 128, 4, 4)
+            nn.LeakyReLU(0.2),  # Shape: (BATCH_SIZE, 128, 4, 4)
             nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
             nn.BatchNorm2d(64),
-            nn.ReLU(),  # Shape: (BATCH_SIZE, 64, 8, 8)
+            nn.LeakyReLU(0.2),  # Shape: (BATCH_SIZE, 64, 8, 8)
             nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
             # LinearAttention(32),
             nn.BatchNorm2d(32),
-            nn.ReLU(),  # Shape: (BATCH_SIZE, 32, 16, 16)
+            nn.LeakyReLU(0.2),  # Shape: (BATCH_SIZE, 32, 16, 16)
             nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
             nn.BatchNorm2d(16),
-            nn.ReLU(),  # Shape: (BATCH_SIZE, 16, 32, 32)
+            nn.LeakyReLU(0.2),  # Shape: (BATCH_SIZE, 16, 32, 32)
             nn.ConvTranspose2d(16, 8, kernel_size=6, stride=4, padding=1),
             nn.BatchNorm2d(8),
-            nn.ReLU(),  # Shape: (BATCH_SIZE, 8, 128, 128)
+            nn.LeakyReLU(0.2),  # Shape: (BATCH_SIZE, 8, 128, 128)
             nn.ConvTranspose2d(8, N_CHANNELS, kernel_size=6, stride=2, padding=2),
             # Shape: (BATCH_SIZE, N_CHANNELS, 256, 256)
             nn.Tanh(),
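Swapping `nn.ReLU()` for `nn.LeakyReLU(0.2)` keeps a small gradient for negative pre-activations, a common guard against dead units in GAN generators. A quick standalone illustration (plain PyTorch, nothing repo-specific):

```python
import torch
import torch.nn as nn

x = torch.tensor([-2.0, -0.5, 0.0, 1.0])
print(nn.ReLU()(x))          # tensor([0., 0., 0., 1.]): negatives are zeroed, so their gradient is 0
print(nn.LeakyReLU(0.2)(x))  # tensor([-0.4000, -0.1000, 0.0000, 1.0000]): slope 0.2 keeps gradients flowing
```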
src/data_processing/encode_audio_data.py (4 changes: 2 additions & 2 deletions)
@@ -13,7 +13,7 @@

 # Encode samples
 if len(sys.argv) > 1:
-    visualize = sys.argv[1].lower() == "true"
+    visualize = sys.argv[1].lower() == "visualize"
 else:
     visualize = False

@@ -22,4 +22,4 @@
 real_data = load_loudness_data(
     compiled_data_path
 )  # datapts, channels, frames, freq bins
-print("Data shape,", str(real_data.shape))
+print("Data shape:", str(real_data.shape))
src/generate.py (7 changes: 5 additions & 2 deletions)
@@ -2,7 +2,7 @@
 import torch
 from architecture import Generator, LATENT_DIM
 from utils.file_helpers import get_device, outputs_dir, save_audio
-from utils.signal_helpers import graph_spectrogram, norm_db_to_audio
+from utils.signal_helpers import audio_to_norm_db, graph_spectrogram, norm_db_to_audio

 # Initialize Generator
 device = get_device()
@@ -22,8 +22,11 @@
 generated_output = generated_output.squeeze().numpy()
 print("Generated output shape:", generated_output.shape)

-graph_spectrogram(generated_output, "generated output")
+# graph_spectrogram(generated_output, "generated output")
 audio_info = norm_db_to_audio(generated_output)
 audio_save_path = os.path.join(outputs_dir, "generated_audio.wav")

 save_audio(audio_save_path, audio_info)
+
+vis_signal_after_istft = audio_to_norm_db(audio_info)
+graph_spectrogram(vis_signal_after_istft, "generated audio (after istft)")
src/stereo_sample_gan.py (22 changes: 1 addition & 21 deletions)
@@ -1,11 +1,8 @@
 import torch
-from torch.optim.rmsprop import RMSprop
 from torch.utils.data import DataLoader, TensorDataset, random_split

 from architecture import (
     BATCH_SIZE,
-    Critic,
-    Generator,
 )
 from train import training_loop
 from utils.file_helpers import (
@@ -29,23 +26,6 @@
 train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
 val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

-# Initialize models and optimizers
-generator = Generator()
-critic = Critic()
-optimizer_G = RMSprop(generator.parameters(), lr=LR_G, weight_decay=0.05)
-optimizer_C = RMSprop(critic.parameters(), lr=LR_C, weight_decay=0.05)
-
-
-# Train
-device = get_device()
-generator.to(device)
-critic.to(device)
-training_loop(
-    generator,
-    critic,
-    train_loader,
-    val_loader,
-    optimizer_G,
-    optimizer_C,
-    device,
-)
+training_loop(train_loader, val_loader)
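With models, optimizers, and schedulers now constructed inside `training_loop`, the entry point shrinks to data preparation plus one call. A sketch of the resulting flow; the dataset construction is collapsed in the hunk above, so `data_tensor` below is a dummy stand-in:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

from architecture import BATCH_SIZE
from train import training_loop

data_tensor = torch.randn(64, 2, 256, 256)  # dummy stand-in for the encoded spectrogram set
dataset = TensorDataset(data_tensor)

train_size = int(0.8 * len(dataset))  # 80/20 split, per the paper
train_dataset, val_dataset = random_split(dataset, [train_size, len(dataset) - train_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

training_loop(train_loader, val_loader)  # models, optimizers, schedulers all built inside
```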
src/train.py (85 changes: 61 additions & 24 deletions)
@@ -1,22 +1,28 @@
 import torch
-from architecture import LATENT_DIM
+from architecture import LATENT_DIM, Critic, Generator
+import numpy as np
+from torch.optim.rmsprop import RMSprop
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from utils.file_helpers import (
     get_device,
     save_model,
 )
 from utils.signal_helpers import graph_spectrogram

 # Constants
-N_EPOCHS = 4
-VALIDATION_INTERVAL = 1
+N_EPOCHS = 8
+VALIDATION_INTERVAL = 4
 SAVE_INTERVAL = int(N_EPOCHS / 1)

 LR_G = 0.003
 LR_C = 0.004
 LAMBDA_GP = 5
-N_CRITIC = 5
+CRITIC_STEPS = 5


-# Total loss functions
+# Loss metrics
 def compute_g_loss(critic, fake_validity, fake_audio_data, real_audio_data):
-    feat_match = 0.25 * calculate_feature_match_diff(
+    feat_match = 0.35 * calculate_feature_match_diff(
         critic, real_audio_data, fake_audio_data
     )

@@ -43,14 +49,13 @@ def compute_c_loss(
     )

     computed_c_loss = (
-        torch.mean(fake_validity)
-        - torch.mean(real_validity)
+        -(torch.mean(real_validity) - torch.mean(fake_validity))
         + spectral_diff
         + spectral_convergence
     )

     if training:
-        gradient_penalty = compute_gradient_penalty(
+        gradient_penalty = calculate_gradient_penalty(
             critic, real_audio_data, fake_audio_data, device
         )
         computed_c_loss += LAMBDA_GP * gradient_penalty
@@ -84,7 +89,7 @@ def calculate_spectral_convergence_diff(real_audio_data, fake_audio_data):
     return numerator / denominator


-def compute_gradient_penalty(critic, real_samples, fake_samples, device):
+def calculate_gradient_penalty(critic, real_samples, fake_samples, device):
     real_samples.requires_grad_(True)
     fake_samples.requires_grad_(True)

@@ -108,7 +113,16 @@ def compute_gradient_penalty(critic, real_samples, fake_samples, device):

 # Training
-def train_epoch(generator, critic, dataloader, optimizer_G, optimizer_C, device):
+def train_epoch(
+    generator,
+    critic,
+    dataloader,
+    optimizer_G,
+    optimizer_C,
+    scheduler_G,
+    scheduler_C,
+    device,
+):
     generator.train()
     critic.train()
     total_g_loss, total_c_loss = 0, 0
@@ -121,7 +135,6 @@ def train_epoch(generator, critic, dataloader, optimizer_G, optimizer_C, device)
         optimizer_C.zero_grad()
         z = torch.randn(batch, LATENT_DIM, 1, 1).to(device)
         fake_audio_data = generator(z)
-
         real_validity = critic(real_audio_data)
         fake_validity = critic(fake_audio_data.detach())

@@ -134,28 +147,32 @@ def train_epoch(generator, critic, dataloader, optimizer_G, optimizer_C, device)
             True,
             device,
         )
-
         c_loss.backward()
         optimizer_C.step()

         total_c_loss += c_loss.item()

-        # Train generator every n_critic steps
-        if i % N_CRITIC == 0:
+        # Train generator every CRITIC_STEPS steps
+        if i % CRITIC_STEPS == 0:
             optimizer_G.zero_grad()
             fake_audio_data = generator(z)
             fake_validity = critic(fake_audio_data)

             g_loss = compute_g_loss(
                 critic, fake_validity, fake_audio_data, real_audio_data
             )

             g_loss.backward()
             optimizer_G.step()

             total_g_loss += g_loss.item()

-    return total_g_loss / len(dataloader), total_c_loss / len(dataloader)
+    avg_g_loss = total_g_loss / len(dataloader)
+    avg_c_loss = total_c_loss / len(dataloader)
+
+    scheduler_G.step(avg_g_loss)
+    scheduler_C.step(avg_c_loss)
+
+    return avg_g_loss, avg_c_loss


 def validate(generator, critic, dataloader, device):
@@ -193,25 +210,41 @@ def validate(generator, critic, dataloader, device):
     return total_g_loss / len(dataloader), total_c_loss / len(dataloader)


-def training_loop(
-    generator, critic, train_loader, val_loader, optimizer_G, optimizer_C, device
-):
+def training_loop(train_loader, val_loader):
+    # Initialize models and optimizers
+    generator = Generator()
+    critic = Critic()
+    optimizer_G = RMSprop(generator.parameters(), lr=LR_G, weight_decay=0.05)
+    optimizer_C = RMSprop(critic.parameters(), lr=LR_C, weight_decay=0.05)
+
+    scheduler_G = ReduceLROnPlateau(optimizer_G, mode="min", factor=0.5, patience=5)
+    scheduler_C = ReduceLROnPlateau(optimizer_C, mode="min", factor=0.5, patience=5)
+
+    # Train
+    device = get_device()
+    generator.to(device)
+    critic.to(device)
     for epoch in range(N_EPOCHS):
         train_g_loss, train_c_loss = train_epoch(
             generator,
             critic,
             train_loader,
             optimizer_G,
             optimizer_C,
+            scheduler_G,
+            scheduler_C,
             device,
         )

         print(
             f"[{epoch+1}/{N_EPOCHS}] Train - G Loss: {train_g_loss:.6f}, C Loss: {train_c_loss:.6f}"
         )

-        # Validate periodically
-        if (epoch + 1) % VALIDATION_INTERVAL == 0:
+        # Validation and saving
+        early_exit_loss_thresh = 0.2
+        early_exit_condition = np.abs(train_g_loss) <= early_exit_loss_thresh
+
+        if (epoch + 1) % VALIDATION_INTERVAL == 0 or early_exit_condition is True:
             val_g_loss, val_c_loss = validate(generator, critic, val_loader, device)
             print(
                 f"------ Val ------ G Loss: {val_g_loss:.6f}, C Loss: {val_c_loss:.6f}"
@@ -228,6 +261,10 @@ def training_loop(
                     f"Epoch {epoch + 1} Generated Audio #{i + 1}",
                 )

-        # Save models periodically
-        if (epoch + 1) % SAVE_INTERVAL == 0:
+        # Save model
+        if (epoch + 1) % SAVE_INTERVAL == 0 or early_exit_condition is True:
+            print(
+                f"Training stopped at epoch {epoch+1}, Final g_loss: {train_g_loss:.6f}"
+            )
             save_model(generator, "StereoSampleGAN-Kick")
+            break
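The body of `calculate_gradient_penalty` stays collapsed in this view. Since the critic loss adds `LAMBDA_GP * gradient_penalty`, it is presumably the standard WGAN-GP penalty (Gulrajani et al., 2017); below is a minimal sketch of that computation, which may differ from the repo's exact implementation:

```python
import torch

def wgan_gradient_penalty(critic, real_samples, fake_samples, device):
    # Interpolate between real and fake samples with a per-sample mixing factor
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)
    interpolated = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    validity = critic(interpolated)
    grads = torch.autograd.grad(
        outputs=validity,
        inputs=interpolated,
        grad_outputs=torch.ones_like(validity),
        create_graph=True,   # keep the graph so the penalty itself is differentiable
        retain_graph=True,
    )[0]

    # Penalize deviation of the critic's gradient norm from 1 (Lipschitz constraint)
    grads = grads.view(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()  # scaled by LAMBDA_GP at the call site
```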
src/utils/file_helpers.py (8 changes: 1 addition & 7 deletions)
@@ -24,13 +24,7 @@ def save_audio(save_path, audio):
     sf.write(save_path, audio.T, GLOBAL_SR)


-def save_model(model, name, preserve_old=False):
-    # Clear previous models
-    if preserve_old is not True:
-        for filename in os.listdir(outputs_dir):
-            file_path = os.path.join(outputs_dir, filename)
-            os.remove(file_path)
-
+def save_model(model, name):
     # Save model
     torch.save(
         model.state_dict(),
src/utils/signal_helpers.py (21 changes: 14 additions & 7 deletions)
@@ -4,6 +4,7 @@
 import plotly.graph_objects as go
 import plotly.subplots as sp
 import scipy
+from scipy.signal import convolve2d

 from utils.file_helpers import (
     GLOBAL_SR,
@@ -56,17 +57,20 @@ def norm_db_to_audio(loudness_info):

     for i in range(N_CHANNELS):
         data = scale_data_to_range(loudness_info[i], -40, 40)
+        data[data < -35] = -40  # Noise gate
         magnitudes = librosa.db_to_amplitude(data)
         istft = griffin_lim_istft(magnitudes)

         stereo_audio.append(istft)

-    return np.array(stereo_audio)  # OUT: Audio information
+    stereo_audio = np.array(stereo_audio)
+
+    return stereo_audio


 def griffin_lim_istft(channel_magnitudes):
-    iterations = 5
-    momentum = 0.3
+    iterations = 10
+    momentum = 0.99

     angles = np.exp(2j * np.pi * np.random.rand(*channel_magnitudes.shape))
     stft = channel_magnitudes.astype(np.complex64) * angles
@@ -96,9 +100,12 @@ def griffin_lim_istft(channel_magnitudes):
         pad_mode="linear_ramp",
     )

-    stft = stft[:DATA_SHAPE, :DATA_SHAPE]
-    angles = np.exp(1j * np.angle(stft.T))
+    stft = stft[:DATA_SHAPE, :DATA_SHAPE]  # preserve shape
+    new_angles = np.exp(1j * np.angle(stft.T))
+
+    stft = channel_magnitudes * new_angles

+    channel_magnitudes[channel_magnitudes < 0.05] = 0  # Noise gate
     complex_istft = librosa.istft(
         (channel_magnitudes * angles).T,
         hop_length=GLOBAL_HOP,
@@ -110,7 +117,7 @@ def griffin_lim_istft(channel_magnitudes):
     return complex_istft


-# Data Helpers
+# Audio Helpers
 def load_audio(path):
     y, sr = librosa.load(path, sr=GLOBAL_SR, mono=False)
     if y.ndim == 1:
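Raising momentum from 0.3 to 0.99 matches the fast Griffin-Lim algorithm (Perraudin et al., 2013), where momentum near 1 speeds convergence of the phase estimate; 0.99 is also librosa's default. For comparison, librosa's built-in implementation accepts the same knobs (shapes and hop below are illustrative stand-ins for the repo's DATA_SHAPE and GLOBAL_HOP):

```python
import librosa
import numpy as np

# Illustrative magnitude spectrogram: (freq bins, frames)
magnitudes = np.abs(np.random.randn(1025, 256)).astype(np.float32)

audio = librosa.griffinlim(
    magnitudes,
    n_iter=10,       # mirrors the bump from 5 to 10 iterations in this PR
    hop_length=256,  # stand-in for GLOBAL_HOP
    momentum=0.99,   # fast Griffin-Lim momentum, as adopted here
)
```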