Commit

metrics
shuklabhay committed Jan 31, 2025
1 parent 42a12a1 commit 45801ac
Showing 8 changed files with 195 additions and 132 deletions.
Binary file modified outputs/StereoSampleGAN-Snare.pth
Binary file not shown.
1 change: 1 addition & 0 deletions requirements.txt
@@ -79,6 +79,7 @@ threadpoolctl==3.5.0
 tinycss2==1.3.0
 torch==2.5.1
 torchaudio==2.5.1
+torchmetrics==1.6.1
 torchvggish==2.1.3
 tornado==6.4.1
 tqdm==4.66.5
209 changes: 81 additions & 128 deletions src/train.py
@@ -2,13 +2,12 @@

 import torch
 import torch.nn.functional as F
-import torchaudio
 from architecture import Critic, Generator
-from torch.optim.lr_scheduler import OneCycleLR
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.data import DataLoader
-from torchaudio.prototype.pipelines import VGGISH
 from tqdm import tqdm
 from utils.constants import model_selection
+from utils.evaluation import calculate_audio_metrics
 from utils.helpers import DataUtils, ModelParams, ModelUtils, SignalProcessing

 # Initialize parameters
@@ -17,56 +16,12 @@
 signal_processing = SignalProcessing(model_params.sample_length)


-def calculate_fad(real_specs: torch.Tensor, generated_specs: torch.Tensor) -> float:
-    """Calculate Fréchet Audio Distance using manual implementation."""
-    # Set up model
-    vggish = VGGISH.get_model().to(model_params.DEVICE)
-
-    # Preprocess audio
-    real_specs = torch.tensor(
-        DataUtils.scale_data_to_range(real_specs.detach().cpu().numpy(), -1, 1),
-        device=model_params.DEVICE,
-    )
-    real_specs = F.interpolate(
-        real_specs.mean(dim=1, keepdim=True),
-        size=(96, 64),
-        mode="bilinear",
-    )
-    generated_specs = torch.tensor(
-        DataUtils.scale_data_to_range(generated_specs.detach().cpu().numpy(), -1, 1),
-        device=model_params.DEVICE,
-    )
-    generated_specs = F.interpolate(
-        generated_specs.mean(dim=1, keepdim=True),
-        size=(96, 64),
-        mode="bilinear",
-    )
-
-    # Extract VGGish features
-    with torch.no_grad():
-        real_feats = vggish(real_specs)
-        generated_feats = vggish(generated_specs)
-
-    # Calculate features
-    mu_real = real_feats.mean(0)
-    sigma_real = torch.cov(real_feats.T)
-    mu_generated = generated_feats.mean(0)
-    sigma_generated = torch.cov(generated_feats.T)
-
-    # Total FAD
-    fad = torchaudio.functional.frechet_distance(
-        mu_real, sigma_real, mu_generated, sigma_generated
-    )
-    return fad.item()
-
-
 def compute_g_loss(critic, generated_validity, generated_specs, real_specs):
     # Loss metrics
     adversarial_loss = -torch.mean(generated_validity)
-    feat_match = 0.6 * calculate_feature_match_diff(critic, real_specs, generated_specs)
-    freq_loss = 0.5 * frequency_band_loss(generated_specs, real_specs)
-    # silence_loss = 0.5 * torch.mean(torch.abs(generated_specs[real_specs < -0.5] + 1.0))
-    return adversarial_loss + feat_match + freq_loss  # + silence_loss
+    feat_match = 0.8 * calculate_feature_match_diff(critic, real_specs, generated_specs)
+
+    return adversarial_loss + feat_match  # + silence_loss


 def compute_c_loss(
@@ -77,18 +32,38 @@ def compute_c_loss(
     real_spec: torch.Tensor,
     training: bool,
 ) -> torch.Tensor:
-    """Calculate critic loss."""
+    """Calculate critic loss with additional multi-scale spectral reconstruction."""
+    # Wasserstein + spectral features
     wasserstein_dist = calculate_wasserstein_diff(real_validity, generated_validity)
-    spectral_diff = 0.3 * calculate_spectral_diff(real_spec, generated_spec)
-    spectral_convergence = 0.3 * calculate_spectral_convergence_diff(
+    spectral_diff = 0.15 * calculate_spectral_diff(real_spec, generated_spec)
+    spectral_convergence = 0.15 * calculate_spectral_convergence_diff(
         real_spec, generated_spec
     )

-    computed_c_loss = wasserstein_dist + spectral_diff + spectral_convergence
+    # Multi-scale L1: downsample and compare at progressively coarser scales
+    # (full resolution is already covered by the spectral terms above)
+    scales = [0.5, 0.25]
+    multi_scale_loss = torch.zeros(1, device=real_spec.device)
+    for s in scales:
+        size = (
+            int(real_spec.shape[2] * s),
+            int(real_spec.shape[3] * s),
+        )
+        real_down = F.interpolate(real_spec, size=size, mode="bilinear")
+        gen_down = F.interpolate(generated_spec, size=size, mode="bilinear")
+        multi_scale_loss += torch.mean(torch.abs(real_down - gen_down))
+
+    computed_c_loss = (
+        wasserstein_dist
+        + spectral_diff
+        + spectral_convergence
+        + 0.05 * multi_scale_loss
+    )

     # Gradient penalty if training
     if training:
         gradient_penalty = calculate_gradient_penalty(critic, real_spec, generated_spec)
-        computed_c_loss = computed_c_loss + model_params.LAMBDA_GP * gradient_penalty
+        computed_c_loss += model_params.LAMBDA_GP * gradient_penalty

     return computed_c_loss

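For reference, `calculate_gradient_penalty` is called above but not expanded in this diff. A minimal sketch of a standard WGAN-GP penalty consistent with that call (the project's actual helper may differ):

def calculate_gradient_penalty(
    critic: Critic, real_spec: torch.Tensor, generated_spec: torch.Tensor
) -> torch.Tensor:
    # Random per-sample interpolation between real and generated spectrograms
    alpha = torch.rand(real_spec.size(0), 1, 1, 1, device=real_spec.device)
    interpolated = (
        alpha * real_spec + (1 - alpha) * generated_spec
    ).requires_grad_(True)

    # Gradient of the critic's scores with respect to the interpolated inputs
    gradients = torch.autograd.grad(
        outputs=critic(interpolated).sum(),
        inputs=interpolated,
        create_graph=True,
    )[0]

    # Penalize deviation of the per-sample gradient norm from 1
    grad_norm = gradients.view(gradients.size(0), -1).norm(2, dim=1)
    return torch.mean((grad_norm - 1) ** 2)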
@@ -114,20 +89,6 @@ def calculate_feature_match_diff(
     return loss / len(real_features)


-def frequency_band_loss(
-    generated_specs: torch.Tensor, real_specs: torch.Tensor
-) -> torch.Tensor:
-    """Calculate loss across frequency bands."""
-    # Calculate mean energy per frequency band
-    generated_freq_energy = torch.mean(generated_specs, dim=3)
-    real_freq_energy = torch.mean(real_specs, dim=3)
-
-    # Get loss per frequency band
-    freq_loss = F.l1_loss(generated_freq_energy, real_freq_energy)
-
-    return freq_loss
-
-
 def calculate_spectral_diff(
     real_spec: torch.Tensor, generated_spec: torch.Tensor
 ) -> torch.Tensor:
@@ -181,10 +142,8 @@ def train_epoch(
     dataloader: DataLoader,
     optimizer_G: torch.optim.Adam,
     optimizer_C: torch.optim.Adam,
-    scheduler_G: OneCycleLR,
-    scheduler_C: OneCycleLR,
     epoch_number: int,
-) -> Tuple[float, float, float]:
+) -> dict[str, float]:
     """Training."""
     generator.train()
     critic.train()
@@ -207,7 +166,7 @@

         DataUtils.visualize_spectrogram_grid(
             real_spec.cpu(),
-            f"Input data",
+            f"{model_selection.name.lower().capitalize()} Input Data",
             f"static/{model_selection.name.lower()}_data_visualization.png",
         )

@@ -219,9 +178,7 @@
             critic, generated_validity, real_validity, generated_spec, real_spec, True
         )
         c_loss.backward()
-        torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm=1.0)
         optimizer_C.step()
-        scheduler_C.step()

         total_c_loss += c_loss.item()
         total_w_dist += calculate_wasserstein_diff(
@@ -244,33 +201,36 @@
             real_spec,
         )
         g_loss.backward()
-        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
         optimizer_G.step()
-        scheduler_G.step()

        total_g_loss += g_loss.item()

     avg_g_loss = total_g_loss / (len(dataloader) // model_params.CRITIC_STEPS)
     avg_c_loss = total_c_loss / len(dataloader)
     avg_w_dist = total_w_dist / len(dataloader)

-    scheduler_G.step()
-    scheduler_C.step()
-
-    return avg_g_loss, avg_c_loss, avg_w_dist
+    return {"g_loss": avg_g_loss, "c_loss": avg_c_loss, "w_dist": avg_w_dist}


 def validate(
     generator: Generator,
     critic: Critic,
     dataloader: DataLoader,
     epoch_number: int,
-) -> Tuple[float, float, float, float, torch.Tensor]:
-    """Validation."""
+) -> dict[str, float | torch.Tensor]:
+    """Validation returning G loss, C loss, W-dist, FAD, IS, KID, and sample spectrograms."""
     generator.eval()
     critic.eval()
-    total_g_loss, total_c_loss, total_w_dist, total_fad = 0.0, 0.0, 0.0, 0.0

+    # Running totals
+    total_g_loss = 0.0
+    total_c_loss = 0.0
+    total_w_dist = 0.0
+    total_fad = 0.0
+    total_is = 0.0
+    total_kid = 0.0
+
+    # Validation loop
     with torch.no_grad():
         for _, (real_spec,) in tqdm(
             enumerate(dataloader),
@@ -288,10 +248,7 @@ def validate(
             generated_validity = critic(generated_spec)

             g_loss = compute_g_loss(
-                critic,
-                generated_validity,
-                generated_spec,
-                real_spec,
+                critic, generated_validity, generated_spec, real_spec
             )
             c_loss = compute_c_loss(
                 critic,
@@ -302,20 +259,27 @@ def validate(
                 False,
             )

+            # Audio quality metrics for this batch (FAD, IS, KID)
+            metrics = calculate_audio_metrics(real_spec, generated_spec)
+
+            # Accumulate running totals
             total_g_loss += g_loss.item()
             total_c_loss += c_loss.item()
             total_w_dist += calculate_wasserstein_diff(
                 real_validity, generated_validity
             ).item()
-            total_fad += calculate_fad(real_spec, generated_spec)
+            total_fad += metrics["fad"]
+            total_is += metrics["inception_score"]
+            total_kid += metrics["kernel_inception_distance"]

-    return (
-        total_g_loss / len(dataloader),
-        total_c_loss / len(dataloader),
-        total_w_dist / len(dataloader),
-        total_fad / len(dataloader),
-        generated_spec,
-    )
+    return {
+        "g_loss": total_g_loss / len(dataloader),
+        "c_loss": total_c_loss / len(dataloader),
+        "w_dist": total_w_dist / len(dataloader),
+        "fad": total_fad / len(dataloader),
+        "is": total_is / len(dataloader),
+        "kid": total_kid / len(dataloader),
+        "val_specs": generated_spec.cpu(),
+    }


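`calculate_audio_metrics` comes from the new `utils.evaluation` module, which this commit adds alongside the `torchmetrics` dependency but which is not expanded in this view. A minimal sketch of one plausible implementation, assuming the torchmetrics Inception-based image metrics are applied to spectrograms treated as images, with FID standing in for a true VGGish-based FAD (these torchmetrics classes additionally require the torch-fidelity package; the actual module may differ, e.g. by reusing VGGish features as the removed calculate_fad did):

import torch
import torch.nn.functional as F
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance


def calculate_audio_metrics(
    real_specs: torch.Tensor, generated_specs: torch.Tensor
) -> dict[str, float]:
    """Hypothetical sketch; only the return keys are taken from the diff above."""

    def to_images(specs: torch.Tensor) -> torch.Tensor:
        # Map [-1, 1] spectrograms to [0, 1], collapse stereo to mono, repeat
        # to 3 channels, and resize to the Inception input resolution
        imgs = (specs.mean(dim=1, keepdim=True) + 1.0) / 2.0
        imgs = imgs.clamp(0.0, 1.0).repeat(1, 3, 1, 1)
        return F.interpolate(imgs, size=(299, 299), mode="bilinear")

    real_imgs, gen_imgs = to_images(real_specs), to_images(generated_specs)
    device = real_specs.device

    # Fréchet distance between Gaussians fit to Inception features:
    #   d^2 = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2))
    fid = FrechetInceptionDistance(normalize=True).to(device)
    fid.update(real_imgs, real=True)
    fid.update(gen_imgs, real=False)

    inception = InceptionScore(normalize=True).to(device)
    inception.update(gen_imgs)
    is_mean, _ = inception.compute()

    # subset_size must not exceed the number of samples in a batch
    kid = KernelInceptionDistance(
        subset_size=min(16, real_imgs.size(0)), normalize=True
    ).to(device)
    kid.update(real_imgs, real=True)
    kid.update(gen_imgs, real=False)
    kid_mean, _ = kid.compute()

    return {
        "fad": fid.compute().item(),
        "inception_score": is_mean.item(),
        "kernel_inception_distance": kid_mean.item(),
    }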
@@ -331,24 +295,10 @@ def training_loop(train_loader: DataLoader, val_loader: DataLoader) -> None:
         generator.parameters(), lr=model_params.LR_G, betas=(0.5, 0.999)
     )
     optimizer_C = torch.optim.Adam(
-        generator.parameters(), lr=model_params.LR_C, betas=(0.5, 0.999)
+        critic.parameters(), lr=model_params.LR_C, betas=(0.5, 0.999)
     )
-    scheduler_G = OneCycleLR(
-        optimizer_G,
-        max_lr=model_params.LR_G,
-        total_steps=len(train_loader) * model_params.N_EPOCHS,
-        pct_start=0.2,
-        div_factor=35,
-        final_div_factor=1e4,
-    )
-    scheduler_C = OneCycleLR(
-        optimizer_C,
-        max_lr=model_params.LR_C,
-        total_steps=len(train_loader) * model_params.N_EPOCHS,
-        pct_start=0.2,
-        div_factor=35,
-        final_div_factor=1e4,
-    )
+    scheduler_G = ReduceLROnPlateau(optimizer_G, mode="min", factor=0.5, patience=3)
+    scheduler_C = ReduceLROnPlateau(optimizer_C, mode="min", factor=0.5, patience=3)
     generator.to(model_params.DEVICE)
     critic.to(model_params.DEVICE)

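# Scheduler note: ReduceLROnPlateau(mode="min", factor=0.5, patience=3), stepped
# once per epoch on validation FAD (see below), halves each learning rate after
# FAD has gone three consecutive epochs without improving. This metric-driven
# schedule replaces the removed per-batch OneCycleLR and likely motivates the
# much higher base learning rates in src/utils/constants.py.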
@@ -358,51 +308,54 @@ def training_loop(train_loader: DataLoader, val_loader: DataLoader) -> None:

     for epoch in range(model_params.N_EPOCHS):
         # Train
-        train_g_loss, train_c_loss, train_w_dist = train_epoch(
+        train_metrics = train_epoch(
             generator,
             critic,
             train_loader,
             optimizer_G,
             optimizer_C,
-            scheduler_G,
-            scheduler_C,
             epoch,
         )

         # Validate
-        val_g_loss, val_c_loss, val_w_dist, current_fad, val_audio_items = validate(
-            generator, critic, val_loader, epoch
-        )
+        val_metrics = validate(generator, critic, val_loader, epoch)
+
+        # Step schedulers based on validation FAD
+        scheduler_G.step(val_metrics["fad"])
+        scheduler_C.step(val_metrics["fad"])

         print(
-            f"TRAIN w_dist: {train_w_dist:.4f} g_loss: {train_g_loss:.4f} c_loss: {train_c_loss:.4f},"
+            f"TRAIN g_loss: {train_metrics['g_loss']:.4f} c_loss: {train_metrics['c_loss']:.4f} w_dist: {train_metrics['w_dist']:.4f}"
         )
         print(
-            f"VAL w_dist: {val_w_dist:.4f} g_loss: {val_g_loss:.4f} c_loss: {val_c_loss:.4f}"
+            f"VAL g_loss: {val_metrics['g_loss']:.4f} c_loss: {val_metrics['c_loss']:.4f} w_dist: {val_metrics['w_dist']:.4f}"
         )
-        print(f"FAD {current_fad:.4f}")
         print(
-            f"Generated: [{val_audio_items.cpu().min():.4f}, {val_audio_items.cpu().max():.4f}]"
+            f"VAL FAD: {val_metrics['fad']:.4f} IS: {val_metrics['is']:.4f} KID: {val_metrics['kid']:.4f}"
+        )
+        print(
+            f"Generated Range: [{val_metrics['val_specs'].min():.4f}, {val_metrics['val_specs'].max():.4f}]"
         )

         # End of epoch handling
         DataUtils.visualize_spectrogram_grid(
-            val_audio_items,
-            f"Raw Model Output Epoch {epoch+1} - w_dist={val_w_dist:.4f} fad={current_fad:.4f}",
+            val_metrics["val_specs"],
+            f"Raw Model Output Epoch {epoch+1} - w_dist: {val_metrics['w_dist']:.4f} FAD: {val_metrics['fad']:.4f} IS: {val_metrics['is']:.4f} KID: {val_metrics['kid']:.4f}",
             f"static/{model_selection.name.lower()}_progress_val_spectrograms.png",
         )

         # Early exit
-        if current_fad < best_fad:
-            best_fad = current_fad
+        if val_metrics["fad"] < best_fad:
+            best_fad = val_metrics["fad"]
             epochs_no_improve = 0
             DataUtils.visualize_spectrogram_grid(
-                val_audio_items,
-                f"Raw Model Output Epoch {epoch+1} - w_dist={val_w_dist:.4f} fad={current_fad:.4f}",
+                val_metrics["val_specs"],
+                f"Raw Model Output Epoch {epoch+1} - w_dist: {val_metrics['w_dist']:.4f} FAD: {val_metrics['fad']:.4f} IS: {val_metrics['is']:.4f} KID: {val_metrics['kid']:.4f}",
                 f"static/{model_selection.name.lower()}_best_val_spectrograms.png",
             )
             model_utils.save_model(generator)
             print(
-                f"New best model saved at w_dist {val_w_dist:.4f} fad {current_fad:.4f}"
+                f"New best model saved at w_dist: {val_metrics['w_dist']:.4f} FAD: {val_metrics['fad']:.4f} IS: {val_metrics['is']:.4f} KID: {val_metrics['kid']:.4f}"
             )
         else:
             epochs_no_improve += 1
7 changes: 3 additions & 4 deletions src/utils/constants.py
@@ -21,13 +21,12 @@ class ModelParams:
     DEVICE = "cuda:7"
     LATENT_DIM = 128
     BATCH_SIZE = 64
-    DROPOUT_RATE = 0.15
+    DROPOUT_RATE = 0.1

     # Training params
     CRITIC_STEPS = 5
-    LR_G = 2e-5
-    LR_C = 4e-5
-    LR_DECAY = 0.99
+    LR_G = 3e-4
+    LR_C = 6e-4
     LAMBDA_GP = 10
     N_EPOCHS = 25
