diff --git a/outputs/StereoSampleGAN-Snare.pth b/outputs/StereoSampleGAN-Snare.pth
index 318ef23..46f9a30 100644
Binary files a/outputs/StereoSampleGAN-Snare.pth and b/outputs/StereoSampleGAN-Snare.pth differ
diff --git a/requirements.txt b/requirements.txt
index 587638c..1846387 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -79,6 +79,7 @@ threadpoolctl==3.5.0
 tinycss2==1.3.0
 torch==2.5.1
 torchaudio==2.5.1
+torchmetrics==1.6.1
 torchvggish==2.1.3
 tornado==6.4.1
 tqdm==4.66.5
diff --git a/src/train.py b/src/train.py
index e344d42..ebf8a2c 100644
--- a/src/train.py
+++ b/src/train.py
@@ -2,13 +2,12 @@
 import torch
 import torch.nn.functional as F
-import torchaudio
 from architecture import Critic, Generator
-from torch.optim.lr_scheduler import OneCycleLR
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.data import DataLoader
-from torchaudio.prototype.pipelines import VGGISH
 from tqdm import tqdm
 from utils.constants import model_selection
+from utils.evaluation import calculate_audio_metrics
 from utils.helpers import DataUtils, ModelParams, ModelUtils, SignalProcessing
 
 # Initialize parameters
@@ -17,56 +16,12 @@
 signal_processing = SignalProcessing(model_params.sample_length)
 
 
-def calculate_fad(real_specs: torch.Tensor, generated_specs: torch.Tensor) -> float:
-    """Calculate Fréchet Audio Distance using manual implementation."""
-    # Set up model
-    vggish = VGGISH.get_model().to(model_params.DEVICE)
-
-    # Preprocess audio
-    real_specs = torch.tensor(
-        DataUtils.scale_data_to_range(real_specs.detach().cpu().numpy(), -1, 1),
-        device=model_params.DEVICE,
-    )
-    real_specs = F.interpolate(
-        real_specs.mean(dim=1, keepdim=True),
-        size=(96, 64),
-        mode="bilinear",
-    )
-    generated_specs = torch.tensor(
-        DataUtils.scale_data_to_range(generated_specs.detach().cpu().numpy(), -1, 1),
-        device=model_params.DEVICE,
-    )
-    generated_specs = F.interpolate(
-        generated_specs.mean(dim=1, keepdim=True),
-        size=(96, 64),
-        mode="bilinear",
-    )
-
-    # Extract VGGish features
-    with torch.no_grad():
-        real_feats = vggish(real_specs)
-        generated_feats = vggish(generated_specs)
-
-    # Calculate features
-    mu_real = real_feats.mean(0)
-    sigma_real = torch.cov(real_feats.T)
-    mu_generated = generated_feats.mean(0)
-    sigma_generated = torch.cov(generated_feats.T)
-
-    # Total FAD
-    fad = torchaudio.functional.frechet_distance(
-        mu_real, sigma_real, mu_generated, sigma_generated
-    )
-    return fad.item()
-
-
 def compute_g_loss(critic, generated_validity, generated_specs, real_specs):
     # Loss metrics
     adversarial_loss = -torch.mean(generated_validity)
-    feat_match = 0.6 * calculate_feature_match_diff(critic, real_specs, generated_specs)
-    freq_loss = 0.5 * frequency_band_loss(generated_specs, real_specs)
-    # silence_loss = 0.5 * torch.mean(torch.abs(generated_specs[real_specs < -0.5] + 1.0))
-    return adversarial_loss + feat_match + freq_loss  # + silence_loss
+    feat_match = 0.8 * calculate_feature_match_diff(critic, real_specs, generated_specs)
+
+    return adversarial_loss + feat_match  # + silence_loss
 
 
 def compute_c_loss(
@@ -77,18 +32,38 @@ def compute_c_loss(
     critic: Critic,
     generated_validity: torch.Tensor,
     real_validity: torch.Tensor,
     generated_spec: torch.Tensor,
     real_spec: torch.Tensor,
     training: bool,
 ) -> torch.Tensor:
-    """Calculate critic loss."""
+    """Calculate critic loss with additional multi-scale spectral reconstruction."""
+    # Wasserstein + spectral features
     wasserstein_dist = calculate_wasserstein_diff(real_validity, generated_validity)
-    spectral_diff = 0.3 * calculate_spectral_diff(real_spec, generated_spec)
-    spectral_convergence = 0.3 * calculate_spectral_convergence_diff(
+    spectral_diff = 0.15 * calculate_spectral_diff(real_spec, generated_spec)
+    spectral_convergence = 0.15 * calculate_spectral_convergence_diff(
         real_spec, generated_spec
     )
-    computed_c_loss = wasserstein_dist + spectral_diff + spectral_convergence
+    # Multi-scale L1: downsample and compare at different scales
+    scales = [1, 0.5, 0.25]
+    multi_scale_loss = torch.zeros(1, device=real_spec.device)
+    for s in scales:
+        if s < 1.0:
+            size = (
+                int(real_spec.shape[2] * s),
+                int(real_spec.shape[3] * s),
+            )
+            real_down = F.interpolate(real_spec, size=size, mode="bilinear")
+            gen_down = F.interpolate(generated_spec, size=size, mode="bilinear")
+            multi_scale_loss += torch.mean(torch.abs(real_down - gen_down))
+        else:
+            # Scale 1 compares at full resolution (previously a no-op)
+            multi_scale_loss += torch.mean(torch.abs(real_spec - generated_spec))
+
+    computed_c_loss = (
+        wasserstein_dist
+        + spectral_diff
+        + spectral_convergence
+        + 0.05 * multi_scale_loss
+    )
+
+    # Gradient penalty if training
     if training:
         gradient_penalty = calculate_gradient_penalty(critic, real_spec, generated_spec)
-        computed_c_loss = computed_c_loss + model_params.LAMBDA_GP * gradient_penalty
+        computed_c_loss += model_params.LAMBDA_GP * gradient_penalty
 
     return computed_c_loss
 
@@ -114,20 +89,6 @@ def calculate_feature_match_diff(
     return loss / len(real_features)
 
 
-def frequency_band_loss(
-    generated_specs: torch.Tensor, real_specs: torch.Tensor
-) -> torch.Tensor:
-    """Calculate loss across frequency bands."""
-    # Calculate mean energy per frequency band
-    generated_freq_energy = torch.mean(generated_specs, dim=3)
-    real_freq_energy = torch.mean(real_specs, dim=3)
-
-    # Get loss per frequency band
-    freq_loss = F.l1_loss(generated_freq_energy, real_freq_energy)
-
-    return freq_loss
-
-
 def calculate_spectral_diff(
     real_spec: torch.Tensor, generated_spec: torch.Tensor
 ) -> torch.Tensor:
@@ -181,10 +142,8 @@ def train_epoch(
     dataloader: DataLoader,
     optimizer_G: torch.optim.Adam,
     optimizer_C: torch.optim.Adam,
-    scheduler_G: OneCycleLR,
-    scheduler_C: OneCycleLR,
     epoch_number: int,
-) -> Tuple[float, float, float]:
+) -> dict[str, float]:
     """Training."""
     generator.train()
     critic.train()
@@ -207,7 +166,7 @@ def train_epoch(
         DataUtils.visualize_spectrogram_grid(
             real_spec.cpu(),
-            f"Input data",
+            f"{model_selection.name.lower().capitalize()} Input Data",
             f"static/{model_selection.name.lower()}_data_visualization.png",
         )
@@ -219,9 +178,7 @@ def train_epoch(
             critic, generated_validity, real_validity, generated_spec, real_spec, True
         )
         c_loss.backward()
-        torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm=1.0)
         optimizer_C.step()
-        scheduler_C.step()
 
         total_c_loss += c_loss.item()
         total_w_dist += calculate_wasserstein_diff(
@@ -244,9 +201,7 @@ def train_epoch(
             real_spec,
         )
         g_loss.backward()
-        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
         optimizer_G.step()
-        scheduler_G.step()
 
         total_g_loss += g_loss.item()
 
@@ -254,10 +209,7 @@ def train_epoch(
     avg_c_loss = total_c_loss / len(dataloader)
     avg_w_dist = total_w_dist / len(dataloader)
 
-    scheduler_G.step()
-    scheduler_C.step()
-
-    return avg_g_loss, avg_c_loss, avg_w_dist
+    return {"g_loss": avg_g_loss, "c_loss": avg_c_loss, "w_dist": avg_w_dist}
 
 
 def validate(
@@ -265,12 +217,20 @@ def validate(
     critic: Critic,
     dataloader: DataLoader,
     epoch_number: int,
-) -> Tuple[float, float, float, float, torch.Tensor]:
-    """Validation."""
+) -> dict[str, float | torch.Tensor]:
+    """Validation returning G loss, C loss, W-dist, FAD, IS, KID, and a sample batch."""
     generator.eval()
     critic.eval()
 
-    total_g_loss, total_c_loss, total_w_dist, total_fad = 0.0, 0.0, 0.0, 0.0
+    # Counters
+    total_g_loss = 0.0
+    total_c_loss = 0.0
+    total_w_dist = 0.0
+    total_fad = 0.0
+    total_is = 0.0
+    total_kid = 0.0
+
+    # Val loop
     with torch.no_grad():
         for _, (real_spec,) in tqdm(
             enumerate(dataloader),
@@ -288,10 +248,7 @@ def validate(
             generated_validity = critic(generated_spec)
 
             g_loss = compute_g_loss(
-                critic,
-                generated_validity,
-                generated_spec,
-                real_spec,
+                critic, generated_validity, generated_spec, real_spec
             )
             c_loss = compute_c_loss(
                 critic,
@@ -302,20 +259,27 @@ def validate(
                 False,
             )
 
-            # Iterate pointers
+            metrics = calculate_audio_metrics(real_spec, generated_spec)
+
             total_g_loss += g_loss.item()
             total_c_loss += c_loss.item()
             total_w_dist += calculate_wasserstein_diff(
                 real_validity, generated_validity
             ).item()
-            total_fad += calculate_fad(real_spec, generated_spec)
+            total_fad += metrics["fad"]
+            total_is += metrics["inception_score"]
+            total_kid += metrics["kernel_inception_distance"]
 
     return (
-        total_g_loss / len(dataloader),
-        total_c_loss / len(dataloader),
-        total_w_dist / len(dataloader),
-        total_fad / len(dataloader),
-        generated_spec,
+        {
+            "g_loss": total_g_loss / len(dataloader),
+            "c_loss": total_c_loss / len(dataloader),
+            "w_dist": total_w_dist / len(dataloader),
+            "fad": total_fad / len(dataloader),
+            "is": total_is / len(dataloader),
+            "kid": total_kid / len(dataloader),
+            "val_specs": generated_spec.cpu(),
+        }
     )
 
 
@@ -331,24 +295,10 @@ def training_loop(train_loader: DataLoader, val_loader: DataLoader) -> None:
         generator.parameters(), lr=model_params.LR_G, betas=(0.5, 0.999)
     )
     optimizer_C = torch.optim.Adam(
-        generator.parameters(), lr=model_params.LR_C, betas=(0.5, 0.999)
-    )
-    scheduler_G = OneCycleLR(
-        optimizer_G,
-        max_lr=model_params.LR_G,
-        total_steps=len(train_loader) * model_params.N_EPOCHS,
-        pct_start=0.2,
-        div_factor=35,
-        final_div_factor=1e4,
-    )
-    scheduler_C = OneCycleLR(
-        optimizer_C,
-        max_lr=model_params.LR_C,
-        total_steps=len(train_loader) * model_params.N_EPOCHS,
-        pct_start=0.2,
-        div_factor=35,
-        final_div_factor=1e4,
+        critic.parameters(), lr=model_params.LR_C, betas=(0.5, 0.999)
     )
+    scheduler_G = ReduceLROnPlateau(optimizer_G, mode="min", factor=0.5, patience=3)
+    scheduler_C = ReduceLROnPlateau(optimizer_C, mode="min", factor=0.5, patience=3)
 
     generator.to(model_params.DEVICE)
     critic.to(model_params.DEVICE)
@@ -358,51 +308,54 @@ def training_loop(train_loader: DataLoader, val_loader: DataLoader) -> None:
     for epoch in range(model_params.N_EPOCHS):
         # Train
-        train_g_loss, train_c_loss, train_w_dist = train_epoch(
+        train_metrics = train_epoch(
             generator,
             critic,
             train_loader,
             optimizer_G,
             optimizer_C,
-            scheduler_G,
-            scheduler_C,
             epoch,
         )
 
         # Validate
-        val_g_loss, val_c_loss, val_w_dist, current_fad, val_audio_items = validate(
-            generator, critic, val_loader, epoch
+        val_metrics = validate(generator, critic, val_loader, epoch)
+
+        # Step schedulers based on validation FAD
+        scheduler_G.step(val_metrics["fad"])
+        scheduler_C.step(val_metrics["fad"])
+
+        print(
+            f"TRAIN g_loss: {train_metrics['g_loss']:.4f} c_loss: {train_metrics['c_loss']:.4f} w_dist: {train_metrics['w_dist']:.4f}"
         )
         print(
-            f"TRAIN w_dist: {train_w_dist:.4f} g_loss: {train_g_loss:.4f} c_loss: {train_c_loss:.4f},"
+            f"VAL g_loss: {val_metrics['g_loss']:.4f} c_loss: {val_metrics['c_loss']:.4f} w_dist: {val_metrics['w_dist']:.4f}"
         )
         print(
-            f"VAL w_dist: {val_w_dist:.4f} g_loss: {val_g_loss:.4f} c_loss: {val_c_loss:.4f}"
+            f"VAL FAD: {val_metrics['fad']:.4f} IS: {val_metrics['is']:.4f} KID: {val_metrics['kid']:.4f}"
        )
-        print(f"FAD {current_fad:.4f}")
         print(
-            f"Generated: [{val_audio_items.cpu().min():.4f}, {val_audio_items.cpu().max():.4f}]"
+            f"Generated Range: [{val_metrics['val_specs'].min():.4f}, {val_metrics['val_specs'].max():.4f}]"
         )
 
         # End of epoch handling
         DataUtils.visualize_spectrogram_grid(
-            val_audio_items,
-            f"Raw Model Output Epoch {epoch+1} - w_dist={val_w_dist:.4f} fad={current_fad:.4f}",
+            val_metrics["val_specs"],
+            f"Raw Model Output Epoch {epoch+1} - w_dist: {val_metrics['w_dist']:.4f} FAD: {val_metrics['fad']:.4f} IS: {val_metrics['is']:.4f} KID: {val_metrics['kid']:.4f}",
             f"static/{model_selection.name.lower()}_progress_val_spectrograms.png",
         )
 
         # Early exit
-        if current_fad < best_fad:
-            best_fad = current_fad
+        if val_metrics["fad"] < best_fad:
+            best_fad = val_metrics["fad"]
             epochs_no_improve = 0
             DataUtils.visualize_spectrogram_grid(
-                val_audio_items,
-                f"Raw Model Output Epoch {epoch+1} - w_dist={val_w_dist:.4f} fad={current_fad:.4f}",
+                val_metrics["val_specs"],
+                f"Raw Model Output Epoch {epoch+1} - w_dist: {val_metrics['w_dist']:.4f} FAD: {val_metrics['fad']:.4f} IS: {val_metrics['is']:.4f} KID: {val_metrics['kid']:.4f}",
                 f"static/{model_selection.name.lower()}_best_val_spectrograms.png",
             )
             model_utils.save_model(generator)
             print(
-                f"New best model saved at w_dist {val_w_dist:.4f} fad {current_fad:.4f}"
+                f"New best model saved at w_dist: {val_metrics['w_dist']:.4f} FAD: {val_metrics['fad']:.4f} IS: {val_metrics['is']:.4f} KID: {val_metrics['kid']:.4f}"
             )
         else:
             epochs_no_improve += 1
diff --git a/src/utils/constants.py b/src/utils/constants.py
index 686cc36..9e7353e 100644
--- a/src/utils/constants.py
+++ b/src/utils/constants.py
@@ -21,13 +21,12 @@ class ModelParams:
     DEVICE = "cuda:7"
     LATENT_DIM = 128
     BATCH_SIZE = 64
-    DROPOUT_RATE = 0.15
+    DROPOUT_RATE = 0.1
 
     # Training params
     CRITIC_STEPS = 5
-    LR_G = 2e-5
-    LR_C = 4e-5
-    LR_DECAY = 0.99
+    LR_G = 3e-4
+    LR_C = 6e-4
     LAMBDA_GP = 10
     N_EPOCHS = 25
diff --git a/src/utils/evaluation.py b/src/utils/evaluation.py
new file mode 100644
index 0000000..002f1f2
--- /dev/null
+++ b/src/utils/evaluation.py
@@ -0,0 +1,110 @@
+import torch
+import torch.nn.functional as F
+import torchaudio
+from torchaudio.prototype.pipelines import VGGISH
+from torchmetrics.image.inception import InceptionScore
+from torchmetrics.image.kid import KernelInceptionDistance
+from utils.helpers import DataUtils, ModelParams
+
+
+def calculate_audio_metrics(
+    real_specs: torch.Tensor, generated_specs: torch.Tensor
+) -> dict:
+    """Calculate FAD, IS, and KID."""
+    model_params = ModelParams()
+    fad_value = calculate_fad(model_params, real_specs, generated_specs)
+    is_value = calculate_inception_score(model_params, generated_specs)
+    kid_value = calculate_kernel_inception_distance(
+        model_params, real_specs, generated_specs
+    )
+
+    return {
+        "fad": fad_value,
+        "inception_score": is_value,
+        "kernel_inception_distance": kid_value,
+    }
+
+
+def calculate_fad(
+    model_params: ModelParams, real_specs: torch.Tensor, generated_specs: torch.Tensor
+) -> float:
+    """Calculate Fréchet Audio Distance."""
+    vggish = VGGISH.get_model().to(model_params.DEVICE)
+
+    real_specs = torch.tensor(
+        DataUtils.scale_data_to_range(real_specs.detach().cpu().numpy(), -1, 1),
+        device=model_params.DEVICE,
+    )
+    real_specs = F.interpolate(
+        real_specs.mean(dim=1, keepdim=True),
+        size=(96, 64),
+        mode="bilinear",
+    )
+    generated_specs = torch.tensor(
+        DataUtils.scale_data_to_range(generated_specs.detach().cpu().numpy(), -1, 1),
+        device=model_params.DEVICE,
+    )
+    generated_specs = F.interpolate(
+        generated_specs.mean(dim=1, keepdim=True),
+        size=(96, 64),
+        mode="bilinear",
+    )
+
+    with torch.no_grad():
+        real_feats = vggish(real_specs)
+        generated_feats = vggish(generated_specs)
+
+    mu_real = real_feats.mean(0)
+    sigma_real = torch.cov(real_feats.T)
+    mu_generated = generated_feats.mean(0)
+    sigma_generated = torch.cov(generated_feats.T)
+    fad_val = torchaudio.functional.frechet_distance(
+        mu_real, sigma_real, mu_generated, sigma_generated
+    )
+
+    return fad_val.item()
+
+
+def calculate_inception_score(
+    model_params: ModelParams, generated_specs: torch.Tensor
+) -> float:
+    """Calculate Inception Score for generated audio spectrograms."""
+    # Process generated specs
+    generated_mono = generated_specs.mean(dim=1, keepdim=True)
+    generated_scaled = (generated_mono + 1) / 2
+    generated_scaled = F.interpolate(generated_scaled, size=(299, 299), mode="bilinear")
+    generated_rgb = generated_scaled.repeat(1, 3, 1, 1)
+
+    # Calculate Inception Score
+    inception_score = InceptionScore(splits=10, normalize=True).to(model_params.DEVICE)
+    inception_score.update(generated_rgb)
+    score = inception_score.compute()
+    return score[0].item()
+
+
+def calculate_kernel_inception_distance(
+    model_params: ModelParams, real_specs: torch.Tensor, generated_specs: torch.Tensor
+) -> float:
+    """Calculate KID between real and generated spectrograms."""
+    # Process generated specs
+    generated_mono = generated_specs.mean(dim=1, keepdim=True)
+    generated_scaled = (generated_mono + 1) / 2
+    generated_scaled = F.interpolate(generated_scaled, size=(299, 299), mode="bilinear")
+    generated_rgb = generated_scaled.repeat(1, 3, 1, 1)
+
+    # Process real specs
+    real_mono = real_specs.mean(dim=1, keepdim=True)
+    real_scaled = (real_mono + 1) / 2
+    real_scaled = F.interpolate(real_scaled, size=(299, 299), mode="bilinear")
+    real_rgb = real_scaled.repeat(1, 3, 1, 1)
+
+    # Calculate KID
+    kid = KernelInceptionDistance(subsets=100, subset_size=50, normalize=True).to(
+        model_params.DEVICE
+    )
+    kid.update(real_rgb, real=True)
+    kid.update(generated_rgb, real=False)
+    kid_mean, _ = kid.compute()
+    return kid_mean.item()
diff --git a/static/snare_best_val_spectrograms.png b/static/snare_best_val_spectrograms.png
index c08a56d..698cec6 100644
Binary files a/static/snare_best_val_spectrograms.png and b/static/snare_best_val_spectrograms.png differ
diff --git a/static/snare_data_visualization.png b/static/snare_data_visualization.png
index b04af33..90c2ff1 100644
Binary files a/static/snare_data_visualization.png and b/static/snare_data_visualization.png differ
diff --git a/static/snare_progress_val_spectrograms.png b/static/snare_progress_val_spectrograms.png
index c08a56d..f33fe6d 100644
Binary files a/static/snare_progress_val_spectrograms.png and b/static/snare_progress_val_spectrograms.png differ
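
Review note: the new src/utils/evaluation.py module can be smoke-tested in isolation before a full training run. Below is a minimal sketch, not part of the patch; it assumes inputs shaped like the stereo spectrograms train.py feeds the critic ((batch, 2, freq_bins, time_frames), scaled to [-1, 1]), that ModelParams().DEVICE points at a usable device, and that the batch size is at least KID's subset_size of 50. The 256x256 spatial size is a placeholder.

# smoke_test_evaluation.py -- hypothetical standalone check, not part of the patch
import torch

from utils.evaluation import calculate_audio_metrics
from utils.helpers import ModelParams

# Random stand-in batches on the configured device; 64 >= KID subset_size (50).
device = ModelParams().DEVICE
real_specs = torch.rand(64, 2, 256, 256, device=device) * 2 - 1
generated_specs = torch.rand(64, 2, 256, 256, device=device) * 2 - 1

metrics = calculate_audio_metrics(real_specs, generated_specs)
print(f"FAD: {metrics['fad']:.4f}")
print(f"IS:  {metrics['inception_score']:.4f}")
print(f"KID: {metrics['kernel_inception_distance']:.4f}")

Two draws from the same random distribution should yield a FAD near zero, which makes this a cheap sanity check that the VGGish and Inception pipelines run end to end on the chosen device.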