Commit
improve model (#9)
shuklabhay authored Aug 27, 2024
1 parent 00b9840 commit cb51f06
Showing 13 changed files with 489 additions and 343 deletions.
@@ -1,4 +1,4 @@
-name: Pylint
+name: On Push
 
 on: [push]
2 changes: 2 additions & 0 deletions README.md
@@ -1,5 +1,7 @@
 # Deep Convolution Audio Generation
 
+[![On Push](https://github.com/shuklabhay/deep-convolution-audio-generation/actions/workflows/push.yml/badge.svg)](https://github.com/shuklabhay/deep-convolution-audio-generation/actions/workflows/push.yml/badge.svg)
+
 Implementing Deep Convolution to generate audio using a generative network
 
 ## Directories
Binary file modified model/DCGAN.pth
Binary file not shown.
137 changes: 86 additions & 51 deletions paper/main.md

Large diffs are not rendered by default.

95 changes: 60 additions & 35 deletions src/architecture.py
@@ -1,47 +1,66 @@
 import torch
 import torch.nn as nn
-from utils.helpers import N_CHANNELS, N_FRAMES, N_FREQ_BINS
+from utils.signal_helpers import N_CHANNELS
 import torch.nn.functional as F
 from torch.nn.utils import spectral_norm
 
-# Constants Constants
+# Constants
 BATCH_SIZE = 16
-LATENT_DIM = 100
 N_EPOCHS = 10
 
 VALIDATION_INTERVAL = int(N_EPOCHS / 2)
 SAVE_INTERVAL = int(N_EPOCHS / 1)
+LATENT_DIM = 128
 
 
 # Model Components
+class LinearAttention(nn.Module):
+    def __init__(self, in_channels):
+        super(LinearAttention, self).__init__()
+        self.reduced_channels = max(in_channels // 8, 1)
+        self.query = nn.Conv2d(
+            in_channels, self.reduced_channels, kernel_size=1, groups=1
+        )
+        self.key = nn.Conv2d(
+            in_channels, self.reduced_channels, kernel_size=1, groups=1
+        )
+        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1, groups=1)
+        self.gamma = nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        batch, channels, height, width = x.size()
+
+        query = self.query(x).view(batch, -1, height * width)
+        key = self.key(x).view(batch, -1, height * width).permute(0, 2, 1)
+        value = self.value(x).view(batch, -1, height * width)
+
+        attention = torch.bmm(key, query)
+        attention = F.normalize(F.relu(attention), p=1, dim=1)
+
+        out = torch.bmm(value, attention)
+        out = out.view(batch, channels, height, width)
+        out = self.gamma * out + x
+        return out
+
+
 class Generator(nn.Module):
     def __init__(self):
         super(Generator, self).__init__()
         self.deconv_blocks = nn.Sequential(
-            nn.ConvTranspose2d(LATENT_DIM, 256, kernel_size=4, stride=2, padding=1),
-            nn.BatchNorm2d(256),
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
+            nn.ConvTranspose2d(LATENT_DIM, 128, kernel_size=4, stride=1, padding=0),
             nn.BatchNorm2d(128),
-            nn.ReLU(inplace=True),
+            nn.ReLU(),  # Shape: (BATCH_SIZE, 128, 4, 4)
             nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
             nn.BatchNorm2d(64),
-            nn.ReLU(inplace=True),
+            nn.ReLU(),  # Shape: (BATCH_SIZE, 64, 8, 8)
             nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
+            # LinearAttention(32),
             nn.BatchNorm2d(32),
-            nn.ReLU(inplace=True),
+            nn.ReLU(),  # Shape: (BATCH_SIZE, 32, 16, 16)
             nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
             nn.BatchNorm2d(16),
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1),
+            nn.ReLU(),  # Shape: (BATCH_SIZE, 16, 32, 32)
+            nn.ConvTranspose2d(16, 8, kernel_size=6, stride=4, padding=1),
             nn.BatchNorm2d(8),
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(8, 4, kernel_size=4, stride=2, padding=1),
-            nn.BatchNorm2d(4),
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(4, N_CHANNELS, kernel_size=4, stride=2, padding=1),
-            nn.Upsample(
-                size=(N_FRAMES, N_FREQ_BINS), mode="bilinear", align_corners=False
-            ),
+            nn.ReLU(),  # Shape: (BATCH_SIZE, 8, 128, 128)
+            nn.ConvTranspose2d(8, N_CHANNELS, kernel_size=6, stride=2, padding=2),
+            # Shape: (BATCH_SIZE, N_CHANNELS, 256, 256)
             nn.Tanh(),
         )

@@ -54,32 +73,38 @@ class Discriminator(nn.Module):
     def __init__(self):
         super(Discriminator, self).__init__()
         self.conv_blocks = nn.Sequential(
-            nn.Upsample(size=(256, 256), mode="bilinear", align_corners=False),
             spectral_norm(nn.Conv2d(N_CHANNELS, 4, kernel_size=4, stride=2, padding=1)),
-            nn.LeakyReLU(0.2, inplace=True),
+            nn.LeakyReLU(0.2),
             spectral_norm(nn.Conv2d(4, 8, kernel_size=4, stride=2, padding=1)),
-            nn.LeakyReLU(0.2, inplace=True),
+            nn.LeakyReLU(0.2),
             nn.BatchNorm2d(8),
             spectral_norm(nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=1)),
-            nn.LeakyReLU(0.2, inplace=True),
+            nn.LeakyReLU(0.2),
             nn.BatchNorm2d(16),
+            LinearAttention(16),
             spectral_norm(nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1)),
-            nn.LeakyReLU(0.2, inplace=True),
+            nn.LeakyReLU(0.2),
             nn.BatchNorm2d(32),
             spectral_norm(nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)),
-            nn.LeakyReLU(0.2, inplace=True),
+            nn.LeakyReLU(0.2),
             nn.BatchNorm2d(64),
             spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)),
-            nn.LeakyReLU(0.2, inplace=True),
+            nn.LeakyReLU(0.2),
             nn.BatchNorm2d(128),
-            spectral_norm(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.BatchNorm2d(256),
-            nn.Conv2d(256, 1, kernel_size=4, stride=2, padding=1),
+            spectral_norm(nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=0)),
+            nn.Flatten(),
             nn.Sigmoid(),
         )
 
+    def extract_features(self, x):
+        feature_indices = [1, 3, 9]  # Conv block index
+        features = []
+        for i, layer in enumerate(self.conv_blocks):
+            x = layer(x)
+            if i in feature_indices:
+                features.append(x)
+        return features
+
     def forward(self, x):
         x = self.conv_blocks(x)
         return x
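
Taken together, the new architecture.py pairs a latent-to-spectrogram generator with a spectrally normalized discriminator and a softmax-free attention block. As a minimal shape check of the annotations above (an illustrative sketch, not part of the commit; it assumes the repo root is on PYTHONPATH so architecture and the utils modules import cleanly):

import torch

from architecture import LATENT_DIM, Discriminator, Generator, LinearAttention

# LinearAttention is shape-preserving: ReLU scores are L1-normalized in place
# of a softmax, and a learned gamma gates the residual connection.
attn = LinearAttention(16)
x = torch.randn(4, 16, 32, 32)
assert attn(x).shape == x.shape

# Generator: (B, LATENT_DIM, 1, 1) grows spatially 4 -> 8 -> 16 -> 32 -> 128 -> 256.
g = Generator()
z = torch.randn(4, LATENT_DIM, 1, 1)
fake = g.deconv_blocks(z)
print(fake.shape)  # expected: (4, N_CHANNELS, 256, 256)

# Discriminator: six stride-2 convs halve 256 down to 4, then a 4x4 valid conv
# plus Flatten leaves one sigmoid score per sample.
d = Discriminator()
print(d(fake).shape)  # expected: (4, 1)

# extract_features taps conv_blocks activations at indices 1, 3, and 9,
# presumably for a feature-matching term in the training loss.
print([f.shape for f in d.extract_features(fake)])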
27 changes: 14 additions & 13 deletions src/dcgan.py
@@ -9,34 +9,36 @@
     Generator,
 )
 from train import training_loop
-from utils.helpers import (
-    compiled_data_path,
+from utils.file_helpers import (
     get_device,
     load_npy_data,
+    compiled_data_path,
 )
 
 # Constants
-LR_G = 0.002
-LR_D = 0.001
+LR_G = 0.004
+LR_D = 0.004
 
 # Load data
-spectrogram_bank = load_npy_data(compiled_data_path)
-spectrogram_bank = torch.FloatTensor(spectrogram_bank)
-train_size = int(0.8 * len(spectrogram_bank))
-val_size = len(spectrogram_bank) - train_size
+all_spectrograms = load_npy_data(compiled_data_path)
+all_spectrograms = torch.FloatTensor(all_spectrograms)
+train_size = int(0.8 * len(all_spectrograms))
+val_size = len(all_spectrograms) - train_size
 train_dataset, val_dataset = random_split(
-    TensorDataset(spectrogram_bank), [train_size, val_size]
+    TensorDataset(all_spectrograms), [train_size, val_size]
 )
 train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
 val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
 
 # Initialize models and optimizers
 generator = Generator()
 discriminator = Discriminator()
-criterion = nn.BCEWithLogitsLoss()
-optimizer_G = optim.Adam(generator.parameters(), lr=LR_G, betas=(0.5, 0.999))  # type: ignore
-optimizer_D = optim.Adam(discriminator.parameters(), lr=LR_D, betas=(0.5, 0.999))  # type: ignore
+criterion = nn.MSELoss()
+optimizer_G = optim.AdamW(generator.parameters(), lr=LR_G, betas=(0.5, 0.999))  # type: ignore
+optimizer_D = optim.AdamW(discriminator.parameters(), lr=LR_D, betas=(0.5, 0.999))  # type: ignore
 
+
+# Train
 device = get_device()
 generator.to(device)
 discriminator.to(device)
@@ -50,6 +52,5 @@
     criterion,
     optimizer_G,
     optimizer_D,
-    spectrogram_bank,
     device,
 )
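
Switching the criterion from BCEWithLogitsLoss to MSELoss against the discriminator's sigmoid output makes the objective least-squares (LSGAN-style). The real loop lives in src/train.py, which this page does not render; purely to illustrate how criterion, the optimizers, and the loaders plug together, one plausible update step follows (a hypothetical sketch, not the repo's training_loop):

import torch

from architecture import LATENT_DIM

def lsgan_step(generator, discriminator, criterion, optimizer_G, optimizer_D, real, device):
    # Hypothetical LSGAN step consistent with criterion = nn.MSELoss():
    # D pushes real scores toward 1 and fake scores toward 0; G pushes D(fake) toward 1.
    # Assumes Generator.forward runs deconv_blocks (its body is collapsed above).
    real = real.to(device)
    b = real.size(0)
    ones = torch.ones(b, 1, device=device)
    zeros = torch.zeros(b, 1, device=device)

    # Discriminator update (fake batch detached so no gradient reaches G)
    z = torch.randn(b, LATENT_DIM, 1, 1, device=device)
    fake = generator(z).detach()
    loss_D = criterion(discriminator(real), ones) + criterion(discriminator(fake), zeros)
    optimizer_D.zero_grad()
    loss_D.backward()
    optimizer_D.step()

    # Generator update
    z = torch.randn(b, LATENT_DIM, 1, 1, device=device)
    loss_G = criterion(discriminator(generator(z)), ones)
    optimizer_G.zero_grad()
    loss_G.backward()
    optimizer_G.step()

    return loss_D.item(), loss_G.item()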
5 changes: 3 additions & 2 deletions src/generate.py
@@ -1,6 +1,7 @@
 import torch
 from architecture import Generator, LATENT_DIM
-from utils.helpers import normalized_db_to_wav, get_device, graph_spectrogram
+from utils.file_helpers import get_device
+from utils.signal_helpers import normalized_loudness_to_audio
 
 # Initialize Generator
 device = get_device()
@@ -19,4 +20,4 @@
 
 generated_output = generated_output.squeeze().numpy()
 print("Generated output shape:", generated_output.shape)
-normalized_db_to_wav(generated_output, "generated_audio")
+normalized_loudness_to_audio(generated_output, "generated_audio")
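
The middle of generate.py is collapsed above. Assuming it loads the checkpoint modified in this commit and samples a single latent vector (the model path and the eval/no_grad handling below are guesses, not shown in this diff), the full flow would look roughly like:

import torch

from architecture import Generator, LATENT_DIM
from utils.file_helpers import get_device
from utils.signal_helpers import normalized_loudness_to_audio

device = get_device()
generator = Generator()
# Assumed checkpoint path, matching the binary modified in this commit
generator.load_state_dict(torch.load("model/DCGAN.pth", map_location=device))
generator.to(device)
generator.eval()

with torch.no_grad():
    z = torch.randn(1, LATENT_DIM, 1, 1, device=device)
    generated_output = generator(z).cpu()

generated_output = generated_output.squeeze().numpy()
print("Generated output shape:", generated_output.shape)
normalized_loudness_to_audio(generated_output, "generated_audio")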